diff --git a/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt b/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt index 66bfc631..fd9d0c22 100644 --- a/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt +++ b/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt @@ -13,8 +13,8 @@ import dev.openrs2.deob.transform.BitShiftTransformer import dev.openrs2.deob.transform.BitwiseOpTransformer import dev.openrs2.deob.transform.CanvasTransformer import dev.openrs2.deob.transform.ClassLiteralTransformer +import dev.openrs2.deob.transform.ConstantArgTransformer import dev.openrs2.deob.transform.CounterTransformer -import dev.openrs2.deob.transform.DummyArgTransformer import dev.openrs2.deob.transform.DummyLocalTransformer import dev.openrs2.deob.transform.EmptyClassTransformer import dev.openrs2.deob.transform.ExceptionTracingTransformer @@ -155,7 +155,7 @@ class Deobfuscator(private val input: Path, private val output: Path) { FieldOrderTransformer(), BitwiseOpTransformer(), RemapTransformer(), - DummyArgTransformer(), + ConstantArgTransformer(), DummyLocalTransformer(), UnusedArgTransformer(), UnusedMethodTransformer(), diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt new file mode 100644 index 00000000..62f7b4d3 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt @@ -0,0 +1,78 @@ +package dev.openrs2.deob.analysis + +import dev.openrs2.deob.analysis.IntBranchResult.Companion.fromTakenNotTaken +import org.objectweb.asm.Opcodes.IFEQ +import org.objectweb.asm.Opcodes.IFGE +import org.objectweb.asm.Opcodes.IFGT +import org.objectweb.asm.Opcodes.IFLE +import org.objectweb.asm.Opcodes.IFLT +import org.objectweb.asm.Opcodes.IFNE +import org.objectweb.asm.Opcodes.IF_ICMPEQ +import org.objectweb.asm.Opcodes.IF_ICMPGE +import org.objectweb.asm.Opcodes.IF_ICMPGT +import org.objectweb.asm.Opcodes.IF_ICMPLE +import org.objectweb.asm.Opcodes.IF_ICMPLT +import org.objectweb.asm.Opcodes.IF_ICMPNE + +object IntBranch { + fun evaluateUnary(opcode: Int, values: Set): IntBranchResult { + require(values.isNotEmpty()) + + var taken = 0 + var notTaken = 0 + + for (v in values) { + if (evaluateUnary(opcode, v)) { + taken++ + } else { + notTaken++ + } + } + + return fromTakenNotTaken(taken, notTaken) + } + + private fun evaluateUnary(opcode: Int, value: Int): Boolean { + return when (opcode) { + IFEQ -> value == 0 + IFNE -> value != 0 + IFLT -> value < 0 + IFGE -> value >= 0 + IFGT -> value > 0 + IFLE -> value <= 0 + else -> throw IllegalArgumentException("Opcode $opcode is not a unary conditional branch instruction") + } + } + + fun evaluateBinary(opcode: Int, values1: Set, values2: Set): IntBranchResult { + require(values1.isNotEmpty()) + require(values2.isNotEmpty()) + + var taken = 0 + var notTaken = 0 + + for (v1 in values1) { + for (v2 in values2) { + if (evaluateBinary(opcode, v1, v2)) { + taken++ + } else { + notTaken++ + } + } + } + + return fromTakenNotTaken(taken, notTaken) + } + + private fun evaluateBinary(opcode: Int, value1: Int, value2: Int): Boolean { + return when (opcode) { + IF_ICMPEQ -> value1 == value2 + IF_ICMPNE -> value1 != value2 + IF_ICMPLT -> value1 < value2 + IF_ICMPGE -> value1 >= value2 + IF_ICMPGT -> value1 > value2 + IF_ICMPLE -> value1 <= value2 + else -> throw IllegalArgumentException("Opcode $opcode is not a binary conditional branch instruction") + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt new file mode 100644 index 00000000..e6513151 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt @@ -0,0 +1,17 @@ +package dev.openrs2.deob.analysis + +enum class IntBranchResult { + ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; + + companion object { + fun fromTakenNotTaken(taken: Int, notTaken: Int): IntBranchResult { + require(taken != 0 || notTaken != 0) + + return when { + taken == 0 -> NEVER_TAKEN + notTaken == 0 -> ALWAYS_TAKEN + else -> UNKNOWN + } + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt index 3843b2f1..920b12ee 100644 --- a/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt @@ -8,40 +8,33 @@ import org.objectweb.asm.tree.IincInsnNode import org.objectweb.asm.tree.analysis.BasicInterpreter import org.objectweb.asm.tree.analysis.Interpreter -class IntInterpreter(private val parameters: Array?>?) : Interpreter(Opcodes.ASM8) { +class IntInterpreter(private val args: Array) : Interpreter(Opcodes.ASM8) { private val basicInterpreter = BasicInterpreter() override fun newValue(type: Type?): IntValue? { val basicValue = basicInterpreter.newValue(type) ?: return null - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } override fun newParameterValue(isInstanceMethod: Boolean, local: Int, type: Type): IntValue { val basicValue = basicInterpreter.newParameterValue(isInstanceMethod, local, type) - if (parameters != null) { - val parameterIndex = when { - isInstanceMethod && local == 0 -> return IntValue.Unknown(basicValue) - isInstanceMethod -> local - 1 - else -> local - } - - val parameter = parameters[parameterIndex] - if (parameter != null) { - return IntValue.Constant(basicValue, parameter) - } + val index = when { + isInstanceMethod && local == 0 -> return IntValue(basicValue) + isInstanceMethod -> local - 1 + else -> local } - return IntValue.Unknown(basicValue) + return IntValue(basicValue, args[index]) } override fun newOperation(insn: AbstractInsnNode): IntValue { val basicValue = basicInterpreter.newOperation(insn) val v = insn.intConstant return if (v != null) { - IntValue.Constant(basicValue, v) + IntValue(basicValue, IntValueSet.singleton(v)) } else { - IntValue.Unknown(basicValue) + IntValue(basicValue) } } @@ -52,48 +45,48 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter() - for (v in value.values) { + for (v in value.set.values) { val result = when { insn.opcode == Opcodes.INEG -> -v insn is IincInsnNode -> v + insn.incr insn.opcode == Opcodes.I2B -> v.toByte().toInt() insn.opcode == Opcodes.I2C -> v.toChar().toInt() insn.opcode == Opcodes.I2S -> v.toShort().toInt() - else -> return IntValue.Unknown(basicValue) + else -> return IntValue(basicValue) } set.add(result) } - return IntValue.Constant(basicValue, set) + return IntValue(basicValue, IntValueSet.Constant(set)) } override fun binaryOperation(insn: AbstractInsnNode, value1: IntValue, value2: IntValue): IntValue? { val basicValue = basicInterpreter.binaryOperation(insn, value1.basicValue, value2.basicValue) ?: return null - if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) { - return IntValue.Unknown(basicValue) + if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) { + return IntValue(basicValue) } val set = mutableSetOf() - for (v1 in value1.values) { - for (v2 in value2.values) { + for (v1 in value1.set.values) { + for (v2 in value2.set.values) { val result = when (insn.opcode) { Opcodes.IADD -> v1 + v2 Opcodes.ISUB -> v1 - v2 Opcodes.IMUL -> v1 * v2 Opcodes.IDIV -> { if (v2 == 0) { - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } v1 / v2 } Opcodes.IREM -> { if (v2 == 0) { - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } v1 % v2 } @@ -103,12 +96,12 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter v1 and v2 Opcodes.IOR -> v1 or v2 Opcodes.IXOR -> v1 xor v2 - else -> return IntValue.Unknown(basicValue) + else -> return IntValue(basicValue) } set.add(result) } } - return IntValue.Constant(basicValue, set) + return IntValue(basicValue, IntValueSet.Constant(set)) } override fun ternaryOperation( @@ -120,13 +113,13 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter): IntValue? { val args = values.map(IntValue::basicValue) val basicValue = basicInterpreter.naryOperation(insn, args) ?: return null - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } override fun returnOperation(insn: AbstractInsnNode, value: IntValue, expected: IntValue) { @@ -140,15 +133,15 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter MAX_TRACKED_VALUES) { - IntValue.Unknown(basicValue) + val set = value1.set union value2.set + return if (set is IntValueSet.Constant && set.values.size > MAX_TRACKED_VALUES) { + IntValue(basicValue) } else { - IntValue.Constant(basicValue, set) + IntValue(basicValue, set) } } diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt index 8455b874..8f022ed3 100644 --- a/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt @@ -3,22 +3,7 @@ package dev.openrs2.deob.analysis import org.objectweb.asm.tree.analysis.BasicValue import org.objectweb.asm.tree.analysis.Value -sealed class IntValue : Value { - data class Unknown(override val basicValue: BasicValue) : IntValue() - data class Constant(override val basicValue: BasicValue, val values: Set) : IntValue() { - val singleton: Int? - - init { - require(values.isNotEmpty()) - - singleton = if (values.size == 1) values.first() else null - } - - constructor(basicValue: BasicValue, value: Int) : this(basicValue, setOf(value)) - } - - abstract val basicValue: BasicValue - +data class IntValue(val basicValue: BasicValue, val set: IntValueSet = IntValueSet.Unknown) : Value { override fun getSize(): Int { return basicValue.size } diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt new file mode 100644 index 00000000..a7a0b6a8 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt @@ -0,0 +1,42 @@ +package dev.openrs2.deob.analysis + +sealed class IntValueSet { + data class Constant(val values: Set) : IntValueSet() { + init { + require(values.isNotEmpty()) + } + + override val singleton: Int? + get() = if (values.size == 1) { + values.first() + } else { + null + } + + override fun union(other: IntValueSet): IntValueSet { + return if (other is Constant) { + Constant(values union other.values) + } else { + Unknown + } + } + } + + object Unknown : IntValueSet() { + override val singleton: Int? + get() = null + + override fun union(other: IntValueSet): IntValueSet { + return Unknown + } + } + + abstract val singleton: Int? + abstract infix fun union(other: IntValueSet): IntValueSet + + companion object { + fun singleton(value: Int): IntValueSet { + return Constant(setOf(value)) + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt b/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt deleted file mode 100644 index 43cd8acd..00000000 --- a/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt +++ /dev/null @@ -1,6 +0,0 @@ -package dev.openrs2.deob.analysis - -import dev.openrs2.asm.MemberRef -import dev.openrs2.util.collect.DisjointSet - -data class SourcedIntValue(val source: DisjointSet.Partition, val intValue: IntValue) diff --git a/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt b/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt new file mode 100644 index 00000000..065b2ea2 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt @@ -0,0 +1,310 @@ +package dev.openrs2.deob.transform + +import com.github.michaelbull.logging.InlineLogger +import dev.openrs2.asm.MemberRef +import dev.openrs2.asm.classpath.ClassPath +import dev.openrs2.asm.classpath.Library +import dev.openrs2.asm.createIntConstant +import dev.openrs2.asm.deleteExpression +import dev.openrs2.asm.hasCode +import dev.openrs2.asm.intConstant +import dev.openrs2.asm.nextReal +import dev.openrs2.asm.pure +import dev.openrs2.asm.replaceExpression +import dev.openrs2.asm.stackMetadata +import dev.openrs2.asm.transform.Transformer +import dev.openrs2.deob.ArgRef +import dev.openrs2.deob.analysis.IntBranch +import dev.openrs2.deob.analysis.IntBranchResult.ALWAYS_TAKEN +import dev.openrs2.deob.analysis.IntBranchResult.NEVER_TAKEN +import dev.openrs2.deob.analysis.IntInterpreter +import dev.openrs2.deob.analysis.IntValueSet +import dev.openrs2.deob.remap.TypedRemapper +import dev.openrs2.util.collect.DisjointSet +import dev.openrs2.util.collect.removeFirstOrNull +import org.objectweb.asm.Opcodes.GOTO +import org.objectweb.asm.Opcodes.IFEQ +import org.objectweb.asm.Opcodes.IFGE +import org.objectweb.asm.Opcodes.IFGT +import org.objectweb.asm.Opcodes.IFLE +import org.objectweb.asm.Opcodes.IFLT +import org.objectweb.asm.Opcodes.IFNE +import org.objectweb.asm.Opcodes.IF_ICMPEQ +import org.objectweb.asm.Opcodes.IF_ICMPGE +import org.objectweb.asm.Opcodes.IF_ICMPGT +import org.objectweb.asm.Opcodes.IF_ICMPLE +import org.objectweb.asm.Opcodes.IF_ICMPLT +import org.objectweb.asm.Opcodes.IF_ICMPNE +import org.objectweb.asm.Type +import org.objectweb.asm.tree.AbstractInsnNode +import org.objectweb.asm.tree.ClassNode +import org.objectweb.asm.tree.JumpInsnNode +import org.objectweb.asm.tree.MethodInsnNode +import org.objectweb.asm.tree.MethodNode +import org.objectweb.asm.tree.analysis.Analyzer + +class ConstantArgTransformer : Transformer() { + private val pendingMethods = LinkedHashSet() + private val arglessMethods = mutableSetOf>() + private val argValues = mutableMapOf() + private lateinit var inheritedMethodSets: DisjointSet + private var branchesSimplified = 0 + private var constantsInlined = 0 + + override fun preTransform(classPath: ClassPath) { + pendingMethods.clear() + arglessMethods.clear() + argValues.clear() + inheritedMethodSets = classPath.createInheritedMethodSets() + branchesSimplified = 0 + constantsInlined = 0 + + queueEntryPoints(classPath) + + while (true) { + val method = pendingMethods.removeFirstOrNull() ?: break + analyzeMethod(classPath, method) + } + } + + private fun queueEntryPoints(classPath: ClassPath) { + for (partition in inheritedMethodSets) { + /* + * The EXCLUDED_METHODS set roughly matches up with the methods we + * want to consider as entry points. It isn't perfect - it counts + * every method as an entry point, but strictly speaking we + * only need to count methods invoked with reflection as + * entry points (like VisibilityTransformer). However, it makes no + * difference in this case, as the deobfuscator does not add dummy + * constant arguments to constructors. + */ + val excluded = partition.first().name in TypedRemapper.EXCLUDED_METHODS + val overridesDependency = partition.any { classPath[it.owner]!!.dependency } + if (excluded || overridesDependency) { + pendingMethods.addAll(partition) + } + } + } + + private fun analyzeMethod(classPath: ClassPath, ref: MemberRef) { + // find ClassNode/MethodNode + val owner = classPath.getNode(ref.owner) ?: return + val originalMethod = owner.methods.singleOrNull { it.name == ref.name && it.desc == ref.desc } ?: return + if (!originalMethod.hasCode()) { + return + } + + /* + * Clone the method - we don't want to mutate it permanently until the + * final pass, as we might discover more routes through the call graph + * later which reduce the number of branches we can simplify. + */ + val method = MethodNode( + originalMethod.access, + originalMethod.name, + originalMethod.desc, + originalMethod.signature, + originalMethod.exceptions?.toTypedArray() + ) + originalMethod.accept(method) + + // find existing constant arguments + val args = getArgs(ref) + + // simplify branches + simplifyBranches(owner, method, args) + + /* + * Record new constant arguments in method calls. This re-runs the + * analyzer rather than re-using the frames from simplifyBranches. This + * ensures we ignore branches that always evaluate to false, preventing + * us from recording constant arguments found in dummy calls (which + * would prevent us from removing further dummy calls/branches). + */ + addArgValues(owner, method, args) + } + + private fun getArgs(ref: MemberRef): Array { + val partition = inheritedMethodSets[ref]!! + val size = Type.getArgumentTypes(ref.desc).sumBy { it.size } + return Array(size) { i -> argValues[ArgRef(partition, i)] ?: IntValueSet.Unknown } + } + + private fun addArgValues(owner: ClassNode, method: MethodNode, args: Array) { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(owner.name, method) + for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn !is MethodInsnNode) { + continue + } + + val invokedMethod = inheritedMethodSets[MemberRef(insn)] ?: continue + val size = Type.getArgumentTypes(insn.desc).size + + var index = 0 + for (j in 0 until size) { + val value = frame.getStack(frame.stackSize - size + j) + if (addArgValues(ArgRef(invokedMethod, index), value.set)) { + pendingMethods.addAll(invokedMethod) + } + index += value.size + } + + if (size == 0 && arglessMethods.add(invokedMethod)) { + pendingMethods.addAll(invokedMethod) + } + } + } + + private fun addArgValues(ref: ArgRef, value: IntValueSet): Boolean { + val old = argValues[ref] + + val new = if (value.singleton != null) { + if (old != null) { + old union value + } else { + value + } + } else { + IntValueSet.Unknown + } + argValues[ref] = new + + return old != new + } + + private fun simplifyBranches(owner: ClassNode, method: MethodNode, args: Array): Int { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(owner.name, method) + + val alwaysTakenBranches = mutableListOf() + val neverTakenBranches = mutableListOf() + + frame@ for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn !is JumpInsnNode) { + continue + } + + when (insn.opcode) { + IFEQ, IFNE, IFLT, IFGE, IFGT, IFLE -> { + val value = frame.getStack(frame.stackSize - 1) + if (value.set !is IntValueSet.Constant) { + continue@frame + } + + @Suppress("NON_EXHAUSTIVE_WHEN") + when (IntBranch.evaluateUnary(insn.opcode, value.set.values)) { + ALWAYS_TAKEN -> alwaysTakenBranches += insn + NEVER_TAKEN -> neverTakenBranches += insn + } + } + IF_ICMPEQ, IF_ICMPNE, IF_ICMPLT, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE -> { + val value1 = frame.getStack(frame.stackSize - 2) + val value2 = frame.getStack(frame.stackSize - 1) + if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) { + continue@frame + } + + @Suppress("NON_EXHAUSTIVE_WHEN") + when (IntBranch.evaluateBinary(insn.opcode, value1.set.values, value2.set.values)) { + ALWAYS_TAKEN -> alwaysTakenBranches += insn + NEVER_TAKEN -> neverTakenBranches += insn + } + } + } + } + + var simplified = 0 + + for (insn in alwaysTakenBranches) { + val replacement = JumpInsnNode(GOTO, insn.label) + if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { + simplified++ + } + } + + for (insn in neverTakenBranches) { + if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { + simplified++ + } + } + + return simplified + } + + private fun inlineConstantArgs(clazz: ClassNode, method: MethodNode, args: Array): Int { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(clazz.name, method) + + val constInsns = mutableMapOf() + + for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn.intConstant != null) { + // already constant + continue + } else if (!insn.pure) { + // can't replace instructions with a side effect + continue + } else if (insn.stackMetadata().pushes != 1) { + // can't replace instructions pushing more than one result + continue + } + + // the value pushed by this instruction is held in the following frame + val nextInsn = insn.nextReal ?: continue + val nextInsnIndex = method.instructions.indexOf(nextInsn) + val nextFrame = frames[nextInsnIndex] + + val value = nextFrame.getStack(nextFrame.stackSize - 1) + val singleton = value.set.singleton + if (singleton != null) { + constInsns[insn] = singleton + } + } + + var inlined = 0 + + for ((insn, value) in constInsns) { + if (insn !in method.instructions) { + continue + } + + val replacement = createIntConstant(value) + if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { + inlined++ + } + } + + return inlined + } + + override fun transformCode(classPath: ClassPath, library: Library, clazz: ClassNode, method: MethodNode): Boolean { + val args = getArgs(MemberRef(clazz, method)) + branchesSimplified += simplifyBranches(clazz, method, args) + constantsInlined += inlineConstantArgs(clazz, method, args) + return false + } + + override fun postTransform(classPath: ClassPath) { + logger.info { "Simplified $branchesSimplified branches and inlined $constantsInlined constants" } + } + + companion object { + private val logger = InlineLogger() + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt b/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt deleted file mode 100644 index 110257a4..00000000 --- a/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt +++ /dev/null @@ -1,429 +0,0 @@ -package dev.openrs2.deob.transform - -import com.github.michaelbull.logging.InlineLogger -import com.google.common.collect.HashMultimap -import com.google.common.collect.Multimap -import dev.openrs2.asm.InsnMatcher -import dev.openrs2.asm.MemberRef -import dev.openrs2.asm.classpath.ClassPath -import dev.openrs2.asm.classpath.Library -import dev.openrs2.asm.createIntConstant -import dev.openrs2.asm.deleteExpression -import dev.openrs2.asm.intConstant -import dev.openrs2.asm.nextReal -import dev.openrs2.asm.pure -import dev.openrs2.asm.replaceExpression -import dev.openrs2.asm.stackMetadata -import dev.openrs2.asm.transform.Transformer -import dev.openrs2.deob.ArgRef -import dev.openrs2.deob.analysis.IntInterpreter -import dev.openrs2.deob.analysis.IntValue -import dev.openrs2.deob.analysis.SourcedIntValue -import dev.openrs2.util.collect.DisjointSet -import org.objectweb.asm.Opcodes -import org.objectweb.asm.Type -import org.objectweb.asm.tree.AbstractInsnNode -import org.objectweb.asm.tree.ClassNode -import org.objectweb.asm.tree.JumpInsnNode -import org.objectweb.asm.tree.MethodInsnNode -import org.objectweb.asm.tree.MethodNode -import org.objectweb.asm.tree.VarInsnNode -import org.objectweb.asm.tree.analysis.Analyzer - -class DummyArgTransformer : Transformer() { - private data class ConditionalCall( - val conditionVar: Int, - val conditionOpcode: Int, - val conditionValue: Int?, - val method: DisjointSet.Partition, - val constArgs: List - ) - - private enum class BranchResult { - ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; - - companion object { - fun fromTakenNotTaken(taken: Int, notTaken: Int): BranchResult { - require(taken != 0 || notTaken != 0) - - return when { - taken == 0 -> NEVER_TAKEN - notTaken == 0 -> ALWAYS_TAKEN - else -> UNKNOWN - } - } - } - } - - private val argValues: Multimap = HashMultimap.create() - private val conditionalCalls: Multimap?, ConditionalCall> = HashMultimap.create() - private val constArgs = mutableMapOf, Array?>>() - private lateinit var inheritedMethodSets: DisjointSet - private var branchesSimplified = 0 - private var constantsInlined = 0 - - private fun isMutuallyRecursiveDummy( - method: DisjointSet.Partition, - arg: Int, - source: DisjointSet.Partition, - value: Int - ): Boolean { - for (sourceToMethodCall in conditionalCalls[source]) { - if (sourceToMethodCall.method != method) { - continue - } - - for (methodToSourceCall in conditionalCalls[method]) { - if (methodToSourceCall.method != source || methodToSourceCall.conditionVar != arg) { - continue - } - - var taken = if (methodToSourceCall.conditionValue != null) { - evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue) - } else { - evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value) - } - - if (taken) { - continue - } - - val constArg = methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]!! - - taken = if (sourceToMethodCall.conditionValue != null) { - evaluateBinaryBranch( - sourceToMethodCall.conditionOpcode, - constArg, - sourceToMethodCall.conditionValue - ) - } else { - evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, constArg) - } - - if (taken) { - continue - } - - return true - } - } - - return false - } - - private fun union( - method: DisjointSet.Partition, - arg: Int, - intValues: Collection - ): Set? { - val set = mutableSetOf() - - for ((source, intValue) in intValues) { - if (intValue !is IntValue.Constant) { - return null - } - - if (source == method) { - continue - } - - if (intValue.singleton != null) { - if (isMutuallyRecursiveDummy(method, arg, source, intValue.singleton)) { - continue - } - } - - set.addAll(intValue.values) - } - - return if (set.isEmpty()) { - null - } else { - set - } - } - - override fun preTransform(classPath: ClassPath) { - inheritedMethodSets = classPath.createInheritedMethodSets() - branchesSimplified = 0 - constantsInlined = 0 - } - - override fun prePass(classPath: ClassPath): Boolean { - argValues.clear() - conditionalCalls.clear() - return false - } - - override fun transformCode( - classPath: ClassPath, - library: Library, - clazz: ClassNode, - method: MethodNode - ): Boolean { - val parentMethod = inheritedMethodSets[MemberRef(clazz, method)]!! - - val stores = BooleanArray(method.maxLocals) - for (insn in method.instructions) { - if (insn is VarInsnNode && insn.opcode == Opcodes.ISTORE) { - stores[insn.`var`] = true - } - } - - for (match in CONDITIONAL_CALL_MATCHER.match(method)) { - var matchIndex = 0 - - val load = match[matchIndex++] as VarInsnNode - if (stores[load.`var`]) { - continue - } - - var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) shr 2 - if (method.access and Opcodes.ACC_STATIC != 0) { - callerSlots-- - } - if (load.`var` >= callerSlots) { - continue - } - - val conditionValue: Int? - var conditionOpcode = match[matchIndex].opcode - if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) { - conditionValue = null - matchIndex++ - } else { - conditionValue = match[matchIndex++].intConstant - conditionOpcode = match[matchIndex++].opcode - } - - val invoke = match[match.size - 1] as MethodInsnNode - var invokeArgCount = Type.getArgumentTypes(invoke.desc).size - if (invoke.opcode != Opcodes.INVOKESTATIC) { - invokeArgCount++ - } - - val constArgs = arrayOfNulls(invokeArgCount) - for (i in constArgs.indices) { - val insn = match[matchIndex++] - if (insn.opcode == Opcodes.ACONST_NULL) { - matchIndex++ - } else { - constArgs[i] = insn.intConstant - } - } - - val callee = inheritedMethodSets[MemberRef(invoke)] ?: continue - conditionalCalls.put( - parentMethod, - ConditionalCall(load.`var`, conditionOpcode, conditionValue, callee, constArgs.asList()) - ) - } - - val parameters = constArgs[parentMethod] - val analyzer = Analyzer(IntInterpreter(parameters)) - val frames = analyzer.analyze(clazz.name, method) - - var changed = false - val alwaysTakenBranches = mutableListOf() - val neverTakenBranches = mutableListOf() - val constInsns = mutableMapOf() - - frame@ for ((i, frame) in frames.withIndex()) { - if (frame == null) { - continue - } - - val stackSize = frame.stackSize - - val insn = method.instructions[i] - when (insn.opcode) { - Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEINTERFACE -> { - val invoke = insn as MethodInsnNode - val invokedMethod = inheritedMethodSets[MemberRef(invoke)] ?: continue@frame - val args = Type.getArgumentTypes(invoke.desc).size - - var k = 0 - for (j in 0 until args) { - val arg = frame.getStack(stackSize - args + j) - argValues.put(ArgRef(invokedMethod, k), SourcedIntValue(parentMethod, arg)) - k += arg.size - } - } - Opcodes.IFEQ, Opcodes.IFNE -> { - val value = frame.getStack(stackSize - 1) - if (value !is IntValue.Constant) { - continue@frame - } - - val result = evaluateUnaryBranch(insn.opcode, value.values) - @Suppress("NON_EXHAUSTIVE_WHEN") - when (result) { - BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) - BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) - } - } - Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPGE, Opcodes.IF_ICMPGT, - Opcodes.IF_ICMPLE -> { - val value1 = frame.getStack(stackSize - 2) - val value2 = frame.getStack(stackSize - 1) - if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) { - continue@frame - } - - val result = evaluateBinaryBranch(insn.opcode, value1.values, value2.values) - @Suppress("NON_EXHAUSTIVE_WHEN") - when (result) { - BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) - BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) - } - } - else -> { - if (!insn.pure || insn.intConstant != null) { - continue@frame - } - - if (insn.stackMetadata().pushes != 1) { - continue@frame - } - - val nextInsn = insn.nextReal ?: continue@frame - val nextInsnIndex = method.instructions.indexOf(nextInsn) - val nextFrame = frames[nextInsnIndex] - - val value = nextFrame.getStack(nextFrame.stackSize - 1) - if (value is IntValue.Constant && value.singleton != null) { - constInsns[insn] = value.singleton - } - } - } - } - - for (insn in alwaysTakenBranches) { - val replacement = JumpInsnNode(Opcodes.GOTO, insn.label) - if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { - branchesSimplified++ - changed = true - } - } - - for (insn in neverTakenBranches) { - if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { - branchesSimplified++ - changed = true - } - } - - for ((insn, value) in constInsns) { - if (insn !in method.instructions) { - continue - } - - val replacement = createIntConstant(value) - if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { - constantsInlined++ - changed = true - } - } - - return changed - } - - override fun postPass(classPath: ClassPath): Boolean { - for (method in inheritedMethodSets) { - val args = (Type.getArgumentsAndReturnSizes(method.first().desc) shr 2) - 1 - - var allUnknown = true - val parameters = arrayOfNulls?>(args) - - for (i in 0 until args) { - val parameter = union(method, i, argValues[ArgRef(method, i)]) - if (parameter != null) { - allUnknown = false - } - parameters[i] = parameter - } - - if (allUnknown) { - constArgs.remove(method) - } else { - constArgs[method] = parameters - } - } - - return false - } - - override fun postTransform(classPath: ClassPath) { - logger.info { "Simplified $branchesSimplified dummy branches and inlined $constantsInlined constants" } - } - - companion object { - private val logger = InlineLogger() - private val CONDITIONAL_CALL_MATCHER = InsnMatcher.compile( - """ - ILOAD - (IFEQ | IFNE | - (ICONST | BIPUSH | SIPUSH | LDC) - (IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE) - ) - ALOAD? - (ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+ - (INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE) - """ - ) - - private fun evaluateUnaryBranch(opcode: Int, values: Set): BranchResult { - require(values.isNotEmpty()) - - var taken = 0 - var notTaken = 0 - for (v in values) { - if (evaluateUnaryBranch(opcode, v)) { - taken++ - } else { - notTaken++ - } - } - - return BranchResult.fromTakenNotTaken(taken, notTaken) - } - - private fun evaluateUnaryBranch(opcode: Int, value: Int): Boolean { - return when (opcode) { - Opcodes.IFEQ -> value == 0 - Opcodes.IFNE -> value != 0 - else -> throw IllegalArgumentException() - } - } - - private fun evaluateBinaryBranch(opcode: Int, values1: Set, values2: Set): BranchResult { - require(values1.isNotEmpty() && values2.isNotEmpty()) - - var taken = 0 - var notTaken = 0 - for (v1 in values1) { - for (v2 in values2) { - if (evaluateBinaryBranch(opcode, v1, v2)) { - taken++ - } else { - notTaken++ - } - } - } - - return BranchResult.fromTakenNotTaken(taken, notTaken) - } - - private fun evaluateBinaryBranch(opcode: Int, value1: Int, value2: Int): Boolean { - return when (opcode) { - Opcodes.IF_ICMPEQ -> value1 == value2 - Opcodes.IF_ICMPNE -> value1 != value2 - Opcodes.IF_ICMPLT -> value1 < value2 - Opcodes.IF_ICMPGE -> value1 >= value2 - Opcodes.IF_ICMPGT -> value1 > value2 - Opcodes.IF_ICMPLE -> value1 <= value2 - else -> throw IllegalArgumentException() - } - } - } -}