import argparse
import torch
from torch_geometric.loader import DataLoader
import numpy as np

from Utils.utils import check_task, load_model, detect_exp_setting, detect_motif_nodes, show, GC_vis_graph
from Utils.metrics import efidelity
from Utils.datasets import get_dataset, get_graph_data
import torch.nn.functional as F
import copy
from torch_geometric.nn import global_mean_pool
from itertools import product
import time

from fast_pytorch_kmeans import KMeans
import mixem
from statistics import mean
from numpy import linalg as LA
import pickle

from GNN_Models.graph_sage import GraphSAGE
from st import STNet, train_stnet, sep_train_stnet

def get_global_exps(args):
    """Extract global concept subgraphs for a trained GNN and weight them per class.

    Nothing beyond loading the dataset/model happens unless ``check_task``
    reports a graph-classification task (``"GC"``). For ``"GC"`` the steps are:

    1. Local step - for every graph the GNN classifies correctly, the node
       representations from ``gnn_model.get_hid_repr(d, args.nlayer)`` are
       grouped into ``args.clusters`` clusters (k-means, or an EM Gaussian
       mixture when ``args.local_cluster != 'kmeans'``). Each cluster with
       more than one member becomes a candidate subgraph built from the edges
       that occur in enough of the members' ``args.nlayer``-hop neighbourhoods.
    2. Global step - the candidates' embeddings are clustered into
       ``args.into_st`` groups with k-means. Each group with more than one
       member is represented by the member with the smallest ``dist`` value;
       groups whose representative has the same ``gnn_model.get_graph_emb``
       signature as an earlier one are merged into it.
    3. Scoring - for each class an ``STNet`` built on the last two child
       modules of the GNN is trained to weight the concepts, and the result is
       pickled under ``saved_data/``.

    Args:
        args: Namespace as produced by ``build_args``.

    Returns:
        None. Results are printed and written to ``saved_data/*.pkl``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataname = args.dataset
    into_st = args.into_st
    st_hop = args.nlayer

    # Multiplier used to pack (graph index, node index) into a single integer
    # key; decoding with `% max_graph` below is only correct while no graph
    # has more than this many nodes.
    max_graph = 5000

    task_type = check_task(dataname)
    dataset = get_dataset(dataname)
    try:
        dataset.print_summary()
    except AttributeError:
        # Not every dataset object offers a summary; it is purely informative.
        pass

    try:
        n_fea, n_cls = dataset.num_features, dataset.num_classes
    except AttributeError:
        # Datasets without `num_classes` are treated as binary.
        n_fea, n_cls = dataset.num_features, 2
    explain_ids = detect_exp_setting(dataname, dataset)
    motif_nodes_number = detect_motif_nodes(dataname)
    gnn_model = load_model(dataname, args.gnn, n_fea, n_cls)
    gnn_model.eval()
    print(f"GNN Model Loaded. {dataname}, {task_type}. \nsize of Motif: {motif_nodes_number}. num of samples to explain: {len(explain_ids)}")
    print(f'Dataset={dataname}-{args.gnn}-lclus{args.clusters}-gclus{args.into_st}-{args.local_cluster}-{args.lmda}')

    gnn_modules = list(gnn_model.children())

    if task_type != "GC":
        return

    loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
    if args.local_cluster == 'kmeans':
        # NOTE(review): the code below unpacks three values from
        # `fit_predict`; presumably a KMeans variant returning
        # (labels, dist, centroids) is installed - confirm.
        kmeans = KMeans(n_clusters=args.clusters, mode='euclidean', verbose=0)

    st_dict = {}             # packed (graph, node) key -> [x, edge_index, y]
    global_st_dict = {}      # candidate id -> [x, edge_index, #members, graph slot]
    global_st_dict_rev = {}  # candidate embedding tensor -> candidate id
    global_id = 0

    g_emb = {a: [] for a in range(n_cls)}
    all_ys = []

    # NOTE(review): hash_model is created and put in eval mode but never
    # called in this function; the de-duplication signature further down
    # comes from gnn_model.get_graph_emb.
    # Placed on `device` (not `.cuda()`) so CPU-only hosts work as well.
    hash_model = GraphSAGE(n_fea, n_cls, 3, 32).to(device)
    hash_model.reset_parameters()
    hash_model.eval()

    start = time.time()

    for i, d in enumerate(loader):
        d = d.to(device)
        logits = gnn_model(d)[0]
        # Only graphs the GNN predicts correctly contribute concepts.
        if torch.argmax(logits) != int(d.y):
            continue
        all_ys.append(int(d.y))

        g_emb[int(d.y)].append(gnn_model.get_gemb(d))

        print(f'\n{i}, y={int(d.y)}')
        st_dict_rev = {}
        for im, m in enumerate(gnn_model.get_hid_repr(d, st_hop)):
            st_emb_m = m.cpu().detach()
            hsh = i*max_graph+im
            info = [d.x.detach(), d.edge_index.detach(), d.y.detach()]
            st_dict_rev[st_emb_m] = hsh
            st_dict[hsh] = info

        all_embs = list(st_dict_rev.keys())
        id2hash = {j: st_dict_rev[emb] for j, emb in enumerate(all_embs)}
        torch.set_printoptions(precision=4)
        if args.local_cluster == 'kmeans':
            labels, dist, centroids = kmeans.fit_predict(torch.stack(all_embs).detach())
            _, label_idx = torch.unique(labels, return_inverse=True)
        else:
            n_clus = args.clusters
            hidden = len(all_embs[0])
            # Gaussian mixture initialised with means (mid+1)*ones and
            # identity covariance for each of the n_clus components.
            weights, distributions, ll = mixem.em(np.array(torch.stack(all_embs).detach()), [
                mixem.distribution.MultivariateNormalDistribution((mid+1)*(np.ones(hidden).astype(float)), np.identity(hidden)) for mid in range(n_clus)
            ])

            centroids = [torch.nan_to_num(torch.tensor(dis.mu).float()) for dis in distributions]
            all_dist = mixem.probability(np.array(torch.stack(all_embs).detach()), weights, distributions)
            norm_all_dist = (all_dist-np.min(all_dist))/(np.max(all_dist)-np.min(all_dist))
            # Soft membership: a node belongs to every component whose
            # min-max normalised value is non-zero.
            belong_clusters = np.nonzero(norm_all_dist)

        for clus in range(args.clusters):
            if args.local_cluster == 'kmeans':
                all_idx = torch.nonzero(label_idx == clus).view(-1)
            else:
                all_idx = belong_clusters[0][np.nonzero(belong_clusters[1] == clus)[0]]

            # Singleton / empty clusters do not form a concept.
            if len(all_idx) <= 1:
                continue

            # Count, per edge, how many cluster members have it inside their
            # st_hop-hop neighbourhood.
            st_edge_count = {eid: 0 for eid in range(int(d.edge_index.shape[1]))}
            for k in all_idx:
                _d = st_dict[id2hash[int(k)]]
                im = id2hash[int(k)] % max_graph
                st_edge_id_set = get_lhop_edge_id(_d[1], im, st_hop)
                for eid in st_edge_id_set:
                    st_edge_count[eid] += 1

            # Keep edges seen at least min(max_count, 5) times.
            max_count = max(st_edge_count.values())
            st_id = [int(a) for a in st_edge_count if st_edge_count[a] >= min(max_count, 5)]
            # All members come from the current graph, so the last `_d` of
            # the loop above describes the same graph as every other member.
            st_edges = _d[1][:, st_id]

            # Relabel the surviving nodes to 0..n-1 for the extracted subgraph.
            st_nodes = torch.unique(st_edges).tolist()
            st_resort = {nd: ndid for ndid, nd in enumerate(st_nodes)}
            clus_x = _d[0][st_nodes]
            # Fallback is never hit: every value in st_edges is in st_nodes.
            val_if_not_shown = 0
            clus_edge_index = st_edges.cpu().apply_(lambda val: st_resort.get(val, val_if_not_shown)).to(device)

            # The candidate embedding is [summed member embeddings, centroid]
            # or the reverse; only its first half is used for training later,
            # so args.local_agg decides which of the two that is.
            local_emb = sum(torch.stack(all_embs).detach()[all_idx.tolist()])
            if args.local_agg == 'centroid':
                local_emb = torch.cat([centroids[clus].cpu().detach(), local_emb])
            else:
                local_emb = torch.cat([local_emb, centroids[clus].cpu().detach()])

            global_st_dict_rev[local_emb] = global_id
            # `len(all_ys)-1` is this graph's position among the correctly
            # predicted graphs, in loader order.
            global_st_dict[global_id] = [clus_x, clus_edge_index, len(all_idx), len(all_ys)-1]
            global_id += 1

    print(f' ... Local clustering takes {time.time()-start} seconds')

    kmeans_global = KMeans(n_clusters=into_st, mode='euclidean', verbose=1)
    all_embs = list(global_st_dict_rev.keys())
    labels, dist, centroids = kmeans_global.fit_predict(torch.stack(all_embs).detach())
    torch.set_printoptions(sci_mode=False)
    _, label_idx = torch.unique(labels, return_inverse=True)
    clus2graph = {}   # global centroid tensor -> [x, edge_index, signature]
    hashedclus = {}   # signature -> graph slots covered by that concept

    separate = args.train_mode == 'separate'
    # `keeprows` is filled in every mode; the membership mask only exists in
    # 'separate' mode.
    keeprows = []
    mask = torch.zeros(len(all_ys), into_st).to(device).float() if separate else None

    for clus in range(into_st):
        all_idx = torch.nonzero(label_idx == clus).view(-1)
        if len(all_idx) <= 1:
            continue

        # Representative = member with the smallest |dist|.
        closest_id = int(all_idx[int(torch.argmin(torch.abs(dist[all_idx])))])

        clusg = global_st_dict[closest_id]
        hash_model.eval()
        # The printed form of the graph embedding serves as a de-duplication key.
        hash_emb = str(gnn_model.get_graph_emb(clusg[0], clusg[1]).detach().cpu().view(-1))
        member_graphs = [global_st_dict[int(idx)][-1] for idx in all_idx.cpu().tolist()]
        if hash_emb in hashedclus:
            # Same signature as an earlier concept: merge the coverage only.
            hashedclus[hash_emb] += member_graphs
            print('old', clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))
            continue

        hashedclus[hash_emb] = member_graphs
        print(clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))

        clus2graph[centroids[clus]] = clusg[:2]+[hash_emb]

        keeprows.append(clus)

        if separate:
            # Mark every graph that contributed a candidate to this concept.
            for dat_id in member_graphs:
                mask[dat_id][clus] = 1

    if separate:
        # Drop the columns of skipped / merged clusters.
        mask = mask[:, keeprows]

    if args.train_mode == 'all':
        for tcls in range(n_cls):
            clus_emb = torch.stack(list(clus2graph.keys())).to(device)
            thres_cut = clus_emb.shape[1]//2
            clus_emb = clus_emb[:, :thres_cut]
            mean_gemb = torch.mean(torch.stack(g_emb[tcls]).to(device), dim=0).view(1, -1)
            stnet = STNet(clus_emb.shape[0], gnn_modules[-2:])

            y = torch.tensor([tcls]).view(-1).long().to(device)
            st_weights = train_stnet(stnet, clus_emb, y, args).detach()[0]

            print(f'\nclass={tcls}, {mean_gemb}')
            _save_class_concepts(args, tcls, st_weights, clus2graph, hashedclus)

    else:
        # Converted once; the labels do not change between classes.
        all_ys = torch.tensor(all_ys).long().to(device)
        for tcls in range(n_cls):
            clus_emb = torch.stack(list(clus2graph.keys())).to(device)
            thres_cut = clus_emb.shape[1]//2
            clus_emb = clus_emb[:, :thres_cut]
            mean_gemb = torch.mean(torch.stack(g_emb[tcls]).to(device), dim=0).view(1, -1)
            stnet = STNet(clus_emb.shape[0], gnn_modules[-2:])

            # NOTE(review): care_idx, train_mask, val_mask, test_mask,
            # f_val_mask, f_test_mask, val_all_ys, test_all_ys, f_val_all_ys
            # and f_test_all_ys are not defined anywhere in the visible file,
            # so this branch raises NameError/UnboundLocalError as written.
            # The train/val/test split that should provide them has to be
            # restored before 'separate' mode can run; `mask` computed above
            # is presumably the intended source of `train_mask` - confirm.
            val_all_ys = torch.tensor(val_all_ys).long().to(device)
            val_care_idx = (val_all_ys == tcls)
            test_all_ys = torch.tensor(test_all_ys).long().to(device)
            test_care_idx = (test_all_ys == tcls)

            f_val_all_ys = torch.tensor(f_val_all_ys).long().to(device)
            f_val_care_idx = (f_val_all_ys == tcls)
            f_test_all_ys = torch.tensor(f_test_all_ys).long().to(device)
            f_test_care_idx = (f_test_all_ys == tcls)

            tcls_ys = [all_ys[care_idx], val_all_ys[val_care_idx], test_all_ys[test_care_idx]]
            tcls_masks = [train_mask[care_idx], val_mask[val_care_idx], test_mask[test_care_idx]]
            f_tcls_ys = [all_ys[care_idx], f_val_all_ys[f_val_care_idx], f_test_all_ys[f_test_care_idx]]
            f_tcls_masks = [train_mask[care_idx], f_val_mask[f_val_care_idx], f_test_mask[f_test_care_idx]]

            st_weights = sep_train_stnet(stnet, clus_emb, tcls_ys, args, tcls_masks, f_tcls_ys, f_tcls_masks).detach()[0]

            print(f'\nclass={tcls}, {mean_gemb}')
            _save_class_concepts(args, tcls, st_weights, clus2graph, hashedclus)


def _save_class_concepts(args, tcls, st_weights, clus2graph, hashedclus):
    """Print, optionally plot, and pickle the weighted concepts of class ``tcls``.

    Writes two files under ``saved_data/``: ``...global_concepts.pkl`` (list of
    ``[x, edge_index]`` per concept) and ``...concepts-data.pkl`` (dict mapping
    the concept position to ``[weight, covered graph slots]``). The file name
    uses the same pattern that ``load_saved_results`` reads.
    """
    import os  # local import: only needed to create the output directory

    save_dict = {}
    global_concepts = []

    clus_embs = list(clus2graph.keys())
    for j, weight in enumerate(st_weights):
        emb = clus_embs[j]
        print(f'class={tcls}, {j}: {weight}')
        clus_x, clus_edge_index, hash_emb = clus2graph[emb]

        save_dict[j] = [weight, hashedclus[hash_emb]]
        global_concepts.append(clus2graph[emb][:2])

        if args.plot == 1:
            GC_vis_graph(clus_x, clus_edge_index, Hedges=range(clus_edge_index.shape[1]), good_nodes=None, datasetname=args.dataset)
            show()

    # Without the directory the (expensive) run would fail at the very end.
    os.makedirs("saved_data", exist_ok=True)
    path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
    with open(path+'global_concepts'+".pkl", 'wb') as handle:
        pickle.dump(global_concepts, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(path+'concepts-data'+".pkl", 'wb') as handle:
        pickle.dump(save_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)


def get_neighbor_edge_id(edge_index, node):
    """Return the out-edges of ``node`` and the nodes they lead to.

    Args:
        edge_index: Two aligned sequences; ``edge_index[0][i]`` is the source
            and ``edge_index[1][i]`` the target of edge ``i``.
        node: Node id; compared after ``int()`` conversion.

    Returns:
        ``(edge_ids, neighbours)``: positions of the edges whose source is
        ``node`` (ascending), and the de-duplicated list of their targets.
    """
    edge_ids = [
        pos for pos, src in enumerate(edge_index[0]) if int(src) == int(node)
    ]
    # A set removes duplicates from parallel edges before listing the targets.
    neighbours = {int(edge_index[1][pos]) for pos in edge_ids}
    return edge_ids, list(neighbours)

def get_lhop_edge_id(edge_index, node, L):
    """Collect the ids of all edges traversed within ``L`` hops from ``node``.

    Edges are followed from source (``edge_index[0]``) to target
    (``edge_index[1]``). Hop 1 takes the out-edges of ``node``, hop 2 the
    out-edges of the nodes reached in hop 1, and so on.

    Args:
        edge_index: Two aligned sequences (lists or a 2-row tensor/array) of
            source and target node ids.
        node: Start node id; converted with ``int()``.
        L: Number of hops to expand.

    Returns:
        Set of edge positions (column indices into ``edge_index``).
    """
    def _as_ints(row):
        # `.tolist()` converts tensor/array rows in one call instead of one
        # Python-level element access per entry.
        values = row.tolist() if hasattr(row, "tolist") else row
        return [int(v) for v in values]

    if L <= 0:
        return set()

    sources = _as_ints(edge_index[0])
    targets = _as_ints(edge_index[1])

    # Source -> edge-id map built once, so each hop is a dictionary lookup
    # instead of a scan over every edge for every frontier node.
    outgoing = {}
    for eid, src in enumerate(sources):
        outgoing.setdefault(src, []).append(eid)

    frontier = {int(node)}
    edge_ids = set()
    for _ in range(L):
        reached = set()
        for nd in frontier:
            for eid in outgoing.get(nd, ()):
                edge_ids.add(eid)
                reached.add(targets[eid])
        if not reached:
            # Dead end: further hops cannot add edges.
            break
        frontier = reached
    return edge_ids

def build_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with one attribute per option listed below
        (attribute name = flag without the leading dashes).
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        ('--dataset', {'type': str, 'default': 'ba_2motifs'}),
        ('--gnn', {'type': str, 'default': 'gin'}),
        ('--local_cluster', {'type': str, 'default': 'kmeans', 'choices': ['kmeans', 'em']}),
        ('--lr', {'type': float, 'default': 2e-3}),
        ('--epochs', {'type': int, 'default': 200}),
        ('--max_nodes', {'type': int, 'default': 200}),
        ('--into_st', {'type': int, 'default': 6}),
        ('--into_class', {'type': int, 'default': 0}),
        ('--nlayer', {'type': int, 'default': 3}),
        ('--clusters', {'type': int, 'default': 3}),
        # The default stays the int 1: it is formatted into output file names.
        ('--lmda', {'type': float, 'default': 1}),
        ('--local_agg', {'type': str, 'default': 'sum', 'choices': ['centroid', 'sum']}),
        ('--train_mode', {'type': str, 'default': 'separate', 'choices': ['separate', 'all']}),
        ('--plot', {'type': int, 'default': 0}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_saved_results(args):
    """Print a coverage report from the concept pickles of a previous run.

    For every class the two pickles matching the current options are loaded
    from ``saved_data/``; each concept's weight and the number of distinct
    graphs it covers are printed (and the concept plotted when
    ``args.plot == 1``), followed by two ratios: covered graphs over
    ``len(explain_ids)``, and graphs covered by a concept with
    ``|weight| >= 0.3`` over all covered graphs. A ratio whose denominator is
    zero is reported as ``nan`` instead of raising.

    Args:
        args: Namespace as produced by ``build_args``.

    Raises:
        FileNotFoundError: If a pickle for any class is missing.
    """
    dataset = get_dataset(args.dataset)
    try:
        n_cls = dataset.num_classes
    except AttributeError:
        # Datasets without `num_classes` are treated as binary.
        n_cls = 2
    explain_ids = detect_exp_setting(args.dataset, dataset)

    for tcls in range(n_cls):
        path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # this project wrote itself.
        with open(path+'global_concepts'+".pkl", 'rb') as handle:
            global_concepts = pickle.load(handle)
        with open(path+'concepts-data'+".pkl", 'rb') as handle:
            save_dict = pickle.load(handle)

        covered_idx = set()
        high_conf_covered_idx = set()
        for j, glconc in enumerate(global_concepts):
            clus_x, clus_edge_index = glconc[0], glconc[1]
            weight, graph_ids = save_dict[j][0], save_dict[j][1]
            covered_idx.update(graph_ids)
            if abs(weight) >= 0.3:
                high_conf_covered_idx.update(graph_ids)
            print(tcls, j, weight, len(set(graph_ids)))
            if args.plot == 1:
                GC_vis_graph(clus_x, clus_edge_index, Hedges=range(clus_edge_index.shape[1]), good_nodes=None, datasetname=args.dataset)
                show()

        # Guard both denominators so an empty result prints nan, not a crash.
        n_explain = len(explain_ids)
        n_covered = len(covered_idx)
        coverage = n_covered/n_explain if n_explain else float('nan')
        high_conf = len(high_conf_covered_idx)/n_covered if n_covered else float('nan')
        print(f"covered: {coverage}, covered with high condidence: {high_conf}\n")
    
if __name__ == "__main__":
    import warnings

    # Silence all library warnings for the whole run.
    warnings.filterwarnings("ignore")

    args = build_args()
    try:
        # Reuse cached concepts when a previous run with the same options
        # already wrote them.
        load_saved_results(args)
    except FileNotFoundError:
        # No (complete) cache on disk: compute the concepts from scratch.
        get_global_exps(args)
    print("done")


