Replace DummyArgTransformer with new ConstantArgTransformer

The new transformer uses a different approach to the old one. It starts
exploring the call graph from the entry points, recursively analysing
method calls. Methods are only re-analysed if their possible argument
values change, with the Unknown value being used if we can't identify a
single integer constant at a call site. This prevents us from recursing
infinitely if the client code does.

While this first pass does simplify branches in order to ignore dummy
method calls that are never evaluated at runtime, it operates on a copy
of the method (as we initially ignore more calls while the argument
value sets are smaller, ignoring fewer calls as they build up).

A separate second pass simplifies branches on the original method and
inlines singleton constants, paving the way for the UnusedArgTransformer
to actually remove the newly unused arguments.

This new approach has several benefits:

- It is much faster than the old approach, as we only re-analyse methods
  as required by argument value changes, rather than re-analysing every
  method during every pass.

- It doesn't require special cases for dealing with mutually recursive
  dummy calls. The old approach hard-coded special cases for mutually
  recursive calls involving groups of 1 and 2 methods. The code for this
  wasn't clean. Furthermore, while it was just about good enough for the
  HD client, the SD client contains a mutually recursive group of 3
  methods. The new approach is capable of dealing with mutually
  recursive groups of any size.

Finally, the new transformer has a much cleaner implementation.

Signed-off-by: Graham <gpe@openrs2.dev>
pull/102/head
Graham 4 years ago
parent 0626fd5133
commit ff594848d5
  1. 4
      deob/src/main/java/dev/openrs2/deob/Deobfuscator.kt
  2. 78
      deob/src/main/java/dev/openrs2/deob/analysis/IntBranch.kt
  3. 17
      deob/src/main/java/dev/openrs2/deob/analysis/IntBranchResult.kt
  4. 67
      deob/src/main/java/dev/openrs2/deob/analysis/IntInterpreter.kt
  5. 17
      deob/src/main/java/dev/openrs2/deob/analysis/IntValue.kt
  6. 42
      deob/src/main/java/dev/openrs2/deob/analysis/IntValueSet.kt
  7. 6
      deob/src/main/java/dev/openrs2/deob/analysis/SourcedIntValue.kt
  8. 310
      deob/src/main/java/dev/openrs2/deob/transform/ConstantArgTransformer.kt
  9. 429
      deob/src/main/java/dev/openrs2/deob/transform/DummyArgTransformer.kt

@ -13,8 +13,8 @@ import dev.openrs2.deob.transform.BitShiftTransformer
import dev.openrs2.deob.transform.BitwiseOpTransformer
import dev.openrs2.deob.transform.CanvasTransformer
import dev.openrs2.deob.transform.ClassLiteralTransformer
import dev.openrs2.deob.transform.ConstantArgTransformer
import dev.openrs2.deob.transform.CounterTransformer
import dev.openrs2.deob.transform.DummyArgTransformer
import dev.openrs2.deob.transform.DummyLocalTransformer
import dev.openrs2.deob.transform.EmptyClassTransformer
import dev.openrs2.deob.transform.ExceptionTracingTransformer
@ -155,7 +155,7 @@ class Deobfuscator(private val input: Path, private val output: Path) {
FieldOrderTransformer(),
BitwiseOpTransformer(),
RemapTransformer(),
DummyArgTransformer(),
ConstantArgTransformer(),
DummyLocalTransformer(),
UnusedArgTransformer(),
UnusedMethodTransformer(),

@ -0,0 +1,78 @@
package dev.openrs2.deob.analysis
import dev.openrs2.deob.analysis.IntBranchResult.Companion.fromTakenNotTaken
import org.objectweb.asm.Opcodes.IFEQ
import org.objectweb.asm.Opcodes.IFGE
import org.objectweb.asm.Opcodes.IFGT
import org.objectweb.asm.Opcodes.IFLE
import org.objectweb.asm.Opcodes.IFLT
import org.objectweb.asm.Opcodes.IFNE
import org.objectweb.asm.Opcodes.IF_ICMPEQ
import org.objectweb.asm.Opcodes.IF_ICMPGE
import org.objectweb.asm.Opcodes.IF_ICMPGT
import org.objectweb.asm.Opcodes.IF_ICMPLE
import org.objectweb.asm.Opcodes.IF_ICMPLT
import org.objectweb.asm.Opcodes.IF_ICMPNE
object IntBranch {
fun evaluateUnary(opcode: Int, values: Set<Int>): IntBranchResult {
require(values.isNotEmpty())
var taken = 0
var notTaken = 0
for (v in values) {
if (evaluateUnary(opcode, v)) {
taken++
} else {
notTaken++
}
}
return fromTakenNotTaken(taken, notTaken)
}
private fun evaluateUnary(opcode: Int, value: Int): Boolean {
return when (opcode) {
IFEQ -> value == 0
IFNE -> value != 0
IFLT -> value < 0
IFGE -> value >= 0
IFGT -> value > 0
IFLE -> value <= 0
else -> throw IllegalArgumentException("Opcode $opcode is not a unary conditional branch instruction")
}
}
fun evaluateBinary(opcode: Int, values1: Set<Int>, values2: Set<Int>): IntBranchResult {
require(values1.isNotEmpty())
require(values2.isNotEmpty())
var taken = 0
var notTaken = 0
for (v1 in values1) {
for (v2 in values2) {
if (evaluateBinary(opcode, v1, v2)) {
taken++
} else {
notTaken++
}
}
}
return fromTakenNotTaken(taken, notTaken)
}
private fun evaluateBinary(opcode: Int, value1: Int, value2: Int): Boolean {
return when (opcode) {
IF_ICMPEQ -> value1 == value2
IF_ICMPNE -> value1 != value2
IF_ICMPLT -> value1 < value2
IF_ICMPGE -> value1 >= value2
IF_ICMPGT -> value1 > value2
IF_ICMPLE -> value1 <= value2
else -> throw IllegalArgumentException("Opcode $opcode is not a binary conditional branch instruction")
}
}
}

@ -0,0 +1,17 @@
package dev.openrs2.deob.analysis
enum class IntBranchResult {
ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN;
companion object {
fun fromTakenNotTaken(taken: Int, notTaken: Int): IntBranchResult {
require(taken != 0 || notTaken != 0)
return when {
taken == 0 -> NEVER_TAKEN
notTaken == 0 -> ALWAYS_TAKEN
else -> UNKNOWN
}
}
}
}

@ -8,40 +8,33 @@ import org.objectweb.asm.tree.IincInsnNode
import org.objectweb.asm.tree.analysis.BasicInterpreter
import org.objectweb.asm.tree.analysis.Interpreter
class IntInterpreter(private val parameters: Array<Set<Int>?>?) : Interpreter<IntValue>(Opcodes.ASM8) {
class IntInterpreter(private val args: Array<IntValueSet>) : Interpreter<IntValue>(Opcodes.ASM8) {
private val basicInterpreter = BasicInterpreter()
override fun newValue(type: Type?): IntValue? {
val basicValue = basicInterpreter.newValue(type) ?: return null
return IntValue.Unknown(basicValue)
return IntValue(basicValue)
}
override fun newParameterValue(isInstanceMethod: Boolean, local: Int, type: Type): IntValue {
val basicValue = basicInterpreter.newParameterValue(isInstanceMethod, local, type)
if (parameters != null) {
val parameterIndex = when {
isInstanceMethod && local == 0 -> return IntValue.Unknown(basicValue)
isInstanceMethod -> local - 1
else -> local
}
val parameter = parameters[parameterIndex]
if (parameter != null) {
return IntValue.Constant(basicValue, parameter)
}
val index = when {
isInstanceMethod && local == 0 -> return IntValue(basicValue)
isInstanceMethod -> local - 1
else -> local
}
return IntValue.Unknown(basicValue)
return IntValue(basicValue, args[index])
}
override fun newOperation(insn: AbstractInsnNode): IntValue {
val basicValue = basicInterpreter.newOperation(insn)
val v = insn.intConstant
return if (v != null) {
IntValue.Constant(basicValue, v)
IntValue(basicValue, IntValueSet.singleton(v))
} else {
IntValue.Unknown(basicValue)
IntValue(basicValue)
}
}
@ -52,48 +45,48 @@ class IntInterpreter(private val parameters: Array<Set<Int>?>?) : Interpreter<In
override fun unaryOperation(insn: AbstractInsnNode, value: IntValue): IntValue? {
val basicValue = basicInterpreter.unaryOperation(insn, value.basicValue) ?: return null
if (value !is IntValue.Constant) {
return IntValue.Unknown(basicValue)
if (value.set !is IntValueSet.Constant) {
return IntValue(basicValue)
}
val set = mutableSetOf<Int>()
for (v in value.values) {
for (v in value.set.values) {
val result = when {
insn.opcode == Opcodes.INEG -> -v
insn is IincInsnNode -> v + insn.incr
insn.opcode == Opcodes.I2B -> v.toByte().toInt()
insn.opcode == Opcodes.I2C -> v.toChar().toInt()
insn.opcode == Opcodes.I2S -> v.toShort().toInt()
else -> return IntValue.Unknown(basicValue)
else -> return IntValue(basicValue)
}
set.add(result)
}
return IntValue.Constant(basicValue, set)
return IntValue(basicValue, IntValueSet.Constant(set))
}
override fun binaryOperation(insn: AbstractInsnNode, value1: IntValue, value2: IntValue): IntValue? {
val basicValue = basicInterpreter.binaryOperation(insn, value1.basicValue, value2.basicValue) ?: return null
if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) {
return IntValue.Unknown(basicValue)
if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) {
return IntValue(basicValue)
}
val set = mutableSetOf<Int>()
for (v1 in value1.values) {
for (v2 in value2.values) {
for (v1 in value1.set.values) {
for (v2 in value2.set.values) {
val result = when (insn.opcode) {
Opcodes.IADD -> v1 + v2
Opcodes.ISUB -> v1 - v2
Opcodes.IMUL -> v1 * v2
Opcodes.IDIV -> {
if (v2 == 0) {
return IntValue.Unknown(basicValue)
return IntValue(basicValue)
}
v1 / v2
}
Opcodes.IREM -> {
if (v2 == 0) {
return IntValue.Unknown(basicValue)
return IntValue(basicValue)
}
v1 % v2
}
@ -103,12 +96,12 @@ class IntInterpreter(private val parameters: Array<Set<Int>?>?) : Interpreter<In
Opcodes.IAND -> v1 and v2
Opcodes.IOR -> v1 or v2
Opcodes.IXOR -> v1 xor v2
else -> return IntValue.Unknown(basicValue)
else -> return IntValue(basicValue)
}
set.add(result)
}
}
return IntValue.Constant(basicValue, set)
return IntValue(basicValue, IntValueSet.Constant(set))
}
override fun ternaryOperation(
@ -120,13 +113,13 @@ class IntInterpreter(private val parameters: Array<Set<Int>?>?) : Interpreter<In
val basicValue =
basicInterpreter.ternaryOperation(insn, value1.basicValue, value2.basicValue, value3.basicValue)
?: return null
return IntValue.Unknown(basicValue)
return IntValue(basicValue)
}
override fun naryOperation(insn: AbstractInsnNode, values: List<IntValue>): IntValue? {
val args = values.map(IntValue::basicValue)
val basicValue = basicInterpreter.naryOperation(insn, args) ?: return null
return IntValue.Unknown(basicValue)
return IntValue(basicValue)
}
override fun returnOperation(insn: AbstractInsnNode, value: IntValue, expected: IntValue) {
@ -140,15 +133,15 @@ class IntInterpreter(private val parameters: Array<Set<Int>?>?) : Interpreter<In
return value1
}
if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) {
return IntValue.Unknown(basicValue)
if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) {
return IntValue(basicValue)
}
val set = value1.values union value2.values
return if (set.size > MAX_TRACKED_VALUES) {
IntValue.Unknown(basicValue)
val set = value1.set union value2.set
return if (set is IntValueSet.Constant && set.values.size > MAX_TRACKED_VALUES) {
IntValue(basicValue)
} else {
IntValue.Constant(basicValue, set)
IntValue(basicValue, set)
}
}

@ -3,22 +3,7 @@ package dev.openrs2.deob.analysis
import org.objectweb.asm.tree.analysis.BasicValue
import org.objectweb.asm.tree.analysis.Value
sealed class IntValue : Value {
data class Unknown(override val basicValue: BasicValue) : IntValue()
data class Constant(override val basicValue: BasicValue, val values: Set<Int>) : IntValue() {
val singleton: Int?
init {
require(values.isNotEmpty())
singleton = if (values.size == 1) values.first() else null
}
constructor(basicValue: BasicValue, value: Int) : this(basicValue, setOf(value))
}
abstract val basicValue: BasicValue
data class IntValue(val basicValue: BasicValue, val set: IntValueSet = IntValueSet.Unknown) : Value {
override fun getSize(): Int {
return basicValue.size
}

@ -0,0 +1,42 @@
package dev.openrs2.deob.analysis
sealed class IntValueSet {
data class Constant(val values: Set<Int>) : IntValueSet() {
init {
require(values.isNotEmpty())
}
override val singleton: Int?
get() = if (values.size == 1) {
values.first()
} else {
null
}
override fun union(other: IntValueSet): IntValueSet {
return if (other is Constant) {
Constant(values union other.values)
} else {
Unknown
}
}
}
object Unknown : IntValueSet() {
override val singleton: Int?
get() = null
override fun union(other: IntValueSet): IntValueSet {
return Unknown
}
}
abstract val singleton: Int?
abstract infix fun union(other: IntValueSet): IntValueSet
companion object {
fun singleton(value: Int): IntValueSet {
return Constant(setOf(value))
}
}
}

@ -1,6 +0,0 @@
package dev.openrs2.deob.analysis
import dev.openrs2.asm.MemberRef
import dev.openrs2.util.collect.DisjointSet
data class SourcedIntValue(val source: DisjointSet.Partition<MemberRef>, val intValue: IntValue)

@ -0,0 +1,310 @@
package dev.openrs2.deob.transform
import com.github.michaelbull.logging.InlineLogger
import dev.openrs2.asm.MemberRef
import dev.openrs2.asm.classpath.ClassPath
import dev.openrs2.asm.classpath.Library
import dev.openrs2.asm.createIntConstant
import dev.openrs2.asm.deleteExpression
import dev.openrs2.asm.hasCode
import dev.openrs2.asm.intConstant
import dev.openrs2.asm.nextReal
import dev.openrs2.asm.pure
import dev.openrs2.asm.replaceExpression
import dev.openrs2.asm.stackMetadata
import dev.openrs2.asm.transform.Transformer
import dev.openrs2.deob.ArgRef
import dev.openrs2.deob.analysis.IntBranch
import dev.openrs2.deob.analysis.IntBranchResult.ALWAYS_TAKEN
import dev.openrs2.deob.analysis.IntBranchResult.NEVER_TAKEN
import dev.openrs2.deob.analysis.IntInterpreter
import dev.openrs2.deob.analysis.IntValueSet
import dev.openrs2.deob.remap.TypedRemapper
import dev.openrs2.util.collect.DisjointSet
import dev.openrs2.util.collect.removeFirstOrNull
import org.objectweb.asm.Opcodes.GOTO
import org.objectweb.asm.Opcodes.IFEQ
import org.objectweb.asm.Opcodes.IFGE
import org.objectweb.asm.Opcodes.IFGT
import org.objectweb.asm.Opcodes.IFLE
import org.objectweb.asm.Opcodes.IFLT
import org.objectweb.asm.Opcodes.IFNE
import org.objectweb.asm.Opcodes.IF_ICMPEQ
import org.objectweb.asm.Opcodes.IF_ICMPGE
import org.objectweb.asm.Opcodes.IF_ICMPGT
import org.objectweb.asm.Opcodes.IF_ICMPLE
import org.objectweb.asm.Opcodes.IF_ICMPLT
import org.objectweb.asm.Opcodes.IF_ICMPNE
import org.objectweb.asm.Type
import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.JumpInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.analysis.Analyzer
class ConstantArgTransformer : Transformer() {
private val pendingMethods = LinkedHashSet<MemberRef>()
private val arglessMethods = mutableSetOf<DisjointSet.Partition<MemberRef>>()
private val argValues = mutableMapOf<ArgRef, IntValueSet>()
private lateinit var inheritedMethodSets: DisjointSet<MemberRef>
private var branchesSimplified = 0
private var constantsInlined = 0
override fun preTransform(classPath: ClassPath) {
pendingMethods.clear()
arglessMethods.clear()
argValues.clear()
inheritedMethodSets = classPath.createInheritedMethodSets()
branchesSimplified = 0
constantsInlined = 0
queueEntryPoints(classPath)
while (true) {
val method = pendingMethods.removeFirstOrNull() ?: break
analyzeMethod(classPath, method)
}
}
private fun queueEntryPoints(classPath: ClassPath) {
for (partition in inheritedMethodSets) {
/*
* The EXCLUDED_METHODS set roughly matches up with the methods we
* want to consider as entry points. It isn't perfect - it counts
* every <init> method as an entry point, but strictly speaking we
* only need to count <init> methods invoked with reflection as
* entry points (like VisibilityTransformer). However, it makes no
* difference in this case, as the deobfuscator does not add dummy
* constant arguments to constructors.
*/
val excluded = partition.first().name in TypedRemapper.EXCLUDED_METHODS
val overridesDependency = partition.any { classPath[it.owner]!!.dependency }
if (excluded || overridesDependency) {
pendingMethods.addAll(partition)
}
}
}
private fun analyzeMethod(classPath: ClassPath, ref: MemberRef) {
// find ClassNode/MethodNode
val owner = classPath.getNode(ref.owner) ?: return
val originalMethod = owner.methods.singleOrNull { it.name == ref.name && it.desc == ref.desc } ?: return
if (!originalMethod.hasCode()) {
return
}
/*
* Clone the method - we don't want to mutate it permanently until the
* final pass, as we might discover more routes through the call graph
* later which reduce the number of branches we can simplify.
*/
val method = MethodNode(
originalMethod.access,
originalMethod.name,
originalMethod.desc,
originalMethod.signature,
originalMethod.exceptions?.toTypedArray()
)
originalMethod.accept(method)
// find existing constant arguments
val args = getArgs(ref)
// simplify branches
simplifyBranches(owner, method, args)
/*
* Record new constant arguments in method calls. This re-runs the
* analyzer rather than re-using the frames from simplifyBranches. This
* ensures we ignore branches that always evaluate to false, preventing
* us from recording constant arguments found in dummy calls (which
* would prevent us from removing further dummy calls/branches).
*/
addArgValues(owner, method, args)
}
private fun getArgs(ref: MemberRef): Array<IntValueSet> {
val partition = inheritedMethodSets[ref]!!
val size = Type.getArgumentTypes(ref.desc).sumBy { it.size }
return Array(size) { i -> argValues[ArgRef(partition, i)] ?: IntValueSet.Unknown }
}
private fun addArgValues(owner: ClassNode, method: MethodNode, args: Array<IntValueSet>) {
val analyzer = Analyzer(IntInterpreter(args))
val frames = analyzer.analyze(owner.name, method)
for ((i, frame) in frames.withIndex()) {
if (frame == null) {
continue
}
val insn = method.instructions[i]
if (insn !is MethodInsnNode) {
continue
}
val invokedMethod = inheritedMethodSets[MemberRef(insn)] ?: continue
val size = Type.getArgumentTypes(insn.desc).size
var index = 0
for (j in 0 until size) {
val value = frame.getStack(frame.stackSize - size + j)
if (addArgValues(ArgRef(invokedMethod, index), value.set)) {
pendingMethods.addAll(invokedMethod)
}
index += value.size
}
if (size == 0 && arglessMethods.add(invokedMethod)) {
pendingMethods.addAll(invokedMethod)
}
}
}
private fun addArgValues(ref: ArgRef, value: IntValueSet): Boolean {
val old = argValues[ref]
val new = if (value.singleton != null) {
if (old != null) {
old union value
} else {
value
}
} else {
IntValueSet.Unknown
}
argValues[ref] = new
return old != new
}
private fun simplifyBranches(owner: ClassNode, method: MethodNode, args: Array<IntValueSet>): Int {
val analyzer = Analyzer(IntInterpreter(args))
val frames = analyzer.analyze(owner.name, method)
val alwaysTakenBranches = mutableListOf<JumpInsnNode>()
val neverTakenBranches = mutableListOf<JumpInsnNode>()
frame@ for ((i, frame) in frames.withIndex()) {
if (frame == null) {
continue
}
val insn = method.instructions[i]
if (insn !is JumpInsnNode) {
continue
}
when (insn.opcode) {
IFEQ, IFNE, IFLT, IFGE, IFGT, IFLE -> {
val value = frame.getStack(frame.stackSize - 1)
if (value.set !is IntValueSet.Constant) {
continue@frame
}
@Suppress("NON_EXHAUSTIVE_WHEN")
when (IntBranch.evaluateUnary(insn.opcode, value.set.values)) {
ALWAYS_TAKEN -> alwaysTakenBranches += insn
NEVER_TAKEN -> neverTakenBranches += insn
}
}
IF_ICMPEQ, IF_ICMPNE, IF_ICMPLT, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE -> {
val value1 = frame.getStack(frame.stackSize - 2)
val value2 = frame.getStack(frame.stackSize - 1)
if (value1.set !is IntValueSet.Constant || value2.set !is IntValueSet.Constant) {
continue@frame
}
@Suppress("NON_EXHAUSTIVE_WHEN")
when (IntBranch.evaluateBinary(insn.opcode, value1.set.values, value2.set.values)) {
ALWAYS_TAKEN -> alwaysTakenBranches += insn
NEVER_TAKEN -> neverTakenBranches += insn
}
}
}
}
var simplified = 0
for (insn in alwaysTakenBranches) {
val replacement = JumpInsnNode(GOTO, insn.label)
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) {
simplified++
}
}
for (insn in neverTakenBranches) {
if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) {
simplified++
}
}
return simplified
}
private fun inlineConstantArgs(clazz: ClassNode, method: MethodNode, args: Array<IntValueSet>): Int {
val analyzer = Analyzer(IntInterpreter(args))
val frames = analyzer.analyze(clazz.name, method)
val constInsns = mutableMapOf<AbstractInsnNode, Int>()
for ((i, frame) in frames.withIndex()) {
if (frame == null) {
continue
}
val insn = method.instructions[i]
if (insn.intConstant != null) {
// already constant
continue
} else if (!insn.pure) {
// can't replace instructions with a side effect
continue
} else if (insn.stackMetadata().pushes != 1) {
// can't replace instructions pushing more than one result
continue
}
// the value pushed by this instruction is held in the following frame
val nextInsn = insn.nextReal ?: continue
val nextInsnIndex = method.instructions.indexOf(nextInsn)
val nextFrame = frames[nextInsnIndex]
val value = nextFrame.getStack(nextFrame.stackSize - 1)
val singleton = value.set.singleton
if (singleton != null) {
constInsns[insn] = singleton
}
}
var inlined = 0
for ((insn, value) in constInsns) {
if (insn !in method.instructions) {
continue
}
val replacement = createIntConstant(value)
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) {
inlined++
}
}
return inlined
}
override fun transformCode(classPath: ClassPath, library: Library, clazz: ClassNode, method: MethodNode): Boolean {
val args = getArgs(MemberRef(clazz, method))
branchesSimplified += simplifyBranches(clazz, method, args)
constantsInlined += inlineConstantArgs(clazz, method, args)
return false
}
override fun postTransform(classPath: ClassPath) {
logger.info { "Simplified $branchesSimplified branches and inlined $constantsInlined constants" }
}
companion object {
private val logger = InlineLogger()
}
}

@ -1,429 +0,0 @@
package dev.openrs2.deob.transform
import com.github.michaelbull.logging.InlineLogger
import com.google.common.collect.HashMultimap
import com.google.common.collect.Multimap
import dev.openrs2.asm.InsnMatcher
import dev.openrs2.asm.MemberRef
import dev.openrs2.asm.classpath.ClassPath
import dev.openrs2.asm.classpath.Library
import dev.openrs2.asm.createIntConstant
import dev.openrs2.asm.deleteExpression
import dev.openrs2.asm.intConstant
import dev.openrs2.asm.nextReal
import dev.openrs2.asm.pure
import dev.openrs2.asm.replaceExpression
import dev.openrs2.asm.stackMetadata
import dev.openrs2.asm.transform.Transformer
import dev.openrs2.deob.ArgRef
import dev.openrs2.deob.analysis.IntInterpreter
import dev.openrs2.deob.analysis.IntValue
import dev.openrs2.deob.analysis.SourcedIntValue
import dev.openrs2.util.collect.DisjointSet
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.JumpInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.VarInsnNode
import org.objectweb.asm.tree.analysis.Analyzer
class DummyArgTransformer : Transformer() {
private data class ConditionalCall(
val conditionVar: Int,
val conditionOpcode: Int,
val conditionValue: Int?,
val method: DisjointSet.Partition<MemberRef>,
val constArgs: List<Int?>
)
private enum class BranchResult {
ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN;
companion object {
fun fromTakenNotTaken(taken: Int, notTaken: Int): BranchResult {
require(taken != 0 || notTaken != 0)
return when {
taken == 0 -> NEVER_TAKEN
notTaken == 0 -> ALWAYS_TAKEN
else -> UNKNOWN
}
}
}
}
private val argValues: Multimap<ArgRef, SourcedIntValue> = HashMultimap.create()
private val conditionalCalls: Multimap<DisjointSet.Partition<MemberRef>?, ConditionalCall> = HashMultimap.create()
private val constArgs = mutableMapOf<DisjointSet.Partition<MemberRef>, Array<Set<Int>?>>()
private lateinit var inheritedMethodSets: DisjointSet<MemberRef>
private var branchesSimplified = 0
private var constantsInlined = 0
private fun isMutuallyRecursiveDummy(
method: DisjointSet.Partition<MemberRef>,
arg: Int,
source: DisjointSet.Partition<MemberRef>,
value: Int
): Boolean {
for (sourceToMethodCall in conditionalCalls[source]) {
if (sourceToMethodCall.method != method) {
continue
}
for (methodToSourceCall in conditionalCalls[method]) {
if (methodToSourceCall.method != source || methodToSourceCall.conditionVar != arg) {
continue
}
var taken = if (methodToSourceCall.conditionValue != null) {
evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue)
} else {
evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value)
}
if (taken) {
continue
}
val constArg = methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]!!
taken = if (sourceToMethodCall.conditionValue != null) {
evaluateBinaryBranch(
sourceToMethodCall.conditionOpcode,
constArg,
sourceToMethodCall.conditionValue
)
} else {
evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, constArg)
}
if (taken) {
continue
}
return true
}
}
return false
}
private fun union(
method: DisjointSet.Partition<MemberRef>,
arg: Int,
intValues: Collection<SourcedIntValue>
): Set<Int>? {
val set = mutableSetOf<Int>()
for ((source, intValue) in intValues) {
if (intValue !is IntValue.Constant) {
return null
}
if (source == method) {
continue
}
if (intValue.singleton != null) {
if (isMutuallyRecursiveDummy(method, arg, source, intValue.singleton)) {
continue
}
}
set.addAll(intValue.values)
}
return if (set.isEmpty()) {
null
} else {
set
}
}
override fun preTransform(classPath: ClassPath) {
inheritedMethodSets = classPath.createInheritedMethodSets()
branchesSimplified = 0
constantsInlined = 0
}
override fun prePass(classPath: ClassPath): Boolean {
argValues.clear()
conditionalCalls.clear()
return false
}
override fun transformCode(
classPath: ClassPath,
library: Library,
clazz: ClassNode,
method: MethodNode
): Boolean {
val parentMethod = inheritedMethodSets[MemberRef(clazz, method)]!!
val stores = BooleanArray(method.maxLocals)
for (insn in method.instructions) {
if (insn is VarInsnNode && insn.opcode == Opcodes.ISTORE) {
stores[insn.`var`] = true
}
}
for (match in CONDITIONAL_CALL_MATCHER.match(method)) {
var matchIndex = 0
val load = match[matchIndex++] as VarInsnNode
if (stores[load.`var`]) {
continue
}
var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) shr 2
if (method.access and Opcodes.ACC_STATIC != 0) {
callerSlots--
}
if (load.`var` >= callerSlots) {
continue
}
val conditionValue: Int?
var conditionOpcode = match[matchIndex].opcode
if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) {
conditionValue = null
matchIndex++
} else {
conditionValue = match[matchIndex++].intConstant
conditionOpcode = match[matchIndex++].opcode
}
val invoke = match[match.size - 1] as MethodInsnNode
var invokeArgCount = Type.getArgumentTypes(invoke.desc).size
if (invoke.opcode != Opcodes.INVOKESTATIC) {
invokeArgCount++
}
val constArgs = arrayOfNulls<Int>(invokeArgCount)
for (i in constArgs.indices) {
val insn = match[matchIndex++]
if (insn.opcode == Opcodes.ACONST_NULL) {
matchIndex++
} else {
constArgs[i] = insn.intConstant
}
}
val callee = inheritedMethodSets[MemberRef(invoke)] ?: continue
conditionalCalls.put(
parentMethod,
ConditionalCall(load.`var`, conditionOpcode, conditionValue, callee, constArgs.asList())
)
}
val parameters = constArgs[parentMethod]
val analyzer = Analyzer(IntInterpreter(parameters))
val frames = analyzer.analyze(clazz.name, method)
var changed = false
val alwaysTakenBranches = mutableListOf<JumpInsnNode>()
val neverTakenBranches = mutableListOf<JumpInsnNode>()
val constInsns = mutableMapOf<AbstractInsnNode, Int>()
frame@ for ((i, frame) in frames.withIndex()) {
if (frame == null) {
continue
}
val stackSize = frame.stackSize
val insn = method.instructions[i]
when (insn.opcode) {
Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEINTERFACE -> {
val invoke = insn as MethodInsnNode
val invokedMethod = inheritedMethodSets[MemberRef(invoke)] ?: continue@frame
val args = Type.getArgumentTypes(invoke.desc).size
var k = 0
for (j in 0 until args) {
val arg = frame.getStack(stackSize - args + j)
argValues.put(ArgRef(invokedMethod, k), SourcedIntValue(parentMethod, arg))
k += arg.size
}
}
Opcodes.IFEQ, Opcodes.IFNE -> {
val value = frame.getStack(stackSize - 1)
if (value !is IntValue.Constant) {
continue@frame
}
val result = evaluateUnaryBranch(insn.opcode, value.values)
@Suppress("NON_EXHAUSTIVE_WHEN")
when (result) {
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode)
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode)
}
}
Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPGE, Opcodes.IF_ICMPGT,
Opcodes.IF_ICMPLE -> {
val value1 = frame.getStack(stackSize - 2)
val value2 = frame.getStack(stackSize - 1)
if (value1 !is IntValue.Constant || value2 !is IntValue.Constant) {
continue@frame
}
val result = evaluateBinaryBranch(insn.opcode, value1.values, value2.values)
@Suppress("NON_EXHAUSTIVE_WHEN")
when (result) {
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode)
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode)
}
}
else -> {
if (!insn.pure || insn.intConstant != null) {
continue@frame
}
if (insn.stackMetadata().pushes != 1) {
continue@frame
}
val nextInsn = insn.nextReal ?: continue@frame
val nextInsnIndex = method.instructions.indexOf(nextInsn)
val nextFrame = frames[nextInsnIndex]
val value = nextFrame.getStack(nextFrame.stackSize - 1)
if (value is IntValue.Constant && value.singleton != null) {
constInsns[insn] = value.singleton
}
}
}
}
for (insn in alwaysTakenBranches) {
val replacement = JumpInsnNode(Opcodes.GOTO, insn.label)
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) {
branchesSimplified++
changed = true
}
}
for (insn in neverTakenBranches) {
if (method.instructions.deleteExpression(insn, AbstractInsnNode::pure)) {
branchesSimplified++
changed = true
}
}
for ((insn, value) in constInsns) {
if (insn !in method.instructions) {
continue
}
val replacement = createIntConstant(value)
if (method.instructions.replaceExpression(insn, replacement, AbstractInsnNode::pure)) {
constantsInlined++
changed = true
}
}
return changed
}
override fun postPass(classPath: ClassPath): Boolean {
for (method in inheritedMethodSets) {
val args = (Type.getArgumentsAndReturnSizes(method.first().desc) shr 2) - 1
var allUnknown = true
val parameters = arrayOfNulls<Set<Int>?>(args)
for (i in 0 until args) {
val parameter = union(method, i, argValues[ArgRef(method, i)])
if (parameter != null) {
allUnknown = false
}
parameters[i] = parameter
}
if (allUnknown) {
constArgs.remove(method)
} else {
constArgs[method] = parameters
}
}
return false
}
override fun postTransform(classPath: ClassPath) {
logger.info { "Simplified $branchesSimplified dummy branches and inlined $constantsInlined constants" }
}
companion object {
private val logger = InlineLogger()
private val CONDITIONAL_CALL_MATCHER = InsnMatcher.compile(
"""
ILOAD
(IFEQ | IFNE |
(ICONST | BIPUSH | SIPUSH | LDC)
(IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE)
)
ALOAD?
(ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+
(INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE)
"""
)
private fun evaluateUnaryBranch(opcode: Int, values: Set<Int>): BranchResult {
require(values.isNotEmpty())
var taken = 0
var notTaken = 0
for (v in values) {
if (evaluateUnaryBranch(opcode, v)) {
taken++
} else {
notTaken++
}
}
return BranchResult.fromTakenNotTaken(taken, notTaken)
}
private fun evaluateUnaryBranch(opcode: Int, value: Int): Boolean {
return when (opcode) {
Opcodes.IFEQ -> value == 0
Opcodes.IFNE -> value != 0
else -> throw IllegalArgumentException()
}
}
private fun evaluateBinaryBranch(opcode: Int, values1: Set<Int>, values2: Set<Int>): BranchResult {
require(values1.isNotEmpty() && values2.isNotEmpty())
var taken = 0
var notTaken = 0
for (v1 in values1) {
for (v2 in values2) {
if (evaluateBinaryBranch(opcode, v1, v2)) {
taken++
} else {
notTaken++
}
}
}
return BranchResult.fromTakenNotTaken(taken, notTaken)
}
private fun evaluateBinaryBranch(opcode: Int, value1: Int, value2: Int): Boolean {
return when (opcode) {
Opcodes.IF_ICMPEQ -> value1 == value2
Opcodes.IF_ICMPNE -> value1 != value2
Opcodes.IF_ICMPLT -> value1 < value2
Opcodes.IF_ICMPGE -> value1 >= value2
Opcodes.IF_ICMPGT -> value1 > value2
Opcodes.IF_ICMPLE -> value1 <= value2
else -> throw IllegalArgumentException()
}
}
}
}
Loading…
Cancel
Save