From ff594848d56a5a05b917c91ed5ecc1a4ce4d1d8c Mon Sep 17 00:00:00 2001 From: Graham Date: Thu, 9 Apr 2020 21:27:57 +0100 Subject: [PATCH] Replace DummyArgTransformer with new ConstantArgTransformer The new transformer uses a different approach to the old one. It starts exploring the call graph from the entry points, recursively analysing method calls. Methods are only re-analysed if their possible argument values change, with the Unknown value being used if we can't identify a single integer constant at a call site. This prevents us from recursing infinitely if the client code does. While this first pass does simplify branches in order to ignore dummy method calls that are never evaluated at runtime, it operates on a copy of the method (as we initially ignore more calls while the argument value sets are smaller, ignoring fewer calls as they build up). A separate second pass simplifies branches on the original method and inlines singleton constants, paving the way for the UnusedArgTransformer to actually remove the newly unused arguments. This new approach has several benefits: - It is much faster than the old approach, as we only re-analyse methods as required by argument value changes, rather than re-analysing every method during every pass. - It doesn't require special cases for dealing with mutually recursive dummy calls. The old approach hard-coded special cases for mutually recursive calls involving groups of 1 and 2 methods. The code for this wasn't clean. Furthermore, while it was just about good enough for the HD client, the SD client contains a mutually recursive group of 3 methods. The new approach is capable of dealing with mutually recursive groups of any size. Finally, the new transformer has a much cleaner implementation. Signed-off-by: Graham --- .../java/dev/openrs2/deob/Deobfuscator.kt | 4 +- .../dev/openrs2/deob/analysis/IntBranch.kt | 78 ++++ .../openrs2/deob/analysis/IntBranchResult.kt | 17 + .../openrs2/deob/analysis/IntInterpreter.kt | 67 ++- .../dev/openrs2/deob/analysis/IntValue.kt | 17 +- .../dev/openrs2/deob/analysis/IntValueSet.kt | 42 ++ .../openrs2/deob/analysis/SourcedIntValue.kt | 6 - .../deob/transform/ConstantArgTransformer.kt | 310 +++++++++++++ .../deob/transform/DummyArgTransformer.kt | 429 ------------------ 9 files changed, 480 insertions(+), 490 deletions(-) create mode 100644 deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt create mode 100644 deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt create mode 100644 deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt delete mode 100644 deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt create mode 100644 deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt delete mode 100644 deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt diff --git a/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt b/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt index 66bfc63129..fd9d0c2227 100644 --- a/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt +++ b/deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt @@ -13,8 +13,8 @@ import dev.openrs2.deob.transform.BitShiftTransformer import dev.openrs2.deob.transform.BitwiseOpTransformer import dev.openrs2.deob.transform.CanvasTransformer import dev.openrs2.deob.transform.ClassLiteralTransformer +import dev.openrs2.deob.transform.ConstantArgTransformer import dev.openrs2.deob.transform.CounterTransformer -import dev.openrs2.deob.transform.DummyArgTransformer import dev.openrs2.deob.transform.DummyLocalTransformer import dev.openrs2.deob.transform.EmptyClassTransformer import dev.openrs2.deob.transform.ExceptionTracingTransformer @@ -155,7 +155,7 @@ class Deobfuscator(private val input: Path, private val output: Path) { FieldOrderTransformer(), BitwiseOpTransformer(), RemapTransformer(), - DummyArgTransformer(), + ConstantArgTransformer(), DummyLocalTransformer(), UnusedArgTransformer(), UnusedMethodTransformer(), diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt new file mode 100644 index 0000000000..62f7b4d3ed --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt @@ -0,0 +1,78 @@ +package dev.openrs2.deob.analysis + +import dev.openrs2.deob.analysis.IntBranchResult.Companion.fromTakenNotTaken +import org.objectweb.asm.Opcodes.IFEQ +import org.objectweb.asm.Opcodes.IFGE +import org.objectweb.asm.Opcodes.IFGT +import org.objectweb.asm.Opcodes.IFLE +import org.objectweb.asm.Opcodes.IFLT +import org.objectweb.asm.Opcodes.IFNE +import org.objectweb.asm.Opcodes.IF_ICMPEQ +import org.objectweb.asm.Opcodes.IF_ICMPGE +import org.objectweb.asm.Opcodes.IF_ICMPGT +import org.objectweb.asm.Opcodes.IF_ICMPLE +import org.objectweb.asm.Opcodes.IF_ICMPLT +import org.objectweb.asm.Opcodes.IF_ICMPNE + +object IntBranch { + fun evaluateUnary(opcode: Int, values: Set): IntBranchResult { + require(values.isNotEmpty()) + + var taken = 0 + var notTaken = 0 + + for (v in values) { + if (evaluateUnary(opcode, v)) { + taken++ + } else { + notTaken++ + } + } + + return fromTakenNotTaken(taken, notTaken) + } + + private fun evaluateUnary(opcode: Int, value: Int): Boolean { + return when (opcode) { + IFEQ -> value == 0 + IFNE -> value != 0 + IFLT -> value < 0 + IFGE -> value >= 0 + IFGT -> value > 0 + IFLE -> value <= 0 + else -> throw IllegalArgumentException("Opcode $opcode is not a unary conditional branch instruction") + } + } + + fun evaluateBinary(opcode: Int, values1: Set, values2: Set): IntBranchResult { + require(values1.isNotEmpty()) + require(values2.isNotEmpty()) + + var taken = 0 + var notTaken = 0 + + for (v1 in values1) { + for (v2 in values2) { + if (evaluateBinary(opcode, v1, v2)) { + taken++ + } else { + notTaken++ + } + } + } + + return fromTakenNotTaken(taken, notTaken) + } + + private fun evaluateBinary(opcode: Int, value1: Int, value2: Int): Boolean { + return when (opcode) { + IF_ICMPEQ -> value1 == value2 + IF_ICMPNE -> value1 != value2 + IF_ICMPLT -> value1 < value2 + IF_ICMPGE -> value1 >= value2 + IF_ICMPGT -> value1 > value2 + IF_ICMPLE -> value1 <= value2 + else -> throw IllegalArgumentException("Opcode $opcode is not a binary conditional branch instruction") + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt new file mode 100644 index 0000000000..e6513151d6 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt @@ -0,0 +1,17 @@ +package dev.openrs2.deob.analysis + +enum class IntBranchResult { + ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; + + companion object { + fun fromTakenNotTaken(taken: Int, notTaken: Int): IntBranchResult { + require(taken != 0 || notTaken != 0) + + return when { + taken == 0 -> NEVER_TAKEN + notTaken == 0 -> ALWAYS_TAKEN + else -> UNKNOWN + } + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt index 3843b2f190..920b12eecc 100644 --- a/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt @@ -8,40 +8,33 @@ import org.objectweb.asm.tree.IincInsnNode import org.objectweb.asm.tree.analysis.BasicInterpreter import org.objectweb.asm.tree.analysis.Interpreter -class IntInterpreter(private val parameters: Array?>?) : Interpreter(Opcodes.ASM8) { +class IntInterpreter(private val args: Array) : Interpreter(Opcodes.ASM8) { private val basicInterpreter = BasicInterpreter() override fun newValue(type: Type?): IntValue? { val basicValue = basicInterpreter.newValue(type) ?: return null - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } override fun newParameterValue(isInstanceMethod: Boolean, local: Int, type: Type): IntValue { val basicValue = basicInterpreter.newParameterValue(isInstanceMethod, local, type) - if (parameters != null) { - val parameterIndex = when { - isInstanceMethod && local == 0 -> return IntValue.Unknown(basicValue) - isInstanceMethod -> local - 1 - else -> local - } - - val parameter = parameters[parameterIndex] - if (parameter != null) { - return IntValue.Constant(basicValue, parameter) - } + val index = when { + isInstanceMethod && local == 0 -> return IntValue(basicValue) + isInstanceMethod -> local - 1 + else -> local } - return IntValue.Unknown(basicValue) + return IntValue(basicValue, args[index]) } override fun newOperation(insn: AbstractInsnNode): IntValue { val basicValue = basicInterpreter.newOperation(insn) val v = insn.intConstant return if (v != null) { - IntValue.Constant(basicValue, v) + IntValue(basicValue, IntValueSet.singleton(v)) } else { - IntValue.Unknown(basicValue) + IntValue(basicValue) } } @@ -52,48 +45,48 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter() - for (v in value.values) { + for (v in value.set.values) { val result = when { insn.opcode == Opcodes.INEG -> -v insn is IincInsnNode -> v + insn.incr insn.opcode == Opcodes.I2B -> v.toByte().toInt() insn.opcode == Opcodes.I2C -> v.toChar().toInt() insn.opcode == Opcodes.I2S -> v.toShort().toInt() - else -> return IntValue.Unknown(basicValue) + else -> return IntValue(basicValue) } set.add(result) } - return IntValue.Constant(basicValue, set) + return IntValue(basicValue, IntValueSet.Constant(set)) } override fun binaryOperation(insn: AbstractInsnNode, value1: IntValue, value2: IntValue): IntValue? { val basicValue = basicInterpreter.binaryOperation(insn, value1.basicValue, value2.basicValue) ?: return null - if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) { - return IntValue.Unknown(basicValue) + if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) { + return IntValue(basicValue) } val set = mutableSetOf() - for (v1 in value1.values) { - for (v2 in value2.values) { + for (v1 in value1.set.values) { + for (v2 in value2.set.values) { val result = when (insn.opcode) { Opcodes.IADD -> v1 + v2 Opcodes.ISUB -> v1 - v2 Opcodes.IMUL -> v1 * v2 Opcodes.IDIV -> { if (v2 == 0) { - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } v1 / v2 } Opcodes.IREM -> { if (v2 == 0) { - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } v1 % v2 } @@ -103,12 +96,12 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter v1 and v2 Opcodes.IOR -> v1 or v2 Opcodes.IXOR -> v1 xor v2 - else -> return IntValue.Unknown(basicValue) + else -> return IntValue(basicValue) } set.add(result) } } - return IntValue.Constant(basicValue, set) + return IntValue(basicValue, IntValueSet.Constant(set)) } override fun ternaryOperation( @@ -120,13 +113,13 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter): IntValue? { val args = values.map(IntValue::basicValue) val basicValue = basicInterpreter.naryOperation(insn, args) ?: return null - return IntValue.Unknown(basicValue) + return IntValue(basicValue) } override fun returnOperation(insn: AbstractInsnNode, value: IntValue, expected: IntValue) { @@ -140,15 +133,15 @@ class IntInterpreter(private val parameters: Array?>?) : Interpreter MAX_TRACKED_VALUES) { - IntValue.Unknown(basicValue) + val set = value1.set union value2.set + return if (set is IntValueSet.Constant && set.values.size > MAX_TRACKED_VALUES) { + IntValue(basicValue) } else { - IntValue.Constant(basicValue, set) + IntValue(basicValue, set) } } diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt index 8455b874aa..8f022ed3a0 100644 --- a/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt @@ -3,22 +3,7 @@ package dev.openrs2.deob.analysis import org.objectweb.asm.tree.analysis.BasicValue import org.objectweb.asm.tree.analysis.Value -sealed class IntValue : Value { - data class Unknown(override val basicValue: BasicValue) : IntValue() - data class Constant(override val basicValue: BasicValue, val values: Set) : IntValue() { - val singleton: Int? - - init { - require(values.isNotEmpty()) - - singleton = if (values.size == 1) values.first() else null - } - - constructor(basicValue: BasicValue, value: Int) : this(basicValue, setOf(value)) - } - - abstract val basicValue: BasicValue - +data class IntValue(val basicValue: BasicValue, val set: IntValueSet = IntValueSet.Unknown) : Value { override fun getSize(): Int { return basicValue.size } diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt b/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt new file mode 100644 index 0000000000..a7a0b6a84d --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt @@ -0,0 +1,42 @@ +package dev.openrs2.deob.analysis + +sealed class IntValueSet { + data class Constant(val values: Set) : IntValueSet() { + init { + require(values.isNotEmpty()) + } + + override val singleton: Int? + get() = if (values.size == 1) { + values.first() + } else { + null + } + + override fun union(other: IntValueSet): IntValueSet { + return if (other is Constant) { + Constant(values union other.values) + } else { + Unknown + } + } + } + + object Unknown : IntValueSet() { + override val singleton: Int? + get() = null + + override fun union(other: IntValueSet): IntValueSet { + return Unknown + } + } + + abstract val singleton: Int? + abstract infix fun union(other: IntValueSet): IntValueSet + + companion object { + fun singleton(value: Int): IntValueSet { + return Constant(setOf(value)) + } + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt b/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt deleted file mode 100644 index 43cd8acda8..0000000000 --- a/deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt +++ /dev/null @@ -1,6 +0,0 @@ -package dev.openrs2.deob.analysis - -import dev.openrs2.asm.MemberRef -import dev.openrs2.util.collect.DisjointSet - -data class SourcedIntValue(val source: DisjointSet.Partition, val intValue: IntValue) diff --git a/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt b/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt new file mode 100644 index 0000000000..065b2ea204 --- /dev/null +++ b/deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt @@ -0,0 +1,310 @@ +package dev.openrs2.deob.transform + +import com.github.michaelbull.logging.InlineLogger +import dev.openrs2.asm.MemberRef +import dev.openrs2.asm.classpath.ClassPath +import dev.openrs2.asm.classpath.Library +import dev.openrs2.asm.createIntConstant +import dev.openrs2.asm.deleteExpression +import dev.openrs2.asm.hasCode +import dev.openrs2.asm.intConstant +import dev.openrs2.asm.nextReal +import dev.openrs2.asm.pure +import dev.openrs2.asm.replaceExpression +import dev.openrs2.asm.stackMetadata +import dev.openrs2.asm.transform.Transformer +import dev.openrs2.deob.ArgRef +import dev.openrs2.deob.analysis.IntBranch +import dev.openrs2.deob.analysis.IntBranchResult.ALWAYS_TAKEN +import dev.openrs2.deob.analysis.IntBranchResult.NEVER_TAKEN +import dev.openrs2.deob.analysis.IntInterpreter +import dev.openrs2.deob.analysis.IntValueSet +import dev.openrs2.deob.remap.TypedRemapper +import dev.openrs2.util.collect.DisjointSet +import dev.openrs2.util.collect.removeFirstOrNull +import org.objectweb.asm.Opcodes.GOTO +import org.objectweb.asm.Opcodes.IFEQ +import org.objectweb.asm.Opcodes.IFGE +import org.objectweb.asm.Opcodes.IFGT +import org.objectweb.asm.Opcodes.IFLE +import org.objectweb.asm.Opcodes.IFLT +import org.objectweb.asm.Opcodes.IFNE +import org.objectweb.asm.Opcodes.IF_ICMPEQ +import org.objectweb.asm.Opcodes.IF_ICMPGE +import org.objectweb.asm.Opcodes.IF_ICMPGT +import org.objectweb.asm.Opcodes.IF_ICMPLE +import org.objectweb.asm.Opcodes.IF_ICMPLT +import org.objectweb.asm.Opcodes.IF_ICMPNE +import org.objectweb.asm.Type +import org.objectweb.asm.tree.AbstractInsnNode +import org.objectweb.asm.tree.ClassNode +import org.objectweb.asm.tree.JumpInsnNode +import org.objectweb.asm.tree.MethodInsnNode +import org.objectweb.asm.tree.MethodNode +import org.objectweb.asm.tree.analysis.Analyzer + +class ConstantArgTransformer : Transformer() { + private val pendingMethods = LinkedHashSet() + private val arglessMethods = mutableSetOf>() + private val argValues = mutableMapOf() + private lateinit var inheritedMethodSets: DisjointSet + private var branchesSimplified = 0 + private var constantsInlined = 0 + + override fun preTransform(classPath: ClassPath) { + pendingMethods.clear() + arglessMethods.clear() + argValues.clear() + inheritedMethodSets = classPath.createInheritedMethodSets() + branchesSimplified = 0 + constantsInlined = 0 + + queueEntryPoints(classPath) + + while (true) { + val method = pendingMethods.removeFirstOrNull() ?: break + analyzeMethod(classPath, method) + } + } + + private fun queueEntryPoints(classPath: ClassPath) { + for (partition in inheritedMethodSets) { + /* + * The EXCLUDED_METHODS set roughly matches up with the methods we + * want to consider as entry points. It isn't perfect - it counts + * every method as an entry point, but strictly speaking we + * only need to count methods invoked with reflection as + * entry points (like VisibilityTransformer). However, it makes no + * difference in this case, as the deobfuscator does not add dummy + * constant arguments to constructors. + */ + val excluded = partition.first().name in TypedRemapper.EXCLUDED_METHODS + val overridesDependency = partition.any { classPath[it.owner]!!.dependency } + if (excluded || overridesDependency) { + pendingMethods.addAll(partition) + } + } + } + + private fun analyzeMethod(classPath: ClassPath, ref: MemberRef) { + // find ClassNode/MethodNode + val owner = classPath.getNode(ref.owner) ?: return + val originalMethod = owner.methods.singleOrNull { it.name == ref.name && it.desc == ref.desc } ?: return + if (!originalMethod.hasCode()) { + return + } + + /* + * Clone the method - we don't want to mutate it permanently until the + * final pass, as we might discover more routes through the call graph + * later which reduce the number of branches we can simplify. + */ + val method = MethodNode( + originalMethod.access, + originalMethod.name, + originalMethod.desc, + originalMethod.signature, + originalMethod.exceptions?.toTypedArray() + ) + originalMethod.accept(method) + + // find existing constant arguments + val args = getArgs(ref) + + // simplify branches + simplifyBranches(owner, method, args) + + /* + * Record new constant arguments in method calls. This re-runs the + * analyzer rather than re-using the frames from simplifyBranches. This + * ensures we ignore branches that always evaluate to false, preventing + * us from recording constant arguments found in dummy calls (which + * would prevent us from removing further dummy calls/branches). + */ + addArgValues(owner, method, args) + } + + private fun getArgs(ref: MemberRef): Array { + val partition = inheritedMethodSets[ref]!! + val size = Type.getArgumentTypes(ref.desc).sumBy { it.size } + return Array(size) { i -> argValues[ArgRef(partition, i)] ?: IntValueSet.Unknown } + } + + private fun addArgValues(owner: ClassNode, method: MethodNode, args: Array) { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(owner.name, method) + for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn !is MethodInsnNode) { + continue + } + + val invokedMethod = inheritedMethodSets[MemberRef(insn)] ?: continue + val size = Type.getArgumentTypes(insn.desc).size + + var index = 0 + for (j in 0 until size) { + val value = frame.getStack(frame.stackSize - size + j) + if (addArgValues(ArgRef(invokedMethod, index), value.set)) { + pendingMethods.addAll(invokedMethod) + } + index += value.size + } + + if (size == 0 && arglessMethods.add(invokedMethod)) { + pendingMethods.addAll(invokedMethod) + } + } + } + + private fun addArgValues(ref: ArgRef, value: IntValueSet): Boolean { + val old = argValues[ref] + + val new = if (value.singleton != null) { + if (old != null) { + old union value + } else { + value + } + } else { + IntValueSet.Unknown + } + argValues[ref] = new + + return old != new + } + + private fun simplifyBranches(owner: ClassNode, method: MethodNode, args: Array): Int { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(owner.name, method) + + val alwaysTakenBranches = mutableListOf() + val neverTakenBranches = mutableListOf() + + frame@ for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn !is JumpInsnNode) { + continue + } + + when (insn.opcode) { + IFEQ, IFNE, IFLT, IFGE, IFGT, IFLE -> { + val value = frame.getStack(frame.stackSize - 1) + if (value.set !is IntValueSet.Constant) { + continue@frame + } + + @Suppress("NON_EXHAUSTIVE_WHEN") + when (IntBranch.evaluateUnary(insn.opcode, value.set.values)) { + ALWAYS_TAKEN -> alwaysTakenBranches += insn + NEVER_TAKEN -> neverTakenBranches += insn + } + } + IF_ICMPEQ, IF_ICMPNE, IF_ICMPLT, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE -> { + val value1 = frame.getStack(frame.stackSize - 2) + val value2 = frame.getStack(frame.stackSize - 1) + if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) { + continue@frame + } + + @Suppress("NON_EXHAUSTIVE_WHEN") + when (IntBranch.evaluateBinary(insn.opcode, value1.set.values, value2.set.values)) { + ALWAYS_TAKEN -> alwaysTakenBranches += insn + NEVER_TAKEN -> neverTakenBranches += insn + } + } + } + } + + var simplified = 0 + + for (insn in alwaysTakenBranches) { + val replacement = JumpInsnNode(GOTO, insn.label) + if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { + simplified++ + } + } + + for (insn in neverTakenBranches) { + if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { + simplified++ + } + } + + return simplified + } + + private fun inlineConstantArgs(clazz: ClassNode, method: MethodNode, args: Array): Int { + val analyzer = Analyzer(IntInterpreter(args)) + val frames = analyzer.analyze(clazz.name, method) + + val constInsns = mutableMapOf() + + for ((i, frame) in frames.withIndex()) { + if (frame == null) { + continue + } + + val insn = method.instructions[i] + if (insn.intConstant != null) { + // already constant + continue + } else if (!insn.pure) { + // can't replace instructions with a side effect + continue + } else if (insn.stackMetadata().pushes != 1) { + // can't replace instructions pushing more than one result + continue + } + + // the value pushed by this instruction is held in the following frame + val nextInsn = insn.nextReal ?: continue + val nextInsnIndex = method.instructions.indexOf(nextInsn) + val nextFrame = frames[nextInsnIndex] + + val value = nextFrame.getStack(nextFrame.stackSize - 1) + val singleton = value.set.singleton + if (singleton != null) { + constInsns[insn] = singleton + } + } + + var inlined = 0 + + for ((insn, value) in constInsns) { + if (insn !in method.instructions) { + continue + } + + val replacement = createIntConstant(value) + if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { + inlined++ + } + } + + return inlined + } + + override fun transformCode(classPath: ClassPath, library: Library, clazz: ClassNode, method: MethodNode): Boolean { + val args = getArgs(MemberRef(clazz, method)) + branchesSimplified += simplifyBranches(clazz, method, args) + constantsInlined += inlineConstantArgs(clazz, method, args) + return false + } + + override fun postTransform(classPath: ClassPath) { + logger.info { "Simplified $branchesSimplified branches and inlined $constantsInlined constants" } + } + + companion object { + private val logger = InlineLogger() + } +} diff --git a/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt b/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt deleted file mode 100644 index 110257a47a..0000000000 --- a/deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt +++ /dev/null @@ -1,429 +0,0 @@ -package dev.openrs2.deob.transform - -import com.github.michaelbull.logging.InlineLogger -import com.google.common.collect.HashMultimap -import com.google.common.collect.Multimap -import dev.openrs2.asm.InsnMatcher -import dev.openrs2.asm.MemberRef -import dev.openrs2.asm.classpath.ClassPath -import dev.openrs2.asm.classpath.Library -import dev.openrs2.asm.createIntConstant -import dev.openrs2.asm.deleteExpression -import dev.openrs2.asm.intConstant -import dev.openrs2.asm.nextReal -import dev.openrs2.asm.pure -import dev.openrs2.asm.replaceExpression -import dev.openrs2.asm.stackMetadata -import dev.openrs2.asm.transform.Transformer -import dev.openrs2.deob.ArgRef -import dev.openrs2.deob.analysis.IntInterpreter -import dev.openrs2.deob.analysis.IntValue -import dev.openrs2.deob.analysis.SourcedIntValue -import dev.openrs2.util.collect.DisjointSet -import org.objectweb.asm.Opcodes -import org.objectweb.asm.Type -import org.objectweb.asm.tree.AbstractInsnNode -import org.objectweb.asm.tree.ClassNode -import org.objectweb.asm.tree.JumpInsnNode -import org.objectweb.asm.tree.MethodInsnNode -import org.objectweb.asm.tree.MethodNode -import org.objectweb.asm.tree.VarInsnNode -import org.objectweb.asm.tree.analysis.Analyzer - -class DummyArgTransformer : Transformer() { - private data class ConditionalCall( - val conditionVar: Int, - val conditionOpcode: Int, - val conditionValue: Int?, - val method: DisjointSet.Partition, - val constArgs: List - ) - - private enum class BranchResult { - ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; - - companion object { - fun fromTakenNotTaken(taken: Int, notTaken: Int): BranchResult { - require(taken != 0 || notTaken != 0) - - return when { - taken == 0 -> NEVER_TAKEN - notTaken == 0 -> ALWAYS_TAKEN - else -> UNKNOWN - } - } - } - } - - private val argValues: Multimap = HashMultimap.create() - private val conditionalCalls: Multimap?, ConditionalCall> = HashMultimap.create() - private val constArgs = mutableMapOf, Array?>>() - private lateinit var inheritedMethodSets: DisjointSet - private var branchesSimplified = 0 - private var constantsInlined = 0 - - private fun isMutuallyRecursiveDummy( - method: DisjointSet.Partition, - arg: Int, - source: DisjointSet.Partition, - value: Int - ): Boolean { - for (sourceToMethodCall in conditionalCalls[source]) { - if (sourceToMethodCall.method != method) { - continue - } - - for (methodToSourceCall in conditionalCalls[method]) { - if (methodToSourceCall.method != source || methodToSourceCall.conditionVar != arg) { - continue - } - - var taken = if (methodToSourceCall.conditionValue != null) { - evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue) - } else { - evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value) - } - - if (taken) { - continue - } - - val constArg = methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]!! - - taken = if (sourceToMethodCall.conditionValue != null) { - evaluateBinaryBranch( - sourceToMethodCall.conditionOpcode, - constArg, - sourceToMethodCall.conditionValue - ) - } else { - evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, constArg) - } - - if (taken) { - continue - } - - return true - } - } - - return false - } - - private fun union( - method: DisjointSet.Partition, - arg: Int, - intValues: Collection - ): Set? { - val set = mutableSetOf() - - for ((source, intValue) in intValues) { - if (intValue !is IntValue.Constant) { - return null - } - - if (source == method) { - continue - } - - if (intValue.singleton != null) { - if (isMutuallyRecursiveDummy(method, arg, source, intValue.singleton)) { - continue - } - } - - set.addAll(intValue.values) - } - - return if (set.isEmpty()) { - null - } else { - set - } - } - - override fun preTransform(classPath: ClassPath) { - inheritedMethodSets = classPath.createInheritedMethodSets() - branchesSimplified = 0 - constantsInlined = 0 - } - - override fun prePass(classPath: ClassPath): Boolean { - argValues.clear() - conditionalCalls.clear() - return false - } - - override fun transformCode( - classPath: ClassPath, - library: Library, - clazz: ClassNode, - method: MethodNode - ): Boolean { - val parentMethod = inheritedMethodSets[MemberRef(clazz, method)]!! - - val stores = BooleanArray(method.maxLocals) - for (insn in method.instructions) { - if (insn is VarInsnNode && insn.opcode == Opcodes.ISTORE) { - stores[insn.`var`] = true - } - } - - for (match in CONDITIONAL_CALL_MATCHER.match(method)) { - var matchIndex = 0 - - val load = match[matchIndex++] as VarInsnNode - if (stores[load.`var`]) { - continue - } - - var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) shr 2 - if (method.access and Opcodes.ACC_STATIC != 0) { - callerSlots-- - } - if (load.`var` >= callerSlots) { - continue - } - - val conditionValue: Int? - var conditionOpcode = match[matchIndex].opcode - if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) { - conditionValue = null - matchIndex++ - } else { - conditionValue = match[matchIndex++].intConstant - conditionOpcode = match[matchIndex++].opcode - } - - val invoke = match[match.size - 1] as MethodInsnNode - var invokeArgCount = Type.getArgumentTypes(invoke.desc).size - if (invoke.opcode != Opcodes.INVOKESTATIC) { - invokeArgCount++ - } - - val constArgs = arrayOfNulls(invokeArgCount) - for (i in constArgs.indices) { - val insn = match[matchIndex++] - if (insn.opcode == Opcodes.ACONST_NULL) { - matchIndex++ - } else { - constArgs[i] = insn.intConstant - } - } - - val callee = inheritedMethodSets[MemberRef(invoke)] ?: continue - conditionalCalls.put( - parentMethod, - ConditionalCall(load.`var`, conditionOpcode, conditionValue, callee, constArgs.asList()) - ) - } - - val parameters = constArgs[parentMethod] - val analyzer = Analyzer(IntInterpreter(parameters)) - val frames = analyzer.analyze(clazz.name, method) - - var changed = false - val alwaysTakenBranches = mutableListOf() - val neverTakenBranches = mutableListOf() - val constInsns = mutableMapOf() - - frame@ for ((i, frame) in frames.withIndex()) { - if (frame == null) { - continue - } - - val stackSize = frame.stackSize - - val insn = method.instructions[i] - when (insn.opcode) { - Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEINTERFACE -> { - val invoke = insn as MethodInsnNode - val invokedMethod = inheritedMethodSets[MemberRef(invoke)] ?: continue@frame - val args = Type.getArgumentTypes(invoke.desc).size - - var k = 0 - for (j in 0 until args) { - val arg = frame.getStack(stackSize - args + j) - argValues.put(ArgRef(invokedMethod, k), SourcedIntValue(parentMethod, arg)) - k += arg.size - } - } - Opcodes.IFEQ, Opcodes.IFNE -> { - val value = frame.getStack(stackSize - 1) - if (value !is IntValue.Constant) { - continue@frame - } - - val result = evaluateUnaryBranch(insn.opcode, value.values) - @Suppress("NON_EXHAUSTIVE_WHEN") - when (result) { - BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) - BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) - } - } - Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPGE, Opcodes.IF_ICMPGT, - Opcodes.IF_ICMPLE -> { - val value1 = frame.getStack(stackSize - 2) - val value2 = frame.getStack(stackSize - 1) - if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) { - continue@frame - } - - val result = evaluateBinaryBranch(insn.opcode, value1.values, value2.values) - @Suppress("NON_EXHAUSTIVE_WHEN") - when (result) { - BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) - BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) - } - } - else -> { - if (!insn.pure || insn.intConstant != null) { - continue@frame - } - - if (insn.stackMetadata().pushes != 1) { - continue@frame - } - - val nextInsn = insn.nextReal ?: continue@frame - val nextInsnIndex = method.instructions.indexOf(nextInsn) - val nextFrame = frames[nextInsnIndex] - - val value = nextFrame.getStack(nextFrame.stackSize - 1) - if (value is IntValue.Constant && value.singleton != null) { - constInsns[insn] = value.singleton - } - } - } - } - - for (insn in alwaysTakenBranches) { - val replacement = JumpInsnNode(Opcodes.GOTO, insn.label) - if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { - branchesSimplified++ - changed = true - } - } - - for (insn in neverTakenBranches) { - if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { - branchesSimplified++ - changed = true - } - } - - for ((insn, value) in constInsns) { - if (insn !in method.instructions) { - continue - } - - val replacement = createIntConstant(value) - if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { - constantsInlined++ - changed = true - } - } - - return changed - } - - override fun postPass(classPath: ClassPath): Boolean { - for (method in inheritedMethodSets) { - val args = (Type.getArgumentsAndReturnSizes(method.first().desc) shr 2) - 1 - - var allUnknown = true - val parameters = arrayOfNulls?>(args) - - for (i in 0 until args) { - val parameter = union(method, i, argValues[ArgRef(method, i)]) - if (parameter != null) { - allUnknown = false - } - parameters[i] = parameter - } - - if (allUnknown) { - constArgs.remove(method) - } else { - constArgs[method] = parameters - } - } - - return false - } - - override fun postTransform(classPath: ClassPath) { - logger.info { "Simplified $branchesSimplified dummy branches and inlined $constantsInlined constants" } - } - - companion object { - private val logger = InlineLogger() - private val CONDITIONAL_CALL_MATCHER = InsnMatcher.compile( - """ - ILOAD - (IFEQ | IFNE | - (ICONST | BIPUSH | SIPUSH | LDC) - (IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE) - ) - ALOAD? - (ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+ - (INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE) - """ - ) - - private fun evaluateUnaryBranch(opcode: Int, values: Set): BranchResult { - require(values.isNotEmpty()) - - var taken = 0 - var notTaken = 0 - for (v in values) { - if (evaluateUnaryBranch(opcode, v)) { - taken++ - } else { - notTaken++ - } - } - - return BranchResult.fromTakenNotTaken(taken, notTaken) - } - - private fun evaluateUnaryBranch(opcode: Int, value: Int): Boolean { - return when (opcode) { - Opcodes.IFEQ -> value == 0 - Opcodes.IFNE -> value != 0 - else -> throw IllegalArgumentException() - } - } - - private fun evaluateBinaryBranch(opcode: Int, values1: Set, values2: Set): BranchResult { - require(values1.isNotEmpty() && values2.isNotEmpty()) - - var taken = 0 - var notTaken = 0 - for (v1 in values1) { - for (v2 in values2) { - if (evaluateBinaryBranch(opcode, v1, v2)) { - taken++ - } else { - notTaken++ - } - } - } - - return BranchResult.fromTakenNotTaken(taken, notTaken) - } - - private fun evaluateBinaryBranch(opcode: Int, value1: Int, value2: Int): Boolean { - return when (opcode) { - Opcodes.IF_ICMPEQ -> value1 == value2 - Opcodes.IF_ICMPNE -> value1 != value2 - Opcodes.IF_ICMPLT -> value1 < value2 - Opcodes.IF_ICMPGE -> value1 >= value2 - Opcodes.IF_ICMPGT -> value1 > value2 - Opcodes.IF_ICMPLE -> value1 <= value2 - else -> throw IllegalArgumentException() - } - } - } -}