diff --git a/common/src/main/java/dev/openrs2/common/crypto/Rsa.kt b/common/src/main/java/dev/openrs2/common/crypto/Rsa.kt index 4916a1c4..fa27e02d 100644 --- a/common/src/main/java/dev/openrs2/common/crypto/Rsa.kt +++ b/common/src/main/java/dev/openrs2/common/crypto/Rsa.kt @@ -13,6 +13,7 @@ import org.bouncycastle.crypto.params.RSAKeyParameters import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters import org.bouncycastle.crypto.util.PrivateKeyInfoFactory import org.bouncycastle.crypto.util.SubjectPublicKeyInfoFactory +import org.bouncycastle.util.BigIntegers import org.bouncycastle.util.io.pem.PemObject import org.bouncycastle.util.io.pem.PemReader import org.bouncycastle.util.io.pem.PemWriter @@ -51,6 +52,64 @@ object Rsa { return Pair(keyPair.public as RSAKeyParameters, keyPair.private as RSAPrivateCrtKeyParameters) } + fun encrypt(plaintext: BigInteger, key: RSAKeyParameters): BigInteger { + require(!key.isPrivate) + return plaintext.modPow(key.exponent, key.modulus) + } + + private fun generateBlindingFactor(m: BigInteger): Pair { + val max = m - BigInteger.ONE + + while (true) { + val r = BigIntegers.createRandomInRange(BigInteger.ONE, max, secureRandom) + val rInv = try { + r.modInverse(m) + } catch (ex: ArithmeticException) { + continue + } + return Pair(r, rInv) + } + } + + fun decrypt(ciphertext: BigInteger, key: RSAKeyParameters): BigInteger { + require(key.isPrivate) + + if (key is RSAPrivateCrtKeyParameters) { + // blind the input + val e = key.publicExponent + val m = key.modulus + val (r, rInv) = generateBlindingFactor(m) + + val blindCiphertext = (r.modPow(e, m) * ciphertext).mod(m) + + // decrypt using the Chinese Remainder Theorem + val p = key.p + val q = key.q + val dP = key.dp + val dQ = key.dq + val qInv = key.qInv + + val mP = (blindCiphertext.mod(p)).modPow(dP, p) + val mQ = (blindCiphertext.mod(q)).modPow(dQ, q) + + val h = (qInv * (mP - mQ)).mod(p) + + val blindPlaintext = (h * q) + mQ + + // unblind output + val plaintext = (blindPlaintext * rInv).mod(m) + + // defend against CRT faults (see https://people.redhat.com/~fweimer/rsa-crt-leaks.pdf) + if (plaintext.modPow(e, m) != ciphertext) { + throw IllegalStateException() + } + + return plaintext + } else { + return ciphertext.modPow(key.exponent, key.modulus) + } + } + fun readPublicKey(path: Path): RSAKeyParameters { val der = readSinglePemObject(path, PUBLIC_KEY) diff --git a/common/src/test/java/dev/openrs2/common/crypto/RsaTest.kt b/common/src/test/java/dev/openrs2/common/crypto/RsaTest.kt new file mode 100644 index 00000000..9f2e7232 --- /dev/null +++ b/common/src/test/java/dev/openrs2/common/crypto/RsaTest.kt @@ -0,0 +1,54 @@ +package dev.openrs2.common.crypto + +import org.bouncycastle.crypto.params.RSAKeyParameters +import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters +import org.bouncycastle.util.Properties +import java.math.BigInteger +import kotlin.test.Test +import kotlin.test.assertEquals + +object RsaTest { + private const val ALLOW_UNSAFE_MOD = "org.bouncycastle.rsa.allow_unsafe_mod" + + @Test + fun testEncrypt() { + // from https://en.wikipedia.org/wiki/RSA_(cryptosystem)#Example + val public = allowUnsafeMod { RSAKeyParameters(false, BigInteger("3233"), BigInteger("17")) } + val ciphertext = Rsa.encrypt(BigInteger("65"), public) + assertEquals(BigInteger("2790"), ciphertext) + } + + @Test + fun testDecrypt() { + // from https://en.wikipedia.org/wiki/RSA_(cryptosystem)#Example + val public = allowUnsafeMod { RSAKeyParameters(true, BigInteger("3233"), BigInteger("413")) } + val ciphertext = Rsa.decrypt(BigInteger("2790"), public) + assertEquals(BigInteger("65"), ciphertext) + } + + @Test + fun testDecryptCrt() { + // from https://en.wikipedia.org/wiki/RSA_(cryptosystem)#Example + val private = allowUnsafeMod { RSAPrivateCrtKeyParameters( + BigInteger("3233"), // modulus + BigInteger("17"), // public exponent + BigInteger("413"), // private exponent + BigInteger("61"), // p + BigInteger("53"), // q + BigInteger("53"), // dP + BigInteger("49"), // dQ + BigInteger("38") // qInv + ) } + val ciphertext = Rsa.decrypt(BigInteger("2790"), private) + assertEquals(BigInteger("65"), ciphertext) + } + + private fun allowUnsafeMod(f: () -> T): T { + Properties.setThreadOverride(ALLOW_UNSAFE_MOD, true) + try { + return f() + } finally { + Properties.setThreadOverride(ALLOW_UNSAFE_MOD, false) + } + } +}