import torch
import numpy as np
from copy import deepcopy


### Disjoint set union structure to maintain cluster structure of a graph
class DSU:
    """Disjoint-set union (union-find) over the vertices ``0 .. n_vertices - 1``.

    Uses union by rank and path compression.

    Attributes:
        parent: numpy array; ``parent[v]`` is v's parent, roots point to themselves.
        rank: numpy array of per-root rank values used to decide which root
            becomes the parent on a union.
    """

    def __init__(self, n_vertices):
        self.parent = np.arange(n_vertices)
        self.rank = np.zeros(n_vertices)

    def find(self, v):
        """Return the root of the set containing ``v``, compressing the path to it.

        Iterative rather than recursive so that a deep parent chain cannot hit
        Python's recursion limit.
        """
        root = v
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: re-point every node on the walked path at the root.
        node = v
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def unite(self, u, v):
        """Merge the sets containing ``u`` and ``v`` (no-op if already merged)."""
        u_root = self.find(u)
        v_root = self.find(v)
        if u_root == v_root:
            # Already in the same set: do not touch the rank, otherwise repeated
            # unions of connected vertices inflate it and unbalance later merges.
            return
        # Attach the lower-rank tree under the higher-rank root.
        if self.rank[u_root] < self.rank[v_root]:
            u_root, v_root = v_root, u_root
        if self.rank[u_root] == self.rank[v_root]:
            self.rank[u_root] += 1
        self.parent[v_root] = u_root
        
### Prim's minimum spanning tree algorithm

def prim_algo(adjacency_matrix):
    """Compute a minimum spanning tree of a dense weighted graph with Prim's algorithm.

    The tree is grown from vertex 0.

    Args:
        adjacency_matrix: square ``(n, n)`` torch tensor whose entry ``[i, j]`` is
            the weight of edge ``i``-``j`` (presumably symmetric — TODO confirm
            with callers; only row ``v`` and entry ``[v, parent]`` are read).

    Returns:
        (s, mst_edges, edge_weights):
            s: 0-d tensor holding the total weight of the tree.
            mst_edges: ``(max(n - 1, 0), 2)`` int32 numpy array; row ``k`` is
                ``(vertex, parent)`` for the k-th vertex attached to the tree.
            edge_weights: CPU tensor with ``adjacency_matrix[vertex, parent]``
                for each row of ``mst_edges``.
    """
    n = len(adjacency_matrix)
    device = adjacency_matrix.device

    s = torch.tensor(0.0, device=device)
    mst_edges = np.zeros((max(n - 1, 0), 2), dtype=np.int32)
    if n == 0:
        # Empty graph: torch.max below is undefined on an empty tensor.
        return s, mst_edges, adjacency_matrix.new_empty(0).cpu()

    # Sentinel strictly larger than every edge weight; used for "already in tree".
    infty = torch.max(adjacency_matrix).item() + 10
    dst = torch.ones(n, device=device) * infty
    ancestors = -torch.ones(n, dtype=torch.long, device=device)
    visited = torch.zeros(n, dtype=torch.bool, device=device)

    v = 0
    for i in range(n - 1):
        visited[v] = True

        row = adjacency_matrix[v]
        # Vertices for which v offers a cheaper connection now hang off v.
        ancestors[dst > row] = v
        dst = torch.minimum(dst, row)
        dst[visited] = infty

        v = int(torch.argmin(dst))
        u = int(ancestors[v])
        s += adjacency_matrix[v, u]
        mst_edges[i] = (v, u)

    # Gather with int64 indices: int32 index arrays are rejected by older torch releases.
    idx = torch.as_tensor(mst_edges, dtype=torch.long, device=device)
    edge_weights = adjacency_matrix[idx[:, 0], idx[:, 1]].cpu()
    return s, mst_edges, edge_weights

### RTD_Lite (version3)

### Main part
class RTD_Lite_TSP:
    """RTD-Lite (version 3) barcode between two weight matrices on the same vertices.

    Calling the object builds the minimum spanning tree of ``min(r1, r2)`` and the
    minimum spanning tree of ``r1`` (both via ``prim_algo``) and returns the
    '1->2' barcode derived from them.

    The MST of ``r1`` can be supplied ahead of time by assigning ``r1_edge_idx`` /
    ``r1_edge_w`` (in the format returned by ``prim_algo``); when they are left as
    None it is computed on each call.
    """

    def __init__(self, r1, r2, quant_outer=None, quant_inner=None, distance='euclidean', cache_r2_min=False):
        """
        Args:
            r1, r2: square torch weight matrices over the same vertex set.
                Stacked barcodes are moved to ``r2.device``.
            quant_outer, quant_inner, distance: accepted for interface
                compatibility; not used by this class.
            cache_r2_min: if True, the MST of ``min(r1, r2)`` computed on the
                first call is reused on every later call.
        """
        self.r1 = r1
        self.r2 = r2

        self.device = r2.device
        self.cache_r2_min = cache_r2_min
        self.cache = None
        # Optional precomputed MST of r1 (edge-index array, edge-weight tensor);
        # when None, __call__ computes it with prim_algo.
        self.r1_edge_idx = None
        self.r1_edge_w = None

    def __call__(self):
        """Compute the barcodes.

        Returns:
            (barcodes, barcodes_idx): dicts keyed by '1->2' and '2->1'. Only
            '1->2' is filled: ``barcodes['1->2']`` is a ``(k, 2)`` tensor of
            ``(birth, death)`` pairs on ``self.device`` (left as an empty list
            when k == 0), and ``barcodes_idx['1->2']`` holds the matching
            ``(birth_edge, death_edge)`` vertex-index pairs. The '2->1' entries
            are always empty lists.
        """
        r1_edge_idx, r1_edge_w = self.r1_edge_idx, self.r1_edge_w
        if r1_edge_idx is None or r1_edge_w is None:
            _, r1_edge_idx, r1_edge_w = prim_algo(self.r1.cpu())
        # r1 MST edges in increasing weight order (one argsort keeps idx and w aligned).
        order = r1_edge_w.argsort()
        r1_edge_idx = r1_edge_idx[order]
        r1_edge_w = r1_edge_w[order]

        if self.cache_r2_min and self.cache is not None:
            rmin_edge_idx, rmin_edge_w = self.cache
        else:
            rmin = torch.minimum(self.r1, self.r2)
            _, rmin_edge_idx, rmin_edge_w = prim_algo(rmin.cpu())
            order = rmin_edge_w.argsort()
            rmin_edge_idx = rmin_edge_idx[order]
            rmin_edge_w = rmin_edge_w[order]
            if self.cache_r2_min:
                self.cache = (rmin_edge_idx, rmin_edge_w)

        #
        # main loop: walk the min-graph MST edges in increasing weight order
        #
        min_graph_dsu = DSU(self.r1.shape[0])
        barcodes = {'1->2': [], '2->1': []}
        barcodes_idx = {'1->2': [], '2->1': []}

        for birth_idx, birth in zip(rmin_edge_idx, rmin_edge_w):
            u_clique = min_graph_dsu.find(birth_idx[0])
            v_clique = min_graph_dsu.find(birth_idx[1])

            # Starting from the min-graph components present before this edge, add
            # r1 MST edges in increasing weight until the endpoints' components
            # merge; the weight of that r1 edge is the death value.
            r1_graph_dsu = deepcopy(min_graph_dsu)
            death_1 = death_1_idx = None
            for r1_idx, r1_w in zip(r1_edge_idx, r1_edge_w):
                r1_graph_dsu.unite(r1_idx[0], r1_idx[1])
                if r1_graph_dsu.find(u_clique) == r1_graph_dsu.find(v_clique):
                    death_1, death_1_idx = r1_w, r1_idx
                    break

            # death_1 stays None only if the r1 edges never join the two components
            # (e.g. a user-supplied r1 edge set that is not spanning); no bar then,
            # instead of a NameError or a stale death value from a previous edge.
            if death_1 is not None and death_1 > birth:
                barcodes['1->2'].append(torch.stack((birth, death_1)))
                barcodes_idx['1->2'].append((birth_idx, death_1_idx))

            min_graph_dsu.unite(birth_idx[0], birth_idx[1])

        if len(barcodes['1->2']) > 0:
            barcodes['1->2'] = torch.stack(barcodes['1->2']).to(self.device)

        return barcodes, barcodes_idx
    
