forked from openrs2/openrs2
The new transformer uses a different approach to the old one. It starts exploring the call graph from the entry points, recursively analysing method calls. Methods are only re-analysed if their possible argument values change, with the Unknown value being used if we can't identify a single integer constant at a call site. This prevents us from recursing infinitely if the client code does. While this first pass does simplify branches in order to ignore dummy method calls that are never evaluated at runtime, it operates on a copy of the method (as we initially ignore more calls while the argument value sets are smaller, ignoring fewer calls as they build up). A separate second pass simplifies branches on the original method and inlines singleton constants, paving the way for the UnusedArgTransformer to actually remove the newly unused arguments. This new approach has several benefits: - It is much faster than the old approach, as we only re-analyse methods as required by argument value changes, rather than re-analysing every method during every pass. - It doesn't require special cases for dealing with mutually recursive dummy calls. The old approach hard-coded special cases for mutually recursive calls involving groups of 1 and 2 methods. The code for this wasn't clean. Furthermore, while it was just about good enough for the HD client, the SD client contains a mutually recursive group of 3 methods. The new approach is capable of dealing with mutually recursive groups of any size. Finally, the new transformer has a much cleaner implementation. Signed-off-by: Graham <gpe@openrs2.dev>bzip2
parent
0626fd5133
commit
ff594848d5
@ -0,0 +1,78 @@ |
||||
package dev.openrs2.deob.analysis |
||||
|
||||
import dev.openrs2.deob.analysis.IntBranchResult.Companion.fromTakenNotTaken |
||||
import org.objectweb.asm.Opcodes.IFEQ |
||||
import org.objectweb.asm.Opcodes.IFGE |
||||
import org.objectweb.asm.Opcodes.IFGT |
||||
import org.objectweb.asm.Opcodes.IFLE |
||||
import org.objectweb.asm.Opcodes.IFLT |
||||
import org.objectweb.asm.Opcodes.IFNE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPEQ |
||||
import org.objectweb.asm.Opcodes.IF_ICMPGE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPGT |
||||
import org.objectweb.asm.Opcodes.IF_ICMPLE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPLT |
||||
import org.objectweb.asm.Opcodes.IF_ICMPNE |
||||
|
||||
object IntBranch { |
||||
fun evaluateUnary(opcode: Int, values: Set<Int>): IntBranchResult { |
||||
require(values.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
|
||||
for (v in values) { |
||||
if (evaluateUnary(opcode, v)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
|
||||
return fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateUnary(opcode: Int, value: Int): Boolean { |
||||
return when (opcode) { |
||||
IFEQ -> value == 0 |
||||
IFNE -> value != 0 |
||||
IFLT -> value < 0 |
||||
IFGE -> value >= 0 |
||||
IFGT -> value > 0 |
||||
IFLE -> value <= 0 |
||||
else -> throw IllegalArgumentException("Opcode $opcode is not a unary conditional branch instruction") |
||||
} |
||||
} |
||||
|
||||
fun evaluateBinary(opcode: Int, values1: Set<Int>, values2: Set<Int>): IntBranchResult { |
||||
require(values1.isNotEmpty()) |
||||
require(values2.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
|
||||
for (v1 in values1) { |
||||
for (v2 in values2) { |
||||
if (evaluateBinary(opcode, v1, v2)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
} |
||||
|
||||
return fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateBinary(opcode: Int, value1: Int, value2: Int): Boolean { |
||||
return when (opcode) { |
||||
IF_ICMPEQ -> value1 == value2 |
||||
IF_ICMPNE -> value1 != value2 |
||||
IF_ICMPLT -> value1 < value2 |
||||
IF_ICMPGE -> value1 >= value2 |
||||
IF_ICMPGT -> value1 > value2 |
||||
IF_ICMPLE -> value1 <= value2 |
||||
else -> throw IllegalArgumentException("Opcode $opcode is not a binary conditional branch instruction") |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,17 @@ |
||||
package dev.openrs2.deob.analysis |
||||
|
||||
enum class IntBranchResult { |
||||
ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; |
||||
|
||||
companion object { |
||||
fun fromTakenNotTaken(taken: Int, notTaken: Int): IntBranchResult { |
||||
require(taken != 0 || notTaken != 0) |
||||
|
||||
return when { |
||||
taken == 0 -> NEVER_TAKEN |
||||
notTaken == 0 -> ALWAYS_TAKEN |
||||
else -> UNKNOWN |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,42 @@ |
||||
package dev.openrs2.deob.analysis |
||||
|
||||
sealed class IntValueSet { |
||||
data class Constant(val values: Set<Int>) : IntValueSet() { |
||||
init { |
||||
require(values.isNotEmpty()) |
||||
} |
||||
|
||||
override val singleton: Int? |
||||
get() = if (values.size == 1) { |
||||
values.first() |
||||
} else { |
||||
null |
||||
} |
||||
|
||||
override fun union(other: IntValueSet): IntValueSet { |
||||
return if (other is Constant) { |
||||
Constant(values union other.values) |
||||
} else { |
||||
Unknown |
||||
} |
||||
} |
||||
} |
||||
|
||||
object Unknown : IntValueSet() { |
||||
override val singleton: Int? |
||||
get() = null |
||||
|
||||
override fun union(other: IntValueSet): IntValueSet { |
||||
return Unknown |
||||
} |
||||
} |
||||
|
||||
abstract val singleton: Int? |
||||
abstract infix fun union(other: IntValueSet): IntValueSet |
||||
|
||||
companion object { |
||||
fun singleton(value: Int): IntValueSet { |
||||
return Constant(setOf(value)) |
||||
} |
||||
} |
||||
} |
@ -1,6 +0,0 @@ |
||||
package dev.openrs2.deob.analysis |
||||
|
||||
import dev.openrs2.asm.MemberRef |
||||
import dev.openrs2.util.collect.DisjointSet |
||||
|
||||
data class SourcedIntValue(val source: DisjointSet.Partition<MemberRef>, val intValue: IntValue) |
@ -0,0 +1,310 @@ |
||||
package dev.openrs2.deob.transform |
||||
|
||||
import com.github.michaelbull.logging.InlineLogger |
||||
import dev.openrs2.asm.MemberRef |
||||
import dev.openrs2.asm.classpath.ClassPath |
||||
import dev.openrs2.asm.classpath.Library |
||||
import dev.openrs2.asm.createIntConstant |
||||
import dev.openrs2.asm.deleteExpression |
||||
import dev.openrs2.asm.hasCode |
||||
import dev.openrs2.asm.intConstant |
||||
import dev.openrs2.asm.nextReal |
||||
import dev.openrs2.asm.pure |
||||
import dev.openrs2.asm.replaceExpression |
||||
import dev.openrs2.asm.stackMetadata |
||||
import dev.openrs2.asm.transform.Transformer |
||||
import dev.openrs2.deob.ArgRef |
||||
import dev.openrs2.deob.analysis.IntBranch |
||||
import dev.openrs2.deob.analysis.IntBranchResult.ALWAYS_TAKEN |
||||
import dev.openrs2.deob.analysis.IntBranchResult.NEVER_TAKEN |
||||
import dev.openrs2.deob.analysis.IntInterpreter |
||||
import dev.openrs2.deob.analysis.IntValueSet |
||||
import dev.openrs2.deob.remap.TypedRemapper |
||||
import dev.openrs2.util.collect.DisjointSet |
||||
import dev.openrs2.util.collect.removeFirstOrNull |
||||
import org.objectweb.asm.Opcodes.GOTO |
||||
import org.objectweb.asm.Opcodes.IFEQ |
||||
import org.objectweb.asm.Opcodes.IFGE |
||||
import org.objectweb.asm.Opcodes.IFGT |
||||
import org.objectweb.asm.Opcodes.IFLE |
||||
import org.objectweb.asm.Opcodes.IFLT |
||||
import org.objectweb.asm.Opcodes.IFNE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPEQ |
||||
import org.objectweb.asm.Opcodes.IF_ICMPGE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPGT |
||||
import org.objectweb.asm.Opcodes.IF_ICMPLE |
||||
import org.objectweb.asm.Opcodes.IF_ICMPLT |
||||
import org.objectweb.asm.Opcodes.IF_ICMPNE |
||||
import org.objectweb.asm.Type |
||||
import org.objectweb.asm.tree.AbstractInsnNode |
||||
import org.objectweb.asm.tree.ClassNode |
||||
import org.objectweb.asm.tree.JumpInsnNode |
||||
import org.objectweb.asm.tree.MethodInsnNode |
||||
import org.objectweb.asm.tree.MethodNode |
||||
import org.objectweb.asm.tree.analysis.Analyzer |
||||
|
||||
class ConstantArgTransformer : Transformer() { |
||||
private val pendingMethods = LinkedHashSet<MemberRef>() |
||||
private val arglessMethods = mutableSetOf<DisjointSet.Partition<MemberRef>>() |
||||
private val argValues = mutableMapOf<ArgRef, IntValueSet>() |
||||
private lateinit var inheritedMethodSets: DisjointSet<MemberRef> |
||||
private var branchesSimplified = 0 |
||||
private var constantsInlined = 0 |
||||
|
||||
override fun preTransform(classPath: ClassPath) { |
||||
pendingMethods.clear() |
||||
arglessMethods.clear() |
||||
argValues.clear() |
||||
inheritedMethodSets = classPath.createInheritedMethodSets() |
||||
branchesSimplified = 0 |
||||
constantsInlined = 0 |
||||
|
||||
queueEntryPoints(classPath) |
||||
|
||||
while (true) { |
||||
val method = pendingMethods.removeFirstOrNull() ?: break |
||||
analyzeMethod(classPath, method) |
||||
} |
||||
} |
||||
|
||||
private fun queueEntryPoints(classPath: ClassPath) { |
||||
for (partition in inheritedMethodSets) { |
||||
/* |
||||
* The EXCLUDED_METHODS set roughly matches up with the methods we |
||||
* want to consider as entry points. It isn't perfect - it counts |
||||
* every <init> method as an entry point, but strictly speaking we |
||||
* only need to count <init> methods invoked with reflection as |
||||
* entry points (like VisibilityTransformer). However, it makes no |
||||
* difference in this case, as the deobfuscator does not add dummy |
||||
* constant arguments to constructors. |
||||
*/ |
||||
val excluded = partition.first().name in TypedRemapper.EXCLUDED_METHODS |
||||
val overridesDependency = partition.any { classPath[it.owner]!!.dependency } |
||||
if (excluded || overridesDependency) { |
||||
pendingMethods.addAll(partition) |
||||
} |
||||
} |
||||
} |
||||
|
||||
private fun analyzeMethod(classPath: ClassPath, ref: MemberRef) { |
||||
// find ClassNode/MethodNode |
||||
val owner = classPath.getNode(ref.owner) ?: return |
||||
val originalMethod = owner.methods.singleOrNull { it.name == ref.name && it.desc == ref.desc } ?: return |
||||
if (!originalMethod.hasCode()) { |
||||
return |
||||
} |
||||
|
||||
/* |
||||
* Clone the method - we don't want to mutate it permanently until the |
||||
* final pass, as we might discover more routes through the call graph |
||||
* later which reduce the number of branches we can simplify. |
||||
*/ |
||||
val method = MethodNode( |
||||
originalMethod.access, |
||||
originalMethod.name, |
||||
originalMethod.desc, |
||||
originalMethod.signature, |
||||
originalMethod.exceptions?.toTypedArray() |
||||
) |
||||
originalMethod.accept(method) |
||||
|
||||
// find existing constant arguments |
||||
val args = getArgs(ref) |
||||
|
||||
// simplify branches |
||||
simplifyBranches(owner, method, args) |
||||
|
||||
/* |
||||
* Record new constant arguments in method calls. This re-runs the |
||||
* analyzer rather than re-using the frames from simplifyBranches. This |
||||
* ensures we ignore branches that always evaluate to false, preventing |
||||
* us from recording constant arguments found in dummy calls (which |
||||
* would prevent us from removing further dummy calls/branches). |
||||
*/ |
||||
addArgValues(owner, method, args) |
||||
} |
||||
|
||||
private fun getArgs(ref: MemberRef): Array<IntValueSet> { |
||||
val partition = inheritedMethodSets[ref]!! |
||||
val size = Type.getArgumentTypes(ref.desc).sumBy { it.size } |
||||
return Array(size) { i -> argValues[ArgRef(partition, i)] ?: IntValueSet.Unknown } |
||||
} |
||||
|
||||
private fun addArgValues(owner: ClassNode, method: MethodNode, args: Array<IntValueSet>) { |
||||
val analyzer = Analyzer(IntInterpreter(args)) |
||||
val frames = analyzer.analyze(owner.name, method) |
||||
for ((i, frame) in frames.withIndex()) { |
||||
if (frame == null) { |
||||
continue |
||||
} |
||||
|
||||
val insn = method.instructions[i] |
||||
if (insn !is MethodInsnNode) { |
||||
continue |
||||
} |
||||
|
||||
val invokedMethod = inheritedMethodSets[MemberRef(insn)] ?: continue |
||||
val size = Type.getArgumentTypes(insn.desc).size |
||||
|
||||
var index = 0 |
||||
for (j in 0 until size) { |
||||
val value = frame.getStack(frame.stackSize - size + j) |
||||
if (addArgValues(ArgRef(invokedMethod, index), value.set)) { |
||||
pendingMethods.addAll(invokedMethod) |
||||
} |
||||
index += value.size |
||||
} |
||||
|
||||
if (size == 0 && arglessMethods.add(invokedMethod)) { |
||||
pendingMethods.addAll(invokedMethod) |
||||
} |
||||
} |
||||
} |
||||
|
||||
private fun addArgValues(ref: ArgRef, value: IntValueSet): Boolean { |
||||
val old = argValues[ref] |
||||
|
||||
val new = if (value.singleton != null) { |
||||
if (old != null) { |
||||
old union value |
||||
} else { |
||||
value |
||||
} |
||||
} else { |
||||
IntValueSet.Unknown |
||||
} |
||||
argValues[ref] = new |
||||
|
||||
return old != new |
||||
} |
||||
|
||||
private fun simplifyBranches(owner: ClassNode, method: MethodNode, args: Array<IntValueSet>): Int { |
||||
val analyzer = Analyzer(IntInterpreter(args)) |
||||
val frames = analyzer.analyze(owner.name, method) |
||||
|
||||
val alwaysTakenBranches = mutableListOf<JumpInsnNode>() |
||||
val neverTakenBranches = mutableListOf<JumpInsnNode>() |
||||
|
||||
frame@ for ((i, frame) in frames.withIndex()) { |
||||
if (frame == null) { |
||||
continue |
||||
} |
||||
|
||||
val insn = method.instructions[i] |
||||
if (insn !is JumpInsnNode) { |
||||
continue |
||||
} |
||||
|
||||
when (insn.opcode) { |
||||
IFEQ, IFNE, IFLT, IFGE, IFGT, IFLE -> { |
||||
val value = frame.getStack(frame.stackSize - 1) |
||||
if (value.set !is IntValueSet.Constant) { |
||||
continue@frame |
||||
} |
||||
|
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (IntBranch.evaluateUnary(insn.opcode, value.set.values)) { |
||||
ALWAYS_TAKEN -> alwaysTakenBranches += insn |
||||
NEVER_TAKEN -> neverTakenBranches += insn |
||||
} |
||||
} |
||||
IF_ICMPEQ, IF_ICMPNE, IF_ICMPLT, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE -> { |
||||
val value1 = frame.getStack(frame.stackSize - 2) |
||||
val value2 = frame.getStack(frame.stackSize - 1) |
||||
if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) { |
||||
continue@frame |
||||
} |
||||
|
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (IntBranch.evaluateBinary(insn.opcode, value1.set.values, value2.set.values)) { |
||||
ALWAYS_TAKEN -> alwaysTakenBranches += insn |
||||
NEVER_TAKEN -> neverTakenBranches += insn |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
var simplified = 0 |
||||
|
||||
for (insn in alwaysTakenBranches) { |
||||
val replacement = JumpInsnNode(GOTO, insn.label) |
||||
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { |
||||
simplified++ |
||||
} |
||||
} |
||||
|
||||
for (insn in neverTakenBranches) { |
||||
if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { |
||||
simplified++ |
||||
} |
||||
} |
||||
|
||||
return simplified |
||||
} |
||||
|
||||
private fun inlineConstantArgs(clazz: ClassNode, method: MethodNode, args: Array<IntValueSet>): Int { |
||||
val analyzer = Analyzer(IntInterpreter(args)) |
||||
val frames = analyzer.analyze(clazz.name, method) |
||||
|
||||
val constInsns = mutableMapOf<AbstractInsnNode, Int>() |
||||
|
||||
for ((i, frame) in frames.withIndex()) { |
||||
if (frame == null) { |
||||
continue |
||||
} |
||||
|
||||
val insn = method.instructions[i] |
||||
if (insn.intConstant != null) { |
||||
// already constant |
||||
continue |
||||
} else if (!insn.pure) { |
||||
// can't replace instructions with a side effect |
||||
continue |
||||
} else if (insn.stackMetadata().pushes != 1) { |
||||
// can't replace instructions pushing more than one result |
||||
continue |
||||
} |
||||
|
||||
// the value pushed by this instruction is held in the following frame |
||||
val nextInsn = insn.nextReal ?: continue |
||||
val nextInsnIndex = method.instructions.indexOf(nextInsn) |
||||
val nextFrame = frames[nextInsnIndex] |
||||
|
||||
val value = nextFrame.getStack(nextFrame.stackSize - 1) |
||||
val singleton = value.set.singleton |
||||
if (singleton != null) { |
||||
constInsns[insn] = singleton |
||||
} |
||||
} |
||||
|
||||
var inlined = 0 |
||||
|
||||
for ((insn, value) in constInsns) { |
||||
if (insn !in method.instructions) { |
||||
continue |
||||
} |
||||
|
||||
val replacement = createIntConstant(value) |
||||
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { |
||||
inlined++ |
||||
} |
||||
} |
||||
|
||||
return inlined |
||||
} |
||||
|
||||
override fun transformCode(classPath: ClassPath, library: Library, clazz: ClassNode, method: MethodNode): Boolean { |
||||
val args = getArgs(MemberRef(clazz, method)) |
||||
branchesSimplified += simplifyBranches(clazz, method, args) |
||||
constantsInlined += inlineConstantArgs(clazz, method, args) |
||||
return false |
||||
} |
||||
|
||||
override fun postTransform(classPath: ClassPath) { |
||||
logger.info { "Simplified $branchesSimplified branches and inlined $constantsInlined constants" } |
||||
} |
||||
|
||||
companion object { |
||||
private val logger = InlineLogger() |
||||
} |
||||
} |
@ -1,429 +0,0 @@ |
||||
package dev.openrs2.deob.transform |
||||
|
||||
import com.github.michaelbull.logging.InlineLogger |
||||
import com.google.common.collect.HashMultimap |
||||
import com.google.common.collect.Multimap |
||||
import dev.openrs2.asm.InsnMatcher |
||||
import dev.openrs2.asm.MemberRef |
||||
import dev.openrs2.asm.classpath.ClassPath |
||||
import dev.openrs2.asm.classpath.Library |
||||
import dev.openrs2.asm.createIntConstant |
||||
import dev.openrs2.asm.deleteExpression |
||||
import dev.openrs2.asm.intConstant |
||||
import dev.openrs2.asm.nextReal |
||||
import dev.openrs2.asm.pure |
||||
import dev.openrs2.asm.replaceExpression |
||||
import dev.openrs2.asm.stackMetadata |
||||
import dev.openrs2.asm.transform.Transformer |
||||
import dev.openrs2.deob.ArgRef |
||||
import dev.openrs2.deob.analysis.IntInterpreter |
||||
import dev.openrs2.deob.analysis.IntValue |
||||
import dev.openrs2.deob.analysis.SourcedIntValue |
||||
import dev.openrs2.util.collect.DisjointSet |
||||
import org.objectweb.asm.Opcodes |
||||
import org.objectweb.asm.Type |
||||
import org.objectweb.asm.tree.AbstractInsnNode |
||||
import org.objectweb.asm.tree.ClassNode |
||||
import org.objectweb.asm.tree.JumpInsnNode |
||||
import org.objectweb.asm.tree.MethodInsnNode |
||||
import org.objectweb.asm.tree.MethodNode |
||||
import org.objectweb.asm.tree.VarInsnNode |
||||
import org.objectweb.asm.tree.analysis.Analyzer |
||||
|
||||
class DummyArgTransformer : Transformer() { |
||||
private data class ConditionalCall( |
||||
val conditionVar: Int, |
||||
val conditionOpcode: Int, |
||||
val conditionValue: Int?, |
||||
val method: DisjointSet.Partition<MemberRef>, |
||||
val constArgs: List<Int?> |
||||
) |
||||
|
||||
private enum class BranchResult { |
||||
ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; |
||||
|
||||
companion object { |
||||
fun fromTakenNotTaken(taken: Int, notTaken: Int): BranchResult { |
||||
require(taken != 0 || notTaken != 0) |
||||
|
||||
return when { |
||||
taken == 0 -> NEVER_TAKEN |
||||
notTaken == 0 -> ALWAYS_TAKEN |
||||
else -> UNKNOWN |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
private val argValues: Multimap<ArgRef, SourcedIntValue> = HashMultimap.create() |
||||
private val conditionalCalls: Multimap<DisjointSet.Partition<MemberRef>?, ConditionalCall> = HashMultimap.create() |
||||
private val constArgs = mutableMapOf<DisjointSet.Partition<MemberRef>, Array<Set<Int>?>>() |
||||
private lateinit var inheritedMethodSets: DisjointSet<MemberRef> |
||||
private var branchesSimplified = 0 |
||||
private var constantsInlined = 0 |
||||
|
||||
private fun isMutuallyRecursiveDummy( |
||||
method: DisjointSet.Partition<MemberRef>, |
||||
arg: Int, |
||||
source: DisjointSet.Partition<MemberRef>, |
||||
value: Int |
||||
): Boolean { |
||||
for (sourceToMethodCall in conditionalCalls[source]) { |
||||
if (sourceToMethodCall.method != method) { |
||||
continue |
||||
} |
||||
|
||||
for (methodToSourceCall in conditionalCalls[method]) { |
||||
if (methodToSourceCall.method != source || methodToSourceCall.conditionVar != arg) { |
||||
continue |
||||
} |
||||
|
||||
var taken = if (methodToSourceCall.conditionValue != null) { |
||||
evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue) |
||||
} else { |
||||
evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value) |
||||
} |
||||
|
||||
if (taken) { |
||||
continue |
||||
} |
||||
|
||||
val constArg = methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]!! |
||||
|
||||
taken = if (sourceToMethodCall.conditionValue != null) { |
||||
evaluateBinaryBranch( |
||||
sourceToMethodCall.conditionOpcode, |
||||
constArg, |
||||
sourceToMethodCall.conditionValue |
||||
) |
||||
} else { |
||||
evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, constArg) |
||||
} |
||||
|
||||
if (taken) { |
||||
continue |
||||
} |
||||
|
||||
return true |
||||
} |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
private fun union( |
||||
method: DisjointSet.Partition<MemberRef>, |
||||
arg: Int, |
||||
intValues: Collection<SourcedIntValue> |
||||
): Set<Int>? { |
||||
val set = mutableSetOf<Int>() |
||||
|
||||
for ((source, intValue) in intValues) { |
||||
if (intValue !is IntValue.Constant) { |
||||
return null |
||||
} |
||||
|
||||
if (source == method) { |
||||
continue |
||||
} |
||||
|
||||
if (intValue.singleton != null) { |
||||
if (isMutuallyRecursiveDummy(method, arg, source, intValue.singleton)) { |
||||
continue |
||||
} |
||||
} |
||||
|
||||
set.addAll(intValue.values) |
||||
} |
||||
|
||||
return if (set.isEmpty()) { |
||||
null |
||||
} else { |
||||
set |
||||
} |
||||
} |
||||
|
||||
override fun preTransform(classPath: ClassPath) { |
||||
inheritedMethodSets = classPath.createInheritedMethodSets() |
||||
branchesSimplified = 0 |
||||
constantsInlined = 0 |
||||
} |
||||
|
||||
override fun prePass(classPath: ClassPath): Boolean { |
||||
argValues.clear() |
||||
conditionalCalls.clear() |
||||
return false |
||||
} |
||||
|
||||
override fun transformCode( |
||||
classPath: ClassPath, |
||||
library: Library, |
||||
clazz: ClassNode, |
||||
method: MethodNode |
||||
): Boolean { |
||||
val parentMethod = inheritedMethodSets[MemberRef(clazz, method)]!! |
||||
|
||||
val stores = BooleanArray(method.maxLocals) |
||||
for (insn in method.instructions) { |
||||
if (insn is VarInsnNode && insn.opcode == Opcodes.ISTORE) { |
||||
stores[insn.`var`] = true |
||||
} |
||||
} |
||||
|
||||
for (match in CONDITIONAL_CALL_MATCHER.match(method)) { |
||||
var matchIndex = 0 |
||||
|
||||
val load = match[matchIndex++] as VarInsnNode |
||||
if (stores[load.`var`]) { |
||||
continue |
||||
} |
||||
|
||||
var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) shr 2 |
||||
if (method.access and Opcodes.ACC_STATIC != 0) { |
||||
callerSlots-- |
||||
} |
||||
if (load.`var` >= callerSlots) { |
||||
continue |
||||
} |
||||
|
||||
val conditionValue: Int? |
||||
var conditionOpcode = match[matchIndex].opcode |
||||
if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) { |
||||
conditionValue = null |
||||
matchIndex++ |
||||
} else { |
||||
conditionValue = match[matchIndex++].intConstant |
||||
conditionOpcode = match[matchIndex++].opcode |
||||
} |
||||
|
||||
val invoke = match[match.size - 1] as MethodInsnNode |
||||
var invokeArgCount = Type.getArgumentTypes(invoke.desc).size |
||||
if (invoke.opcode != Opcodes.INVOKESTATIC) { |
||||
invokeArgCount++ |
||||
} |
||||
|
||||
val constArgs = arrayOfNulls<Int>(invokeArgCount) |
||||
for (i in constArgs.indices) { |
||||
val insn = match[matchIndex++] |
||||
if (insn.opcode == Opcodes.ACONST_NULL) { |
||||
matchIndex++ |
||||
} else { |
||||
constArgs[i] = insn.intConstant |
||||
} |
||||
} |
||||
|
||||
val callee = inheritedMethodSets[MemberRef(invoke)] ?: continue |
||||
conditionalCalls.put( |
||||
parentMethod, |
||||
ConditionalCall(load.`var`, conditionOpcode, conditionValue, callee, constArgs.asList()) |
||||
) |
||||
} |
||||
|
||||
val parameters = constArgs[parentMethod] |
||||
val analyzer = Analyzer(IntInterpreter(parameters)) |
||||
val frames = analyzer.analyze(clazz.name, method) |
||||
|
||||
var changed = false |
||||
val alwaysTakenBranches = mutableListOf<JumpInsnNode>() |
||||
val neverTakenBranches = mutableListOf<JumpInsnNode>() |
||||
val constInsns = mutableMapOf<AbstractInsnNode, Int>() |
||||
|
||||
frame@ for ((i, frame) in frames.withIndex()) { |
||||
if (frame == null) { |
||||
continue |
||||
} |
||||
|
||||
val stackSize = frame.stackSize |
||||
|
||||
val insn = method.instructions[i] |
||||
when (insn.opcode) { |
||||
Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEINTERFACE -> { |
||||
val invoke = insn as MethodInsnNode |
||||
val invokedMethod = inheritedMethodSets[MemberRef(invoke)] ?: continue@frame |
||||
val args = Type.getArgumentTypes(invoke.desc).size |
||||
|
||||
var k = 0 |
||||
for (j in 0 until args) { |
||||
val arg = frame.getStack(stackSize - args + j) |
||||
argValues.put(ArgRef(invokedMethod, k), SourcedIntValue(parentMethod, arg)) |
||||
k += arg.size |
||||
} |
||||
} |
||||
Opcodes.IFEQ, Opcodes.IFNE -> { |
||||
val value = frame.getStack(stackSize - 1) |
||||
if (value !is IntValue.Constant) { |
||||
continue@frame |
||||
} |
||||
|
||||
val result = evaluateUnaryBranch(insn.opcode, value.values) |
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (result) { |
||||
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) |
||||
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) |
||||
} |
||||
} |
||||
Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPGE, Opcodes.IF_ICMPGT, |
||||
Opcodes.IF_ICMPLE -> { |
||||
val value1 = frame.getStack(stackSize - 2) |
||||
val value2 = frame.getStack(stackSize - 1) |
||||
if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) { |
||||
continue@frame |
||||
} |
||||
|
||||
val result = evaluateBinaryBranch(insn.opcode, value1.values, value2.values) |
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (result) { |
||||
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) |
||||
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) |
||||
} |
||||
} |
||||
else -> { |
||||
if (!insn.pure || insn.intConstant != null) { |
||||
continue@frame |
||||
} |
||||
|
||||
if (insn.stackMetadata().pushes != 1) { |
||||
continue@frame |
||||
} |
||||
|
||||
val nextInsn = insn.nextReal ?: continue@frame |
||||
val nextInsnIndex = method.instructions.indexOf(nextInsn) |
||||
val nextFrame = frames[nextInsnIndex] |
||||
|
||||
val value = nextFrame.getStack(nextFrame.stackSize - 1) |
||||
if (value is IntValue.Constant && value.singleton != null) { |
||||
constInsns[insn] = value.singleton |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
for (insn in alwaysTakenBranches) { |
||||
val replacement = JumpInsnNode(Opcodes.GOTO, insn.label) |
||||
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { |
||||
branchesSimplified++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
for (insn in neverTakenBranches) { |
||||
if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) { |
||||
branchesSimplified++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
for ((insn, value) in constInsns) { |
||||
if (insn !in method.instructions) { |
||||
continue |
||||
} |
||||
|
||||
val replacement = createIntConstant(value) |
||||
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) { |
||||
constantsInlined++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
return changed |
||||
} |
||||
|
||||
override fun postPass(classPath: ClassPath): Boolean { |
||||
for (method in inheritedMethodSets) { |
||||
val args = (Type.getArgumentsAndReturnSizes(method.first().desc) shr 2) - 1 |
||||
|
||||
var allUnknown = true |
||||
val parameters = arrayOfNulls<Set<Int>?>(args) |
||||
|
||||
for (i in 0 until args) { |
||||
val parameter = union(method, i, argValues[ArgRef(method, i)]) |
||||
if (parameter != null) { |
||||
allUnknown = false |
||||
} |
||||
parameters[i] = parameter |
||||
} |
||||
|
||||
if (allUnknown) { |
||||
constArgs.remove(method) |
||||
} else { |
||||
constArgs[method] = parameters |
||||
} |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
override fun postTransform(classPath: ClassPath) { |
||||
logger.info { "Simplified $branchesSimplified dummy branches and inlined $constantsInlined constants" } |
||||
} |
||||
|
||||
companion object { |
||||
private val logger = InlineLogger() |
||||
private val CONDITIONAL_CALL_MATCHER = InsnMatcher.compile( |
||||
""" |
||||
ILOAD |
||||
(IFEQ | IFNE | |
||||
(ICONST | BIPUSH | SIPUSH | LDC) |
||||
(IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE) |
||||
) |
||||
ALOAD? |
||||
(ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+ |
||||
(INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE) |
||||
""" |
||||
) |
||||
|
||||
private fun evaluateUnaryBranch(opcode: Int, values: Set<Int>): BranchResult { |
||||
require(values.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
for (v in values) { |
||||
if (evaluateUnaryBranch(opcode, v)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateUnaryBranch(opcode: Int, value: Int): Boolean { |
||||
return when (opcode) { |
||||
Opcodes.IFEQ -> value == 0 |
||||
Opcodes.IFNE -> value != 0 |
||||
else -> throw IllegalArgumentException() |
||||
} |
||||
} |
||||
|
||||
private fun evaluateBinaryBranch(opcode: Int, values1: Set<Int>, values2: Set<Int>): BranchResult { |
||||
require(values1.isNotEmpty() && values2.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
for (v1 in values1) { |
||||
for (v2 in values2) { |
||||
if (evaluateBinaryBranch(opcode, v1, v2)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateBinaryBranch(opcode: Int, value1: Int, value2: Int): Boolean { |
||||
return when (opcode) { |
||||
Opcodes.IF_ICMPEQ -> value1 == value2 |
||||
Opcodes.IF_ICMPNE -> value1 != value2 |
||||
Opcodes.IF_ICMPLT -> value1 < value2 |
||||
Opcodes.IF_ICMPGE -> value1 >= value2 |
||||
Opcodes.IF_ICMPGT -> value1 > value2 |
||||
Opcodes.IF_ICMPLE -> value1 <= value2 |
||||
else -> throw IllegalArgumentException() |
||||
} |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue