diff --git a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/BitMaskTransformer.kt b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/BitMaskTransformer.kt index 48632086..301ddfd1 100644 --- a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/BitMaskTransformer.kt +++ b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/BitMaskTransformer.kt @@ -15,12 +15,29 @@ import javax.inject.Singleton @Singleton class BitMaskTransformer : Transformer() { override fun transformUnit(group: LibraryGroup, library: Library, unit: CompilationUnit) { + /* + * Transform: + * + * (x & y) >> z + * + * to: + * + * (x >> z) & (y >> z) + * + * For example: + * + * (x & 0xFF00) >> 8 + * + * to: + * + * (x >> 8) & 0xFF + */ unit.walk { expr: BinaryExpr -> val shiftOp = expr.operator val bitwiseExpr = expr.left val shamtExpr = expr.right - if (shiftOp !in SHIFT_OPS || bitwiseExpr !is BinaryExpr || shamtExpr !is IntegerLiteralExpr) { + if (shiftOp !in RIGHT_SHIFT_OPS || bitwiseExpr !is BinaryExpr || shamtExpr !is IntegerLiteralExpr) { return@walk } @@ -61,10 +78,70 @@ class BitMaskTransformer : Transformer() { expr.replace(BinaryExpr(BinaryExpr(argExpr.clone(), shamtExpr.clone(), shiftOp), maskExpr, bitwiseOp)) } + + /* + * Transform: + * + * (x << y) & z + * + * to: + * + * (x & (z >>> y)) << y + * + * For example: + * + * (x << 8) & 0xFF00 + * + * to: + * + * (x & 0xFF) << 8 + */ + unit.walk { expr: BinaryExpr -> + val bitwiseOp = expr.operator + val shiftExpr = expr.left + val maskExpr = expr.right + + if (bitwiseOp !in BITWISE_OPS || shiftExpr !is BinaryExpr) { + return@walk + } + + val shiftOp = shiftExpr.operator + val argExpr = shiftExpr.left + val shamtExpr = shiftExpr.right + + if (shiftOp != BinaryExpr.Operator.LEFT_SHIFT || shamtExpr !is IntegerLiteralExpr) { + return@walk + } + + val shamt = shamtExpr.checkedAsInt() + val newMaskExpr = when (maskExpr) { + is IntegerLiteralExpr -> { + var mask = maskExpr.checkedAsInt() + if (shamt > Integer.numberOfTrailingZeros(mask)) { + return@walk + } + + mask = mask ushr shamt + IntegerLiteralExpr(mask.toString()) + } + is LongLiteralExpr -> { + var mask = maskExpr.checkedAsLong() + if (shamt > java.lang.Long.numberOfTrailingZeros(mask)) { + return@walk + } + + mask = mask ushr shamt + mask.toLongLiteralExpr() + } + else -> return@walk + } + + expr.replace(BinaryExpr(BinaryExpr(argExpr.clone(), newMaskExpr, bitwiseOp), shamtExpr.clone(), shiftOp)) + } } private companion object { - private val SHIFT_OPS = setOf( + private val RIGHT_SHIFT_OPS = setOf( BinaryExpr.Operator.SIGNED_RIGHT_SHIFT, BinaryExpr.Operator.UNSIGNED_RIGHT_SHIFT )