forked from openrs2/openrs2
parent
6ad99645dc
commit
5fa44c9016
@ -1,472 +0,0 @@ |
||||
package dev.openrs2.deob.transform; |
||||
|
||||
import java.util.ArrayList; |
||||
import java.util.Collection; |
||||
import java.util.HashMap; |
||||
import java.util.Map; |
||||
import java.util.Set; |
||||
|
||||
import com.google.common.base.MoreObjects; |
||||
import com.google.common.base.Preconditions; |
||||
import com.google.common.collect.HashMultimap; |
||||
import com.google.common.collect.ImmutableSet; |
||||
import com.google.common.collect.Multimap; |
||||
import dev.openrs2.asm.InsnListUtilsKt; |
||||
import dev.openrs2.asm.InsnMatcher; |
||||
import dev.openrs2.asm.InsnNodeUtilsKt; |
||||
import dev.openrs2.asm.MemberRef; |
||||
import dev.openrs2.asm.StackMetadataKt; |
||||
import dev.openrs2.asm.classpath.ClassPath; |
||||
import dev.openrs2.asm.classpath.Library; |
||||
import dev.openrs2.asm.transform.Transformer; |
||||
import dev.openrs2.common.collect.DisjointSet; |
||||
import dev.openrs2.deob.ArgRef; |
||||
import dev.openrs2.deob.analysis.IntInterpreter; |
||||
import dev.openrs2.deob.analysis.SourcedIntValue; |
||||
import org.objectweb.asm.Opcodes; |
||||
import org.objectweb.asm.Type; |
||||
import org.objectweb.asm.tree.AbstractInsnNode; |
||||
import org.objectweb.asm.tree.ClassNode; |
||||
import org.objectweb.asm.tree.JumpInsnNode; |
||||
import org.objectweb.asm.tree.MethodInsnNode; |
||||
import org.objectweb.asm.tree.MethodNode; |
||||
import org.objectweb.asm.tree.VarInsnNode; |
||||
import org.objectweb.asm.tree.analysis.Analyzer; |
||||
import org.objectweb.asm.tree.analysis.AnalyzerException; |
||||
import org.slf4j.Logger; |
||||
import org.slf4j.LoggerFactory; |
||||
|
||||
public final class DummyArgTransformer extends Transformer { |
||||
private static final Logger logger = LoggerFactory.getLogger(DummyArgTransformer.class); |
||||
|
||||
private static final InsnMatcher CONDITIONAL_CALL_MATCHER = InsnMatcher.compile("ILOAD (IFEQ | IFNE | (ICONST | BIPUSH | SIPUSH | LDC) (IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE)) ALOAD? (ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+ (INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE)"); |
||||
|
||||
private static final class ConditionalCall { |
||||
private final int conditionVar, conditionOpcode; |
||||
private final Integer conditionValue; |
||||
private final DisjointSet.Partition<MemberRef> method; |
||||
private final Integer[] constArgs; |
||||
|
||||
public ConditionalCall(int conditionVar, int conditionOpcode, Integer conditionValue, DisjointSet.Partition<MemberRef> method, Integer[] constArgs) { |
||||
this.conditionVar = conditionVar; |
||||
this.conditionOpcode = conditionOpcode; |
||||
this.conditionValue = conditionValue; |
||||
this.method = method; |
||||
this.constArgs = constArgs; |
||||
} |
||||
|
||||
@Override |
||||
public String toString() { |
||||
return MoreObjects.toStringHelper(this) |
||||
.add("conditionVar", conditionVar) |
||||
.add("conditionOpcode", conditionOpcode) |
||||
.add("conditionValue", conditionValue) |
||||
.add("method", method) |
||||
.add("constArgs", constArgs) |
||||
.toString(); |
||||
} |
||||
} |
||||
|
||||
private enum BranchResult { |
||||
ALWAYS_TAKEN, |
||||
NEVER_TAKEN, |
||||
UNKNOWN; |
||||
|
||||
public static BranchResult fromTakenNotTaken(int taken, int notTaken) { |
||||
Preconditions.checkArgument(taken != 0 || notTaken != 0); |
||||
|
||||
if (taken == 0) { |
||||
return NEVER_TAKEN; |
||||
} else if (notTaken == 0) { |
||||
return ALWAYS_TAKEN; |
||||
} else { |
||||
return UNKNOWN; |
||||
} |
||||
} |
||||
} |
||||
|
||||
private static BranchResult evaluateUnaryBranch(int opcode, Set<Integer> values) { |
||||
Preconditions.checkArgument(!values.isEmpty()); |
||||
|
||||
int taken = 0, notTaken = 0; |
||||
for (var v : values) { |
||||
if (evaluateUnaryBranch(opcode, v)) { |
||||
taken++; |
||||
} else { |
||||
notTaken++; |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken); |
||||
} |
||||
|
||||
private static boolean evaluateUnaryBranch(int opcode, int value) { |
||||
switch (opcode) { |
||||
case Opcodes.IFEQ: |
||||
return value == 0; |
||||
case Opcodes.IFNE: |
||||
return value != 0; |
||||
default: |
||||
throw new IllegalArgumentException(); |
||||
} |
||||
} |
||||
|
||||
private static BranchResult evaluateBinaryBranch(int opcode, Set<Integer> values1, Set<Integer> values2) { |
||||
Preconditions.checkArgument(!values1.isEmpty() && !values2.isEmpty()); |
||||
|
||||
int taken = 0, notTaken = 0; |
||||
for (var v1 : values1) { |
||||
for (var v2 : values2) { |
||||
if (evaluateBinaryBranch(opcode, v1, v2)) { |
||||
taken++; |
||||
} else { |
||||
notTaken++; |
||||
} |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken); |
||||
} |
||||
|
||||
private static boolean evaluateBinaryBranch(int opcode, int value1, int value2) { |
||||
switch (opcode) { |
||||
case Opcodes.IF_ICMPEQ: |
||||
return value1 == value2; |
||||
case Opcodes.IF_ICMPNE: |
||||
return value1 != value2; |
||||
case Opcodes.IF_ICMPLT: |
||||
return value1 < value2; |
||||
case Opcodes.IF_ICMPGE: |
||||
return value1 >= value2; |
||||
case Opcodes.IF_ICMPGT: |
||||
return value1 > value2; |
||||
case Opcodes.IF_ICMPLE: |
||||
return value1 <= value2; |
||||
default: |
||||
throw new IllegalArgumentException(); |
||||
} |
||||
} |
||||
|
||||
private final Multimap<ArgRef, SourcedIntValue> argValues = HashMultimap.create(); |
||||
private final Multimap<DisjointSet.Partition<MemberRef>, ConditionalCall> conditionalCalls = HashMultimap.create(); |
||||
private final Map<DisjointSet.Partition<MemberRef>, ImmutableSet<Integer>[]> constArgs = new HashMap<>(); |
||||
private DisjointSet<MemberRef> inheritedMethodSets; |
||||
private int branchesSimplified, constantsInlined; |
||||
|
||||
private boolean isMutuallyRecursiveDummy(DisjointSet.Partition<MemberRef> method, int arg, DisjointSet.Partition<MemberRef> source, int value) { |
||||
for (var sourceToMethodCall : conditionalCalls.get(source)) { |
||||
if (!sourceToMethodCall.method.equals(method)) { |
||||
continue; |
||||
} |
||||
|
||||
for (var methodToSourceCall : conditionalCalls.get(method)) { |
||||
if (!methodToSourceCall.method.equals(source)) { |
||||
continue; |
||||
} |
||||
|
||||
if (methodToSourceCall.conditionVar != arg) { |
||||
continue; |
||||
} |
||||
|
||||
boolean taken; |
||||
if (methodToSourceCall.conditionValue != null) { |
||||
taken = evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue); |
||||
} else { |
||||
taken = evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value); |
||||
} |
||||
|
||||
if (taken) { |
||||
continue; |
||||
} |
||||
|
||||
if (sourceToMethodCall.conditionValue != null) { |
||||
taken = evaluateBinaryBranch(sourceToMethodCall.conditionOpcode, methodToSourceCall.constArgs[sourceToMethodCall.conditionVar], sourceToMethodCall.conditionValue); |
||||
} else { |
||||
taken = evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]); |
||||
} |
||||
|
||||
if (taken) { |
||||
continue; |
||||
} |
||||
|
||||
return true; |
||||
} |
||||
} |
||||
|
||||
return false; |
||||
} |
||||
|
||||
private ImmutableSet<Integer> union(DisjointSet.Partition<MemberRef> method, int arg, Collection<SourcedIntValue> intValues) { |
||||
var builder = ImmutableSet.<Integer>builder(); |
||||
|
||||
for (var value : intValues) { |
||||
var intValue = value.getIntValue(); |
||||
if (intValue.isUnknown()) { |
||||
return null; |
||||
} |
||||
|
||||
var source = value.getSource(); |
||||
if (source.equals(method)) { |
||||
continue; |
||||
} |
||||
|
||||
if (intValue.isSingleConstant()) { |
||||
if (isMutuallyRecursiveDummy(method, arg, source, intValue.getIntValue())) { |
||||
continue; |
||||
} |
||||
} |
||||
|
||||
builder.addAll(intValue.getIntValues()); |
||||
} |
||||
|
||||
var set = builder.build(); |
||||
if (set.isEmpty()) { |
||||
return null; |
||||
} |
||||
|
||||
return set; |
||||
} |
||||
|
||||
@Override |
||||
protected void preTransform(ClassPath classPath) { |
||||
inheritedMethodSets = classPath.createInheritedMethodSets(); |
||||
branchesSimplified = 0; |
||||
constantsInlined = 0; |
||||
} |
||||
|
||||
@Override |
||||
protected void prePass(ClassPath classPath) { |
||||
argValues.clear(); |
||||
conditionalCalls.clear(); |
||||
} |
||||
|
||||
@Override |
||||
protected boolean transformCode(ClassPath classPath, Library library, ClassNode clazz, MethodNode method) throws AnalyzerException { |
||||
var parentMethod = inheritedMethodSets.get(new MemberRef(clazz, method)); |
||||
|
||||
var stores = new boolean[method.maxLocals]; |
||||
|
||||
for (AbstractInsnNode insn : method.instructions) { |
||||
var opcode = insn.getOpcode(); |
||||
if (opcode != Opcodes.ISTORE) { |
||||
continue; |
||||
} |
||||
|
||||
var store = (VarInsnNode) insn; |
||||
stores[store.var] = true; |
||||
} |
||||
|
||||
CONDITIONAL_CALL_MATCHER.match(method).forEach(match -> { |
||||
var matchIndex = 0; |
||||
var load = (VarInsnNode) match.get(matchIndex++); |
||||
if (stores[load.var]) { |
||||
return; |
||||
} |
||||
|
||||
var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) >> 2; |
||||
if ((method.access & Opcodes.ACC_STATIC) != 0) { |
||||
callerSlots++; |
||||
} |
||||
|
||||
if (load.var >= callerSlots) { |
||||
return; |
||||
} |
||||
|
||||
Integer conditionValue; |
||||
var conditionOpcode = match.get(matchIndex).getOpcode(); |
||||
if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) { |
||||
conditionValue = null; |
||||
matchIndex++; |
||||
} else { |
||||
conditionValue = InsnNodeUtilsKt.getIntConstant(match.get(matchIndex++)); |
||||
conditionOpcode = match.get(matchIndex++).getOpcode(); |
||||
} |
||||
|
||||
var invoke = (MethodInsnNode) match.get(match.size() - 1); |
||||
|
||||
var invokeArgTypes = Type.getArgumentTypes(invoke.desc).length; |
||||
var constArgs = new Integer[invokeArgTypes]; |
||||
|
||||
if (invoke.getOpcode() != Opcodes.INVOKESTATIC) { |
||||
matchIndex++; |
||||
} |
||||
|
||||
for (int i = 0; i < constArgs.length; i++) { |
||||
var insn = match.get(matchIndex++); |
||||
if (insn.getOpcode() == Opcodes.ACONST_NULL) { |
||||
matchIndex++; |
||||
} else { |
||||
constArgs[i] = InsnNodeUtilsKt.getIntConstant(insn); |
||||
} |
||||
} |
||||
|
||||
var callee = inheritedMethodSets.get(new MemberRef(invoke)); |
||||
if (callee == null) { |
||||
return; |
||||
} |
||||
conditionalCalls.put(parentMethod, new ConditionalCall(load.var, conditionOpcode, conditionValue, callee, constArgs)); |
||||
}); |
||||
|
||||
var parameters = constArgs.get(parentMethod); |
||||
|
||||
var analyzer = new Analyzer<>(new IntInterpreter(parameters)); |
||||
var frames = analyzer.analyze(clazz.name, method); |
||||
|
||||
var changed = false; |
||||
|
||||
var alwaysTakenBranches = new ArrayList<JumpInsnNode>(); |
||||
var neverTakenBranches = new ArrayList<JumpInsnNode>(); |
||||
var constInsns = new HashMap<AbstractInsnNode, Integer>(); |
||||
|
||||
for (var i = 0; i < frames.length; i++) { |
||||
var frame = frames[i]; |
||||
if (frame == null) { |
||||
continue; |
||||
} |
||||
|
||||
var stackSize = frame.getStackSize(); |
||||
|
||||
var insn = method.instructions.get(i); |
||||
switch (insn.getOpcode()) { |
||||
case Opcodes.INVOKEVIRTUAL: |
||||
case Opcodes.INVOKESPECIAL: |
||||
case Opcodes.INVOKESTATIC: |
||||
case Opcodes.INVOKEINTERFACE: |
||||
var invoke = (MethodInsnNode) insn; |
||||
var invokedMethod = inheritedMethodSets.get(new MemberRef(invoke)); |
||||
if (invokedMethod == null) { |
||||
continue; |
||||
} |
||||
|
||||
var args = Type.getArgumentTypes(invoke.desc).length; |
||||
for (int j = 0, k = 0; j < args; j++) { |
||||
var arg = frame.getStack(stackSize - args + j); |
||||
argValues.put(new ArgRef(invokedMethod, k), new SourcedIntValue(parentMethod, arg)); |
||||
k += arg.getSize(); |
||||
} |
||||
break; |
||||
case Opcodes.IFEQ: |
||||
case Opcodes.IFNE: |
||||
var value = frame.getStack(stackSize - 1); |
||||
if (value.isUnknown()) { |
||||
continue; |
||||
} |
||||
|
||||
var result = evaluateUnaryBranch(insn.getOpcode(), value.getIntValues()); |
||||
switch (result) { |
||||
case ALWAYS_TAKEN: |
||||
alwaysTakenBranches.add((JumpInsnNode) insn); |
||||
break; |
||||
case NEVER_TAKEN: |
||||
neverTakenBranches.add((JumpInsnNode) insn); |
||||
break; |
||||
} |
||||
break; |
||||
case Opcodes.IF_ICMPEQ: |
||||
case Opcodes.IF_ICMPNE: |
||||
case Opcodes.IF_ICMPLT: |
||||
case Opcodes.IF_ICMPGE: |
||||
case Opcodes.IF_ICMPGT: |
||||
case Opcodes.IF_ICMPLE: |
||||
var value1 = frame.getStack(stackSize - 2); |
||||
var value2 = frame.getStack(stackSize - 1); |
||||
if (value1.isUnknown() || value2.isUnknown()) { |
||||
continue; |
||||
} |
||||
|
||||
result = evaluateBinaryBranch(insn.getOpcode(), value1.getIntValues(), value2.getIntValues()); |
||||
switch (result) { |
||||
case ALWAYS_TAKEN: |
||||
alwaysTakenBranches.add((JumpInsnNode) insn); |
||||
break; |
||||
case NEVER_TAKEN: |
||||
neverTakenBranches.add((JumpInsnNode) insn); |
||||
break; |
||||
} |
||||
break; |
||||
default: |
||||
if (!InsnNodeUtilsKt.getPure(insn) || InsnNodeUtilsKt.getIntConstant(insn) != null) { |
||||
continue; |
||||
} |
||||
|
||||
if (StackMetadataKt.stackMetadata(insn).getPushes() != 1) { |
||||
continue; |
||||
} |
||||
|
||||
var nextInsn = InsnNodeUtilsKt.getNextReal(insn); |
||||
if (nextInsn == null) { |
||||
continue; |
||||
} |
||||
|
||||
var nextInsnIndex = method.instructions.indexOf(nextInsn); |
||||
var nextFrame = frames[nextInsnIndex]; |
||||
|
||||
value = nextFrame.getStack(nextFrame.getStackSize() - 1); |
||||
if (!value.isSingleConstant()) { |
||||
continue; |
||||
} |
||||
|
||||
constInsns.put(insn, value.getIntValue()); |
||||
break; |
||||
} |
||||
} |
||||
|
||||
for (var insn : alwaysTakenBranches) { |
||||
if (InsnListUtilsKt.replaceSimpleExpression(method.instructions, insn, new JumpInsnNode(Opcodes.GOTO, insn.label))) { |
||||
branchesSimplified++; |
||||
changed = true; |
||||
} |
||||
} |
||||
|
||||
for (var insn : neverTakenBranches) { |
||||
if (InsnListUtilsKt.deleteSimpleExpression(method.instructions, insn)) { |
||||
branchesSimplified++; |
||||
changed = true; |
||||
} |
||||
} |
||||
|
||||
for (var entry : constInsns.entrySet()) { |
||||
var insn = entry.getKey(); |
||||
if (!method.instructions.contains(insn)) { |
||||
continue; |
||||
} |
||||
|
||||
var replacement = InsnNodeUtilsKt.createIntConstant(entry.getValue()); |
||||
if (InsnListUtilsKt.replaceSimpleExpression(method.instructions, insn, replacement)) { |
||||
constantsInlined++; |
||||
changed = true; |
||||
} |
||||
} |
||||
|
||||
return changed; |
||||
} |
||||
|
||||
@Override |
||||
protected void postPass(ClassPath classPath) { |
||||
for (var method : inheritedMethodSets) { |
||||
var args = (Type.getArgumentsAndReturnSizes(method.iterator().next().getDesc()) >> 2) - 1; |
||||
|
||||
var allUnknown = true; |
||||
@SuppressWarnings("unchecked") |
||||
var parameters = (ImmutableSet<Integer>[]) new ImmutableSet<?>[args]; |
||||
for (var i = 0; i < args; i++) { |
||||
var parameter = union(method, i, argValues.get(new ArgRef(method, i))); |
||||
if (parameter != null) { |
||||
allUnknown = false; |
||||
} |
||||
parameters[i] = parameter; |
||||
} |
||||
|
||||
if (allUnknown) { |
||||
constArgs.remove(method); |
||||
} else { |
||||
constArgs.put(method, parameters); |
||||
} |
||||
} |
||||
} |
||||
|
||||
@Override |
||||
protected void postTransform(ClassPath classPath) { |
||||
logger.info("Simplified {} dummy branches and inlined {} constants", branchesSimplified, constantsInlined); |
||||
} |
||||
} |
@ -0,0 +1,403 @@ |
||||
package dev.openrs2.deob.transform |
||||
|
||||
import com.github.michaelbull.logging.InlineLogger |
||||
import com.google.common.collect.HashMultimap |
||||
import com.google.common.collect.ImmutableSet |
||||
import com.google.common.collect.Multimap |
||||
import dev.openrs2.asm.* |
||||
import dev.openrs2.asm.classpath.ClassPath |
||||
import dev.openrs2.asm.classpath.Library |
||||
import dev.openrs2.asm.transform.Transformer |
||||
import dev.openrs2.common.collect.DisjointSet |
||||
import dev.openrs2.deob.ArgRef |
||||
import dev.openrs2.deob.analysis.IntInterpreter |
||||
import dev.openrs2.deob.analysis.SourcedIntValue |
||||
import org.objectweb.asm.Opcodes |
||||
import org.objectweb.asm.Type |
||||
import org.objectweb.asm.tree.* |
||||
import org.objectweb.asm.tree.analysis.Analyzer |
||||
import org.objectweb.asm.tree.analysis.AnalyzerException |
||||
|
||||
class DummyArgTransformer : Transformer() { |
||||
private data class ConditionalCall( |
||||
val conditionVar: Int, |
||||
val conditionOpcode: Int, |
||||
val conditionValue: Int?, |
||||
val method: DisjointSet.Partition<MemberRef>, |
||||
val constArgs: List<Int?> |
||||
) |
||||
|
||||
private enum class BranchResult { |
||||
ALWAYS_TAKEN, NEVER_TAKEN, UNKNOWN; |
||||
|
||||
companion object { |
||||
fun fromTakenNotTaken(taken: Int, notTaken: Int): BranchResult { |
||||
require(taken != 0 || notTaken != 0) |
||||
|
||||
return when { |
||||
taken == 0 -> NEVER_TAKEN |
||||
notTaken == 0 -> ALWAYS_TAKEN |
||||
else -> UNKNOWN |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
private val argValues: Multimap<ArgRef, SourcedIntValue> = HashMultimap.create() |
||||
private val conditionalCalls: Multimap<DisjointSet.Partition<MemberRef>?, ConditionalCall> = HashMultimap.create() |
||||
private val constArgs = mutableMapOf<DisjointSet.Partition<MemberRef>, Array<ImmutableSet<Int>?>>() |
||||
private lateinit var inheritedMethodSets: DisjointSet<MemberRef> |
||||
private var branchesSimplified = 0 |
||||
private var constantsInlined = 0 |
||||
|
||||
private fun isMutuallyRecursiveDummy( |
||||
method: DisjointSet.Partition<MemberRef>, |
||||
arg: Int, |
||||
source: DisjointSet.Partition<MemberRef>, |
||||
value: Int |
||||
): Boolean { |
||||
for (sourceToMethodCall in conditionalCalls[source]) { |
||||
if (sourceToMethodCall.method != method) { |
||||
continue |
||||
} |
||||
|
||||
for (methodToSourceCall in conditionalCalls[method]) { |
||||
if (methodToSourceCall.method != source || methodToSourceCall.conditionVar != arg) { |
||||
continue |
||||
} |
||||
|
||||
var taken = if (methodToSourceCall.conditionValue != null) { |
||||
evaluateBinaryBranch(methodToSourceCall.conditionOpcode, value, methodToSourceCall.conditionValue) |
||||
} else { |
||||
evaluateUnaryBranch(methodToSourceCall.conditionOpcode, value) |
||||
} |
||||
|
||||
if (taken) { |
||||
continue |
||||
} |
||||
|
||||
val constArg = methodToSourceCall.constArgs[sourceToMethodCall.conditionVar]!! |
||||
|
||||
taken = if (sourceToMethodCall.conditionValue != null) { |
||||
evaluateBinaryBranch( |
||||
sourceToMethodCall.conditionOpcode, |
||||
constArg, |
||||
sourceToMethodCall.conditionValue |
||||
) |
||||
} else { |
||||
evaluateUnaryBranch(sourceToMethodCall.conditionOpcode, constArg) |
||||
} |
||||
|
||||
if (taken) { |
||||
continue |
||||
} |
||||
|
||||
return true |
||||
} |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
private fun union( |
||||
method: DisjointSet.Partition<MemberRef>, |
||||
arg: Int, |
||||
intValues: Collection<SourcedIntValue> |
||||
): ImmutableSet<Int>? { |
||||
val builder = ImmutableSet.builder<Int>() |
||||
|
||||
for ((source, intValue) in intValues) { |
||||
if (intValue.isUnknown) { |
||||
return null |
||||
} |
||||
|
||||
if (source == method) { |
||||
continue |
||||
} |
||||
|
||||
if (intValue.isSingleConstant) { |
||||
if (isMutuallyRecursiveDummy(method, arg, source, intValue.intValue)) { |
||||
continue |
||||
} |
||||
} |
||||
|
||||
builder.addAll(intValue.intValues) |
||||
} |
||||
|
||||
val set = builder.build() |
||||
return if (set.isEmpty()) { |
||||
null |
||||
} else { |
||||
set |
||||
} |
||||
} |
||||
|
||||
override fun preTransform(classPath: ClassPath) { |
||||
inheritedMethodSets = classPath.createInheritedMethodSets() |
||||
branchesSimplified = 0 |
||||
constantsInlined = 0 |
||||
} |
||||
|
||||
override fun prePass(classPath: ClassPath) { |
||||
argValues.clear() |
||||
conditionalCalls.clear() |
||||
} |
||||
|
||||
@Throws(AnalyzerException::class) |
||||
override fun transformCode( |
||||
classPath: ClassPath, |
||||
library: Library, |
||||
clazz: ClassNode, |
||||
method: MethodNode |
||||
): Boolean { |
||||
val parentMethod = inheritedMethodSets[MemberRef(clazz, method)]!! |
||||
|
||||
val stores = BooleanArray(method.maxLocals) |
||||
for (insn in method.instructions) { |
||||
if (insn is VarInsnNode && insn.opcode == Opcodes.ISTORE) { |
||||
stores[insn.`var`] = true |
||||
} |
||||
} |
||||
|
||||
CONDITIONAL_CALL_MATCHER.match(method).forEach { match -> |
||||
var matchIndex = 0 |
||||
|
||||
val load = match[matchIndex++] as VarInsnNode |
||||
if (stores[load.`var`]) { |
||||
return@forEach |
||||
} |
||||
|
||||
var callerSlots = Type.getArgumentsAndReturnSizes(method.desc) shr 2 |
||||
if (method.access and Opcodes.ACC_STATIC != 0) { |
||||
callerSlots++ |
||||
} |
||||
if (load.`var` >= callerSlots) { |
||||
return@forEach |
||||
} |
||||
|
||||
val conditionValue: Int? |
||||
var conditionOpcode = match[matchIndex].opcode |
||||
if (conditionOpcode == Opcodes.IFEQ || conditionOpcode == Opcodes.IFNE) { |
||||
conditionValue = null |
||||
matchIndex++ |
||||
} else { |
||||
conditionValue = match[matchIndex++].intConstant |
||||
conditionOpcode = match[matchIndex++].opcode |
||||
} |
||||
|
||||
val invoke = match[match.size - 1] as MethodInsnNode |
||||
val invokeArgTypes = Type.getArgumentTypes(invoke.desc).size |
||||
val constArgs = arrayOfNulls<Int>(invokeArgTypes) |
||||
if (invoke.opcode != Opcodes.INVOKESTATIC) { |
||||
matchIndex++ |
||||
} |
||||
for (i in constArgs.indices) { |
||||
val insn = match[matchIndex++] |
||||
if (insn.opcode == Opcodes.ACONST_NULL) { |
||||
matchIndex++ |
||||
} else { |
||||
constArgs[i] = insn.intConstant |
||||
} |
||||
} |
||||
|
||||
val callee = inheritedMethodSets[MemberRef(invoke)] ?: return@forEach |
||||
conditionalCalls.put( |
||||
parentMethod, |
||||
ConditionalCall(load.`var`, conditionOpcode, conditionValue, callee, constArgs.asList()) |
||||
) |
||||
} |
||||
|
||||
val parameters = constArgs[parentMethod] |
||||
val analyzer = Analyzer(IntInterpreter(parameters)) |
||||
val frames = analyzer.analyze(clazz.name, method) |
||||
|
||||
var changed = false |
||||
val alwaysTakenBranches = mutableListOf<JumpInsnNode>() |
||||
val neverTakenBranches = mutableListOf<JumpInsnNode>() |
||||
val constInsns = mutableMapOf<AbstractInsnNode, Int>() |
||||
|
||||
frame@ for ((i, frame) in frames.withIndex()) { |
||||
if (frame == null) { |
||||
continue |
||||
} |
||||
|
||||
val stackSize = frame.stackSize |
||||
|
||||
val insn = method.instructions[i] |
||||
when (insn.opcode) { |
||||
Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEINTERFACE -> { |
||||
val invoke = insn as MethodInsnNode |
||||
val invokedMethod = inheritedMethodSets[MemberRef(invoke)] ?: continue@frame |
||||
val args = Type.getArgumentTypes(invoke.desc).size |
||||
|
||||
var k = 0 |
||||
for (j in 0 until args) { |
||||
val arg = frame.getStack(stackSize - args + j) |
||||
argValues.put(ArgRef(invokedMethod, k), SourcedIntValue(parentMethod, arg)) |
||||
k += arg.size |
||||
} |
||||
} |
||||
Opcodes.IFEQ, Opcodes.IFNE -> { |
||||
val value = frame.getStack(stackSize - 1) |
||||
if (value.isUnknown) { |
||||
continue@frame |
||||
} |
||||
|
||||
val result = evaluateUnaryBranch(insn.opcode, value.intValues) |
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (result) { |
||||
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) |
||||
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) |
||||
} |
||||
} |
||||
Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE, Opcodes.IF_ICMPLT, Opcodes.IF_ICMPGE, Opcodes.IF_ICMPGT, Opcodes.IF_ICMPLE -> { |
||||
val value1 = frame.getStack(stackSize - 2) |
||||
val value2 = frame.getStack(stackSize - 1) |
||||
if (value1.isUnknown || value2.isUnknown) { |
||||
continue@frame |
||||
} |
||||
|
||||
val result = evaluateBinaryBranch(insn.opcode, value1.intValues, value2.intValues) |
||||
@Suppress("NON_EXHAUSTIVE_WHEN") |
||||
when (result) { |
||||
BranchResult.ALWAYS_TAKEN -> alwaysTakenBranches.add(insn as JumpInsnNode) |
||||
BranchResult.NEVER_TAKEN -> neverTakenBranches.add(insn as JumpInsnNode) |
||||
} |
||||
} |
||||
else -> { |
||||
if (!insn.pure || insn.intConstant != null) { |
||||
continue@frame |
||||
} |
||||
|
||||
if (insn.stackMetadata().pushes != 1) { |
||||
continue@frame |
||||
} |
||||
|
||||
val nextInsn = insn.nextReal ?: continue@frame |
||||
val nextInsnIndex = method.instructions.indexOf(nextInsn) |
||||
val nextFrame = frames[nextInsnIndex] |
||||
|
||||
val value = nextFrame.getStack(nextFrame.stackSize - 1) |
||||
if (value.isSingleConstant) { |
||||
constInsns[insn] = value.intValue |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
for (insn in alwaysTakenBranches) { |
||||
if (method.instructions.replaceSimpleExpression(insn, JumpInsnNode(Opcodes.GOTO, insn.label))) { |
||||
branchesSimplified++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
for (insn in neverTakenBranches) { |
||||
if (method.instructions.deleteSimpleExpression(insn)) { |
||||
branchesSimplified++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
for ((insn, value) in constInsns) { |
||||
if (!method.instructions.contains(insn)) { |
||||
continue |
||||
} |
||||
|
||||
val replacement = createIntConstant(value) |
||||
if (method.instructions.replaceSimpleExpression(insn, replacement)) { |
||||
constantsInlined++ |
||||
changed = true |
||||
} |
||||
} |
||||
|
||||
return changed |
||||
} |
||||
|
||||
override fun postPass(classPath: ClassPath) { |
||||
for (method in inheritedMethodSets) { |
||||
val args = (Type.getArgumentsAndReturnSizes(method.first().desc) shr 2) - 1 |
||||
|
||||
var allUnknown = true |
||||
val parameters = arrayOfNulls<ImmutableSet<Int>?>(args) |
||||
|
||||
for (i in 0 until args) { |
||||
val parameter = union(method, i, argValues[ArgRef(method, i)]) |
||||
if (parameter != null) { |
||||
allUnknown = false |
||||
} |
||||
parameters[i] = parameter |
||||
} |
||||
|
||||
if (allUnknown) { |
||||
constArgs.remove(method) |
||||
} else { |
||||
constArgs[method] = parameters |
||||
} |
||||
} |
||||
} |
||||
|
||||
override fun postTransform(classPath: ClassPath) { |
||||
logger.info { "Simplified $branchesSimplified dummy branches and inlined $constantsInlined constants" } |
||||
} |
||||
|
||||
companion object { |
||||
private val logger = InlineLogger() |
||||
private val CONDITIONAL_CALL_MATCHER = |
||||
InsnMatcher.compile("ILOAD (IFEQ | IFNE | (ICONST | BIPUSH | SIPUSH | LDC) (IF_ICMPEQ | IF_ICMPNE | IF_ICMPLT | IF_ICMPGE | IF_ICMPGT | IF_ICMPLE)) ALOAD? (ICONST | FCONST | DCONST | BIPUSH | SIPUSH | LDC | ACONST_NULL CHECKCAST)+ (INVOKEVIRTUAL | INVOKESTATIC | INVOKEINTERFACE)") |
||||
|
||||
private fun evaluateUnaryBranch(opcode: Int, values: Set<Int>): BranchResult { |
||||
require(values.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
for (v in values) { |
||||
if (evaluateUnaryBranch(opcode, v)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateUnaryBranch(opcode: Int, value: Int): Boolean { |
||||
return when (opcode) { |
||||
Opcodes.IFEQ -> value == 0 |
||||
Opcodes.IFNE -> value != 0 |
||||
else -> throw IllegalArgumentException() |
||||
} |
||||
} |
||||
|
||||
private fun evaluateBinaryBranch(opcode: Int, values1: Set<Int>, values2: Set<Int>): BranchResult { |
||||
require(values1.isNotEmpty() && values2.isNotEmpty()) |
||||
|
||||
var taken = 0 |
||||
var notTaken = 0 |
||||
for (v1 in values1) { |
||||
for (v2 in values2) { |
||||
if (evaluateBinaryBranch(opcode, v1, v2)) { |
||||
taken++ |
||||
} else { |
||||
notTaken++ |
||||
} |
||||
} |
||||
} |
||||
|
||||
return BranchResult.fromTakenNotTaken(taken, notTaken) |
||||
} |
||||
|
||||
private fun evaluateBinaryBranch(opcode: Int, value1: Int, value2: Int): Boolean { |
||||
return when (opcode) { |
||||
Opcodes.IF_ICMPEQ -> value1 == value2 |
||||
Opcodes.IF_ICMPNE -> value1 != value2 |
||||
Opcodes.IF_ICMPLT -> value1 < value2 |
||||
Opcodes.IF_ICMPGE -> value1 >= value2 |
||||
Opcodes.IF_ICMPGT -> value1 > value2 |
||||
Opcodes.IF_ICMPLE -> value1 <= value2 |
||||
else -> throw IllegalArgumentException() |
||||
} |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue