import argparse
import torch
from torch_geometric.loader import DataLoader
import numpy as np
from torch.utils.data import random_split

from Utils.utils import check_task, load_model, detect_exp_setting, detect_motif_nodes, show, GC_vis_graph
from Utils.metrics import efidelity
from Utils.datasets import get_dataset, get_graph_data
import torch.nn.functional as F
import copy
from torch_geometric.nn import global_mean_pool
from itertools import product
import time

from fast_pytorch_kmeans import KMeans
import mixem
from statistics import mean
from numpy import linalg as LA
import pickle

from GNN_Models.graph_sage import GraphSAGE
from st import STNet, sep_train_stnet, redetermine_pred, eval_acc

def get_global_exps(args):
    """Run the global-concept explanation pipeline for one dataset / GNN pair.

    Steps:
      1. Load ``args.dataset`` and the pretrained ``args.gnn`` model, split the
         dataset 80/10/10 with a fixed seed (1234) and print the GNN's accuracy
         on each split.
      2. Extract local concepts from every split with ``local_clustering``,
         cluster the training concepts globally (``global_clustering``) and
         assign validation/test concepts to those clusters (``fit_global``).
      3. Extract concepts from the test graphs the GNN misclassifies
         (``false_pred_local_clustering``).
      4. Train one ``STNet`` per class over the global concept embeddings,
         checkpointing on best validation accuracy.
      5. For the misclassified test graphs, visualise every concept that gets a
         non-zero score from a class model.

    Args:
        args: namespace produced by ``build_args()``.

    Returns:
        None. Progress and metrics are printed; STNet checkpoints are written
        under ``saved_models/`` and the global k-means object under
        ``saved_data/``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataname = args.dataset

    task_type = check_task(dataname)
    dataset = get_dataset(dataname)
    try:
        dataset.print_summary()
    except AttributeError:
        pass

    try:
        n_fea, n_cls = dataset.num_features, dataset.num_classes
    except AttributeError:
        # Datasets without ``num_classes`` are treated as binary.
        n_fea, n_cls = dataset.num_features, 2
    explain_ids = detect_exp_setting(dataname, dataset)
    motif_nodes_number = detect_motif_nodes(dataname)
    gnn_model = load_model(dataname, args.gnn, n_fea, n_cls)
    gnn_model.eval()
    print(f"GNN Model Loaded. {dataname}, {task_type}. \nsize of Motif: {motif_nodes_number}. num of samples to explain: {len(explain_ids)}")
    print(f'Dataset={dataname}-{args.gnn}-lclus{args.clusters}-gclus{args.into_st}-{args.local_cluster}-{args.lmda}')

    gnn_modules = list(gnn_model.children())

    # 80/10/10 split with a fixed seed so every run sees the same partition.
    num_train = int(0.8 * len(dataset))
    num_eval = int(0.1 * len(dataset))
    num_test = len(dataset) - num_train - num_eval
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, lengths=[num_train, num_eval, num_test],
        generator=torch.Generator().manual_seed(1234))

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    def _eval_acc(model, loader):
        """Return the fraction of graphs whose argmax prediction equals ``data.y`` (0.0 for an empty loader)."""
        model.eval()
        correct = 0
        length = 0
        for data in loader:
            data = data.to(device)
            with torch.no_grad():
                pred = model(data).max(1)[1]
            correct += pred.eq(data.y.view(-1)).sum().item()
            length += len(pred)
        return correct / length if length else 0.0

    # Accuracy of the original GNN on each split.
    train_acc = _eval_acc(gnn_model, train_loader)
    val_acc = _eval_acc(gnn_model, val_loader)
    test_acc = _eval_acc(gnn_model, test_loader)
    print("+original GNN acc - train_acc:", train_acc, "- val_acc:", val_acc, "- test_acc:", test_acc)

    # Local concepts of the training split, then global clusters fitted on them.
    all_ys, _all_logits, global_st_dict, global_st_dict_rev, _ = local_clustering(
        train_loader, args, device, gnn_model, omit_wrong=False)

    path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
    kmeans_global, train_mask, clus2graph, _, keeprows = global_clustering(
        args, device, path, gnn_model, all_ys, global_st_dict, global_st_dict_rev)

    # Validation/test concepts are only assigned to the clusters fitted above.
    val_all_ys, _val_all_logits, val_global_st_dict, val_global_st_dict_rev, _ = local_clustering(
        val_loader, args, device, gnn_model, omit_wrong=False)
    val_mask = fit_global(args, device, gnn_model, kmeans_global, val_all_ys,
                          val_global_st_dict, val_global_st_dict_rev, keeprows)

    test_all_ys, _test_all_logits, test_global_st_dict, test_global_st_dict_rev, _ = local_clustering(
        test_loader, args, device, gnn_model, omit_wrong=False)
    test_mask = fit_global(args, device, gnn_model, kmeans_global, test_all_ys,
                           test_global_st_dict, test_global_st_dict_rev, keeprows)

    f_test_all_ys, f_test_global_st_dict, f_test_global_st_dict_rev, false_data = false_pred_local_clustering(
        test_loader, args, device, gnn_model)
    # Bound up front so the code below also works when the GNN makes no test mistakes.
    f_test_mask, f_concepts = None, None
    has_false = len(f_test_all_ys) > 0
    if has_false:
        f_test_mask, f_concepts = fit_global(args, device, gnn_model, kmeans_global, f_test_all_ys,
                                             f_test_global_st_dict, f_test_global_st_dict_rev, keeprows, f=True)

    def _train_stnet(model, data, y, orig_logits, args, masks, tcls, f_tcls_ys, f_tcls_masks, f_g_data, f_concepts):
        """Train ``model`` for target class ``tcls`` and restore its best-validation checkpoint.

        ``y``, ``orig_logits`` and ``masks`` are [train, val, test] triples.
        ``f_tcls_ys``, ``f_tcls_masks``, ``f_g_data`` and ``f_concepts`` are
        accepted but currently unused.
        """
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        path = "saved_models/"+f'sep_st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-stnum{len(data)}-{args.local_cluster}-{args.lmda}'+".model"

        y_train, _, _ = y
        train_masks, _, _ = masks

        best_val_acc = 0
        best_loss = 9999
        tot_loss = 0  # defined even if the epoch loop does not run
        for epoch in range(1, args.epochs + 1):
            tot_loss = 0
            for i, mask in enumerate(train_masks):
                model.train()
                optimizer.zero_grad()
                out = model(data, mask.view(1, -1))
                label = y_train[i]
                # NOTE(review): for binary labels both branches yield target == tcls;
                # confirm this is the intended objective.
                target = label if int(label) == tcls else (label + 1) % 2
                loss = F.nll_loss(out, target.view(-1)) + float(args.lmda) * torch.norm(model.weight, p=2)
                loss.backward(retain_graph=True)
                optimizer.step()
                tot_loss += loss.item()

            [train_acc, val_acc, test_acc], [train_prob, val_prob, test_prob] = eval_acc(masks, data, model, y, orig_logits)

            print(f"Epoch: {epoch} - traing loss: {tot_loss} - train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}")
            print(f"train_prob: {train_prob} - val_prob: {val_prob} - test_prob: {test_prob}")
            if val_acc >= best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), path)
                # Early stop: loss plateaued, or every split is perfectly classified.
                if abs(tot_loss - best_loss) < 5e-3 or (train_acc == 1.0 and val_acc == 1.0 and test_acc == 1.0):
                    break
                best_loss = tot_loss

        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
        [train_acc, val_acc, test_acc], [train_prob, val_prob, test_prob] = eval_acc(masks, data, model, y, orig_logits)
        print(f"\n-------------------------\nFinally, traing loss: {tot_loss} - train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}\nnumber of train={len(masks[0])}, val={len(masks[1])}, test={len(masks[2])}")
        print(f"train_prob: {train_prob} - val_prob: {val_prob} - test_prob: {test_prob}")

    # Loop-invariant inputs, converted once rather than once per class.
    all_ys = torch.tensor(all_ys).long().to(device)
    val_all_ys = torch.tensor(val_all_ys).long().to(device)
    test_all_ys = torch.tensor(test_all_ys).long().to(device)
    all_logits = torch.stack(_all_logits).to(device)
    val_all_logits = torch.stack(_val_all_logits).to(device)
    test_all_logits = torch.stack(_test_all_logits).to(device)
    if has_false:
        f_test_all_ys = torch.tensor(f_test_all_ys).long().to(device)

    clus_emb = torch.stack(list(clus2graph.keys())).to(device)
    # Global centroids live in the space of the concatenated local embeddings;
    # only the first half of each is used.
    clus_emb = clus_emb[:, :clus_emb.shape[1] // 2]

    ys = [all_ys, val_all_ys, test_all_ys]
    logits = [all_logits, val_all_logits, test_all_logits]
    masks = [train_mask, val_mask, test_mask]

    st_model = []
    for tcls in range(n_cls):
        stnet = STNet(clus_emb.shape[0], gnn_modules[-2:])
        if has_false:
            f_test_care_idx = (f_test_all_ys == tcls)
            f_tcls_ys = f_test_all_ys[f_test_care_idx]
            f_tcls_masks = f_test_mask[f_test_care_idx]
        else:
            f_tcls_ys, f_tcls_masks = None, None

        _train_stnet(stnet, clus_emb, ys, logits, args, masks, tcls, f_tcls_ys, f_tcls_masks, false_data, f_concepts)
        st_model.append(copy.deepcopy(stnet))

    def f_eval_acc(masks, data, models, y, f_g_data, f_concepts):
        """Visualise scoring concepts of misclassified graphs and return per-(graph, model) hits.

        For graph ``j`` and class model ``i``, each concept ``ik`` whose mask value
        times model weight is non-zero is printed and plotted. A hit is recorded
        when the argmax of ``model.abs_forward`` equals ``y[j]``.
        """
        test_acc = []
        for j, d in enumerate(f_g_data):
            mask = masks[j].view(1, -1)
            concepts = f_concepts[j]
            for i, model in enumerate(models):
                f_weight = model.weight
                for ik, cpt in enumerate(concepts):
                    if len(cpt) > 0:
                        score = mask.view(-1)[ik] * f_weight.view(-1)[ik]
                        if score != 0:
                            print(f'id={j}, y={int(y[j])}, Class={i}, score={score}')
                            GC_vis_graph(d.x, d.edge_index, Hedges=cpt, good_nodes=None, datasetname=dataname)
                            show()
                out = model.abs_forward(data, mask)[0]
                pred = int(torch.argmax(out))
                test_acc.append(1 * (pred == int(y[j])))
        return test_acc

    if f_test_mask is not None:
        f_eval_acc(f_test_mask, clus_emb, st_model, f_test_all_ys, false_data, f_concepts)

def global_clustering(args, device, path, gnn_model, all_ys, global_st_dict, global_st_dict_rev):
    """Cluster the local concepts of the training split into global concepts.

    Args:
        args: uses ``into_st`` (number of global clusters) and ``train_mode``
            (the membership mask is only built for ``'separate'``).
        device: device on which the membership mask is allocated.
        path: filename prefix; the fitted clusterer is pickled to
            ``path + 'kmeans_global.pkl'``.
        gnn_model: its ``get_graph_emb`` of each cluster's representative
            subgraph is stringified and used as a de-duplication key.
        all_ys: per-graph labels; only its length is used.
        global_st_dict: global_id -> [x, edge_index, local_emb, member count,
            graph index], as built by ``local_clustering``.
        global_st_dict_rev: local_emb -> global_id. The code assumes its i-th
            key has global_id i.

    Returns:
        kmeans_global: the fitted KMeans object.
        mask: float tensor (len(all_ys), number of kept clusters). Entry [g, c]
            is the member count of the last local concept of graph g assigned to
            kept cluster c, or 0. It is ``None`` unless ``train_mode == 'separate'``.
        clus2graph: centroid tensor -> [x, edge_index, hash] of each kept
            cluster's representative subgraph.
        hashedclus: hash -> graph indices of all members, with duplicates merged.
        keeprows: global cluster id -> mask column. A cluster whose
            representative duplicates an earlier one shares that earlier column.
    """
    kmeans_global = KMeans(n_clusters=args.into_st, mode='euclidean', verbose=1)

    all_embs = list(global_st_dict_rev.keys())
    labels, dist, centroids = kmeans_global.fit_predict(torch.stack(all_embs).detach())
    # Pickle after fitting so the saved object carries the learned centroids.
    with open(path+'kmeans_global'+".pkl", 'wb') as handle:
        pickle.dump(kmeans_global, handle, protocol=pickle.HIGHEST_PROTOCOL)
    torch.set_printoptions(sci_mode=False)

    separate = args.train_mode == 'separate'
    mask = torch.zeros(len(all_ys), args.into_st).to(device).float() if separate else None
    clus2graph = {}
    hashedclus = {}
    keeprows = {}
    keep = []  # original cluster ids of the non-duplicate clusters, in column order

    for clus in range(args.into_st):
        # Raw labels index ``centroids`` directly; re-indexing via torch.unique
        # would misalign them whenever a cluster is empty.
        all_idx = torch.nonzero(labels == clus).view(-1)
        if len(all_idx) < 1:
            continue

        # The member with the smallest |dist| represents the cluster.
        closest_id = int(all_idx[int(torch.argmin(torch.abs(dist[all_idx])))])
        clusg = global_st_dict[closest_id]
        hash_emb = str(gnn_model.get_graph_emb(clusg[0], clusg[1]).detach().cpu().view(-1))
        member_graphs = [global_st_dict[int(idx)][-1] for idx in all_idx.cpu().tolist()]

        if hash_emb in hashedclus:
            hashedclus[hash_emb] += member_graphs
            real_clus = list(hashedclus).index(hash_emb)
            print('old', clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))
            keeprows[clus] = real_clus
            # Members of a duplicate cluster go into the column of the cluster it
            # duplicates, matching what fit_global does for val/test.
            column = keep[real_clus]
        else:
            hashedclus[hash_emb] = member_graphs
            print(clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))
            keeprows[clus] = len(keep)
            keep.append(clus)
            clus2graph[centroids[clus]] = clusg[:2] + [hash_emb]
            column = clus

        if separate:
            for gid in all_idx:
                dat_id = global_st_dict[int(gid)][-1]
                mask[dat_id][column] = global_st_dict[int(gid)][-2]

    if separate:
        mask = mask[:, keep]
    return kmeans_global, mask, clus2graph, hashedclus, keeprows

def fit_global(args, device, gnn_model, kmeans_global, all_ys, global_st_dict, global_st_dict_rev, keeprows, f=False):
    """Assign the local concepts of a held-out split to already-fitted global clusters.

    Args:
        args: uses ``train_mode``; the mask is only built for ``'separate'``.
        device: device on which the mask is allocated.
        gnn_model: unused; kept for signature compatibility.
        kmeans_global: fitted clusterer exposing ``predict``.
        all_ys: labels of the split's graphs; only its length is used.
        global_st_dict: global_id -> entry. Index 1 is the concept payload,
            index -2 the member count, index -1 the graph index.
        global_st_dict_rev: concept embedding -> global_id. The code assumes
            its i-th key has global_id i.
        keeprows: global cluster id -> mask column, from ``global_clustering``.
        f: when True, also return per-graph concept payloads.

    Returns:
        mask: float tensor (len(all_ys), number of distinct columns). Entry
            [g, c] is the member count of the last concept of graph g assigned to
            column c. It is ``None`` unless ``train_mode == 'separate'``.
        concepts: only when ``f``; ``concepts[g][c]`` is the payload
            (``global_st_dict[gid][1]``) of that concept, or ``[]``.
    """
    all_embs = list(global_st_dict_rev.keys())
    # Use the raw predicted cluster ids: re-indexing them with torch.unique
    # would shift ids whenever this split leaves some cluster empty, so
    # concepts would land in the wrong ``keeprows`` column.
    labels = kmeans_global.predict(torch.stack(all_embs).detach())
    n_cols = len(set(keeprows.values()))
    separate = args.train_mode == 'separate'
    mask = torch.zeros(len(all_ys), n_cols).to(device).float() if separate else None
    concepts = [[[] for _ in range(n_cols)] for _ in range(len(all_ys))] if f else None

    for clus in sorted(keeprows):
        if not separate:
            continue
        column = keeprows[clus]
        for gid in torch.nonzero(labels == clus).view(-1):
            entry = global_st_dict[int(gid)]
            dat_id = entry[-1]
            mask[dat_id][column] = entry[-2]
            if f:
                concepts[dat_id][column] = entry[1]

    if f:
        return mask, concepts
    return mask

def local_clustering(loader, args, device, gnn_model, omit_wrong=True):
    """Extract per-graph ("local") concepts by clustering node representations.

    For every graph ``d`` in ``loader``:
      * run ``gnn_model`` and record the label, the logits and ``get_gemb(d)``;
      * cluster the per-node representations from
        ``gnn_model.get_hid_repr(d, args.nlayer)`` into ``args.clusters`` groups
        (k-means or EM, chosen by ``args.local_cluster``);
      * for each cluster with at least two members, keep the edges that lie in
        the ``args.nlayer``-hop out-neighbourhoods of at least
        ``min(max_count, 5)`` members. ``max_count`` is the largest such count
        over all edges. The kept edges form a concept subgraph with node ids
        re-indexed from 0;
      * aggregate the members' representations into one local embedding, as
        set by ``args.local_agg``.

    Args:
        loader: iterable of graphs. ``int(d.y)`` requires one graph per batch.
        args: uses ``nlayer``, ``clusters``, ``local_cluster``, ``local_agg``.
        device: device each graph is moved to.
        gnn_model: model exposing ``__call__``, ``get_gemb`` and ``get_hid_repr``.
        omit_wrong: skip graphs whose argmax logit differs from ``d.y``.

    Returns:
        all_ys: int labels of the kept graphs.
        all_logits: logits of the kept graphs.
        global_st_dict: global_id -> [concept node features, re-indexed concept
            edge_index, local embedding, number of member nodes, index into
            ``all_ys``].
        global_st_dict_rev: local embedding tensor -> global_id.
        g_embs: ``get_gemb`` outputs of the kept graphs.

    Raises:
        ValueError: on an unknown ``local_cluster`` or ``local_agg``.
    """
    if args.local_cluster not in ('kmeans', 'em'):
        raise ValueError(f"unknown local_cluster: {args.local_cluster!r}")
    if args.local_agg not in ('centroid', 'mean', 'sum'):
        raise ValueError(f"unknown local_agg: {args.local_agg!r}")

    global_st_dict = {}
    global_st_dict_rev = {}
    global_id = 0
    all_ys, all_logits = [], []
    g_embs = []

    torch.set_printoptions(precision=4)
    start = time.time()
    for d in loader:
        d = d.to(device)
        logits = gnn_model(d)[0]
        if omit_wrong and torch.argmax(logits) != int(d.y):
            continue
        all_ys.append(int(d.y))
        all_logits.append(logits)
        g_embs.append(gnn_model.get_gemb(d))

        x = d.x.detach()
        edge_index = d.edge_index.detach()
        # Row k is treated as the representation of node k.
        emb_mat = torch.stack([m.cpu().detach() for m in gnn_model.get_hid_repr(d, args.nlayer)])

        if args.local_cluster == 'kmeans':
            kmeans = KMeans(n_clusters=args.clusters, mode='euclidean', verbose=0)
            labels, _, centroids = kmeans.fit_predict(emb_mat)
        else:  # 'em'
            # NOTE(review): indexing belong_clusters[1] below requires
            # mixem.probability to return a 2-D array; confirm.
            hidden = emb_mat.shape[1]
            emb_np = np.array(emb_mat)
            weights, distributions, _ = mixem.em(emb_np, [
                mixem.distribution.MultivariateNormalDistribution(
                    (mid + 1) * (np.ones(hidden).astype(float)), np.identity(hidden))
                for mid in range(args.clusters)
            ])
            centroids = [torch.nan_to_num(torch.tensor(dis.mu).float()) for dis in distributions]
            all_dist = mixem.probability(emb_np, weights, distributions)
            norm_all_dist = (all_dist - np.min(all_dist)) / (np.max(all_dist) - np.min(all_dist))
            belong_clusters = np.nonzero(norm_all_dist)

        for clus in range(args.clusters):
            if args.local_cluster == 'kmeans':
                # Raw labels keep ``centroids[clus]`` aligned with this cluster.
                all_idx = torch.nonzero(labels == clus).view(-1)
            else:
                all_idx = belong_clusters[0][np.nonzero(belong_clusters[1] == clus)[0]]
            if len(all_idx) <= 1:
                continue
            member_ids = [int(k) for k in all_idx]

            # Count, per edge, how many members have it in their L-hop neighbourhood.
            edge_count = [0] * int(edge_index.shape[1])
            for k in member_ids:
                for eid in get_lhop_edge_id(edge_index, k, args.nlayer):
                    edge_count[eid] += 1
            threshold = min(max(edge_count, default=0), 5)
            st_id = [eid for eid, cnt in enumerate(edge_count) if cnt >= threshold]
            st_edges = edge_index[:, st_id]

            # Re-index the concept's nodes to 0..n-1 (sorted by original id).
            st_nodes, clus_edge_index = torch.unique(st_edges, return_inverse=True)
            clus_x = x[st_nodes]

            members = emb_mat[member_ids]
            summed = members.sum(dim=0)
            if args.local_agg == 'centroid':
                local_emb = torch.cat([centroids[clus].cpu().detach(), summed])
            elif args.local_agg == 'mean':
                local_emb = torch.cat([members.mean(dim=0), summed])
            else:  # 'sum'
                local_emb = torch.cat([summed, centroids[clus].cpu().detach()])

            global_st_dict_rev[local_emb] = global_id
            global_st_dict[global_id] = [clus_x, clus_edge_index, local_emb, len(member_ids), len(all_ys) - 1]
            global_id += 1

    if all_ys:
        print(f' ... Local clustering takes {(time.time()-start)/len(all_ys)} seconds per sample')

    return all_ys, all_logits, global_st_dict, global_st_dict_rev, g_embs

def false_pred_local_clustering(loader, args, device, gnn_model):
    """Extract local concepts from the graphs ``gnn_model`` misclassifies.

    Uses the same concept extraction as ``local_clustering`` but keeps only
    graphs whose argmax logit differs from ``d.y``. Each concept's edges are
    stored as a list of edge ids into the graph's original ``edge_index``.

    Args:
        loader: iterable of graphs. ``int(d.y)`` requires one graph per batch.
        args: uses ``nlayer``, ``clusters``, ``local_cluster``, ``local_agg``.
        device: device each graph is moved to.
        gnn_model: model exposing ``__call__`` and ``get_hid_repr``.

    Returns:
        all_ys: true int labels of the misclassified graphs.
        global_st_dict: global_id -> [concept node features, list of edge ids
            into the graph's ``edge_index``, local embedding, number of member
            nodes, index into ``all_ys``].
        global_st_dict_rev: local embedding tensor -> global_id.
        false_data: deep copies of the misclassified graphs, aligned with ``all_ys``.

    Raises:
        ValueError: on an unknown ``local_cluster`` or ``local_agg``.
    """
    if args.local_cluster not in ('kmeans', 'em'):
        raise ValueError(f"unknown local_cluster: {args.local_cluster!r}")
    if args.local_agg not in ('centroid', 'mean', 'sum'):
        raise ValueError(f"unknown local_agg: {args.local_agg!r}")

    global_st_dict = {}
    global_st_dict_rev = {}
    global_id = 0
    all_ys = []
    false_data = []

    torch.set_printoptions(precision=4)
    start = time.time()
    for d in loader:
        d = d.to(device)
        logits = gnn_model(d)[0]
        if torch.argmax(logits) == int(d.y):
            continue
        all_ys.append(int(d.y))
        false_data.append(copy.deepcopy(d))

        x = d.x.detach()
        edge_index = d.edge_index.detach()
        # Row k is treated as the representation of node k.
        emb_mat = torch.stack([m.cpu().detach() for m in gnn_model.get_hid_repr(d, args.nlayer)])

        if args.local_cluster == 'kmeans':
            kmeans = KMeans(n_clusters=args.clusters, mode='euclidean', verbose=0)
            labels, _, centroids = kmeans.fit_predict(emb_mat)
        else:  # 'em'
            # NOTE(review): indexing belong_clusters[1] below requires
            # mixem.probability to return a 2-D array; confirm.
            hidden = emb_mat.shape[1]
            emb_np = np.array(emb_mat)
            weights, distributions, _ = mixem.em(emb_np, [
                mixem.distribution.MultivariateNormalDistribution(
                    (mid + 1) * (np.ones(hidden).astype(float)), np.identity(hidden))
                for mid in range(args.clusters)
            ])
            centroids = [torch.nan_to_num(torch.tensor(dis.mu).float()) for dis in distributions]
            all_dist = mixem.probability(emb_np, weights, distributions)
            norm_all_dist = (all_dist - np.min(all_dist)) / (np.max(all_dist) - np.min(all_dist))
            belong_clusters = np.nonzero(norm_all_dist)

        for clus in range(args.clusters):
            if args.local_cluster == 'kmeans':
                # Raw labels keep ``centroids[clus]`` aligned with this cluster.
                all_idx = torch.nonzero(labels == clus).view(-1)
            else:
                all_idx = belong_clusters[0][np.nonzero(belong_clusters[1] == clus)[0]]
            if len(all_idx) <= 1:
                continue
            member_ids = [int(k) for k in all_idx]

            # Count, per edge, how many members have it in their L-hop neighbourhood.
            edge_count = [0] * int(edge_index.shape[1])
            for k in member_ids:
                for eid in get_lhop_edge_id(edge_index, k, args.nlayer):
                    edge_count[eid] += 1
            threshold = min(max(edge_count, default=0), 5)
            st_id = [eid for eid, cnt in enumerate(edge_count) if cnt >= threshold]
            clus_x = x[torch.unique(edge_index[:, st_id])]

            members = emb_mat[member_ids]
            summed = members.sum(dim=0)
            if args.local_agg == 'centroid':
                local_emb = torch.cat([centroids[clus].cpu().detach(), summed])
            elif args.local_agg == 'mean':
                local_emb = torch.cat([members.mean(dim=0), summed])
            else:  # 'sum'
                local_emb = torch.cat([summed, centroids[clus].cpu().detach()])

            global_st_dict_rev[local_emb] = global_id
            # Edge ids (not a re-indexed edge_index) so the concept can be
            # highlighted on the original graph.
            global_st_dict[global_id] = [clus_x, st_id, local_emb, len(member_ids), len(all_ys) - 1]
            global_id += 1

    if all_ys:
        print(f' ... False Local clustering takes {(time.time()-start)/len(all_ys)} seconds per sample')
    return all_ys, global_st_dict, global_st_dict_rev, false_data

def get_neighbor_edge_id(edge_index, node):
    """Return the outgoing edges of ``node`` and the nodes they lead to.

    Args:
        edge_index: 2 x E tensor; row 0 holds source ids, row 1 target ids.
        node: source node id (int or 0-d tensor).

    Returns:
        (edge_ids, next_nodes): ``edge_ids`` is the ascending list of column
        indices whose source equals ``node``. ``next_nodes`` is the list of
        their distinct targets.
    """
    # Vectorised comparison instead of a Python loop over tensor elements.
    hit = torch.nonzero(edge_index[0] == int(node)).view(-1)
    out = hit.tolist()
    next_nodes = set(edge_index[1][hit].tolist())
    return out, list(next_nodes)

def get_lhop_edge_id(edge_index, node, L):
    """Return the ids of edges reachable within ``L`` hops from ``node``.

    Each round collects the outgoing edges of the current frontier (edges whose
    source, ``edge_index[0]``, is a frontier node). The frontier then moves to
    their targets. Each node is expanded at most once: a node seen again was
    already expanded, and so were its neighbours, so skipping it does not
    change the result.

    Args:
        edge_index: 2 x E tensor of (source, target) node ids.
        node: start node id.
        L: number of hops; ``L <= 0`` gives an empty set.

    Returns:
        Set of int edge ids.
    """
    edge_ids = set()
    expanded = set()
    frontier = [node]
    for _ in range(L):
        next_frontier = set()
        for nd in frontier:
            key = int(nd)
            if key in expanded:
                continue
            expanded.add(key)
            out, neighbours = get_neighbor_edge_id(edge_index, nd)
            edge_ids.update(out)
            next_frontier.update(neighbours)
        frontier = list(next_frontier)
    return edge_ids


def build_args(argv=None):
    """Parse command-line options for the global-explanation pipeline.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with the options below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='ba_2motifs',
                        help='dataset name passed to get_dataset / load_model')
    parser.add_argument('--gnn', type=str, default='gin',
                        help='GNN architecture name passed to load_model')
    parser.add_argument('--local_cluster', type=str, default='kmeans', choices=['kmeans', 'em'],
                        help='algorithm for clustering node representations within each graph')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Adam learning rate for each per-class STNet')
    parser.add_argument('--epochs', type=int, default=100,
                        help='upper bound on STNet training epochs (early stopping may end sooner)')
    parser.add_argument('--max_nodes', type=int, default=200,
                        help='not read by this script')
    parser.add_argument('--into_st', type=int, default=6,
                        help='number of global concept clusters (k-means over all local concepts)')
    parser.add_argument('--into_class', type=int, default=0,
                        help='not read by this script')
    parser.add_argument('--nlayer', type=int, default=3,
                        help='hop count for node representations and concept neighbourhoods')
    parser.add_argument('--clusters', type=int, default=3,
                        help='number of local clusters per graph')
    parser.add_argument('--lmda', type=float, default=1,
                        help='weight of the L2 penalty on STNet weights')

    parser.add_argument('--batch_size', type=int, default=64,
                        help='not read by this script (loaders use batch size 1)')

    parser.add_argument('--local_agg', type=str, default='mean', choices=['centroid', 'mean', 'sum'],
                        help='how member node representations form a local concept embedding')
    parser.add_argument('--train_mode', type=str, default='separate', choices=['separate', 'all'],
                        help="only 'separate' builds the concept-membership masks used for STNet training")

    parser.add_argument('--plot', type=int, default=0,
                        help='not read by this script')

    return parser.parse_args(argv)
    
if __name__ == "__main__":

    import warnings
    # Suppress all warnings for the whole run.
    warnings.filterwarnings("ignore")
    args = build_args()
    get_global_exps(args)
    # TODO: optionally reuse cached results via a `load_saved_results(args)`
    # helper, falling back to get_global_exps on FileNotFoundError.
    print("done")