import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from Utils.utils import show, GC_vis_graph, G_alpha_vis
from Utils.metrics import efidelity, fid

import numpy as np
from fast_pytorch_kmeans import KMeans
from statistics import mean

class STNet(torch.nn.Module):
    """Weight a set of embeddings with an MLP and classify their weighted sum.

    Each row of the input is scored by a small MLP; the scores form a
    ``(1, n)`` weight row that is multiplied with the ``(n, hidden)`` input to
    produce one pooled vector, which is then pushed through ``classifier``.

    Args:
        hidden: Feature size of each input embedding (input width of the MLP).
        n_classes: Stored on the instance; not used by this class itself.
        npatterns: Stored on the instance; not used by this class itself.
        classifier: Indexable, sliceable sequence of callables with an
            ``eval()`` method (e.g. an ``nn.ModuleList``). Every element is
            switched to eval mode on construction.
    """

    def __init__(self, hidden, n_classes, npatterns, classifier) -> None:
        super().__init__()
        self.n_classes = n_classes
        self.npatterns = npatterns
        # Scorer: one scalar per input row.
        self.mlp = Sequential(
            Linear(hidden, 64),
            ReLU(inplace=False),
            Linear(64, 1),
        )
        self.classifier = classifier
        for clf in self.classifier:
            clf.eval()
        self.init_params()

    def init_params(self):
        """Intentionally do nothing; layers keep their default initialisation.

        ``self.weight`` only exists after the first ``forward`` call, so there
        is nothing to initialise here.
        """
        return None

    def forward(self, st_emb):
        """Return the raw output of the last classifier layer.

        Args:
            st_emb: 2-D tensor of shape ``(n, hidden)``.

        Returns:
            The output of ``classifier[-1]`` applied to the pooled vector; no
            softmax is applied.

        Side effects:
            Stores the ``(1, n)`` score row in ``self.weight`` so callers can
            read the per-row weights after the call.
        """
        self.weight = self.mlp(st_emb).view(1, -1)
        out = torch.mm(self.weight, st_emb)
        # ReLU follows every classifier layer except the last one.
        for clf in self.classifier[:-1]:
            out = F.relu(clf(out))
        return self.classifier[-1](out)

class LSTNet(torch.nn.Module):
    """Pool embeddings with a learnable weight row and apply given conv layers.

    Args:
        n_subtree: Number of input rows; the learnable weight has shape
            ``(1, n_subtree)`` and starts uniformly in ``[-0.5, 0.5)``.
        hconvs: Iterable of layers with an ``eval()`` method, each callable as
            ``conv(x, edge_index=..., do_prop=...)``. All are put in eval mode.
    """

    def __init__(self, n_subtree, hconvs) -> None:
        super().__init__()
        initial = torch.rand(1, n_subtree) - 0.5
        self.weight = torch.nn.Parameter(initial)
        self.hconvs = hconvs
        for layer in self.hconvs:
            layer.eval()

    def forward(self, lst_emb):
        """Return the ReLU-activated output of the conv stack on the pooled row.

        Args:
            lst_emb: 2-D tensor with ``n_subtree`` rows.
        """
        pooled = torch.mm(self.weight, lst_emb)
        for layer in self.hconvs:
            # Layers are called with an empty edge list and do_prop=False.
            pooled = torch.relu(layer(pooled, edge_index=[], do_prop=False))
        return pooled

def find_L_hop_edges(node, L, edge_index):
    """Return the indices of all edges reachable from ``node`` within ``L`` hops.

    Hop 1 collects every edge whose source (``edge_index[0]``) is ``node``;
    each following hop collects the edges leaving the targets
    (``edge_index[1]``) of the previous hop.

    Args:
        node: Index of the start node.
        L: Number of hops to expand. ``0`` yields an empty list.
        edge_index: Tensor whose row 0 holds edge sources and row 1 holds
            edge targets.

    Returns:
        Sorted list of unique edge (column) indices.
    """
    if L == 0:
        return []
    sources, targets = edge_index[0], edge_index[1]
    collected = set()
    frontier = [node]
    for _ in range(L):
        mask = torch.zeros_like(sources, dtype=torch.bool)
        for current in frontier:
            mask |= sources == current
        hop_edges = torch.nonzero(mask).view(-1)
        # Stop as soon as a hop adds nothing: an empty frontier can never
        # produce further edges (this also covers isolated start nodes and
        # dead ends in directed graphs).
        if hop_edges.numel() == 0:
            break
        collected.update(hop_edges.cpu().tolist())
        frontier = list(set(targets[hop_edges].cpu().tolist()))
    return sorted(collected)

def find_L_hop_nodes(node, L, edge_index):
    """Return the nodes reachable from ``node`` within ``L`` hops.

    Args:
        node: Index of the start node.
        L: Number of hops to expand. ``0`` yields an empty list (the start
            node is not included in that case).
        edge_index: Tensor whose row 0 holds edge sources and row 1 holds
            edge targets.

    Returns:
        List starting with ``node`` followed by the nodes reached at each hop.
        A node reached at several hops appears several times.
    """
    if L == 0:
        return []
    sources, targets = edge_index[0], edge_index[1]
    reached = [node]
    frontier = [node]
    for _ in range(L):
        # Nothing left to expand (isolated node or dead end): keep what has
        # been collected so far instead of discarding it.
        if not frontier:
            break
        mask = torch.zeros_like(sources, dtype=torch.bool)
        for current in frontier:
            mask |= sources == current
        hop_edges = torch.nonzero(mask).view(-1)
        frontier = list(set(targets[hop_edges].cpu().tolist()))
        reached += frontier
    return reached

def filter_tuples(dist, this_conf):
    """Group linked nodes and score each group by its mean confidence.

    Args:
        dist: Mapping (or iterable) whose keys are ``(node_a, node_b)`` pairs
            marking two nodes as linked.
        this_conf: Mapping from node to a numeric confidence.

    Returns:
        ``None`` when ``dist`` has no pairs; otherwise ``(groups, scores)``.
        Each group starts with an anchor node followed by the partner of every
        pair that contains the anchor. Anchors are visited in order of first
        appearance and nodes already placed in an earlier group are skipped as
        anchors. ``scores[k]`` is the mean of ``this_conf`` over ``groups[k]``.
    """
    pairs = list(dist)
    if not pairs:
        return None

    # Distinct endpoints, in the order they first occur in the pairs.
    ordered_nodes = []
    for pair in pairs:
        for endpoint in (pair[0], pair[1]):
            if endpoint not in ordered_nodes:
                ordered_nodes.append(endpoint)

    groups = []
    covered = []
    for anchor in ordered_nodes:
        if anchor in covered:
            continue
        members = [anchor]
        for pair in pairs:
            if anchor == pair[0]:
                members.append(pair[1])
            elif anchor == pair[1]:
                members.append(pair[0])
        covered.extend(members)
        groups.append(members)

    scores = [
        sum([this_conf[member] for member in members]) / len(members)
        for members in groups
    ]
    return groups, scores

def cluster_union(top_st_indices, st_norm_weights, emb_global_st_dict_list, loader, dataname, args, gnn_model, device, into_class):
    """Build edge-level explanations from selected subtrees and score them.

    For every correctly classified graph that contains at least one selected
    subtree root, the roots are grouped (roots within ``L`` hops of each other
    are merged via ``filter_tuples``), the ``L``-hop edges around each group
    are collected, and ``efidelity`` is evaluated on the resulting edge set.

    Args:
        top_st_indices: Indices into ``emb_global_st_dict_list`` and
            ``st_norm_weights`` of the candidate subtrees.
        st_norm_weights: Per-subtree weights (tensor elements); a subtree is
            kept when its absolute weight exceeds ``args.h_thres``.
        emb_global_st_dict_list: Per-subtree records. ``record[-1][-1]`` is a
            sequence of entries where ``entry[0]`` is a hop count,
            ``entry[-2]`` a node index and ``entry[-1]`` the position of the
            graph in ``loader``.
        loader: Iterable of graph objects exposing ``to``, ``x``,
            ``edge_index`` and ``y``. Presumably yields one graph at a time,
            since ``int(d.y)`` is taken — confirm against the caller.
        dataname: Dataset name forwarded to ``cluster_intersect``.
        args: Namespace providing ``plot_layer``, ``h_thres`` and ``do_plot``
            (plus whatever ``cluster_intersect`` reads when plotting).
        gnn_model: Model callable on a graph (``gnn_model(d)[0]`` are the
            logits) that also offers ``get_hid_repr(d, L)``.
        device: Device every graph is moved to.
        into_class: Class under explanation. For graphs of this class,
            positively weighted groups are selected; for other classes,
            negatively weighted ones.

    Returns:
        ``(cover_ratio, fid, infid)`` where ``cover_ratio`` is the fraction of
        correctly classified graphs that received an explanation, ``fid`` is
        one minus the mean of the first ``efidelity`` score and ``infid`` the
        mean of the second. ``cover_ratio`` is ``0.0`` when no graph was
        classified correctly; ``fid`` and ``infid`` are ``nan`` when no graph
        was explained.
    """
    L = args.plot_layer

    # Keep the shallowest entries of every subtree that reaches depth L and
    # whose absolute weight passes the threshold.
    top_L_st_tuple, top_L_st_weights = [], []
    for j in top_st_indices:
        entries = emb_global_st_dict_list[j][-1][-1]
        hops = [entry[0] for entry in entries]
        if max(hops) == L and torch.abs(st_norm_weights[j]) > args.h_thres:
            into_hop = min(hops)
            for entry in entries:
                if entry[0] == into_hop:
                    top_L_st_tuple.append(entry)
                    top_L_st_weights.append(st_norm_weights[j])
    opt_id_topL = [entry[-1] for entry in top_L_st_tuple]

    ids, nds, all_embs, conf = [], [], [], []
    all_count, explain_cover = 0, 0
    all_fid, all_infid = [], []

    for i, d in enumerate(loader):
        d = d.to(device)
        logits = gnn_model(d)[0]
        label = int(d.y)
        # Only correctly classified graphs are explained.
        if torch.argmax(logits) != label:
            continue
        all_count += 1

        if i not in opt_id_topL:
            continue
        explain_cover += 1

        # Root nodes of the selected subtrees in this graph and their weights.
        this_nds, this_conf = [], {}
        for pid, graph_id in enumerate(opt_id_topL):
            if i == graph_id:
                nd = top_L_st_tuple[pid][-2]
                this_nds.append(nd)
                this_conf[nd] = top_L_st_weights[pid].cpu().tolist()

        # Computed once per graph; assumes get_hid_repr is deterministic for
        # a fixed graph and layer.
        hid = gnn_model.get_hid_repr(d, L)

        # Link every root to the earlier roots lying within its L-hop reach.
        dist = {}
        for pos, nd0 in enumerate(this_nds):
            if pos == 0:
                continue
            reachable = find_L_hop_nodes(nd0, L, d.edge_index)
            for nd1 in this_nds[:pos]:
                if nd1 in reachable:
                    dist[(nd0, nd1)] = 1
        if dist:
            nd_groups, group_scores = filter_tuples(dist, this_conf)
        else:
            # No linked roots: every root forms its own group.
            nd_groups = [[cnd] for cnd in this_nds]
            group_scores = [this_conf[cnd] for cnd in this_nds]

        edges = []
        non_edges = set(range(d.edge_index.shape[1]))
        same_class = into_class == label
        for nd_group, score in zip(nd_groups, group_scores):
            all_embs.append(torch.mean(torch.stack([hid[nd] for nd in nd_group]), dim=0))
            ids.append(i)
            nds.append(nd_group)
            conf.append(score)

            # A group supports the explanation when its sign matches the
            # relation between the graph's label and the explained class.
            supports = score > 0 if same_class else score < 0
            for nd in nd_group:
                lhope = find_L_hop_edges(nd, L, d.edge_index)
                if supports:
                    edges += lhope
                else:
                    non_edges -= set(lhope)

        edges = list(set(edges))
        # With no supporting edges, fall back to everything that was not
        # ruled out by an opposing group.
        explained_edges = edges if edges else list(non_edges)
        (fid_minus, _), (fid_plus, _) = efidelity(explained_edges, gnn_model, d, device)
        all_fid.append(fid_minus)
        all_infid.append(fid_plus)

    cover_ratio = explain_cover / all_count if all_count else 0.0
    if explain_cover:
        fid_score = 1 - sum(all_fid) / explain_cover
        infid_score = sum(all_infid) / explain_cover
    else:
        # No explanation was produced, so the averages are undefined.
        fid_score = infid_score = float('nan')
    print(f'\n---------------------------')
    print(f'cover_ratio={cover_ratio}, all_count={all_count}, explain_cover={explain_cover}, fid={fid_score}, infid={infid_score}')
    print(f'---------------------------\n')
    # Clustering needs at least one embedding.
    if args.do_plot > 0 and all_embs:
        cluster_intersect(all_embs, ids, nds, conf, loader, args, dataname)

    return (cover_ratio, fid_score, infid_score)


def cluster_intersect(all_embs, ids, nds, conf, loader, args, dataname):
    """Cluster explanation embeddings with k-means and plot sample members.

    Args:
        all_embs: List of embedding tensors, one per explanation group.
        ids: For each embedding, the position of its graph in ``loader``.
        nds: For each embedding, the list of nodes forming the group.
        conf: For each embedding, its numeric confidence.
        loader: Iterable of graphs exposing ``x``, ``edge_index`` and ``y``.
        args: Namespace providing ``plot_layer``, ``clusters`` and ``do_plot``.
        dataname: Dataset name forwarded to ``GC_vis_graph``.

    Returns:
        None. Prints per-cluster statistics and, when ``args.do_plot > 0``,
        shows up to three graphs for every cluster with at least 30 members.
    """
    L = args.plot_layer
    n_clusters = args.clusters
    kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=1)
    stacked = torch.stack(all_embs).detach()
    labels, dist = kmeans.fit_predict(stacked)
    torch.set_printoptions(sci_mode=False)
    # Relabel the clusters that actually occur to consecutive ids 0..k-1.
    label_idx = torch.unique(labels, return_inverse=True)[1]

    conf_arr = np.asarray(conf)
    avg_conf = {}
    for clus in range(n_clusters):
        all_idx = torch.nonzero(label_idx == clus).view(-1)
        if len(all_idx) == 0:
            continue
        avg_conf[clus] = mean(conf_arr[all_idx.cpu().tolist()])
        # Only report clusters whose mean confidence is clearly non-zero.
        if abs(avg_conf[clus]) > 0.1:
            print(clus, len(all_idx.cpu()), torch.mean(-dist[all_idx].cpu()))
            print(f'avg_conf={avg_conf[clus]}\n')

    if not args.do_plot > 0:
        return

    for clus in range(n_clusters):
        all_idx = torch.nonzero(label_idx == clus).view(-1)
        if len(all_idx) < 30:
            continue
        shown = 0
        for i, d in enumerate(loader):
            if i not in ids:
                continue
            # Positions of this graph's groups; take the first one in `clus`.
            positions = [m for m, graph_id in enumerate(ids) if graph_id == i]
            hit = next((m for m in positions if int(label_idx[m]) == clus), None)
            if hit is None:
                continue
            nd = nds[hit]
            print(f'cluster={clus}, conf={avg_conf[clus]}, y={int(d.y)}, i={i}, nd={nd}')
            edges = []
            for node in nd:
                edges.extend(find_L_hop_edges(node, L, d.edge_index))
            color = 'green' if avg_conf[clus] < 0 else 'red'
            GC_vis_graph(d.x, d.edge_index, Hedges=edges, good_nodes=None, datasetname=dataname, edge_color=color)
            show()
            shown += 1
            if shown > 2:
                break
    return

def subgraph_extractor(top_st_indices, st_norm_weights, emb_global_st_dict_list, loader, dataname, args, gnn_model, device, into_class):
    """Visualise the selected subtrees on graphs of ``into_class``.

    Graphs of the requested class that are misclassified or contain no
    selected subtree are shown without highlights. Correctly classified graphs
    that contain a selected subtree are counted as explained and, when
    ``args.do_plot > 0``, drawn with the subtree edges highlighted (red for a
    positive weight, blue otherwise).

    Args:
        top_st_indices: Indices into ``emb_global_st_dict_list`` and
            ``st_norm_weights`` of the candidate subtrees; also used to index
            ``st_norm_weights`` as a whole.
        st_norm_weights: Per-subtree weights (tensor).
        emb_global_st_dict_list: Per-subtree records. ``record[-1][-1]`` is a
            sequence of entries where ``entry[0]`` is a hop count,
            ``entry[-2]`` a node index and ``entry[-1]`` the position of the
            graph in ``loader``.
        loader: Iterable of graph objects exposing ``to``, ``x``,
            ``edge_index`` and ``y``.
        dataname: Dataset name forwarded to ``GC_vis_graph``.
        args: Namespace providing ``plot_layer``, ``h_thres``, ``l_thres`` and
            ``do_plot``.
        gnn_model: Model callable on a graph; ``gnn_model(d)[0]`` are the
            logits.
        device: Device every graph is moved to.
        into_class: Only graphs with this label are considered.

    Returns:
        None. Prints diagnostics and the final cover ratio, which is reported
        as ``0.0`` when no graph of the class was classified correctly.
    """
    # Debug dump of the candidate records.
    # NOTE(review): the meaning of record[-1][0][0] and of the 500 cut-off is
    # not visible here — confirm.
    dumped = 0
    for j in top_st_indices:
        if emb_global_st_dict_list[j][-1][0][0] < 500:
            print(int(j), emb_global_st_dict_list[j][-1], "\n")
            dumped += 1
    print(dumped)
    print(top_st_indices)
    print(st_norm_weights[top_st_indices])

    L = args.plot_layer
    top_L_st_tuple, top_L_st_weights = [], []
    rm_st_tuple, rm_st_weights = [], []
    for j in top_st_indices:
        entries = emb_global_st_dict_list[j][-1][-1]
        hops = [entry[0] for entry in entries]
        into_hop = min(hops)
        into_st_list = [entry for entry in entries if entry[0] == into_hop]
        weight = st_norm_weights[j]
        if max(hops) == L:
            # Subtrees reaching depth L are kept when strongly weighted.
            if torch.abs(weight) > args.h_thres:
                for into_st in into_st_list:
                    top_L_st_tuple.append(into_st)
                    top_L_st_weights.append(weight)
        elif weight < args.l_thres:
            # Other subtrees with a low weight are recorded separately.
            for into_st in into_st_list:
                rm_st_tuple.append(into_st)
                rm_st_weights.append(weight)

    opt_id_topL = [entry[-1] for entry in top_L_st_tuple]
    opt_id_rm = [entry[-1] for entry in rm_st_tuple]

    all_count, explain_cover = 0, 0

    for i, d in enumerate(loader):
        d = d.to(device)
        logits = gnn_model(d)[0]
        label = int(d.y)
        if into_class != label:
            continue

        correct = bool(torch.argmax(logits) == label)
        covered = i in opt_id_topL
        # Show graphs that cannot be explained. Explained graphs must NOT be
        # skipped here, otherwise nothing below is ever reached and the cover
        # ratio is always zero.
        if not (correct and covered):
            GC_vis_graph(d.x, d.edge_index, Hedges=[], good_nodes=None, datasetname=dataname, edge_color='red')
            show()
        if not correct:
            continue
        all_count += 1
        if not covered:
            continue

        plotting = args.do_plot > 0
        if plotting:
            print("\n------------\n", label)
            GC_vis_graph(d.x, d.edge_index, Hedges=[], good_nodes=None, datasetname=dataname)
            for pid, graph_id in enumerate(opt_id_topL):
                if i != graph_id:
                    continue
                st_tuple = top_L_st_tuple[pid]
                weight = top_L_st_weights[pid]
                print(weight.cpu().tolist(), st_tuple)
                # Highlight the edges within the subtree's own hop count
                # around its root node.
                edges = find_L_hop_edges(st_tuple[-2], st_tuple[0], d.edge_index)
                color = 'red' if weight > 0 else 'blue'
                GC_vis_graph(d.x, d.edge_index, Hedges=edges, good_nodes=None, datasetname=dataname, edge_color=color)
        # Low-weight subtrees are only reported, never drawn.
        for pid, graph_id in enumerate(opt_id_rm):
            if i == graph_id:
                print(rm_st_weights[pid].cpu().tolist(), rm_st_tuple[pid])
        if plotting:
            show()

        explain_cover += 1

    cover_ratio = explain_cover / all_count if all_count else 0.0
    print(f'cover_ratio={cover_ratio}, all_count={all_count}, explain_cover={explain_cover}')




