import argparse
import torch
from torch_geometric.loader import DataLoader
import numpy as np
from torch.utils.data import random_split

from Utils.utils import check_task, load_model, detect_exp_setting, detect_motif_nodes, show, GC_vis_graph
from Utils.metrics import efidelity
from Utils.datasets import get_dataset, get_graph_data
import torch.nn.functional as F
import copy
from torch_geometric.nn import global_mean_pool
from itertools import product
import time

from fast_pytorch_kmeans import KMeans
import mixem
from statistics import mean
from numpy import linalg as LA
import pickle

from GNN_Models.graph_sage import GraphSAGE
from st import STNet, sep_train_stnet, AugClassifier, train_aug, redetermine_pred

def get_global_exps(args):
    """Run the full global-concept explanation pipeline for one dataset/GNN pair.

    Steps performed (only when ``check_task`` reports a graph-classification
    task, ``"GC"``; otherwise the function returns after loading the model):

    1. Split the dataset 80/10/10 with a fixed seed and print the accuracy of
       the loaded GNN on each split.
    2. Extract local concepts per graph (``local_clustering``), fit the global
       k-means on the training concepts (``global_clustering``) and assign the
       validation/test concepts to those clusters (``fit_global``). The three
       resulting masks are pickled under ``saved_data/``.
    3. Repeat the concept extraction for the misclassified graphs of each
       split (``false_pred_local_clustering``).
    4. If ``args.train_mode == 'separate'``: train one ``STNet`` per class,
       pickle the per-class concepts, then train an ``AugClassifier``.

    Args:
        args: parsed CLI namespace. Attributes read here: ``dataset``, ``gnn``,
            ``clusters``, ``into_st``, ``local_cluster``, ``lmda``,
            ``local_agg``, ``nlayer``, ``train_mode`` and ``plot``.

    Returns:
        None. Results are printed and written to pickle files.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataname = args.dataset
    # hidden = args.hidden
    # NOTE: `into_st` and `st_hop` are not read again in this function; the
    # code below uses args.into_st / args.nlayer directly.
    into_st = args.into_st
    st_hop=args.nlayer

    task_type = check_task(dataname)
    dataset = get_dataset(dataname)
    # Not every dataset object offers print_summary(); skip it when missing.
    try:dataset.print_summary()
    except AttributeError: pass

    # Fall back to two classes when the dataset does not expose num_classes.
    try:n_fea, n_cls = dataset.num_features, dataset.num_classes 
    except AttributeError: n_fea, n_cls = dataset.num_features, 2
    # explain_ids and motif_nodes_number are only used in the log line below.
    explain_ids = detect_exp_setting(dataname, dataset)
    motif_nodes_number = detect_motif_nodes(dataname)
    gnn_model = load_model(dataname, args.gnn, n_fea, n_cls)
    gnn_model.eval()
    print(f"GNN Model Loaded. {dataname}, {task_type}. \nsize of Motif: {motif_nodes_number}. num of samples to explain: {len(explain_ids)}")
    print(f'Dataset={dataname}-{args.gnn}-lclus{args.clusters}-gclus{args.into_st}-{args.local_cluster}-{args.lmda}')

    # Direct sub-modules of the GNN; the last two are handed to STNet below.
    gnn_modules = [mod for mod in gnn_model.children()]

    if task_type == "GC":

        # 80% / 10% / remainder split, reproducible through the fixed seed.
        num_train = int(0.8 * len(dataset))
        num_eval = int(0.1 * len(dataset))
        num_test = len(dataset) - num_train - num_eval

        train_dataset, val_dataset, test_dataset = random_split(dataset, lengths=[num_train, num_eval, num_test],
                                            generator=torch.Generator().manual_seed(42))

        # batch_size=1: the clustering helpers read `gnn_model(d)[0]` and
        # `int(d.y)`, i.e. they handle exactly one graph per batch.
        train_loader = DataLoader(train_dataset, batch_size=1,  
                                shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=1,
                                shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=1,
                                shuffle=False)

        # original GNN acc:
        def _eval_acc(model, loader):
            """Return the fraction of graphs in `loader` that `model` classifies correctly."""
            model.eval()

            correct = 0
            length = 0
            for data in loader:
                data = data.to(device)
                
                with torch.no_grad():
                    # index of the largest output along dim 1 = predicted class
                    pred = model(data).max(1)[1]
                    
                correct += pred.eq(data.y.view(-1)).sum().item()
                length += len(pred)
            return correct / length

        train_acc = _eval_acc(gnn_model, train_loader)
        val_acc = _eval_acc(gnn_model, val_loader)
        test_acc = _eval_acc(gnn_model, test_loader)
        print("+original GNN acc - train_acc:", train_acc, "- val_acc:", val_acc, "- test_acc:", test_acc)

        # Local concepts of every training graph (omit_wrong=False keeps
        # misclassified graphs too).
        all_ys, _all_logits, global_st_dict, global_st_dict_rev, g_embs = local_clustering(train_loader, args, device, gnn_model, omit_wrong=False)

        # Common file-name prefix for the artefacts of this configuration.
        # The files are opened with mode 'wb' only, so the `saved_data/`
        # directory presumably has to exist beforehand.
        path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
        # train
        kmeans_global, train_mask, clus2graph, hashedclus, keeprows = global_clustering(args, device, path, gnn_model, all_ys, global_st_dict, global_st_dict_rev)
        with open(path+'train_mask'+".pkl", 'wb') as handle:
            pickle.dump(train_mask, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Validation / test concepts are assigned to the clusters fitted on train.
        val_all_ys, _val_all_logits, val_global_st_dict, val_global_st_dict_rev, val_g_embs = local_clustering(val_loader, args, device, gnn_model, omit_wrong=False)
        val_mask = fit_global(args, device, gnn_model, kmeans_global, val_all_ys, val_global_st_dict, val_global_st_dict_rev, keeprows)
        with open(path+'val_mask'+".pkl", 'wb') as handle:
            pickle.dump(val_mask, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
        test_all_ys, _test_all_logits, test_global_st_dict, test_global_st_dict_rev, test_g_embs = local_clustering(test_loader, args, device, gnn_model, omit_wrong=False)
        test_mask = fit_global(args, device, gnn_model, kmeans_global, test_all_ys, test_global_st_dict, test_global_st_dict_rev, keeprows)
        with open(path+'test_mask'+".pkl", 'wb') as handle:
            pickle.dump(test_mask, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Same assignment for the misclassified graphs of each split. The
        # f_*_mask names exist only when the split has at least one such graph.
        f_all_ys, f_global_st_dict, f_global_st_dict_rev = false_pred_local_clustering(train_loader, args, device, gnn_model)
        if len(f_all_ys)>0:
            f_train_mask = fit_global(args, device, gnn_model, kmeans_global, f_all_ys, f_global_st_dict, f_global_st_dict_rev, keeprows)

        f_val_all_ys, f_val_global_st_dict, f_val_global_st_dict_rev = false_pred_local_clustering(val_loader, args, device, gnn_model)
        if len(f_val_all_ys)>0:
            f_val_mask = fit_global(args, device, gnn_model, kmeans_global, f_val_all_ys, f_val_global_st_dict, f_val_global_st_dict_rev, keeprows)
        
        f_test_all_ys, f_test_global_st_dict, f_test_global_st_dict_rev = false_pred_local_clustering(test_loader, args, device, gnn_model)
        if len(f_test_all_ys)>0:
            f_test_mask = fit_global(args, device, gnn_model, kmeans_global, f_test_all_ys, f_test_global_st_dict, f_test_global_st_dict_rev, keeprows)
        
        if args.train_mode=='separate': 
            
            # Disabled variant (never executed). NOTE(review): it reads
            # all_logits / val_all_logits / test_all_logits, which are only
            # assigned in the active branch further down, so re-enabling it
            # as-is would fail with a NameError/UnboundLocalError.
            if False:
                st_model=[]
                for tcls in range(n_cls):
                    all_ys = torch.tensor(all_ys).long().to(device)
                    clus_emb = torch.stack(list(clus2graph.keys())).to(device)
                    thres_cut = clus_emb.shape[1]//2
                    clus_emb = clus_emb[:,:thres_cut]

                    stnet = STNet(clus_emb.shape[0], gnn_modules[-2:])
                    val_all_ys=torch.tensor(val_all_ys).long().to(device)
                    test_all_ys = torch.tensor(test_all_ys).long().to(device)

                    if len(f_val_all_ys)>0 and len(f_test_all_ys)>0:
                        f_val_all_ys=torch.tensor(f_val_all_ys).long().to(device)
                        f_test_all_ys = torch.tensor(f_test_all_ys).long().to(device)

                    tcls_ys = [all_ys, val_all_ys, test_all_ys]
                    tcls_all_logits = [all_logits, val_all_logits, test_all_logits]
                    tcls_masks = [train_mask, val_mask, test_mask]
                    if len(f_all_ys)>0 and len(f_val_all_ys)>0 and len(f_test_all_ys)>0:
                        f_tcls_ys = [f_all_ys, f_val_all_ys, f_test_all_ys]
                        f_tcls_masks = [f_train_mask, f_val_mask, f_test_mask]
                    else:
                        f_tcls_ys, f_tcls_masks = None,None

                    st_weights = sep_train_stnet(stnet, clus_emb, tcls_ys, tcls_all_logits, args, tcls_masks, tcls, f_tcls_ys, f_tcls_masks).detach()[0]
                    st_model.append(copy.deepcopy(stnet))

                    save_dict = {}
                    global_concepts=[]
                    clus_embs=list(clus2graph.keys())
                    for j, weight in enumerate(st_weights):
                        emb = clus_embs[j]
                        print(f'class={tcls}, {j}: {weight}')
                        [clus_x, clus_edge_index, hash_emb] = clus2graph[emb]

                        save_dict[j]=[weight, hashedclus[hash_emb]]
                        global_concepts.append(clus2graph[emb][:2])

                        if args.plot==1:
                            GC_vis_graph(clus_x, clus_edge_index, Hedges=range(clus_edge_index.shape[1]), good_nodes=None, datasetname=dataname)
                            show()

                    path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
                    with open(path+'global_concepts'+".pkl", 'wb') as handle:
                        pickle.dump(global_concepts, handle, protocol=pickle.HIGHEST_PROTOCOL)
                    with open(path+'concepts-data'+".pkl", 'wb') as handle:
                        pickle.dump(save_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # Active variant: one STNet per class. The full splits are passed
            # positionally and the class-filtered subsets as keyword arguments.
            if True:
                st_model=[]
                for tcls in range(n_cls):
                    # NOTE(review): the label containers are re-wrapped with
                    # torch.tensor on every class iteration; from the second
                    # pass on they are already tensors.
                    all_ys = torch.tensor(all_ys).long().to(device)
                    care_idx = (all_ys==tcls)
                    # One row per retained global cluster (keys of clus2graph
                    # are the k-means centroids).
                    clus_emb = torch.stack(list(clus2graph.keys())).to(device)
                    # Keep only the first half of each centroid. The clustered
                    # local embeddings are a concatenation of two equal-length
                    # parts (see local_clustering).
                    thres_cut = clus_emb.shape[1]//2
                    clus_emb = clus_emb[:,:thres_cut]

                    stnet = STNet(clus_emb.shape[0], gnn_modules[-2:])
                    val_all_ys=torch.tensor(val_all_ys).long().to(device)
                    val_care_idx=(val_all_ys==tcls)
                    test_all_ys = torch.tensor(test_all_ys).long().to(device)
                    test_care_idx=(test_all_ys==tcls)

                    # print(all_logits.shape)

                    all_logits = torch.stack(_all_logits).to(device)
                    val_all_logits = torch.stack(_val_all_logits).to(device)
                    test_all_logits = torch.stack(_test_all_logits).to(device)

                    # Misclassified-graph data is used only when all three
                    # splits contain at least one misclassified graph.
                    if len(f_all_ys)>0 and len(f_val_all_ys)>0 and len(f_test_all_ys)>0:
                        f_all_ys=torch.tensor(f_all_ys).long().to(device)
                        f_care_idx=(f_all_ys==tcls)
                        f_val_all_ys=torch.tensor(f_val_all_ys).long().to(device)
                        f_val_care_idx=(f_val_all_ys==tcls)
                        f_test_all_ys = torch.tensor(f_test_all_ys).long().to(device)
                        f_test_care_idx=(f_test_all_ys==tcls)

                    # [train, val, test] triples over all graphs ...
                    ys = [all_ys, val_all_ys, test_all_ys]
                    logits = [all_logits, val_all_logits, test_all_logits]
                    masks = [train_mask, val_mask, test_mask]

                    # ... and restricted to the graphs labelled `tcls`.
                    tcls_ys = [all_ys[care_idx], val_all_ys[val_care_idx], test_all_ys[test_care_idx]]
                    tcls_all_logits = [all_logits[care_idx], val_all_logits[val_care_idx], test_all_logits[test_care_idx]]
                    tcls_masks = [train_mask[care_idx], val_mask[val_care_idx], test_mask[test_care_idx]]
                    
                    if len(f_all_ys)>0 and len(f_val_all_ys)>0 and len(f_test_all_ys)>0:
                        f_tcls_ys = [f_all_ys[f_care_idx], f_val_all_ys[f_val_care_idx], f_test_all_ys[f_test_care_idx]]
                        f_tcls_masks = [f_train_mask[f_care_idx], f_val_mask[f_val_care_idx], f_test_mask[f_test_care_idx]]
                    else:
                        f_tcls_ys, f_tcls_masks = None,None

                    # First row of the detached result is used as one weight
                    # per global concept.
                    st_weights = sep_train_stnet(stnet, clus_emb, ys, logits, args, masks, tcls, f_tcls_ys, f_tcls_masks, tcls_ys=tcls_ys, tcls_orig_logits=tcls_all_logits, tcls_masks=tcls_masks).detach()[0]
                    # Snapshot the trained network for this class.
                    st_model.append(copy.deepcopy(stnet))

                    save_dict = {}
                    global_concepts=[]
                    clus_embs=list(clus2graph.keys())
                    for j, weight in enumerate(st_weights):
                        emb = clus_embs[j]
                        print(f'class={tcls}, {j}: {weight}')
                        [clus_x, clus_edge_index, hash_emb] = clus2graph[emb]

                        # concept index -> [weight, graph rows recorded for that concept]
                        save_dict[j]=[weight, hashedclus[hash_emb]]
                        # keep only (node features, edge index) of the prototype
                        global_concepts.append(clus2graph[emb][:2])

                        if args.plot==1:
                            GC_vis_graph(clus_x, clus_edge_index, Hedges=range(clus_edge_index.shape[1]), good_nodes=None, datasetname=dataname)
                            show()

                    # Per-class prefix; this rebinds the `path` defined above.
                    path = "saved_data/"+f'st_{args.dataset}-{args.gnn}-cls{tcls}-clus{args.clusters}-{args.local_agg}-stnum{args.into_st}-{args.local_cluster}-{args.lmda}-'
                    with open(path+'global_concepts'+".pkl", 'wb') as handle:
                        pickle.dump(global_concepts, handle, protocol=pickle.HIGHEST_PROTOCOL)
                    with open(path+'concepts-data'+".pkl", 'wb') as handle:
                        pickle.dump(save_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # The names used below (clus_embs, clus_emb, all_ys, ...) hold the
            # values left by the last class iteration above. The classifier
            # input size is the length of a full, untruncated centroid.
            aug_cls = AugClassifier(len(clus_embs[0]), n_cls)
            tcls_ys = [all_ys, val_all_ys, test_all_ys]
            tcls_masks = [train_mask, val_mask, test_mask]
            # NOTE(review): f_tcls_ys / f_tcls_masks are rebuilt here but are
            # not passed to train_aug below.
            if len(f_all_ys)>0 and len(f_val_all_ys)>0 and len(f_test_all_ys)>0:
                f_tcls_ys = [f_all_ys, f_val_all_ys, f_test_all_ys]
                f_tcls_masks = [f_train_mask, f_val_mask, f_test_mask]
            else:
                f_tcls_ys, f_tcls_masks = None,None
            ori_graph_embs=[g_embs, val_g_embs, test_g_embs]
            train_aug(aug_cls,st_model,clus_emb,tcls_ys,args,tcls_masks,ori_graph_embs)
            # train_acc, val_acc, test_acc = redetermine_pred(st_model,clus_emb,tcls_ys,args,tcls_masks)
            # print(f"*Aug, train_acc: {train_acc} - val_acc: {val_acc} - test_acc: {test_acc}")



def fit_global(args, device, gnn_model, kmeans_global, all_ys, global_st_dict, global_st_dict_rev, keeprows):
    """Assign local concepts to the already-fitted global clusters and build the graph/concept mask.

    Args:
        args: parsed CLI namespace. Kept for signature compatibility; the mask
            is now built for every ``train_mode`` (it used to be created only
            for ``'separate'``, which made any other mode fail with
            ``UnboundLocalError`` on ``return``).
        device: torch device the mask is created on.
        gnn_model: not used by this function; kept for signature compatibility.
        kmeans_global: fitted clustering object exposing ``predict(tensor)``.
        all_ys: one entry per graph; only its length is used (number of rows).
        global_st_dict: ``{concept_id: [..., size, graph_row]}``. The last two
            items of each entry are the value written into the mask and the
            row it is written to.
        global_st_dict_rev: ``{embedding_tensor: concept_id}``. The keys are
            stacked in insertion order, so position ``i`` of the prediction is
            looked up as ``global_st_dict[i]``.
        keeprows: ``{cluster_index: mask_column}``, as produced by
            ``global_clustering``.

    Returns:
        Float tensor of shape ``(len(all_ys), number of distinct columns in
        keeprows)``. Cell ``[graph_row, column]`` holds the size of the concept
        assigned to that column, or 0 when the graph has none.
    """
    n_columns = len(set(keeprows.values()))
    mask = torch.zeros(len(all_ys), n_columns).to(device).float()

    all_embs = list(global_st_dict_rev.keys())
    if not all_embs:
        # No concept was extracted: torch.stack would raise on an empty list.
        return mask

    labels = kmeans_global.predict(torch.stack(all_embs).detach())
    # NOTE(review): label_idx is the *rank* of each label among the labels
    # present in this prediction, not the raw cluster id. The two coincide only
    # when the predicted labels cover 0..k-1 without gaps — confirm that this
    # matches the indices stored in `keeprows`.
    _, label_idx = torch.unique(labels, return_inverse=True)

    # Iterate the keys through a set, exactly like the original code did, so
    # the order of (overwriting) writes into a shared column is unchanged.
    for clus in set(keeprows):
        column = keeprows[clus]
        for gid in torch.nonzero(label_idx == clus).view(-1).tolist():
            entry = global_st_dict[gid]
            mask[entry[-1]][column] = entry[-2]
    return mask

def global_clustering(args, device, path, gnn_model, all_ys, global_st_dict, global_st_dict_rev):
    """Cluster all local concepts into ``args.into_st`` global concepts.

    Args:
        args: parsed CLI namespace; ``into_st`` (number of global clusters) is
            read. ``train_mode`` no longer gates the mask bookkeeping, which
            previously left ``mask``/``keeprows`` undefined for any mode other
            than ``'separate'``.
        device: torch device the mask is created on.
        path: file-name prefix; the fitted k-means is pickled to
            ``path + 'kmeans_global.pkl'``.
        gnn_model: must expose ``get_graph_emb(x, edge_index)``; its output is
            used to detect clusters that share the same prototype graph.
        all_ys: one entry per graph; only its length is used (mask rows).
        global_st_dict: ``{concept_id: [x, edge_index, emb, size, graph_row]}``.
        global_st_dict_rev: ``{embedding_tensor: concept_id}``; keys are
            stacked in insertion order and clustered.

    Returns:
        Tuple ``(kmeans_global, mask, clus2graph, hashedclus, keeprows)``:
          * ``kmeans_global`` – the fitted clustering object;
          * ``mask`` – float tensor ``(len(all_ys), n_kept_clusters)`` holding
            concept sizes;
          * ``clus2graph`` – ``{centroid: [x, edge_index, hash]}`` for each
            kept cluster;
          * ``hashedclus`` – ``{hash: [graph_row, ...]}``;
          * ``keeprows`` – ``{cluster_index: mask_column}`` for every non-empty
            cluster (duplicates point at the column of the first occurrence).
    """
    kmeans_global = KMeans(n_clusters=args.into_st, mode='euclidean', verbose=1)

    all_embs = list(global_st_dict_rev.keys())
    labels, dist, centroids = kmeans_global.fit_predict(torch.stack(all_embs).detach())

    # Persist the model only after fitting. The original dumped it right after
    # construction, so the pickle on disk contained an untrained model.
    with open(path+'kmeans_global'+".pkl", 'wb') as handle:
        pickle.dump(kmeans_global, handle, protocol=pickle.HIGHEST_PROTOCOL)

    torch.set_printoptions(sci_mode=False)
    # NOTE(review): label_idx is the rank of each label among the labels that
    # actually occur. It equals the raw cluster id only when no cluster is
    # empty, yet it is used below to index `centroids` — confirm.
    _,label_idx = torch.unique(labels,return_inverse=True)
    clus2graph={}
    hashedclus={}
    mask = torch.zeros(len(all_ys), args.into_st).to(device).float()
    keeprows = {}   # cluster index -> column in the returned mask
    keep=[]         # cluster indices whose mask column is retained
    count = 0       # number of distinct prototypes seen so far
    for clus in range(args.into_st):
        all_idx = torch.nonzero(label_idx==clus).view(-1)

        if len(all_idx)<1: continue
        # Prototype of the cluster: the member with the smallest |dist|.
        closest_id = int(all_idx[int(torch.argmin(torch.abs(dist[all_idx])))])
        clusg = global_st_dict[closest_id]
        # The printed graph embedding serves as a hashable identity for the
        # prototype, so clusters with the same prototype can be merged.
        hash_emb = str(gnn_model.get_graph_emb(clusg[0],clusg[1]).detach().cpu().view(-1))
        if hash_emb in hashedclus: 
            hashedclus[hash_emb]+=[global_st_dict[int(idx)][-1] for idx in all_idx.cpu().tolist()]
            # Columns are numbered in order of first appearance, which is also
            # the insertion order of `hashedclus`.
            real_clus = list(hashedclus.keys()).index(hash_emb)
            print('old', clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))
            keeprows[clus]=real_clus
            # NOTE(review): members of a duplicate cluster get no entry in the
            # training mask, whereas fit_global writes them into the column of
            # the first occurrence — confirm this asymmetry is intended.
            continue
        else: 
            hashedclus[hash_emb]=[global_st_dict[int(idx)][-1] for idx in all_idx.cpu().tolist()]
            print(clus, len(hashedclus[hash_emb]), dist[closest_id].cpu(), "mean dist", torch.mean(torch.abs(dist[all_idx])))
            keeprows[clus]=count
            count+=1
            keep.append(clus)

        clus2graph[centroids[clus]]=clusg[:2]+[hash_emb]
        # Row = graph the concept came from, value = size of the concept.
        for gid in all_idx:
            dat_id = global_st_dict[int(gid)][-1]
            mask[dat_id][clus]=global_st_dict[int(gid)][-2]
    # Drop the columns of empty and duplicate clusters.
    mask = mask[:,keep]
    return kmeans_global, mask, clus2graph, hashedclus, keeprows

def false_pred_local_clustering(loader, args, device, gnn_model):
    """Extract local concepts from the graphs that ``gnn_model`` misclassifies.

    Each graph whose predicted class differs from its label has its node
    representations clustered (k-means or EM, per ``args.local_cluster``).
    Every cluster with more than one member becomes one concept: a sub-graph
    made of the most frequently covered edges plus an embedding.

    Args:
        loader: iterable of single-graph batches (``gnn_model(d)[0]`` and
            ``int(d.y)`` are read per item).
        args: namespace providing ``nlayer``, ``clusters``, ``local_cluster``
            and ``local_agg``.
        device: torch device the data is moved to.
        gnn_model: callable model that also exposes ``get_hid_repr``.

    Returns:
        Tuple ``(all_ys, global_st_dict, global_st_dict_rev)``: the labels of
        the misclassified graphs, ``{concept_id: [x, edge_index, emb, size,
        row_in_all_ys]}`` and the reverse map ``{emb: concept_id}``.
    """
    node_records = {}        # node key -> [x, edge_index, y] of its graph
    concepts = {}            # concept id -> concept description
    concept_ids = {}         # concept embedding -> concept id
    next_concept = 0
    wrong_labels = []

    # Node keys are graph_index * max_graph + node_index.
    max_graph = 5000

    started = time.time()
    for graph_no, graph in enumerate(loader):

        graph = graph.to(device)
        logits = gnn_model(graph)[0]
        if torch.argmax(logits) == int(graph.y):
            # correctly classified graphs are not of interest here
            continue
        wrong_labels.append(int(graph.y))

        key_by_emb = {}
        for node_no, hidden_repr in enumerate(gnn_model.get_hid_repr(graph, args.nlayer)):
            node_emb = hidden_repr.cpu().detach()
            node_key = graph_no * max_graph + node_no
            record = [graph.x.detach(), graph.edge_index.detach(), graph.y.detach()]
            key_by_emb[node_emb] = node_key
            node_records[node_key] = record

        node_embs = list(key_by_emb)
        key_by_row = {row: key_by_emb[emb] for row, emb in enumerate(node_embs)}
        torch.set_printoptions(precision=4)
        if args.local_cluster == 'kmeans':
            kmeans = KMeans(n_clusters=args.clusters, mode='euclidean', verbose=0)
            labels, dist, centroids = kmeans.fit_predict(torch.stack(node_embs).detach())
            _, label_idx = torch.unique(labels, return_inverse=True)
        elif args.local_cluster == 'em':
            n_clus = args.clusters
            hidden = len(node_embs[0])
            # one Gaussian per cluster, means initialised at 1, 2, ..., n_clus
            components = [
                mixem.distribution.MultivariateNormalDistribution((mid+1)*(np.ones(hidden).astype(float)), np.identity(hidden))
                for mid in range(n_clus)
            ]
            weights, distributions, ll = mixem.em(np.array(torch.stack(node_embs).detach()), components)

            centroids = [torch.nan_to_num(torch.tensor(dis.mu).float()) for dis in distributions]
            all_dist = mixem.probability(np.array(torch.stack(node_embs).detach()), weights, distributions)
            lowest, highest = np.min(all_dist), np.max(all_dist)
            norm_all_dist = (all_dist-lowest)/(highest-lowest)
            belong_clusters = np.nonzero(norm_all_dist)

        for clus in range(args.clusters):
            if args.local_cluster == 'kmeans':
                members = torch.nonzero(label_idx==clus).view(-1)
            elif args.local_cluster == 'em':
                members = belong_clusters[0][np.nonzero(belong_clusters[1]==clus)[0]]

            if len(members) <= 1:
                continue

            # Count, per edge, how many member nodes reach it within nlayer hops.
            edge_votes = dict.fromkeys(range(int(graph.edge_index.shape[1])), 0)
            for member in members:
                member_key = key_by_row[int(member)]
                record = node_records[member_key]
                node_no = member_key % max_graph
                for eid in get_lhop_edge_id(record[1], node_no, args.nlayer):
                    edge_votes[eid] += 1

            # Keep the edges covered at least min(best count, 5) times.
            threshold = min(max(list(edge_votes.values())), 5)
            kept_edge_ids = [int(eid) for eid in edge_votes if edge_votes[eid] >= threshold]
            sub_edges = record[1][:, kept_edge_ids]

            # Re-index the surviving nodes to 0..n-1.
            sub_nodes = torch.unique(sub_edges).tolist()
            new_index = {old: new for new, old in enumerate(sub_nodes)}
            clus_x = record[0][sub_nodes]
            val_if_not_shown = 0  # 0 can be any other number within dtype range
            clus_edge_index = sub_edges.cpu().apply_(lambda val: new_index.get(val, val_if_not_shown)).to(device)

            member_rows = members.tolist()
            if args.local_agg == 'centroid':
                summed = sum(torch.stack(node_embs).detach()[member_rows])
                local_emb = torch.cat([centroids[clus].cpu().detach(), summed])
            elif args.local_agg == 'mean':
                summed = sum(torch.stack(node_embs).detach()[member_rows])
                averaged = torch.mean(torch.stack(node_embs).detach()[member_rows], dim=0)
                local_emb = torch.cat([averaged, summed])
            elif args.local_agg == 'sum':
                summed = sum(torch.stack(node_embs).detach()[member_rows])
                local_emb = torch.cat([summed, centroids[clus].cpu().detach()])

            concept_ids[local_emb] = next_concept
            concepts[next_concept] = [clus_x, clus_edge_index, local_emb, len(members), len(wrong_labels)-1]
            next_concept += 1

    if len(wrong_labels) > 0:
        print(f' ... False Local clustering takes {(time.time()-started)/len(wrong_labels)} seconds per sample')

    return wrong_labels, concepts, concept_ids


def local_clustering(loader, args, device, gnn_model, omit_wrong=True):
    """Extract local concepts from every graph produced by ``loader``.

    The node representations of each graph are clustered (k-means or EM, per
    ``args.local_cluster``). Every cluster with more than one member becomes
    one concept: a sub-graph made of the most frequently covered edges plus an
    embedding built according to ``args.local_agg``.

    Args:
        loader: iterable of single-graph batches (``gnn_model(d)[0]`` and
            ``int(d.y)`` are read per item).
        args: namespace providing ``nlayer``, ``clusters``, ``local_cluster``
            and ``local_agg``.
        device: torch device the data is moved to.
        gnn_model: callable model that also exposes ``get_gemb`` and
            ``get_hid_repr``.
        omit_wrong: when true, graphs the model misclassifies are skipped.

    Returns:
        Tuple ``(all_ys, all_logits, global_st_dict, global_st_dict_rev,
        g_embs)``: labels, logits and graph embeddings of the retained graphs
        (parallel lists), ``{concept_id: [x, edge_index, emb, size,
        row_in_all_ys]}`` and the reverse map ``{emb: concept_id}``.
    """

    st_dict = {}             # node key -> [x, edge_index, y] of its graph
    global_st_dict = {}
    global_st_dict_rev = {}
    global_id = 0
    all_ys, all_logits = [],[]
    g_embs=[]

    # Node keys are graph_index * max_graph + node_index.
    max_graph = 5000

    start = time.time()
    for i, d in enumerate(loader): 

        # if i >300:break

        d = d.to(device)
        logits = gnn_model(d)[0]
        if omit_wrong:
            if torch.argmax(logits) != int(d.y): continue 
        all_ys.append(int(d.y))
        all_logits.append(logits)
        g_embs.append(gnn_model.get_gemb(d))

        # print(f'\n{i}, y={int(d.y)}')

        st_dict_rev = {}
        for im, m in enumerate(gnn_model.get_hid_repr(d,args.nlayer)):
            st_emb_m = m.cpu().detach()
            hsh = i*max_graph+im
            info = [d.x.detach(), d.edge_index.detach(), d.y.detach()]
            st_dict_rev[st_emb_m]=hsh
            st_dict[hsh]=info
        
        all_embs = list(st_dict_rev.keys())
        # row position in the stacked embedding matrix -> node key
        id2hash={j:st_dict_rev[all_embs[j]] for j in range(len(all_embs))}
        torch.set_printoptions(precision=4)
        if args.local_cluster == 'kmeans':
            kmeans = KMeans(n_clusters=args.clusters, mode='euclidean', verbose=0)
            labels, dist, centroids = kmeans.fit_predict(torch.stack(all_embs).detach())
            # NOTE(review): label_idx is the rank of each label among the
            # labels present, but `centroids[clus]` is indexed with it below;
            # the two agree only when no cluster is empty — confirm.
            _,label_idx = torch.unique(labels,return_inverse=True)
        elif args.local_cluster == 'em':
            n_clus=args.clusters
            hidden = len(all_embs[0])
            # one Gaussian per cluster, means initialised at 1, 2, ..., n_clus
            weights, distributions, ll = mixem.em(np.array(torch.stack(all_embs).detach()), [
                mixem.distribution.MultivariateNormalDistribution((mid+1)*(np.ones(hidden).astype(float)), np.identity(hidden)) for mid in range(n_clus)
            ])

            centroids = [torch.nan_to_num(torch.tensor(dis.mu).float()) for dis in distributions]
            all_dist = mixem.probability(np.array(torch.stack(all_embs).detach()), weights, distributions)
            norm_all_dist = (all_dist-np.min(all_dist))/(np.max(all_dist)-np.min(all_dist))
            belong_clusters = np.nonzero(norm_all_dist)

        for clus in range(args.clusters):
            if args.local_cluster == 'kmeans':
                all_idx = torch.nonzero(label_idx==clus).view(-1)
            elif args.local_cluster == 'em':
                all_idx = belong_clusters[0][np.nonzero(belong_clusters[1]==clus)[0]]

            if len(all_idx)<=1: continue
            # Count, per edge, how many member nodes reach it within nlayer hops.
            st_edge_count={eid:0 for eid in range(int(d.edge_index.shape[1]))}
            for k in all_idx:
                _d=st_dict[id2hash[int(k)]]
                im = id2hash[int(k)]%max_graph
                st_edge_id_set = get_lhop_edge_id(_d[1], im, args.nlayer)
                for eid in st_edge_id_set: st_edge_count[eid]+=1

            # Keep the edges covered at least min(best count, 5) times.
            max_count = max(list(st_edge_count.values()))
            st_id=[int(a) for a in st_edge_count if st_edge_count[a]>=min(max_count, 5)]
            # `_d` is the record of the last member; all members of this loop
            # belong to the current graph, so its edge index applies to all.
            st_edges = _d[1][:,st_id]

            # Re-index the surviving nodes to 0..n-1.
            st_nodes = torch.unique(st_edges).tolist()
            st_resort = {nd:ndid for ndid, nd in enumerate(st_nodes)}
            clus_x = _d[0][st_nodes]
            val_if_not_shown = 0  # 0 can be any other number within dtype range
            clus_edge_index = st_edges.cpu().apply_(lambda val: st_resort.get(val, val_if_not_shown)).to(device)

            # The concept embedding is the concatenation of two equal-length
            # halves; which half holds the member sum depends on local_agg.
            if args.local_agg=='centroid':
                local_emb = sum(torch.stack(all_embs).detach()[all_idx.tolist()])
                local_emb = torch.cat([centroids[clus].cpu().detach(), local_emb])
            elif args.local_agg=='mean':
                local_emb = sum(torch.stack(all_embs).detach()[all_idx.tolist()])
                local_emb = torch.cat([torch.mean(torch.stack(all_embs).detach()[all_idx.tolist()], dim=0), local_emb])
            elif args.local_agg=='sum':
                local_emb = sum(torch.stack(all_embs).detach()[all_idx.tolist()])
                local_emb = torch.cat([local_emb, centroids[clus].cpu().detach()])

            global_st_dict_rev[local_emb]=global_id
            global_st_dict[global_id]=[clus_x, clus_edge_index, local_emb, len(all_idx), len(all_ys)-1]
            global_id+=1
            
    # Guard against an empty loader, or every graph being skipped, which
    # previously raised ZeroDivisionError here.
    if len(all_ys)>0: print(f' ... Local clustering takes {(time.time()-start)/len(all_ys)} seconds per sample')

    return all_ys, all_logits, global_st_dict, global_st_dict_rev, g_embs

def get_neighbor_edge_id(edge_index, node):
    """Return the outgoing edges of ``node`` and the nodes they lead to.

    Args:
        edge_index: two-row container; row 0 holds the source node of each
            edge and row 1 the target node.
        node: node whose outgoing edges are wanted.

    Returns:
        ``(edge_ids, neighbours)``: positions of the edges whose source is
        ``node`` (ascending) and the distinct target nodes of those edges.
    """
    matching = [pos for pos, src in enumerate(edge_index[0]) if int(src) == int(node)]
    reached = {int(edge_index[1][pos]) for pos in matching}
    return matching, list(reached)

def get_lhop_edge_id(edge_index, node, L): 
    """Collect the ids of all edges traversed within ``L`` hops of ``node``.

    Starting from ``node``, edges are followed from source (row 0) to target
    (row 1). An edge is included when its source lies fewer than ``L`` hops
    away from ``node``.

    The adjacency is built once per call and each node is expanded at most
    once. The previous version rescanned every edge for every frontier node on
    every hop; the returned set is the same.

    Args:
        edge_index: two-row container (tensor or nested sequence) of source
            and target node ids, one column per edge.
        node: start node.
        L: number of hops; values <= 0 give an empty result.

    Returns:
        Set of edge positions (column indices of ``edge_index``).
    """
    def _as_ints(row):
        # Tensors/arrays are converted in one go instead of element by element.
        values = row.tolist() if hasattr(row, "tolist") else row
        return [int(v) for v in values]

    sources = _as_ints(edge_index[0])
    targets = _as_ints(edge_index[1])

    # source node -> positions of its outgoing edges
    out_edges = {}
    for eid, src in enumerate(sources):
        out_edges.setdefault(src, []).append(eid)

    collected = set()
    expanded = set()
    frontier = {int(node)}
    for _ in range(L):
        next_frontier = set()
        for nd in frontier:
            if nd in expanded:
                # its edges and neighbours were already handled on an earlier hop
                continue
            expanded.add(nd)
            for eid in out_edges.get(nd, ()):
                collected.add(eid)
                next_frontier.add(targets[eid])
        frontier = next_frontier
        if not frontier:
            break
    return collected

def build_args():
    """Parse the command-line options of the script and return the namespace.

    Options are declared in a single table, in the order in which they are
    registered with the parser.
    """
    option_table = (
        ('--dataset', {'type': str, 'default': 'ba_2motifs'}),
        ('--gnn', {'type': str, 'default': 'gin'}),
        ('--local_cluster', {'type': str, 'default': 'kmeans', 'choices': ['kmeans', 'em']}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--epochs', {'type': int, 'default': 100}),
        ('--max_nodes', {'type': int, 'default': 200}),
        ('--into_st', {'type': int, 'default': 6}),
        ('--into_class', {'type': int, 'default': 0}),
        ('--nlayer', {'type': int, 'default': 3}),
        ('--clusters', {'type': int, 'default': 3}),
        # the default stays the int 1: it appears verbatim in output file names
        ('--lmda', {'type': float, 'default': 1}),
        ('--batch_size', {'type': int, 'default': 64}),
        ('--local_agg', {'type': str, 'default': 'mean', 'choices': ['centroid', 'mean', 'sum']}),
        ('--train_mode', {'type': str, 'default': 'separate', 'choices': ['separate', 'all']}),
        ('--plot', {'type': int, 'default': 0}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
    
if __name__ == "__main__":
    # Script entry point: silence all warnings, parse the CLI options and run
    # the explanation pipeline.
    import warnings

    warnings.filterwarnings("ignore")
    get_global_exps(build_args())
    # Disabled alternative: try load_saved_results(args) first and fall back
    # to get_global_exps(args) on FileNotFoundError.
    print("done")

