diff --git a/deob/src/main/java/dev/openrs2/deob/transform/StaticScramblingTransformer.kt b/deob/src/main/java/dev/openrs2/deob/transform/StaticScramblingTransformer.kt index fc0b7d0a1a..3295cd51aa 100644 --- a/deob/src/main/java/dev/openrs2/deob/transform/StaticScramblingTransformer.kt +++ b/deob/src/main/java/dev/openrs2/deob/transform/StaticScramblingTransformer.kt @@ -4,26 +4,53 @@ import dev.openrs2.asm.ClassVersionUtils import dev.openrs2.asm.MemberRef import dev.openrs2.asm.classpath.ClassPath import dev.openrs2.asm.classpath.Library +import dev.openrs2.asm.getExpression +import dev.openrs2.asm.sequential import dev.openrs2.asm.transform.Transformer import dev.openrs2.deob.remap.TypedRemapper import org.objectweb.asm.Opcodes +import org.objectweb.asm.tree.AbstractInsnNode import org.objectweb.asm.tree.ClassNode import org.objectweb.asm.tree.FieldInsnNode +import org.objectweb.asm.tree.FieldNode +import org.objectweb.asm.tree.InsnList +import org.objectweb.asm.tree.InsnNode import org.objectweb.asm.tree.MethodInsnNode import org.objectweb.asm.tree.MethodNode +import kotlin.math.max class StaticScramblingTransformer : Transformer() { - private val fields = mutableMapOf() - private val methods = mutableMapOf() + private data class Field(val node: FieldNode, val initializer: InsnList, val version: Int, val maxStack: Int) { + val dependencies = initializer.asSequence() + .filterIsInstance() + .filter { it.opcode == Opcodes.GETSTATIC } + .map(::MemberRef) + .toSet() + } + + private val fields = mutableMapOf() + private val fieldClasses = mutableMapOf() + private val methodClasses = mutableMapOf() private var nextStaticClass: ClassNode? = null + private var nextClinit: MethodNode? = null private val staticClasses = mutableListOf() - private fun nextClass(): ClassNode { + private fun nextClass(): Pair { var clazz = nextStaticClass if (clazz != null && (clazz.fields.size + clazz.methods.size) < MAX_FIELDS_AND_METHODS) { - return clazz + return Pair(clazz, nextClinit!!) } + val clinit = MethodNode() + clinit.access = Opcodes.ACC_STATIC + clinit.name = "" + clinit.desc = "()V" + clinit.exceptions = mutableListOf() + clinit.parameters = mutableListOf() + clinit.instructions = InsnList() + clinit.instructions.add(InsnNode(Opcodes.RETURN)) + clinit.tryCatchBlocks = mutableListOf() + clazz = ClassNode() clazz.version = Opcodes.V1_1 clazz.access = Opcodes.ACC_PUBLIC or Opcodes.ACC_SUPER @@ -32,17 +59,72 @@ class StaticScramblingTransformer : Transformer() { clazz.interfaces = mutableListOf() clazz.innerClasses = mutableListOf() clazz.fields = mutableListOf() - clazz.methods = mutableListOf() + clazz.methods = mutableListOf(clinit) staticClasses += clazz nextStaticClass = clazz + nextClinit = clinit + + return Pair(clazz, clinit) + } + + private fun extractInitializers(clazz: ClassNode, clinit: MethodNode, block: List) { + val putstatics = block.filterIsInstance() + .filter { it.opcode == Opcodes.PUTSTATIC } + + for (putstatic in putstatics) { + if (putstatic.owner != clazz.name || putstatic.name in TypedRemapper.EXCLUDED_FIELDS) { + continue + } + + val node = clazz.fields.find { it.name == putstatic.name && it.desc == putstatic.desc } ?: continue + // TODO(gpe): use a filter here (pure with no *LOADs?) + val expr = getExpression(putstatic) ?: continue + + val initializer = InsnList() + for (insn in expr) { + clinit.instructions.remove(insn) + initializer.add(insn) + } + clinit.instructions.remove(putstatic) + initializer.add(putstatic) + + clazz.fields.remove(node) + + val ref = MemberRef(putstatic) + fields[ref] = Field(node, initializer, clazz.version, clinit.maxStack) + } + } - return clazz + private fun spliceInitializers() { + val done = mutableSetOf() + for ((ref, field) in fields) { + spliceInitializers(done, ref, field) + } + } + + private fun spliceInitializers(done: MutableSet, ref: MemberRef, field: Field) { + if (!done.add(ref)) { + return + } + + for (dependency in field.dependencies) { + spliceInitializers(done, dependency, fields[dependency]!!) + } + + val (clazz, clinit) = nextClass() + clazz.fields.add(field.node) + clazz.version = ClassVersionUtils.maxVersion(clazz.version, field.version) + clinit.instructions.insertBefore(clinit.instructions.last, field.initializer) + clinit.maxStack = max(clinit.maxStack, field.maxStack) + + fieldClasses[ref] = clazz.name } override fun preTransform(classPath: ClassPath) { fields.clear() - methods.clear() + fieldClasses.clear() + methodClasses.clear() nextStaticClass = null staticClasses.clear() @@ -58,6 +140,31 @@ class StaticScramblingTransformer : Transformer() { continue } + val clinit = clazz.methods.find { it.name == "" } + if (clinit != null) { + val insns = clinit.instructions.toMutableList() + + /* + * Most (or all?) of the methods have "simple" + * initializers that we're capable of moving in the first + * and last basic blocks of the method. The last basic + * block is always at the end of the code and ends in a + * RETURN. This allows us to avoid worrying about making a + * full basic block control flow graph here. + */ + + val entry = insns.takeWhile { it.sequential } + extractInitializers(clazz, clinit, entry) + + val last = insns.lastOrNull() + if (last != null && last.opcode == Opcodes.RETURN) { + insns.removeAt(insns.size - 1) + + val exit = insns.takeLastWhile { it.sequential } + extractInitializers(clazz, clinit, exit) + } + } + clazz.methods.removeIf { method -> if (method.access and Opcodes.ACC_STATIC == 0) { return@removeIf false @@ -67,15 +174,17 @@ class StaticScramblingTransformer : Transformer() { return@removeIf false } - val staticClass = nextClass() + val (staticClass, _) = nextClass() staticClass.methods.add(method) staticClass.version = ClassVersionUtils.maxVersion(staticClass.version, clazz.version) - methods[MemberRef(clazz, method)] = staticClass.name + methodClasses[MemberRef(clazz, method)] = staticClass.name return@removeIf true } } + spliceInitializers() + for (clazz in staticClasses) { library.add(clazz) } @@ -85,8 +194,8 @@ class StaticScramblingTransformer : Transformer() { override fun transformCode(classPath: ClassPath, library: Library, clazz: ClassNode, method: MethodNode): Boolean { for (insn in method.instructions) { when (insn) { - is FieldInsnNode -> insn.owner = fields.getOrDefault(MemberRef(insn), insn.owner) - is MethodInsnNode -> insn.owner = methods.getOrDefault(MemberRef(insn), insn.owner) + is FieldInsnNode -> insn.owner = fieldClasses.getOrDefault(MemberRef(insn), insn.owner) + is MethodInsnNode -> insn.owner = methodClasses.getOrDefault(MemberRef(insn), insn.owner) } }