package dev.openrs2.util.collect

import java.util.ArrayDeque

/**
 * A [DisjointSet] implemented as a disjoint-set forest using union by rank and
 * path compression.
 *
 * Every element added to the set is stored in its own [Node]. Each node keeps
 * two kinds of links:
 *  - a parent pointer, followed (and compressed) by [Node.find] to locate the
 *    root that identifies the node's partition;
 *  - a `children` list, recording the roots that [union] attached beneath it,
 *    which lets the members of a partition be enumerated top-down.
 *
 * Invariant: the number of root nodes (nodes that are their own parent) always
 * equals [partitions]. [add] creates a new root when it increments the counter,
 * and [union] turns exactly one root into a non-root when it decrements it.
 */
class ForestDisjointSet<T> : DisjointSet<T> {
    /**
     * A single element of the forest. A node doubles as the
     * [DisjointSet.Partition] handle given to callers: two handles are equal if
     * and only if they currently share the same root.
     *
     * [hashCode] is derived from the current root's value, so it can change
     * after a [union] merges this node's partition into another one.
     */
    private class Node<T>(val value: T) : DisjointSet.Partition<T> {
        /**
         * Nodes attached directly beneath this one through the [parent] setter.
         *
         * Path compression in [find] writes `_parent` directly and never
         * touches this list. The `children` edges therefore remain the
         * original union tree, which still spans the whole partition from its
         * root. [NodeIterator] relies on this.
         */
        val children = mutableListOf<Node<T>>()

        private var _parent = this

        /**
         * The parent pointer; a node is a root when it is its own parent.
         * Assigning a new parent also records this node in that parent's
         * [children] list.
         */
        var parent
            get() = _parent
            set(parent) {
                _parent = parent
                _parent.children.add(this)
            }

        /** Rank used by [union] to decide which of two roots becomes the parent. */
        var rank = 0

        /**
         * Returns the root of this node's tree, pointing every node on the
         * visited path directly at the root (path compression).
         */
        fun find(): Node<T> {
            if (parent !== this) {
                // Bypass the setter: compression must not add extra `children` edges.
                _parent = parent.find()
            }
            return parent
        }

        /** Iterates over every value in this node's partition, starting from its current root. */
        override fun iterator(): Iterator<T> {
            return NodeIterator(find())
        }

        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other !is Node<*>) return false
            return find() === other.find()
        }

        override fun hashCode(): Int {
            return find().value.hashCode()
        }

        override fun toString(): String {
            return find().value.toString()
        }
    }

    /** Breadth-first traversal over the `children` tree rooted at [root]. */
    private class NodeIterator<T>(root: Node<T>) : Iterator<T> {
        private val queue = ArrayDeque<Node<T>>()

        init {
            queue.add(root)
        }

        override fun hasNext(): Boolean {
            return queue.isNotEmpty()
        }

        override fun next(): T {
            val node = queue.poll() ?: throw NoSuchElementException()
            queue.addAll(node.children)
            return node.value
        }
    }

    private val nodes = mutableMapOf<T, Node<T>>()

    /** Total number of elements ever added to this set. */
    override val elements: Int
        get() = nodes.size

    /** Number of distinct partitions, i.e. the number of root nodes. */
    override var partitions = 0
        private set

    /**
     * Adds [x] as a new singleton partition, or returns the existing partition
     * containing [x] if it was already added.
     */
    override fun add(x: T): DisjointSet.Partition<T> {
        val existing = findNode(x)
        if (existing != null) {
            return existing
        }

        partitions++

        val newNode = Node(x)
        nodes[x] = newNode
        return newNode
    }

    /** Returns the partition containing [x], or `null` if [x] was never added. */
    override fun get(x: T): DisjointSet.Partition<T>? {
        return findNode(x)
    }

    /** Looks up the node for [x] and returns its current root, or `null` if absent. */
    private fun findNode(x: T): Node<T>? {
        val node = nodes[x] ?: return null
        return node.find()
    }

    /**
     * Merges the partitions containing [x] and [y]. Does nothing if they are
     * already the same partition.
     *
     * @throws IllegalArgumentException if either handle was not produced by a
     *   [ForestDisjointSet].
     */
    override fun union(x: DisjointSet.Partition<T>, y: DisjointSet.Partition<T>) {
        // NOTE(review): handles from a *different* ForestDisjointSet instance pass
        // these checks but would desynchronise both sets' partition counters.
        require(x is Node<T>) { "x must be a partition returned by a ForestDisjointSet" }
        require(y is Node<T>) { "y must be a partition returned by a ForestDisjointSet" }

        val xRoot = x.find()
        val yRoot = y.find()

        // Roots are canonical, so identity is the precise "same partition" test.
        if (xRoot === yRoot) {
            return
        }

        // Union by rank: attach the lower-ranked root beneath the higher-ranked one.
        when {
            xRoot.rank < yRoot.rank -> {
                xRoot.parent = yRoot
            }
            xRoot.rank > yRoot.rank -> {
                yRoot.parent = xRoot
            }
            else -> {
                yRoot.parent = xRoot
                xRoot.rank++
            }
        }

        partitions--
    }

    /**
     * Iterates over the partitions of this set, yielding exactly one handle
     * (the partition's current root) per partition. The number of handles
     * yielded therefore equals [partitions].
     *
     * The returned iterator is read-only. Elements can only be removed from
     * the forest by replacing the whole set.
     */
    override fun iterator(): Iterator<DisjointSet.Partition<T>> {
        return nodes.values
            .asSequence()
            .filter { node -> node.parent === node }
            .iterator()
    }
}