/**
 * An implementation of Kruskal's algorithm that finds a *maximum* spanning tree of a weighted
 * undirected graph. Edges are considered in order of decreasing similarity, so the strongest
 * links are kept. If the graph is disconnected, it returns a spanning forest: one tree per
 * connected component.
 * Note: we could do this in the pipeline, but doing it here is much more flexible (e.g. if we
 * only want to display a subset of clusters).
 *
 * @param similarityMatrix the (weighted) adjacency matrix of the graph.
 *  - This matrix is symmetric by definition and all diagonal entries are 1, so only the lower
 *    triangle (minus the diagonal) is stored: `similarityMatrix[i][j]` with `j < i` is the
 *    similarity between matrix indices `i` and `j`.
 *  - Only strictly positive similarities are treated as edges.
 * @param idToIndex a map associating cluster ids to matrix indices. Only indices that appear
 *  here take part in the tree; matrix entries involving any other index are ignored.
 * @returns a list of edges (pairs of cluster ids), in the order they were accepted
 *  (strongest similarity first).
 */
export function generateSpanningTree(
  similarityMatrix: number[][],
  idToIndex: Record<number, number>,
): string[][] {
  type ClusterNode = { id: string; set: DisjointSet };

  // Matrix index -> cluster id plus its union-find node.
  // If several ids share an index, the last one wins.
  const nodes = new Map<number, ClusterNode>();
  for (const [id, index] of Object.entries(idToIndex)) {
    nodes.set(index, { id, set: new DisjointSet(index) });
  }

  // Collect every positive-weight edge whose endpoints are both known clusters.
  const candidates: { left: ClusterNode; right: ClusterNode; weight: number }[] = [];
  for (let i = 0; i < similarityMatrix.length; i++) {
    const left = nodes.get(i);
    if (left === undefined) {
      // Cluster not requested (e.g. only a subset is displayed): skip its row instead of crashing.
      continue;
    }
    const row = similarityMatrix[i];
    // Row i stores similarities to indices j < i only (lower triangle, no diagonal).
    const limit = Math.min(i, row.length);
    for (let j = 0; j < limit; j++) {
      const weight = row[j];
      const right = nodes.get(j);
      // `weight > 0` also rejects NaN.
      if (right !== undefined && weight !== undefined && weight > 0) {
        candidates.push({ left, right, weight });
      }
    }
  }

  // Strongest similarity first => maximum spanning tree.
  // Array.prototype.sort is stable, so ties keep row-major order.
  candidates.sort((a, b) => b.weight - a.weight);

  const edges: string[][] = [];
  for (const { left, right } of candidates) {
    const leftRoot = DisjointSet.find(left.set);
    const rightRoot = DisjointSet.find(right.set);
    if (leftRoot.equals(rightRoot)) {
      // Both endpoints are already in the same tree; this edge would close a cycle.
      continue;
    }
    edges.push([left.id, right.id]);
    DisjointSet.union(leftRoot, rightRoot);
    // A forest over k nodes has at most k - 1 edges; once reached, everything is connected.
    if (edges.length === nodes.size - 1) {
      break;
    }
  }
  return edges;
}

/**
 * A node of the disjoint-set (union-find) structure used by Kruskal's algorithm.
 *
 * Nodes are identified by the integer `self`; two nodes compare equal when their
 * `self` values match. `find` performs full path compression and `union` links the
 * smaller tree beneath the larger one (union by size).
 */
class DisjointSet {
  /** Identifier of this node. */
  self: number;
  /** Number of nodes in the tree rooted at this node (only meaningful on roots). */
  size: number;
  /** Parent pointer; a root points at itself. */
  parent: DisjointSet;

  constructor(id: number) {
    this.self = id;
    this.size = 1;
    this.parent = this;
  }

  /** Returns true when both nodes carry the same identifier. */
  equals(other: DisjointSet): boolean {
    return this.self === other.self;
  }

  /**
   * Returns the root of `node`'s tree and re-points every node on the walked
   * path directly at that root.
   */
  static find(node: DisjointSet): DisjointSet {
    // First pass: climb to the root.
    let root = node;
    while (!root.parent.equals(root)) {
      root = root.parent;
    }
    // Second pass: compress the path so each visited node points straight at the root.
    let cursor = node;
    while (!cursor.parent.equals(cursor)) {
      const next = cursor.parent;
      cursor.parent = root;
      cursor = next;
    }
    return root;
  }

  /**
   * Merges the trees containing `x` and `y`. The root of the smaller tree is
   * attached under the other root; on equal sizes, `x`'s root stays the root.
   */
  static union(x: DisjointSet, y: DisjointSet): void {
    const rootX = DisjointSet.find(x);
    const rootY = DisjointSet.find(y);
    if (rootX.equals(rootY)) {
      return;
    }
    const [winner, loser] = rootX.size < rootY.size ? [rootY, rootX] : [rootX, rootY];
    loser.parent = winner;
    winner.size += loser.size;
  }
}
