import argparse
import torch
from torch_geometric.loader import DataLoader
import numpy as np

from Utils.utils import check_task, load_model, detect_exp_setting, detect_motif_nodes, show, GC_vis_graph
from Utils.metrics import efidelity

from Utils.datasets import get_dataset, get_graph_data
import torch.nn.functional as F
import copy
from torch_geometric.nn import global_mean_pool
from itertools import product
import time

from stnet import STNet, LSTNet, subgraph_extractor, find_L_hop_edges, cluster_intersect, cluster_union

def get_global_exps(args, dataname=None, into_class=None, tminfo=None):
    """Learn class-level subtree importances for a trained graph classifier.

    Steps, as implemented below:

    1. Run the loaded GNN over the graphs selected by ``detect_exp_setting``.
       For every correctly classified graph, each node's hidden
       representation (per layer listed in ``--sthop``) is stringified and
       used as a dictionary key. A key starts with score 1; each later
       sighting adds 1 when the graph label equals ``into_class`` and
       subtracts 1 otherwise.
    2. For each layer, starting at ``args.nlayers`` and going down, keep the
       ``max_st`` highest-ranked keys, fit an ``STNet`` (first iteration) or
       an ``LSTNet`` (later iterations) on them, and select the keys whose
       normalised learned weight has the largest magnitude.
    3. Pass the selection to ``cluster_union``.

    Args:
        args: Parsed command-line namespace (see ``build_args``).
        dataname: Dataset name. ``None`` means ``args.dataset``; otherwise a
            per-dataset ``hidden`` size is chosen from the table below.
        into_class: Class to explain. ``None`` means ``args.into_class``.
        tminfo: Optional ``(max_st, into_st, _lambda)`` triple overriding the
            corresponding ``args`` values.

    Returns:
        Tuple ``(elap_time, results, signif_st_embs, st_importance,
        st_instances, loader, gnn_model)`` holding the values computed for
        the last layer processed. ``results`` is whatever ``cluster_union``
        returns.

    Raises:
        NotImplementedError: If ``check_task(dataname)`` is not ``"GC"``.
            (Previously this path ended in an UnboundLocalError at the
            ``return`` statement.)
    """
    if into_class is None:
        into_class = args.into_class
    # SECURITY NOTE: ``--sthop`` is evaluated with eval(); only pass trusted
    # command-line input (expected form: a list literal such as "[3]").
    st_hop = eval(args.sthop)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if dataname is None:
        dataname = args.dataset
        hidden = args.hidden
        max_st = args.max_st
        into_st = args.into_st
    else:
        # NOTE(review): the max_st / into_st presets below are always
        # overwritten by the ``tminfo`` handling that follows; only
        # ``hidden`` takes effect. Confirm whether that is intended.
        if dataname == "ba_2motifs":
            hidden = 32
            max_st = 186
            into_st = 18
        elif dataname == "bamult":
            hidden = 20
            max_st = 317
            into_st = 32
        elif dataname == "NCI1":
            hidden = 64
            max_st = 450
            into_st = 45
        else:
            hidden = 64
            max_st = 680
            into_st = 68

    if tminfo is None:
        max_st = args.max_st
        into_st = args.into_st
        _lambda = args._lambda
    else:
        (max_st, into_st, _lambda) = tminfo

    task_type = check_task(dataname)
    dataset = get_dataset(dataname)
    try:
        dataset.print_summary()
    except AttributeError:
        # Not every dataset object offers print_summary(); it is optional.
        pass
    n_fea, n_cls = dataset.num_features, dataset.num_classes
    explain_ids = detect_exp_setting(dataname, dataset)
    motif_nodes_number = detect_motif_nodes(dataname)
    gnn_model = load_model(dataname, args.gnn, n_fea, n_cls)
    gnn_model.eval()
    print(f"GNN Model Loaded. {dataname}, {task_type}. \nsize of Motif: {motif_nodes_number}. num of samples to explain: {len(explain_ids)}")

    gnn_modules = [mod for mod in gnn_model.children()]

    if task_type == "GC":
        loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)

        start = time.time()
        global_st_dict = {nl: dict() for nl in st_hop}

        for i, d in enumerate(loader):
            if args.is_single_opt > 0 and i < args.opt_id:
                continue
            if i in explain_ids:
                d = d.to(device)
                logits = gnn_model(d)[0]
                # Only correctly classified graphs contribute subtrees.
                if torch.argmax(logits) != int(d.y):
                    continue

                for layer in st_hop:
                    for im, m in enumerate(gnn_model.get_hid_repr(d, layer)):
                        # The stringified embedding is the subtree's identity.
                        str_m = str(m.cpu().detach().tolist())
                        if str_m not in global_st_dict[layer]:
                            global_st_dict[layer][str_m] = [1, [[layer, int(d.y), im, i]]]
                        else:
                            if int(d.y) == into_class:
                                global_st_dict[layer][str_m][0] += 1
                            else:
                                global_st_dict[layer][str_m][0] -= 1
                            global_st_dict[layer][str_m][1].append([layer, int(d.y), im, i])
                if args.is_single_opt > 0 and i >= args.opt_id:
                    break

        # Report the layer the first model is trained on. The previous
        # hard-coded key 3 raised KeyError whenever 3 was not in --sthop.
        print(f'  >> number of subtrees = {len(global_st_dict.get(args.nlayers, {}))}')

        # Replace each raw score by [abs(score), score] so that sorting ranks
        # by magnitude first and by sign second.
        for ll in global_st_dict:
            for entry in global_st_dict[ll].values():
                entry[0] = [abs(entry[0]), entry[0]]

        print("---------------------\n>> Global st_dict")

        # To measure coverage instead of training, call
        # check_coverage(loader, global_st_dict[args.nlayers], gnn_model, device) here.

        target_class = torch.LongTensor([into_class]).to(device)

        for layer in range(len(st_hop)):
            cur_layer = args.nlayers - layer
            layer_global_st_dict = global_st_dict[cur_layer]
            str_global_st_dict = sorted(layer_global_st_dict.items(), key=lambda item: item[1], reverse=True)
            # eval() here only parses keys generated above by str(list).
            st_emb = torch.tensor([eval(z[0]) for j, z in enumerate(str_global_st_dict) if j < max_st]).to(device)
            # Clamp to the number of subtrees actually available; the clamped
            # value carries over to the following layers.
            max_st = st_emb.shape[0]

            if layer == 0:
                st_model = STNet(hidden, args.nclasses, args.npatterns, gnn_modules[-2:])
            else:
                hconvs = [conv for conv in gnn_modules[1].children()]
                hconvs = [hconvs[-1 * layer]]
                st_model = LSTNet(max_st, hconvs)
            # Follow the detected device instead of unconditionally calling
            # .cuda(), which fails on CPU-only machines.
            st_model.to(device)
            optimizer = torch.optim.Adam(st_model.parameters(), lr=args.lr)

            ckpt_path = "saved_models/" + f'st_{cur_layer}_{args.dataset}-{args.gnn}' + ".model"

            # Start from +inf so the first finite loss is always checkpointed.
            best_loss = float("inf")
            for epoch in range(1, args.epochs):
                if layer == 0:
                    loss = train_st(_lambda, st_model, optimizer, st_emb, y=target_class)
                else:
                    # Lower layers regress onto the weighted embedding that
                    # the previous (higher) layer produced.
                    loss = ltrain(st_model, optimizer, st_emb, y=hst_emb.to(device))
                if loss < best_loss:
                    best_loss = loss
                    torch.save(st_model.state_dict(), ckpt_path)
                    # early stop
                    if loss < 5e-3:
                        break
            st_model.load_state_dict(torch.load(ckpt_path))
            st_model.eval()

            # Scale the first weight row by its largest magnitude, then keep
            # every subtree whose |weight| reaches the value at rank into_st.
            first_row = st_model.weight[0]
            norm_weights = first_row / torch.abs(first_row).max()
            sorted_norm_weights = torch.abs(norm_weights).sort(dim=-1, descending=True)[0]
            # Clamp the rank so a small subtree pool cannot cause IndexError.
            cutoff_rank = min(into_st, sorted_norm_weights.shape[0] - 1)
            into_st_indices = torch.nonzero(torch.abs(norm_weights) >= sorted_norm_weights[cutoff_rank]).view(-1)
            # Order the kept subtrees by signed weight, largest first.
            topk_st_indices = torch.topk(norm_weights[into_st_indices], len(into_st_indices), dim=-1)[1].cpu().detach().numpy()
            ranked_indices = into_st_indices[topk_st_indices]

            elap_time = time.time() - start
            print(f"Elapsed Time: {elap_time}s")
            print(f"best loss={best_loss}")

            signif_st = [str_global_st_dict[j] for j in ranked_indices]
            signif_st_embs = torch.stack([torch.tensor(eval(st[0])) for st in signif_st]).to(device)
            st_instances = [st[1][1:] for st in signif_st]
            st_importance = norm_weights[ranked_indices]

            results = cluster_union(ranked_indices, norm_weights, str_global_st_dict, loader, dataname=dataname, args=args, gnn_model=gnn_model, device=device, into_class=into_class)

            # Weighted combination of this layer's subtree embeddings; it is
            # the regression target for the next (lower) layer.
            hst_emb = torch.mm(st_model.weight, st_emb).detach().clone()
            del st_model
    else:
        raise NotImplementedError(
            f"get_global_exps only supports graph-classification ('GC') tasks, got {task_type!r} for dataset {dataname!r}"
        )

    return elap_time, results, signif_st_embs, st_importance, st_instances, loader, gnn_model

def check_coverage(loader, layer_global_st_dict, gnn_model, device, layer=3):
    """Measure how many correctly classified graphs contain a top-ranked subtree.

    For each percentage ``t`` in a fixed list (0.01 % up to 10 %), the top
    ``t`` percent of entries in ``layer_global_st_dict`` (ranked by their
    dictionary value, descending) are taken. A graph counts as covered when
    at least one of its node representations at ``layer`` is one of those
    entries.

    Args:
        loader: Iterable of graph batches; each item must support
            ``.to(device)`` and expose ``.y``.
        layer_global_st_dict: Mapping from a stringified embedding
            (``str(tensor.tolist())``, the key format built in
            ``get_global_exps``) to a sortable score value.
        gnn_model: Callable returning logits at index 0 and offering
            ``get_hid_repr(data, layer)``.
        device: Device the batches are moved to before inference.
        layer: Hidden layer whose representations are matched (default 3,
            the value that used to be hard-coded).

    Returns:
        List of coverage ratios, one per percentage threshold. The ratio is
        0.0 when no graph is classified correctly.
    """
    str_global_st_dict = sorted(layer_global_st_dict.items(), key=lambda item: item[1], reverse=True)

    topT = [0.01, 0.2, 0.6, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    allnum = len(str_global_st_dict)
    print(allnum)

    # Single inference pass: the per-graph key sets do not depend on the
    # threshold, so they are collected once and reused for every percentage.
    graph_keys = []
    with torch.no_grad():
        for d in loader:
            d = d.to(device)
            logits = gnn_model(d)[0]
            if torch.argmax(logits) != int(d.y):
                continue
            graph_keys.append(
                {str(ndemb.cpu().detach().tolist()) for ndemb in gnn_model.get_hid_repr(d, layer)}
            )
    all_count = len(graph_keys)

    coverage = []
    for tt in topT:
        max_st = tt * allnum * 0.01
        top_keys = {z[0] for j, z in enumerate(str_global_st_dict) if j < max_st}
        # Whole-embedding match. ``tensor in tensor`` is true as soon as ANY
        # single coordinate matches, which over-counted hits.
        hit = sum(1 for keys in graph_keys if not top_keys.isdisjoint(keys))
        coverage.append(hit / all_count if all_count else 0.0)
        print(coverage)
    return coverage


def train_st(_lambda, model, optimizer, data, y):
    """Run one optimisation step of the subtree-weighting model.

    The objective is the cross-entropy between ``model(data)`` and a one-hot
    target for class ``y`` (placed in row 0), plus ``_lambda`` times the
    2-norm of ``model.weight``.

    Args:
        _lambda: Weight of the norm penalty.
        model: Module exposing a ``weight`` tensor.
        optimizer: Optimizer bound to ``model``'s parameters.
        data: Input passed to ``model``.
        y: Single-element tensor holding the target class index.

    Returns:
        The scalar loss value (before the parameter update) as a float.
    """
    model.train()
    optimizer.zero_grad()

    scores = model(data)

    # One-hot class-probability target, created on y's device.
    one_hot = torch.zeros(scores.shape, dtype=torch.float32, device=y.device)
    one_hot[0, int(y)] = 1

    fit_term = F.cross_entropy(scores, one_hot)
    penalty = _lambda * torch.norm(model.weight, p=2.0)
    objective = fit_term + penalty

    objective.backward()
    optimizer.step()
    return objective.item()

def ltrain(model, optimizer, data, y):
    """Run one mean-squared-error regression step for a lower-layer model.

    Args:
        model: Module to optimise.
        optimizer: Optimizer bound to ``model``'s parameters.
        data: Input passed to ``model``.
        y: Regression target with the same shape as ``model(data)``.

    Returns:
        The scalar MSE loss (before the parameter update) as a float.

    Note:
        Autograd anomaly detection is switched on globally on every call and
        the backward pass retains its graph; both are kept as in the
        original training setup.
    """
    model.train()
    torch.autograd.set_detect_anomaly(True)
    optimizer.zero_grad()

    prediction = model(data)
    mse = F.mse_loss(prediction, y)

    mse.backward(retain_graph=True)
    optimizer.step()
    return mse.item()

def build_args():
    """Parse the command-line options of this script.

    Every option is registered as ``--<name>`` with the listed type and
    default value.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (option name, type, default) in registration order
    option_table = (
        ('dataset', str, 'ba_2motifs'),
        ('gnn', str, 'gin'),
        ('lr', float, 0.002),
        ('epochs', int, 2000),
        ('max_st', int, 800),
        ('into_st', int, 60),
        ('into_class', int, 0),
        ('sthop', str, '[3]'),
        ('nlayers', int, 3),
        ('nclasses', int, 2),
        ('npatterns', float, 2),
        ('hidden', int, 64),
        ('h_thres', float, 0.000),
        ('l_thres', float, -0.01),
        ('_lambda', float, 0.33),
        ('clusters', int, 20),
        ('is_single_opt', int, 0),
        ('opt_id', int, 0),
        ('do_plot', int, 0),
        ('plot_layer', int, 3),
    )

    cli = argparse.ArgumentParser()
    for opt_name, opt_type, opt_default in option_table:
        cli.add_argument('--' + opt_name, type=opt_type, default=opt_default)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: parse the CLI options, run the global explanation
    # pipeline once, then report completion.
    cli_options = build_args()
    get_global_exps(cli_options)
    print("done")

