"""
Going to track (nodes, edges), (partial_nodes, partial_edges), (frag_nodes, frag_edges), junction_nodes, atom2frag_edges

= set partial structure to original structure
= get rings and add to frag_nodes
= check nodes in rings for degree >= 3 (either this node connects only rings, connects a ring and single path, or is  a multiple junction)
  - if true
    ~ check non-ring nbrs on a different ring
      + no: add junction node and connect via edge
=  add edges between rings
= remove non-junction ring nodes and ring edges from partial graph and mark junction nodes
= get partial nodes with degree >= 3 
  - get shortest of all shortest paths between junction nodes and new junction nodes
  - if path is just direct connection add node and edge to frag_node, frag_edge
  - generate node for path and edges between for frag_node, frag_edge
= get all nodes with degree == 1
  - get shortest of all shortest paths between junction nodes and new junction nodes
  - if path is just direct connection add node and edge to frag_node, frag_edge
  - generate node for path and edges between for frag_node, frag_edge
        partial_edges = remove_similar_edges(partial_edges, ring_edges)

        # check nodes in rings for degree >= 3
        for atom in ring:
            # remove atoms from partial atoms
            partial_nodes = [node for node in partial_nodes if node != atom]
            if graph.mol.GetAtomWithIdx(atom).GetDegree() >= 3:
                if atom not in junctions:
                    junctions.append(atom)
"""
from rdkit import Chem
from rdkit.Chem import AllChem

import networkx as nx

import torch

from torch_geometric.datasets import QM9
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_undirected

from .bipartite_pair import BipartitePair

class Fragmentation(BaseTransform):
    """Split a molecular graph into fragments and return an atom/fragment pair.

    Fragments are created, in this order, for:

    * every ring reported by RDKit's symmetrized SSSR,
    * ring atoms of degree >= 3 that have more than one neighbour outside
      the ring being visited,
    * non-ring-edge junctions (atoms with >= 3 remaining edges once ring
      edges are removed),
    * leaf atoms (exactly one remaining edge), or one single fragment for
      the whole molecule when there are neither rings nor junctions.

    Atoms on the shortest path from a junction/leaf to the nearest junction
    are attached to that junction's/leaf's fragment.
    """

    def __init__(self, vocab_size=30, max_ring=15) -> None:
        """Store the size limits.

        Args:
            vocab_size: Must be strictly greater than ``max_ring``.
            max_ring: Stored as ``self.max_ring``.

        ``self.max_path`` is set to ``vocab_size - max_ring``. Neither
        attribute is read by ``__call__``.
        """
        self.max_ring = max_ring
        assert vocab_size > max_ring, 'vocab_size must be larger than max_ring'
        self.max_path = vocab_size - max_ring

    @staticmethod
    def _closest_junction(nx_graph, source, junctions, candidates):
        """Return ``(position, path)`` of the nearest reachable junction.

        ``candidates`` are positions into ``junctions``. Ties keep the first
        candidate; ``(None, None)`` is returned when nothing is reachable.
        """
        best_pos, best_path = None, None
        for pos in candidates:
            try:
                path = nx.shortest_path(nx_graph, source=source, target=junctions[pos])
            except nx.NetworkXNoPath:
                continue
            # Unweighted graph: path length is len(path) - 1, so comparing
            # node counts is equivalent to comparing lengths.
            if best_path is None or len(path) < len(best_path):
                best_pos, best_path = pos, path
        return best_pos, best_path

    def __call__(self, graph):
        """Build the fragment graph for ``graph``.

        Args:
            graph: Object exposing ``x``, ``edge_index``, ``y``, ``pos`` and
                ``mol`` (an RDKit molecule whose atom indices match the rows
                of ``x``). ``graph.smiles`` is only read to format the
                message of the final consistency assertion.

        Returns:
            A ``BipartitePair`` with atom data (``x_a``, ``edge_index_a``,
            ``pos_a``), fragment data (``x_f``, ``edge_index_f``), the
            atom-to-fragment edges (``edge_index_af``), all-zero edge
            attributes, and ``y`` copied from the input.

        Raises:
            AssertionError: If the bookkeeping lists get out of step, if a
                ring-free, junction-free molecule does not have exactly two
                leaf atoms, or if a fragment ends up without any atom.
        """
        # Start with the partial structure equal to the original structure.
        num_atoms = graph.x.shape[0]
        partial_nodes = list(range(num_atoms))
        junctions = []
        singletons = []
        frag2frag_edges = []
        junc2frag = []  # junction atom -> fragment, aligned with `junctions`
        sing2frag = []  # leaf atom -> fragment, aligned with `singletons`
        partial_edges = to_undirected(graph.edge_index.clone())

        nx_graph = nx.Graph()
        nx_graph.add_nodes_from(range(num_atoms))

        # ---- RINGS ----
        # Every ring becomes a fragment and each ring atom is linked to it.
        rings = [list(ring) for ring in AllChem.GetSymmSSSR(graph.mol)]
        frag_nodes = list(range(len(rings)))
        atom2frag_edges = [
            [atom, ring_idx] for ring_idx, ring in enumerate(rings) for atom in ring
        ]

        for ring_idx, ring in enumerate(rings):
            # Drop both directions of every ring bond from the partial graph.
            ring_edges = torch.tensor(
                [(ring[i - 1], ring[i]) for i in range(len(ring))]
                + [(ring[i], ring[i - 1]) for i in range(len(ring))]
            ).T
            partial_edges = remove_similar_edges(partial_edges, ring_edges)

            for atom in ring:
                rd_atom = graph.mol.GetAtomWithIdx(atom)
                if rd_atom.GetDegree() < 3:
                    # Atom only belongs to this ring: nothing left to place.
                    partial_nodes = [node for node in partial_nodes if node != atom]
                    continue

                non_ring_nbrs = [
                    nbr for nbr in rd_atom.GetNeighbors() if nbr.GetIdx() not in ring
                ]
                if len(non_ring_nbrs) == 1:
                    # One neighbour outside this ring: mark as junction only.
                    if atom not in junctions:
                        junctions.append(atom)
                        junc2frag.append([atom, ring_idx])
                elif len(non_ring_nbrs) > 1:
                    # Several outside neighbours: the atom also gets its own
                    # fragment (a ring may carry several such junctions).
                    if atom not in junctions:
                        junctions.append(atom)
                        junc2frag.append([atom, ring_idx])
                        frag_nodes.append(len(frag_nodes))
                        atom2frag_edges.append([atom, frag_nodes[-1]])
                else:
                    # All neighbours are in this ring: link this ring to every
                    # other ring containing the atom and drop the atom.
                    # (Previously referenced an undefined `ring_id` here.)
                    shared_rings = [
                        other_idx for other_idx, other in enumerate(rings) if atom in other
                    ]
                    frag2frag_edges += [
                        [ring_idx, target] for target in shared_rings if target != ring_idx
                    ]
                    partial_nodes = [node for node in partial_nodes if node != atom]

        # Connect fragments that share a junction (directed at this point).
        for junction in junctions:
            junction_frags = [frag for atom, frag in atom2frag_edges if atom == junction]
            for pos, frag in enumerate(junction_frags):
                frag2frag_edges += [[frag, other] for other in junction_frags[pos + 1:]]
        frag2frag_edges = list(set(map(tuple, frag2frag_edges)))  # remove duplicates

        # ---- JUNCTIONS ----
        # `partial_edges` is undirected, so counting each node's occurrences
        # as an edge source gives its degree in the partial graph.
        edge_pairs = partial_edges.t().tolist()
        degree = {}
        for src, _ in edge_pairs:
            degree[src] = degree.get(src, 0) + 1

        for node_idx in partial_nodes:
            if degree.get(node_idx, 0) >= 3 and node_idx not in junctions:
                junctions.append(node_idx)
                frag_nodes.append(len(frag_nodes))
                atom2frag_edges.append([node_idx, frag_nodes[-1]])
                junc2frag.append([node_idx, len(frag_nodes) - 1])

        # For each junction find the nearest junction listed after it; the
        # last junction has no successors, so it searches all earlier ones.
        nx_graph.add_edges_from(edge_pairs)
        last = len(junctions) - 1
        shortest_idx = []
        paths = []
        for i, junc_idx in enumerate(junctions):
            candidates = range(last) if i == last else range(i + 1, len(junctions))
            pos, path = self._closest_junction(nx_graph, junc_idx, junctions, candidates)
            shortest_idx.append(pos)
            paths.append(path)

        for i, junc_idx in enumerate(junctions):
            assert junc2frag[i][0] == junc_idx, 'Junctions and junctions2frag are not aligned'
            frag_idx = junc2frag[i][1]
            path = paths[i]
            if path is None:
                continue
            atom2frag_edges += [[atom, frag_idx] for atom in path]
            if len(path) > 1:
                # `shortest_idx` stores a position in `junctions`; translate it
                # to the atom index before matching atom2frag sources. The
                # path atoms were just attached to `frag_idx`, so skip it to
                # avoid a self-edge.
                target_atom = junctions[shortest_idx[i]]
                frag2frag_edges += [
                    [frag_idx, frag]
                    for atom, frag in atom2frag_edges
                    if atom == target_atom and frag != frag_idx
                ]
        frag2frag_edges = list(set(map(tuple, frag2frag_edges)))  # remove duplicates

        # ---- SINGLETONS ----
        if not rings and not junctions:
            # No rings and no junctions: treat the molecule as one path
            # between its two leaf atoms and make it a single fragment.
            singletons = [node for node in partial_nodes if degree.get(node, 0) == 1]
            assert len(singletons) == 2, 'Singletons are not 2'
            path = nx.shortest_path(nx_graph, source=singletons[0], target=singletons[1])
            frag_nodes.append(len(frag_nodes))
            atom2frag_edges += [[atom, frag_nodes[-1]] for atom in path]
        else:
            # Each leaf gets its own fragment unless it is already a junction,
            # in which case it reuses that junction's fragment.
            for node_idx in partial_nodes:
                if degree.get(node_idx, 0) != 1:
                    continue
                singletons.append(node_idx)
                if node_idx not in junctions:
                    frag_nodes.append(len(frag_nodes))
                    atom2frag_edges.append([node_idx, frag_nodes[-1]])
                    sing2frag.append([node_idx, len(frag_nodes) - 1])
                else:
                    sing2frag.append([node_idx, junc2frag[junctions.index(node_idx)][1]])

            # Nearest junction (over all junctions) for every leaf.
            shortest_idx = []
            paths = []
            for sing_idx in singletons:
                pos, path = self._closest_junction(
                    nx_graph, sing_idx, junctions, range(len(junctions))
                )
                shortest_idx.append(pos)
                paths.append(path)

            for i, sing_idx in enumerate(singletons):
                assert sing2frag[i][0] == sing_idx, 'Singletons and sing2frag are not aligned'
                frag_idx = sing2frag[i][1]
                path = paths[i]
                if path is None:
                    continue
                atom2frag_edges += [[atom, frag_idx] for atom in path]
                if len(path) > 1:
                    # Same position-to-atom translation as for junctions.
                    target_atom = junctions[shortest_idx[i]]
                    frag2frag_edges += [
                        [frag_idx, frag]
                        for atom, frag in atom2frag_edges
                        if atom == target_atom and frag != frag_idx
                    ]
            frag2frag_edges = list(set(map(tuple, frag2frag_edges)))  # remove duplicates

        # ---- Update graph ----
        graph_out = BipartitePair()
        graph_out.y = graph.y
        graph_out.x_a = graph.x
        graph_out.x_f = torch.zeros(len(frag_nodes)).to(torch.long).view(-1, 1)
        graph_out.edge_index_a = graph.edge_index
        graph_out.edge_index_af = torch.tensor(atom2frag_edges).t()

        frag2frag_index = torch.tensor(frag2frag_edges).t()
        if len(frag2frag_index) > 0:
            frag2frag_index = to_undirected(frag2frag_index)
        else:
            # No fragment-fragment edges: fall back to a single 0 -> 0 edge.
            frag2frag_index = torch.tensor([[0, 0]]).t()

        graph_out.edge_index_f = frag2frag_index
        graph_out.pos_a = graph.pos
        graph_out.edge_attr_a = torch.zeros((graph_out.edge_index_a.shape[1], 1)).to(torch.long).view(-1)
        graph_out.edge_attr_f = torch.zeros((graph_out.edge_index_f.shape[1], 1)).to(torch.long).view(-1)
        graph_out.edge_attr_af = torch.zeros((graph_out.edge_index_af.shape[1], 1)).to(torch.long).view(-1)

        # ---- Verify ----
        # Every fragment must own at least one atom.
        for frag_idx in frag_nodes:
            assert frag_idx in graph_out.edge_index_af[1], f'Frag nodes and atom2frag_edges are not aligned, {frag_idx}, {atom2frag_edges}, {graph.smiles}'
            #assert frag_idx in graph_out.edge_index_f[0], f'Frag nodes and frag2frag_edges are not aligned, {frag_idx}, {frag2frag_edges}, {graph.smiles}'
            #assert frag_idx in graph_out.edge_index_f[1], f'Frag nodes and frag2frag_edges are not aligned, {frag_idx}, {frag2frag_edges}, {graph.smiles}'

        return graph_out

def remove_similar_edges(edge_index1, edge_index2):
    """Return the edges of ``edge_index1`` that do not occur in ``edge_index2``.

    Both arguments are edge-index tensors whose columns are (source, target)
    pairs. Edges are compared as directed pairs, so ``(a, b)`` does not
    remove ``(b, a)``.

    Returns:
        A ``[2, E]`` tensor of the remaining edges in ascending
        (source, target) order. When nothing remains, an empty ``[2, 0]``
        tensor with the dtype and device of ``edge_index1`` is returned
        (previously this case produced a 1-D float tensor of shape ``[0]``).
    """
    kept = set(map(tuple, edge_index1.t().tolist()))
    kept.difference_update(map(tuple, edge_index2.t().tolist()))
    if not kept:
        # torch.tensor([]) would lose both the 2-row layout and the int dtype.
        return edge_index1.new_empty((2, 0))
    return torch.tensor(sorted(kept)).t()

def qm9_to_mol(data, threshold=1.6):
    """Attach an RDKit molecule, built from atom types and positions, to ``data``.

    One atom is created per entry of ``data.z`` (atomic numbers). A single
    bond is added between every pair of atoms whose Euclidean distance in
    ``data.pos`` is strictly below ``threshold``; bond orders are not
    inferred.

    Args:
        data: Object with ``z`` (atomic numbers) and ``pos`` (one coordinate
            row per atom). It is modified in place.
        threshold: Distance cutoff for creating a bond, in the units of
            ``data.pos``. Defaults to the previously hard-coded 1.6.

    Returns:
        The same ``data`` object with the molecule stored in ``data.mol``.
    """
    mol = Chem.RWMol()
    for atomic_num in data.z.tolist():
        mol.AddAtom(Chem.Atom(int(atomic_num)))

    # Compute all pairwise distances at once instead of one .item() call per
    # pair. triu_indices enumerates (i, j) with i < j in row-major order, so
    # bonds are added in the same order as a nested i/j loop would add them.
    pos = data.pos
    num_atoms = len(pos)
    dist = torch.norm(pos.unsqueeze(1) - pos.unsqueeze(0), dim=-1)
    rows, cols = torch.triu_indices(num_atoms, num_atoms, offset=1, device=pos.device)
    bonded = dist[rows, cols] < threshold
    for i, j in zip(rows[bonded].tolist(), cols[bonded].tolist()):
        mol.AddBond(i, j, Chem.BondType.SINGLE)  # single bonds only, for simplicity

    data.mol = mol.GetMol()
    return data

if __name__ == '__main__':
    # Smoke run: fragment every QM9 molecule and rely on the assertions
    # inside Fragmentation to flag inconsistent output.
    # NOTE(review): the module uses a relative import, so this presumably has
    # to be launched as a module (python -m ...) rather than as a plain script.
    from tqdm import tqdm

    qm9_dataset = QM9(root='data/')
    transform = Fragmentation()
    # Iterating through tqdm closes the progress bar when the loop ends.
    for data in tqdm(qm9_dataset, total=len(qm9_dataset)):
        transform(qm9_to_mol(data))
