diff --git a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/GlConstantTransformer.kt b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/GlConstantTransformer.kt index 2f3e55d5..371fe8b2 100644 --- a/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/GlConstantTransformer.kt +++ b/deob-ast/src/main/java/dev/openrs2/deob/ast/transform/GlConstantTransformer.kt @@ -60,14 +60,13 @@ class GlConstantTransformer : Transformer() { val glUnit = units[GL_CLASS] ?: return val glInterface = glUnit.primaryType.get() - // remove existing declarations first to maintain sort order - for (enum in enums) { - val declaration = glInterface.getFieldByName(enum.name) - declaration.ifPresent { it.remove() } - } + // add missing fields + val fields = enums.filter { !glInterface.getFieldByName(it.name).isPresent } + .map { it.toDeclaration() } + glInterface.members.addAll(fields) - val fields = enums.sortedBy(GlEnum::value).map { it.toDeclaration() } - glInterface.members.addAll(0, fields) + // sort fields by value for consistency + glInterface.members.sortWith(FIELD_METHOD_COMPARATOR.thenComparing(GL_FIELD_VALUE_COMPARATOR)) } private fun transformCall(unit: CompilationUnit, expr: MethodCallExpr) { @@ -203,5 +202,39 @@ class GlConstantTransformer : Transformer() { private const val GL_CLASS = "javax.media.opengl.$GL_CLASS_UNQUALIFIED" private val REGISTRY = GlRegistry.parse() private val VENDORS = setOf("ARB", "EXT") + + private val FIELD_METHOD_COMPARATOR = Comparator> { a, b -> + when { + a.isFieldDeclaration && !b.isFieldDeclaration -> -1 + !a.isFieldDeclaration && b.isFieldDeclaration -> 1 + else -> 0 + } + } + + private fun BodyDeclaration<*>.getIntValue(): Int? { + if (!isFieldDeclaration) { + return null + } + + val variable = asFieldDeclaration().variables.firstOrNull() ?: return null + return variable.initializer.map { + if (it.isIntegerLiteralExpr) { + it.asIntegerLiteralExpr().checkedAsInt() + } else { + null + } + }.orElse(null) + } + + private val GL_FIELD_VALUE_COMPARATOR = Comparator> { a, b -> + val aValue = a.getIntValue() + val bValue = b.getIntValue() + when { + aValue != null && bValue != null -> aValue - bValue + aValue != null && bValue == null -> -1 + aValue == null && bValue != null -> 1 + else -> 0 + } + } } }