package dev.openrs2.util.collect; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Queue; public final class ForestDisjointSet implements DisjointSet { private static final class Node implements Partition { private final List> children = new ArrayList<>(); private final T value; private Node parent = this; private int rank = 0; private Node(T value) { this.value = value; } private void setParent(Node parent) { this.parent = parent; this.parent.children.add(this); } private Node find() { if (parent != this) { parent = parent.find(); } return parent; } @Override public Iterator iterator() { return new NodeIterator<>(find()); } @SuppressWarnings("unchecked") @Override public boolean equals(Object other) { if (other == null || getClass() != other.getClass()) { return false; } var node = (Node) other; return find() == node.find(); } @Override public int hashCode() { return find().value.hashCode(); } } private static class NodeIterator implements Iterator { private final Queue> queue = new ArrayDeque<>(); public NodeIterator(Node root) { this.queue.add(root); } @Override public boolean hasNext() { return !queue.isEmpty(); } @Override public T next() { var node = queue.poll(); if (node == null) { throw new NoSuchElementException(); } queue.addAll(node.children); return node.value; } } private static class SetIterator implements Iterator> { private final Iterator> it; public SetIterator(ForestDisjointSet set) { this.it = new HashSet<>(set.nodes.values()).iterator(); } @Override public boolean hasNext() { return it.hasNext(); } @Override public Partition next() { return it.next(); } } private final Map> nodes = new HashMap<>(); private int elements = 0, partitions = 0; @Override public Partition add(T x) { var node = nodes.get(x); if (node != null) { return node.find(); } elements++; partitions++; nodes.put(x, node = new Node<>(x)); return node; } @Override public Partition get(T x) { return get0(x); } private Node get0(T x) { var node = nodes.get(x); if (node == null) { return null; } return node.find(); } @Override public void union(Partition x, Partition y) { var xRoot = ((Node) x).find(); var yRoot = ((Node) y).find(); if (xRoot == yRoot) { return; } if (xRoot.rank < yRoot.rank) { xRoot.setParent(yRoot); } else if (xRoot.rank > yRoot.rank) { yRoot.setParent(xRoot); } else { yRoot.setParent(xRoot); xRoot.rank++; } partitions--; } @Override public int elements() { return elements; } @Override public int partitions() { return partitions; } @Override public Iterator> iterator() { return new SetIterator<>(this); } }