diff --git a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/IfElseTransformer.kt b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/IfElseTransformer.kt index 73db9f04..a490ce90 100644 --- a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/IfElseTransformer.kt +++ b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/IfElseTransformer.kt @@ -2,6 +2,7 @@ package dev.openrs2.deob.ast.transform import com.github.javaparser.ast.CompilationUnit import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.MethodDeclaration import com.github.javaparser.ast.expr.BinaryExpr import com.github.javaparser.ast.expr.ConditionalExpr import com.github.javaparser.ast.stmt.BlockStmt @@ -9,6 +10,7 @@ import com.github.javaparser.ast.stmt.IfStmt import com.github.javaparser.ast.stmt.ReturnStmt import com.github.javaparser.ast.stmt.Statement import com.github.javaparser.ast.stmt.ThrowStmt +import com.github.javaparser.ast.type.VoidType import dev.openrs2.deob.ast.Library import dev.openrs2.deob.ast.LibraryGroup import dev.openrs2.deob.ast.util.countNots @@ -122,7 +124,7 @@ class IfElseTransformer : Transformer() { * } else { * if (!a) { * ... - * throw ...; + * throw ...; // or return * } * ... * } @@ -133,7 +135,7 @@ class IfElseTransformer : Transformer() { * ... * } else { * ... - * throw ...; + * throw ...; // or return * } */ unit.walk { stmt: IfStmt -> @@ -238,6 +240,199 @@ class IfElseTransformer : Transformer() { outerStmt.condition = BinaryExpr(outerStmt.condition, innerStmt.condition, BinaryExpr.Operator.AND) outerStmt.thenStmt = innerStmt.thenStmt } + + unit.walk { method: MethodDeclaration -> + if (method.type !is VoidType) { + return@walk + } + + method.body.ifPresent { body -> + val ifStmt = body.statements.lastOrNull() ?: return@ifPresent + if (ifStmt !is IfStmt) { + return@ifPresent + } + + val thenStatements = ifStmt.thenStmt.findAll(Statement::class.java).size + ifStmt.elseStmt.ifPresentOrElse({ elseStmt -> + if (elseStmt.isIf()) { + return@ifPresentOrElse + } + + val elseStatements = elseStmt.findAll(Statement::class.java).size + if (thenStatements <= IF_DEINDENT_THRESHOLD && elseStatements <= IF_DEINDENT_THRESHOLD) { + return@ifPresentOrElse + } + + /* + * Rewrite: + * + * void m(...) { + * ... + * if (a) { + * ... + * } else { + * ... + * } + * } + * + * to: + * + * void m(...) { + * ... + * if (!a) { // or `if (a)`, depending on which arm is smaller + * ... + * return; + * } + * ... + * } + */ + if (elseStatements > thenStatements) { + body.statements.addAll(elseStmt.flatten()) + + ifStmt.thenStmt = ifStmt.thenStmt.appendReturn() + ifStmt.removeElseStmt() + } else { + body.statements.addAll(ifStmt.thenStmt.flatten()) + + ifStmt.condition = ifStmt.condition.not() + ifStmt.thenStmt = elseStmt.appendReturn() + ifStmt.removeElseStmt() + } + }, { + /* + * Rewrite: + * + * void m(...) { + * ... + * if (a) { + * ... + * } + * } + * + * to: + * + * void m(...) { + * ... + * if (!a) { + * return; + * } + * ... + * } + */ + if (thenStatements <= IF_DEINDENT_THRESHOLD) { + return@ifPresentOrElse + } + + body.statements.addAll(ifStmt.thenStmt.flatten()) + + ifStmt.condition = ifStmt.condition.not() + ifStmt.thenStmt = BlockStmt(NodeList(ReturnStmt())) + }) + } + } + + /* + * Rewrite: + * + * if (a) { + * ... + * throw ...; // or return + * } else { + * ... + * } + * + * to: + * + * if (a) { // or `if (!a)`, if the arms are swapped + * ... + * throw ...; // or return + * } + * ... + */ + unit.walk { blockStmt: BlockStmt -> + /* + * XXX(gpe): need to iterate through blockStmt.stmts manually as we + * insert extra statements during iteration (ugh!) + */ + var index = 0 + while (index < blockStmt.statements.size) { + val ifStmt = blockStmt.statements[index] + if (ifStmt !is IfStmt) { + index++ + continue + } + + ifStmt.elseStmt.ifPresent { elseStmt -> + if (elseStmt.isIf()) { + return@ifPresent + } + + /* + * If one of the arms consists of just a throw, move that + * into an if regardless of the fact that the method as a + * whole will end up longer. + */ + if (ifStmt.thenStmt.isThrow()) { + blockStmt.statements.addAll(index + 1, elseStmt.flatten()) + + ifStmt.removeElseStmt() + + return@ifPresent + } else if (elseStmt.isThrow()) { + blockStmt.statements.addAll(index + 1, ifStmt.thenStmt.flatten()) + + ifStmt.condition = ifStmt.condition.not() + ifStmt.thenStmt = elseStmt.appendReturn() + ifStmt.removeElseStmt() + + return@ifPresent + } + + val thenStatements = ifStmt.thenStmt.findAll(Statement::class.java).size + val elseStatements = elseStmt.findAll(Statement::class.java).size + if (thenStatements <= IF_DEINDENT_THRESHOLD && elseStatements <= IF_DEINDENT_THRESHOLD) { + return@ifPresent + } + + if (elseStatements > thenStatements && ifStmt.thenStmt.isTailThrowOrReturn()) { + blockStmt.statements.addAll(index + 1, elseStmt.flatten()) + + ifStmt.removeElseStmt() + } else if (elseStmt.isTailThrowOrReturn()) { + blockStmt.statements.addAll(index + 1, ifStmt.thenStmt.flatten()) + + ifStmt.condition = ifStmt.condition.not() + ifStmt.thenStmt = elseStmt.appendReturn() + ifStmt.removeElseStmt() + } + } + + index++ + } + } + } + + private fun Statement.appendReturn(): Statement { + return if (this is BlockStmt) { + val last = statements.lastOrNull() + if (last is ReturnStmt || last is ThrowStmt) { + clone() + } else { + BlockStmt(NodeList(statements.map(Statement::clone).plus(ReturnStmt()))) + } + } else if (this is ReturnStmt || this is ThrowStmt) { + clone() + } else { + BlockStmt(NodeList(clone(), ReturnStmt())) + } + } + + private fun Statement.flatten(): Collection { + return if (this is BlockStmt) { + statements.map(Statement::clone) + } else { + listOf(clone()) + } } private fun Statement.isIf(): Boolean { @@ -259,6 +454,14 @@ class IfElseTransformer : Transformer() { } } + private fun Statement.isThrow(): Boolean { + return when (this) { + is ThrowStmt -> true + is BlockStmt -> statements.singleOrNull() is ThrowStmt + else -> false + } + } + private fun Statement.isTailThrowOrReturn(): Boolean { return when (this) { is ThrowStmt, is ReturnStmt -> true @@ -269,4 +472,8 @@ class IfElseTransformer : Transformer() { else -> false } } + + private companion object { + private const val IF_DEINDENT_THRESHOLD = 5 + } }