import sys
import os
curPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(curPath)[0]
sys.path.append(rootPath)
import numpy
import numpy as np
import torch
from dgl.dataloading import DataLoader
from dgl.dataloading import NeighborSampler
from module import GTC
import datetime
import random
from self_tools.data_tools import load_data, get_batch_pos
from self_tools.evaluate import evaluate_for_test, evaluate_for_train
from self_tools.params import set_params
# Import metrics for NMI and ARI
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

# Run-time arguments, obtained once at import time from the project's set_params().
args = set_params()
if torch.cuda.is_available() and args.device > -1:
    # Build the device handle from the same GPU index that set_device() selects.
    # A hard-coded "cuda:0" would disagree with `.cuda()` calls (which follow the
    # current device) whenever args.device is not 0.
    device = torch.device("cuda:{}".format(args.device))
    torch.cuda.set_device(args.device)
else:
    device = torch.device("cpu")

## name of intermediate document ##

# Module-level default; model_train() builds its own per-experiment name.
own_str = args.dataset
# Number of experiments model_train() runs in its outer loop.
exp_num = 1


def make(config, dgl_graph, feats_dim_list, P, h_dict, category, all_node_idx,
         num_classes, mini_batch_flag=True):
    """
    Build the GTC model, its Adam optimizer and the mini-batch training DataLoader.

    :param config: settings object; this function reads seed, dataset, hidden_dim,
        feat_drop, tau, lam, lr, l2_coef, gnn_branch_layer_num and the t_* fields
    :param dgl_graph: graph handed to the DataLoader; its ``etypes`` become ``rel_names``
    :param feats_dim_list: forwarded to GTC as its second positional argument
    :param P: forwarded to GTC as its fourth positional argument
    :param h_dict: per-node-type features; ``h_dict[category].shape[1]`` is used as ``t_input_dim``
    :param category: node type whose nodes are iterated by the DataLoader
    :param all_node_idx: node ids of ``category`` to iterate over
    :param num_classes: forwarded to GTC as ``t_n_class``
    :param mini_batch_flag: accepted for interface compatibility; not used in this function
    :return: (model, train_dataloader, optimizer)
    """
    print("seed ", config.seed)
    print("Dataset: ", config.dataset)
    print("The number of gnn_branch_num: ", config.gnn_branch_layer_num)

    # Collect the GTC constructor arguments first, then instantiate the model.
    positional_args = (config.hidden_dim, feats_dim_list, config.feat_drop, P, config.tau, config.lam)
    keyword_args = dict(
        t_hops=config.t_hops,
        t_n_class=num_classes,
        t_input_dim=h_dict[category].shape[1],
        t_pe_dim=config.t_pe_dim,
        t_n_layers=config.t_n_layers,
        t_num_heads=config.t_n_heads,
        t_dropout_rate=config.t_dropout,
        t_attention_dropout_rate=config.t_attention_dropout,
        rel_names=dgl_graph.etypes,
        category=category,
        gnn_branch_layer_num=config.gnn_branch_layer_num,
    )
    model = GTC(*positional_args, **keyword_args)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.l2_coef)

    # Neighbor sampling for mini-batch training (see https://docs.dgl.ai/guide/minibatch.html):
    # 20 neighbors on the first hop, 10 on each further hop up to gnn_branch_layer_num.
    fanouts = [20] + [10] * (config.gnn_branch_layer_num - 1)
    train_dataloader_4GTC = DataLoader(
        graph=dgl_graph,
        indices={category: all_node_idx},
        graph_sampler=NeighborSampler(fanouts=fanouts),
        batch_size=128,
        shuffle=True,
    )

    return model, train_dataloader_4GTC, optimizer


def train_flow(model, train_loader, optimizer, config, category, pos, own_str, exp=0):
    """
    Train ``model`` on mini-batches with early stopping on the summed epoch loss.

    Each epoch iterates ``train_loader``, performs one optimizer step per batch and
    sums the batch losses. Whenever that sum is lower than the best one seen so far,
    ``model.state_dict()`` is saved to ``/root/autodl-tmp/data/GTC_<own_str>.pkl``.
    Training ends after ``config.nb_epochs`` epochs, or earlier once
    ``config.patience`` consecutive epochs brought no improvement.

    :param model: module whose forward call returns the scalar training loss
    :param train_loader: iterable yielding ``(input_nodes, output_nodes, blocks)``
    :param optimizer: optimizer stepping the model parameters
    :param config: settings object; reads ``nb_epochs``, ``device`` and ``patience``
    :param category: node type used to index per-type node ids and features
    :param pos: forwarded to ``get_batch_pos`` to build the per-batch ``pos`` input
    :param own_str: suffix of the checkpoint file name
    :param exp: experiment index, used only in log messages
    :return: index of the epoch with the lowest summed loss (0 if none improved)
    :raises ValueError: if the first block carries neither an ``'h'`` nor a
        ``'feature'`` source feature
    """
    cnt_wait = 0
    best = 1e9
    best_t = 0
    print('-' * 60)
    print('train_flow for exp-{}'.format(exp))
    starttime = datetime.datetime.now()
    for epoch in range(config.nb_epochs):
        model.train()
        # Kept as a plain float: summing the loss tensors themselves would retain
        # autograd history for every batch of the epoch.
        loss_epoch = 0.0
        for batch_id, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            blocks = [block.to(config.device) for block in blocks]
            # for GNN_branch batch data
            src_data = blocks[0].srcdata
            if 'h' in src_data:
                input_fea4GNN = src_data['h']
            elif 'feature' in src_data:
                input_fea4GNN = src_data['feature']
            else:
                # Fail loudly: returning None here would let the caller go on and
                # load a missing (or stale) checkpoint.
                raise ValueError('please specify the feature key!')
            if not isinstance(input_fea4GNN, dict):
                input_fea4GNN = {category: input_fea4GNN}
            # deal with pos for mini-batch
            pos_batch = get_batch_pos(pos=pos, batch_node_id_x=output_nodes[category].numpy()).to(config.device)
            # original comment: [num_meta-paths, num_nodes, num_hops, feature_dim]
            multi_hop_features = blocks[-1].dstnodes[category].data['multi_hop_feature'].permute(1, 0, 2, 3)

            loss = model(g=blocks, feats=input_fea4GNN, multi_hop_features=multi_hop_features, pos=pos_batch,
                         mini_batch_flag=True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
            print("exp={}; epoch: {};batch-{}; loss {}".format(exp, epoch, batch_id, loss.data.cpu()))

        print(" epoch: {}; epoch_loss {}".format(epoch, loss_epoch))
        if loss_epoch < best:
            print('best loss: {}->{}'.format(best, loss_epoch))
            best = loss_epoch
            best_t = epoch
            cnt_wait = 0
            # save better checkpoint~
            torch.save(model.state_dict(), '/root/autodl-tmp/data/GTC_' + own_str + '.pkl')
        else:
            cnt_wait += 1
            print('lost not improved~ {}'.format(cnt_wait))
        if cnt_wait >= config.patience:
            print('Early stopping at {} epoch!'.format(epoch))
            break
    print('best epoch is {} !'.format(best_t))
    endtime = datetime.datetime.now()
    # total_seconds() covers runs longer than a day; `.seconds` wraps at 24 h.
    elapsed = int((endtime - starttime).total_seconds())
    print('Total train time {} s'.format(elapsed))
    print('-' * 40)
    return best_t


def evaluate_clustering_metrics(emb, labels, train_idx):
    """
    Cluster the selected node embeddings with KMeans and score them against the labels.

    The number of clusters equals the number of distinct true labels among the
    selected nodes. KMeans is run with ``random_state=0``.

    :param emb: tensor of node embeddings, indexable by ``train_idx``
    :param labels: tensor of true labels; class indices, or 2-D one-hot rows
    :param train_idx: indices of the nodes to evaluate on
    :return: tuple ``(nmi, ari)`` - normalized mutual information and adjusted Rand index
    """
    from sklearn.cluster import KMeans

    # detach() is a no-op for embeddings built under no_grad and keeps numpy()
    # from failing if a caller passes a tensor that still requires grad.
    emb_np = emb[train_idx].detach().cpu().numpy()

    true_labels = labels[train_idx].cpu().numpy()
    if true_labels.ndim > 1:
        if true_labels.shape[1] > 1:
            # One-hot rows -> class indices.
            true_labels = np.argmax(true_labels, axis=1)
        else:
            # Column vector (N, 1) -> flat (N,); the sklearn metrics reject 2-D labels.
            true_labels = true_labels.reshape(-1)

    n_clusters = len(np.unique(true_labels))

    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    cluster_labels = kmeans.fit_predict(emb_np)

    nmi = normalized_mutual_info_score(true_labels, cluster_labels)
    ari = adjusted_rand_score(true_labels, cluster_labels)

    return nmi, ari


def test(model, config, train_idx_list, val_idx_list, test_idx_list, labels, num_classes, fea_evalue, 
         ma_dic_list, mi_dic_list, auc_dic_list, nmi_dic_list, ari_dic_list):
    """
    Evaluate a trained model on every data split and append the scores to the result dicts.

    Embeddings are computed once (without gradients) via ``model.get_embeds``. For each
    split ``i`` they are passed to ``evaluate_for_train`` to get ``(ma, mi, auc)`` and to
    ``evaluate_clustering_metrics`` to get ``(nmi, ari)``. Each score is appended to the
    matching dict under the key ``'<metric>_<config.ratio[i]>'``.

    :param model: model exposing ``eval()`` and ``get_embeds(multi_hop_features=, batch_size=)``
    :param config: settings object; reads hidden_dim, device, dataset, eva_lr, eva_wd,
        patience and ratio
    :param train_idx_list: per-split training indices
    :param val_idx_list: per-split validation indices
    :param test_idx_list: per-split test indices
    :param labels: node labels
    :param num_classes: number of classes, forwarded to ``evaluate_for_train``
    :param fea_evalue: 4-D feature tensor; its first two axes are swapped before embedding
    :param ma_dic_list: dict of lists collecting ``ma`` scores (mutated in place)
    :param mi_dic_list: dict of lists collecting ``mi`` scores (mutated in place)
    :param auc_dic_list: dict of lists collecting ``auc`` scores (mutated in place)
    :param nmi_dic_list: dict of lists collecting NMI scores (mutated in place)
    :param ari_dic_list: dict of lists collecting ARI scores (mutated in place)
    :return: None
    """
    starttime = datetime.datetime.now()
    model.eval()
    with torch.no_grad():  # Disable gradient computation
        emb = model.get_embeds(multi_hop_features=fea_evalue.permute(1, 0, 2, 3), batch_size=1000)
    for i, train_idx in enumerate(train_idx_list):
        ma, mi, auc = evaluate_for_train(config.hidden_dim, train_idx, val_idx_list[i], test_idx_list[i],
                                         labels, num_classes, config.device, config.dataset, config.eva_lr,
                                         config.eva_wd, batch_size=128, patience=config.patience, emb=emb)

        # Clustering quality on the same split's training nodes.
        nmi, ari = evaluate_clustering_metrics(emb, labels, train_idx)

        ratio = config.ratio[i]
        ma_dic_list['ma_{}'.format(ratio)].append(ma)
        mi_dic_list['mi_{}'.format(ratio)].append(mi)
        auc_dic_list['auc_{}'.format(ratio)].append(auc)
        nmi_dic_list['nmi_{}'.format(ratio)].append(nmi)
        ari_dic_list['ari_{}'.format(ratio)].append(ari)

        print('Split {} - NMI: {:.4f}, ARI: {:.4f}'.format(ratio, nmi, ari))

    endtime = datetime.datetime.now()
    # total_seconds() also counts whole days; `.seconds` wraps at 24 h.
    elapsed = int((endtime - starttime).total_seconds())
    print("Total evaluate time: ", elapsed, "s")
    print('-' * 40)


def _new_result_store(prefix):
    """Return an empty result dict with keys ``<prefix>_20``, ``<prefix>_40``, ``<prefix>_60``."""
    return {'{}_{}'.format(prefix, ratio): [] for ratio in (20, 40, 60)}


def _print_result_summary(result_store):
    """Print mean and standard deviation of every score list in ``result_store``."""
    for key, scores in result_store.items():
        # NOTE: the label says "var" but the printed value is np.std (kept for log compatibility).
        print('{}_mean:{},{}_var:{}'.format(key, np.mean(scores), key, np.std(scores)))


def model_train(args):
    """
    Run the full pipeline ``exp_num`` times: load data, build, train, reload the best
    checkpoint, evaluate, and print the accumulated statistics after each experiment.

    :param args: settings object; this function reads device, dataset, seed and t_hops
        and forwards the object to ``make``, ``train_flow`` and ``test``
    :return: None
    """
    # record the result of each exp
    ma_dic_list = _new_result_store('ma')
    mi_dic_list = _new_result_store('mi')
    auc_dic_list = _new_result_store('auc')
    nmi_dic_list = _new_result_store('nmi')
    ari_dic_list = _new_result_store('ari')

    for exp in range(exp_num):  # every exp
        print('-' * 60)
        print('exp:{}'.format(exp))
        print('-' * 60)
        starttime = datetime.datetime.now()
        use_cuda = torch.cuda.is_available() and args.device > -1
        if use_cuda:
            # Use the configured GPU index for the handle as well; a fixed "cuda:0"
            # would put the model on a different GPU than the tensors moved with the
            # current device whenever args.device is not 0.
            device = torch.device("cuda:{}".format(args.device))
            torch.cuda.set_device(args.device)
        else:
            device = torch.device("cpu")

        # name of intermediate document
        own_str = args.dataset + '_' + str(exp)

        # random seed
        # NOTE(review): every experiment is seeded with the same value - confirm this is intended.
        seed = args.seed
        numpy.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        # load data~
        dgl_graph, category, all_node_idx, train_idx_list, val_idx_list, test_idx_list, \
        h_dict, labels, P, num_classes, pos = load_data(
            data_name=args.dataset, data_dir='/root/autodl-tmp/data/', t_hops=args.t_hops,
            cache_sub_dir='cache-opensource')

        feats_dim_list = [feats.shape[-1] for feats in h_dict.values()]

        # build the model, train_loader and optimizer
        # make() takes no relation-name argument (it reads dgl_graph.etypes itself);
        # passing one shifted num_classes into mini_batch_flag.
        model, train_loader, optimizer = make(args, dgl_graph, feats_dim_list, P, h_dict,
                                              category, all_node_idx, num_classes)
        print(model)

        if use_cuda:
            print('Using CUDA~')
            model.to(device)
            labels = labels.to(device)
            for index in range(len(train_idx_list)):
                train_idx_list[index] = train_idx_list[index].long().to(device)
                val_idx_list[index] = val_idx_list[index].long().to(device)
                test_idx_list[index] = test_idx_list[index].long().to(device)

        # train the model~
        best_t = train_flow(model, train_loader, optimizer, args, category, pos, own_str, exp=exp)
        # test the model~
        print('-' * 40)
        print('test paradigm~')
        print('Loading {}th epoch'.format(best_t))
        # load checkpoint
        model.load_state_dict(torch.load('/root/autodl-tmp/data/GTC_' + own_str + '.pkl',
                                         map_location=device))
        fea_evalue = dgl_graph.nodes[category].data['multi_hop_feature'].to(device)
        # test flow
        test(model, args, train_idx_list, val_idx_list, test_idx_list, labels, num_classes, fea_evalue,
             ma_dic_list, mi_dic_list, auc_dic_list, nmi_dic_list, ari_dic_list)

        endtime = datetime.datetime.now()
        # total_seconds() also counts whole days; `.seconds` wraps at 24 h.
        elapsed = int((endtime - starttime).total_seconds())
        print("Total time: ", elapsed, "s")

        # print the results accumulated over the experiments run so far
        for result_store in (ma_dic_list, mi_dic_list, auc_dic_list, nmi_dic_list, ari_dic_list):
            _print_result_summary(result_store)


def test_pre_trained_model(args):
    """
    Load a saved model object and evaluate it on every data split.

    The model is read from ``/root/autodl-tmp/data/<dataset>_model.pkl``. Its ``seed``,
    ``t_hops`` and ``hidden_dim`` attributes drive seeding, data loading and evaluation.
    For each split ``evaluate_for_test`` is called, and the NMI / ARI of a KMeans
    clustering of the embeddings is printed.

    :param args: settings object; reads dataset, device, eva_lr, eva_wd, patience and ratio
    :return: None
    """
    use_cuda = torch.cuda.is_available() and args.device > -1
    # Resolve the device from args so that it matches the GPU index in use, instead of
    # relying on a module-level handle.
    device = torch.device("cuda:{}".format(args.device)) if use_cuda else torch.device("cpu")

    # NOTE(review): torch.load unpickles a whole model object here - only load files
    # from a trusted source.
    model = torch.load('/root/autodl-tmp/data/{}_model.pkl'.format(args.dataset))
    ## random seed ##
    seed = model.seed
    numpy.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # load data
    dgl_graph, category, all_node_idx, train_idx_list, val_idx_list, test_idx_list, \
    h_dict, labels, P, num_classes, pos = load_data(
        data_name=args.dataset, data_dir='/root/autodl-tmp/data/', t_hops=model.t_hops,
        cache_sub_dir='cache-opensource')
    if use_cuda:
        print('Using CUDA')
        model.to(device)
        labels = labels.to(device)
        for index in range(len(train_idx_list)):
            train_idx_list[index] = train_idx_list[index].long().to(device)
            val_idx_list[index] = val_idx_list[index].long().to(device)
            test_idx_list[index] = test_idx_list[index].long().to(device)

    starttime = datetime.datetime.now()
    model.eval()
    fea_evalue = dgl_graph.nodes[category].data['multi_hop_feature'].to(device)

    # Embeddings for all nodes of `category`, computed once and reused for every split.
    with torch.no_grad():
        emb = model.get_embeds(multi_hop_features=fea_evalue.permute(1, 0, 2, 3), batch_size=1000)

    for i, train_idx in enumerate(train_idx_list):
        evaluate_for_test(model.hidden_dim, train_idx, val_idx_list[i], test_idx_list[i],
                          labels,
                          num_classes, device,
                          args.dataset,
                          args.eva_lr, args.eva_wd, model=model, fea_evalue=fea_evalue,
                          patience=args.patience, batch_size=128)

        # Clustering quality on this split's training nodes.
        nmi, ari = evaluate_clustering_metrics(emb, labels, train_idx)
        print('Split {} - NMI: {:.4f}, ARI: {:.4f}'.format(args.ratio[i], nmi, ari))

    endtime = datetime.datetime.now()
    # total_seconds() also counts whole days; `.seconds` wraps at 24 h.
    elapsed = int((endtime - starttime).total_seconds())
    print("Total time: ", elapsed, "s")


# Script entry point: run training and evaluation with the module-level `args`.
if __name__ == "__main__":
    model_train(args)