import argparse
import time
import shutil
import torch 
import os
import numpy as np
import random

import torch.utils
import torch.utils.data
from modules.utils import MLP


from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from model_mutiGIN import GraphModel
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR


from sklearn.linear_model import LogisticRegression

from process_data.utils import dataset_info
from process_data.graph_dataset import TransTUDataset
from process_data.split import split_tu_dataset
from modules.utils import check_nan_inf
from process_data.align import align_feat_downstream, align_single

from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import SpectralClustering
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# from zero_shot import batch_zero_shot

# Names of the TU datasets in the downstream "test" group used by this script.
TEST = [
    'ENZYMES',
    'NCI1',
    'NCI109',
    'DD',
    'Mutagenicity',
]
def seed_all(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed) # prohibit hash randomization and make the experiment reproducible
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def classify(Z_train, Z_test, y_train, y_test, test_dataset, fewshot_set, device,
             batch_size=None):
    """Fit a linear probe on frozen embeddings and return the best test accuracy.

    A one-layer ``MLP`` head is trained with Adam and cross-entropy on
    ``(Z_train, y_train)`` for ``fewshot_set['epoch']`` epochs. After every
    epoch it is evaluated on ``(Z_test, y_test)``.

    Args:
        Z_train: Training embeddings; rows index samples and the last
            dimension is the feature size.
        Z_test: Test embeddings with the same feature size.
        y_train: Integer class labels aligned with ``Z_train``.
        y_test: Integer class labels aligned with ``Z_test``.
        test_dataset: Dataset name (kept for interface compatibility; not
            used in the computation).
        fewshot_set: Dict providing ``'lr'`` and ``'epoch'``.
        device: Device on which the probe is trained and evaluated.
        batch_size: Mini-batch size. When ``None``, falls back to the
            module-level ``args.batch_size`` that the ``__main__`` block sets,
            which is what this function always used before.

    Returns:
        The highest per-epoch test accuracy, a float in ``[0, 1]``.
        NOTE: the epoch is selected on the test split itself, so this is an
        optimistic estimate.
    """
    if batch_size is None:
        # Legacy behaviour: read from the script-level ``args`` global.
        batch_size = args.batch_size

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(Z_train, y_train), batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(Z_test, y_test), batch_size=batch_size)

    # Size the head by the largest label seen in either split. A class that is
    # absent from the few-shot split, or non-contiguous labels, would otherwise
    # produce out-of-range cross-entropy targets. For labels 0..C-1 that all
    # appear in ``y_train`` this equals the number of distinct training labels.
    num_classes = int(torch.cat([y_train.view(-1), y_test.view(-1)]).max().item()) + 1

    classifier = MLP(Z_train.shape[-1], num_classes, num_layers=1).to(device)
    optim = torch.optim.Adam(classifier.parameters(), lr=fewshot_set['lr'])
    criterion = nn.CrossEntropyLoss(reduction='mean')
    retained_acc = 0.

    for _ in range(fewshot_set['epoch']):
        classifier.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device).long()
            loss = criterion(classifier(X), y)
            optim.zero_grad()
            loss.backward()
            optim.step()

        classifier.eval()
        correct = 0
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(device), y.to(device)
                pred = classifier(X).argmax(dim=-1)
                correct += (pred == y).sum().item()
        test_acc = correct / len(Z_test)
        retained_acc = max(retained_acc, test_acc)
    return retained_acc

def cluster(Z, y, n_trials, gamma_scale=8.):
    """Spectral-cluster embeddings ``n_trials`` times and score them against labels.

    A Gaussian affinity ``K = exp(-gamma * D**2)`` is built from the pairwise
    Euclidean distances ``D``, with ``gamma = 1 / mean(D)**2 / gamma_scale``
    (the mean includes the zero diagonal). The number of clusters is the
    number of distinct labels in ``y``. Each trial refits
    ``SpectralClustering``; its ``random_state`` is left unset, so the trials
    can differ. For each trial the function records the Hungarian-matched
    accuracy, NMI and ARI.

    Args:
        Z: Embedding tensor, one row per sample.
        y: Integer label tensor aligned with ``Z``.
        n_trials: Number of clustering repetitions.
        gamma_scale: Divisor of ``gamma``. Larger values give a wider kernel.

    Returns:
        Dict with ``acc_mean``, ``acc_std``, ``nmi_mean``, ``nmi_std``,
        ``ari_mean`` and ``ari_std`` as Python floats. The same statistics are
        also printed.
    """
    Z = Z.detach().cpu()
    y = y.detach().cpu()
    n_class = len(torch.unique(y))

    D = torch.cdist(Z, Z, p=2)
    mean_dist = D.mean().item()
    D = D.numpy()
    y = y.numpy()

    # If every embedding is identical the mean distance is 0 and the formula
    # would divide by zero; gamma = 0 then gives the (correct) all-ones affinity.
    if mean_dist > 0:
        gamma = 1. / (mean_dist) ** 2 / gamma_scale
    else:
        gamma = 0.
    K = np.exp(-gamma * D ** 2)

    nmi, ari, acc = [], [], []
    for _ in range(n_trials):
        sc = SpectralClustering(n_clusters=n_class, affinity='precomputed')
        pred = sc.fit_predict(K)
        nmi.append(normalized_mutual_info_score(y, pred))
        ari.append(adjusted_rand_score(y, pred))
        acc.append(acc_hungarian(y, pred))

    acc_mean, acc_std = float(np.mean(acc)), float(np.std(acc))
    nmi_mean, nmi_std = float(np.mean(nmi)), float(np.std(nmi))
    ari_mean, ari_std = float(np.mean(ari)), float(np.std(ari))
    print(f'Acc: mean {acc_mean}, std {acc_std}')
    print(f'NMI: mean {nmi_mean}, std {nmi_std}')
    print(f'ARI: mean {ari_mean}, std {ari_std}')
    return {
        'acc_mean': acc_mean,
        'acc_std': acc_std,
        'nmi_mean': nmi_mean,
        'nmi_std': nmi_std,
        'ari_mean': ari_mean,
        'ari_std': ari_std,
    }


def acc_hungarian(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster-to-label matching.

    Builds the contingency matrix ``w[pred, true]`` and solves the assignment
    problem with ``linear_sum_assignment`` (Hungarian algorithm). This finds
    the cluster -> label mapping that maximises the number of matched samples.

    Args:
        y_true: Non-negative integer ground-truth labels (array-like).
        y_pred: Non-negative cluster ids of the same size (array-like).
            Integral floats are accepted.

    Returns:
        Fraction of samples labelled correctly under the best mapping.

    Raises:
        ValueError: If the inputs are empty, differ in size, or contain
            negative ids.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred differ in size ({y_true.size} vs {y_pred.size})")
    if y_pred.size == 0:
        raise ValueError("acc_hungarian needs at least one sample")
    if min(y_pred.min(), y_true.min()) < 0:
        raise ValueError("labels and cluster ids must be non-negative")

    D = int(max(y_pred.max(), y_true.max())) + 1
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)  # vectorised contingency counts
    # Minimising (max - w) is equivalent to maximising the matched counts.
    rows, cols = linear_sum_assignment(w.max() - w)
    return float(w[rows, cols].sum()) / y_pred.size

def get_feature(name, root, test_name, dim, scales, load_path, device):
    """Mean-pool the node features of the aligned few-shot splits into graph vectors.

    Calls ``align_feat_downstream`` for dataset ``name`` and averages each
    graph's ``data.x`` over its first dimension. That dimension is presumably
    the node axis; TODO confirm against ``align_feat_downstream``.

    Args:
        name: Downstream dataset name.
        root: Dataset root directory.
        test_name: Forwarded to ``align_feat_downstream`` as ``test_name``.
        dim: Feature dimension forwarded to the alignment.
        scales: Scales forwarded to the alignment.
        load_path: Forwarded as ``save_path``.
        device: Device forwarded to the alignment.

    Returns:
        ``(train_rep, test_rep)``: tensors with one pooled entry (row) per graph.
    """
    fs_train, fs_test = align_feat_downstream(
        name,
        root,
        dim,
        scales=scales,
        device=device,
        save_path=load_path,
        test_name=test_name
    )
    # torch.stack, not torch.cat: each pooled vector is one graph, and cat
    # would flatten all graphs into a single long vector instead of one row each.
    train_rep = torch.stack([data.x.mean(dim=0) for data in fs_train], dim=0)
    test_rep = torch.stack([data.x.mean(dim=0) for data in fs_test], dim=0)
    return train_rep, test_rep

class Trainer:
    """Downstream evaluation harness for a pretrained ``GraphModel``.

    Stores the evaluation configuration and offers helpers to:

    * build few-shot/test loaders (inductive or transductive split);
    * load the checkpoint at ``pretrain_path``;
    * either train the model head (``eval``) or extract frozen graph
      embeddings (``few_shot_feed_forward``, ``feed_forward``).
    """

    def __init__(
        self,
        test_dataset,
        batchsize,
        device,
        hyparam_model: dict = None,
        fewshot_param: dict = None,
        jumping_mode: str = None,

        root: str = './TUDataset',
        pretrain_path: str = './ckpt',
        model_name=None,
        dim=10,
        scales=None,
        gamma_scale=0.01,
        seed=1145,
        align_feat=False,
        inductive=True
    ):
        """Store the evaluation configuration.

        Args:
            test_dataset: Name of the downstream dataset to evaluate on.
            batchsize: Batch size for every DataLoader built here.
            device: Torch device for the model and the batches.
            hyparam_model: Hyper-parameter dict consumed by ``config_model``.
            fewshot_param: Required dict with ``'k'`` (shots per class),
                ``'epoch'`` and ``'lr'``.
            jumping_mode: Forwarded to ``GraphModel``.
            root: Root directory of the datasets.
            pretrain_path: Checkpoint file loaded before evaluation.
            model_name: Passed as ``test_name`` to the feature-alignment
                helpers. Defaults to ``test_dataset``.
            dim: Feature dimension (``feat_dim`` of ``GraphModel``).
            scales: Sequence of scales. Defaults to ``[0.25, 0.5, 1., 2., 5.]``.
                A private copy is stored.
            gamma_scale: Forwarded to the feature-alignment helpers.
            seed: Seed forwarded to ``split_tu_dataset``.
            align_feat: Use the aligned-feature pipeline.
            inductive: In ``few_shot_feed_forward``, choose
                ``set_inductive_dataset`` over ``set_transductive_dataset``.

        Raises:
            TypeError: If ``fewshot_param`` is not supplied.
        """
        if fewshot_param is None:
            raise TypeError(
                "Trainer() requires fewshot_param with keys 'k', 'epoch' and 'lr'")
        if scales is None:
            scales = [0.25, 0.5, 1., 2., 5.]

        self.test_dataset = test_dataset
        self.batchsize = batchsize
        self.device = device
        self.hyparam_model = hyparam_model

        self.jumping_mode = jumping_mode

        self.k_shot = fewshot_param['k']
        self.fewshot_epochs = fewshot_param['epoch']
        self.fewshot_lr = fewshot_param['lr']

        self.root = root
        self.load_path = pretrain_path
        self.dim = dim
        self.scales = list(scales)  # copy: never share the caller's list
        self.gamma_scale = gamma_scale

        self.seed = seed
        self.align = align_feat
        self.inductive = inductive
        self.model_name = test_dataset if model_name is None else model_name

    @staticmethod
    def _filter_large_graphs(dataset, name, max_nodes=1000):
        """For DD, keep only graphs with fewer than ``max_nodes`` nodes; pass other datasets through unchanged."""
        if name != 'DD':
            return dataset
        return [data for data in dataset if data.x.shape[0] < max_nodes]

    def config_model(self, hp_param):
        """Instantiate ``self.model`` (a ``GraphModel``) from ``hp_param`` and move it to ``self.device``."""
        gt_param = hp_param["gt_param"]
        self.model = GraphModel(
            hp_param["num_atoms"],
            hp_param["num_atom_supp"],
            hp_param["gamma"],
            use_mlp_head=False,
            mlp_out_dim=None,
            mlp_num_layers=hp_param['mlp_layers'],
            jumping_mode=self.jumping_mode,
            readout=hp_param["readout"],
            n_graph=len(self.scales),
            feat_dim=self.dim,
            gin_num_layer=3,
            gin_hidden_dim=hp_param['gin_hidden_dim'],
            **gt_param
        )
        self.model.to(self.device)

    def config_optimizer(self):
        """Create an Adam optimiser over ``self.model.mlp_head`` only.

        NOTE(review): ``config_model`` passes ``use_mlp_head=False``. Confirm
        that ``GraphModel`` still exposes a trainable ``mlp_head`` in that case.
        """
        self.optim = torch.optim.Adam(
            params=self.model.mlp_head.parameters(),
            lr=self.fewshot_lr
        )

    def set_test_dataset(self, name: str):
        """Build few-shot train/test loaders from a fresh ``split_tu_dataset`` split.

        NOTE(review): this overwrites ``self.train_dataset`` and
        ``self.test_dataset`` with dataset objects. ``self.test_dataset``
        otherwise holds the dataset *name* that ``eval``, ``feed_forward`` and
        ``few_shot_feed_forward`` rely on, so call this only on a throw-away
        Trainer.
        """
        train_path, test_path = split_tu_dataset(name, self.root,
                                                 mode='fewshot', force_reload=True,
                                                 k_shot=self.k_shot,
                                                 seed=self.seed)
        self.train_dataset = TransTUDataset(
            root=train_path,
            name=name,
            mode='train',
            use_decomp='all_graphs',
            dim=self.dim,
            force_reload=False
        )

        self.num_class = len(torch.unique(self.train_dataset.y))
        self.test_dataset = TransTUDataset(
            root=test_path,
            name=name,
            mode='test',
            use_decomp='all_graphs',
            dim=self.dim,
            force_reload=False,
            align_feat=True
        )
        self.train_dataset = self._filter_large_graphs(self.train_dataset, name)
        self.test_dataset = self._filter_large_graphs(self.test_dataset, name)

        fewshot_loader = DataLoader(self.train_dataset, self.batchsize, shuffle=False)
        test_loader = DataLoader(self.test_dataset, self.batchsize, shuffle=False)
        return fewshot_loader, test_loader

    def set_inductive_dataset(self, name):
        """Return ``(fewshot_loader, test_loader)`` for a pre-split k-shot setup.

        With ``self.align``, both splits come from ``align_feat_downstream``.
        Otherwise they come from ``split_tu_dataset`` plus ``TransTUDataset``.
        Also sets ``self.num_class`` from the few-shot split.
        """
        if self.align:
            train, test = align_feat_downstream(
                name,
                self.root,
                self.dim,
                scales=self.scales,
                device=self.device,
                save_path=os.path.dirname(self.load_path),
                test_name=self.model_name,
                gamma_scale=self.gamma_scale,
                k_shot=self.k_shot
            )
        else:
            train_path, test_path = split_tu_dataset(name, self.root,
                                                     mode='fewshot', force_reload=True,
                                                     k_shot=self.k_shot,
                                                     seed=self.seed)
            train = TransTUDataset(
                root=train_path,
                name=name,
                mode='train',
                use_decomp='all_graphs',
                dim=self.dim,
                force_reload=False
            )
            test = TransTUDataset(
                root=test_path,
                name=name,
                mode='test',
                use_decomp='all_graphs',
                dim=self.dim,
                force_reload=False,
                align_feat=True
            )

        train_set = self._filter_large_graphs(train, name)
        test_set = self._filter_large_graphs(test, name)
        # Count distinct label *values*. A Python set of tensors would hash by
        # object identity and count one "class" per graph.
        self.num_class = len(torch.unique(torch.cat([data.y.view(-1) for data in train_set])))
        fewshot_loader = DataLoader(train_set, self.batchsize, shuffle=True)
        test_loader = DataLoader(test_set, self.batchsize, shuffle=False)
        return fewshot_loader, test_loader

    def set_transductive_dataset(self, name: str):
        """Load one dataset and sample ``k_shot`` graphs per label for training.

        The remaining graphs of each label form the test split. Also sets
        ``self.num_class``.

        Raises:
            ValueError: If some label has ``k_shot`` or fewer graphs, which
                would leave none for testing.
        """
        if self.align:
            full_dataset = align_single(
                name,
                self.root,
                self.dim,
                scales=self.scales,
                device=self.device,
                save_path=os.path.dirname(self.load_path),
                test_name=self.model_name,
                gamma_scale=self.gamma_scale,  # same as feed_forward / inductive path
                force_reload=True
            )
        else:
            full_dataset = TransTUDataset(
                self.root,
                name,
                dim=self.dim,
                scales=self.scales,
                mode=None,
                force_reload=True
            )

        y = torch.cat([data.y for data in full_dataset])
        unique_labels = torch.unique(y)
        self.num_class = len(unique_labels)

        train_indices, test_indices = [], []
        for label in unique_labels:
            indices = torch.where(y == label)[0]
            if len(indices) <= self.k_shot:
                raise ValueError(
                    f"Label {label.item()} has only {len(indices)} samples; "
                    f"need more than k_shot={self.k_shot}")
            # NumPy's global RNG is used here, so ``seed_all`` controls the split.
            indices = torch.tensor(np.random.permutation(indices.numpy()))
            train_indices.append(indices[:self.k_shot])
            test_indices.append(indices[self.k_shot:])

        train_indices = torch.cat(train_indices, dim=0)
        test_indices = torch.cat(test_indices, dim=0)

        train_dataset = [full_dataset[ind.item()] for ind in train_indices.long()]
        test_dataset = [full_dataset[ind.item()] for ind in test_indices.long()]
        train_dataset = self._filter_large_graphs(train_dataset, name)
        test_dataset = self._filter_large_graphs(test_dataset, name)

        fewshot_loader = DataLoader(train_dataset, self.batchsize, shuffle=True)
        test_loader = DataLoader(test_dataset, self.batchsize, shuffle=False)
        return fewshot_loader, test_loader

    def eval(self):
        """Train the model head on the few-shot split and track the best test accuracy.

        The test split is evaluated after the first epoch and then after every
        second epoch. The best test accuracy is stored in ``self.retained_acc``.
        """
        fewshot_loader, test_loader = self.set_inductive_dataset(self.test_dataset)

        self.config_model(self.hyparam_model)
        self.config_optimizer()
        self.model.train()

        # load model
        self.model.load_state_dict(torch.load(self.load_path, weights_only=True, map_location=self.device), strict=False)

        train_accs = []
        test_accs = []
        retained_acc = 0
        for epoch in range(self.fewshot_epochs):
            train_loss, train_acc = self.train_epoch(fewshot_loader)
            print(f'Downstream Dataset: {self.test_dataset}, Epoch: {epoch + 1}, train loss: {train_loss:.4f}, train acc:{train_acc:.4f}')

            if epoch == 0 or (epoch + 1) % 2 == 0:
                test_loss, test_acc = self.valid(test_loader)
                if test_acc > retained_acc:
                    retained_acc = test_acc
                print(f' test loss: {test_loss:.4f},Test acc: {test_acc:.4f}, retained acc: {retained_acc:.4f}')

            # On non-evaluated epochs this repeats the last measured test accuracy.
            train_accs.append(train_acc)
            test_accs.append(test_acc)
        self.retained_acc = retained_acc

    def few_shot_feed_forward(self):
        """Embed the few-shot train and test splits with the frozen checkpoint.

        Returns:
            ``(train_features, test_features, train_y, test_y)`` as CPU tensors.
        """
        if not self.inductive:
            train_loader, test_loader = self.set_transductive_dataset(self.test_dataset)
        else:
            train_loader, test_loader = self.set_inductive_dataset(self.test_dataset)

        self.config_model(self.hyparam_model)
        print(f"loading model from {self.load_path}")
        self.model.load_state_dict(torch.load(self.load_path, weights_only=True, map_location=self.device), strict=False)
        self.model.eval()
        with torch.no_grad():
            train_features, train_y = [], []
            test_features, test_y = [], []
            for data in train_loader:
                data = data.to(self.device)
                train_features.append(self.model(data))
                train_y.append(data.y)
            for data in test_loader:
                data = data.to(self.device)
                test_features.append(self.model(data))
                test_y.append(data.y)
            train_features = torch.cat(train_features, dim=0).cpu()
            test_features = torch.cat(test_features, dim=0).cpu()
            train_y = torch.cat(train_y, dim=0).cpu()
            test_y = torch.cat(test_y, dim=0).cpu()
        return train_features, test_features, train_y, test_y

    def feed_forward(self):
        """Embed the whole downstream dataset with the frozen checkpoint.

        Returns:
            ``(Z, y)``: L2-normalised embeddings (CPU) and their labels.
            Batches are shuffled, but ``Z`` and ``y`` stay row-aligned.
        """
        if self.align:
            dataset = align_single(self.test_dataset,
                                   self.root, self.dim,
                                   scales=self.scales,
                                   device=self.device,
                                   save_path=os.path.dirname(self.load_path),
                                   test_name=self.model_name,
                                   gamma_scale=self.gamma_scale,
                                   force_reload=False)
        else:
            dataset = TransTUDataset(
                self.root,
                name=self.test_dataset,
                mode=None,
                force_reload=True,
                dim=self.dim
            )
        forward_set = self._filter_large_graphs(dataset, self.test_dataset)
        loader = DataLoader(forward_set, self.batchsize, shuffle=True)
        Z, y = [], []
        self.config_model(self.hyparam_model)  # also moves the model to self.device
        print(f"loading model from {self.load_path}")
        self.model.load_state_dict(torch.load(self.load_path, weights_only=True, map_location=self.device), strict=False)
        self.model.eval()
        with torch.no_grad():
            for data in loader:
                y.append(data.y)
                data = data.to(self.device)
                Z.append(self.model(data).cpu())
        Z = torch.cat(Z, dim=0)
        Z = F.normalize(Z, p=2, dim=-1)
        y = torch.cat(y)
        return Z, y

    def criterion(self, out, y):
        """Mean cross-entropy between logits ``out`` and labels ``y`` (cast to long)."""
        criterion = nn.CrossEntropyLoss(reduction='mean')
        return criterion(out, y.long())

    def valid(self, loader):
        """Return ``(mean batch loss, accuracy)`` of ``self.model`` on ``loader`` without gradients."""
        self.model.eval()
        with torch.no_grad():
            correct = 0
            losses = []
            for data in loader:
                data = data.to(self.device)
                out = self.model(data)
                losses.append(self.criterion(out, data.y).item())
                correct += (out.argmax(dim=1) == data.y).sum().item()

        return sum(losses) / len(losses), float(correct) / len(loader.dataset)

    def train_epoch(self, loader):
        """Run one optimisation epoch; return ``(mean batch loss, accuracy)``."""
        self.model.train()
        losses = []
        correct = 0
        for data in loader:
            data = data.to(self.device)
            out = self.model(data)
            loss = self.criterion(out, data.y)
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
            correct += (out.detach().argmax(dim=1) == data.y).sum().item()

        return sum(losses) / len(losses), float(correct) / len(loader.dataset)
    
      
                
        
def main(args, test_datasets, fewshot_set, model_param):
    """Evaluate the pretrained checkpoint on each downstream dataset over 5 seeds.

    For each dataset and each seed in ``range(5)`` a ``Trainer`` is built and
    one of two things runs:

    * with ``args.cluster``: spectral clustering on the frozen embeddings,
      collecting ``args.cluster_metric``;
    * otherwise: a linear probe on the few-shot embeddings (``classify``).

    After all seeds, the per-dataset mean/std of the collected metric is printed.

    Args:
        args: Parsed command-line namespace.
        test_datasets: Iterable of downstream dataset names.
        fewshot_set: Few-shot settings dict (``'k'``, ``'epoch'``, ``'lr'``).
        model_param: Hyper-parameters forwarded to ``Trainer``/``GraphModel``.
    """
    # Prefer the first GPU, but fall back to CPU so the script runs without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(device)

    print(model_param)

    align_feat = args.align
    inductive = args.inductive
    print('=' * 20 + f"Align is {align_feat}" + '=' * 20)
    for test_dataset in test_datasets:
        # Always evaluate the checkpoint pretrained on all datasets. A
        # per-dataset choice for names in TEST used to be computed here, but it
        # was immediately overwritten with "ALL".
        pretrain_set = "ALL"
        pretrain_path = os.path.join(
            args.output, f'{pretrain_set}_pretrain_model_ef_epoch{args.readepoch}.pth')

        best_acc = []
        best_cluster = []
        for seed in range(5):
            seed_all(seed)

            trainer = Trainer(
                test_dataset,
                args.batch_size,
                device,
                hyparam_model=model_param,
                fewshot_param=fewshot_set,
                dim=args.dim,
                scales=args.scales,
                gamma_scale=args.gammascale,
                pretrain_path=pretrain_path,
                seed=seed,
                model_name=pretrain_set,
                align_feat=align_feat,
                inductive=inductive
            )
            print(f"\nEvaluating on {test_dataset}")

            # Alternative evaluation protocols (disabled):
            # Z, y = trainer.feed_forward()
            # _, acc = batch_zero_shot(Z, y, dataset_name=test_dataset)
            # print('Zero-shot accuracy:', acc * 100)

            # train_X, test_X, train_y, test_y = trainer.few_shot_feed_forward()
            # classifier = LogisticRegression(max_iter=1000, random_state=111, C=10.)
            # classifier.fit(train_X.numpy(), train_y.numpy())
            # test_acc = classifier.score(test_X.numpy(), test_y.numpy())
            # print(f"Dataset:{test_dataset} Accuracy: {test_acc:.4f}")
            # best_acc.append(test_acc)

            if args.cluster:
                Z, y = trainer.feed_forward()
                gs = 1.
                print(f"[Cluster] gamma_scale={gs}, trials={args.cluster_trials}")
                res = cluster(Z, y, n_trials=args.cluster_trials, gamma_scale=gs)
                best_cluster.append(res[args.cluster_metric])
            else:
                start = time.time()
                Z_train, Z_test, y_train, y_test = trainer.few_shot_feed_forward()
                retained_acc = classify(Z_train, Z_test, y_train, y_test, test_dataset, fewshot_set, device)
                best_acc.append(retained_acc)
                epoch = fewshot_set['epoch']
                print(f'Retained Acc: {retained_acc:.4f}, train {epoch} epochs cost {time.time() - start} s')

        if args.cluster:
            print(f'{args.cluster_metric} list:', ['{:.4f}'.format(m) for m in best_cluster])
            print(f"Align:{align_feat} {test_dataset} {args.cluster_metric}: "
                  f"mean: {np.mean(best_cluster):.4f}, std: {np.std(best_cluster):.4f}")
        else:
            print('Accuracy list:', ['{:.4f}'.format(acc) for acc in best_acc])
            mean = np.mean(best_acc)
            std = np.std(best_acc)
            print(f"Align:{align_feat} {test_dataset} Accuracy: mean: {mean:.4f}, std: {std:.4f}")





if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean CLI value ('true'/'false', 'yes'/'no', '1'/'0', ...).

        Plain ``type=bool`` is an argparse trap: ``bool('False')`` is True, so
        any explicit value used to switch the flag on.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="downstream-evaluation"
    )

    parser.add_argument("-b", "--batch_size", default=64, type=int)
    parser.add_argument("-e", "--epoch", default=500, type=int)
    parser.add_argument("-s", "--seed", default=120810, type=int)

    parser.add_argument("-nl", "--num_layer", default=3, type=int)
    parser.add_argument("-ns", "--num_supp", default=16, type=int)

    parser.add_argument("-ls", "--lr_schedule", default=0.99, type=float)
    parser.add_argument("-lr1", "--lr1", default=5e-4, type=float)
    parser.add_argument("-lr2", "--lr2", default=1e-1, type=float)

    parser.add_argument("-lam", "--lam", default=0.1, type=float)
    parser.add_argument("-g1", "--gamma", default=100., type=float)
    parser.add_argument('-d', '--dim', default=32, type=int)
    # Space-separated floats, e.g. `-scales 0.5 1 2`. Its length multiplies
    # the embedding width in model_param below. (`type=list` split the string
    # into single characters.)
    parser.add_argument('-scales', default=[0.25, 0.5, 1., 2., 5., 10.], type=float, nargs='+')
    parser.add_argument('-gammascale', default=0.01, type=float)

    parser.add_argument('-r', '--root', default='./TUDataset/', type=str)
    parser.add_argument('-o', '--output',
                        default='./ckpt/ef_ker_d32_4ds_multiGIN_ALIGN_gammascale0.01_btwn_atoms64_nsup16_negweight2_all/',
                        type=str)
    parser.add_argument('--readepoch', default=50, type=int)

    parser.add_argument('--align', default=True, type=_str2bool)
    parser.add_argument('--inductive', default=True, type=_str2bool)
    parser.add_argument('--cluster', default=False, type=_str2bool, help='cluster or fewshot classification')
    parser.add_argument('--cluster_trials', default=10, type=int)
    parser.add_argument('--cluster_metric', choices=['acc_mean', 'nmi_mean', 'ari_mean'], default='acc_mean')

    fewshot_set = {
        'k': 50,
        'epoch': 1000,
        'lr': 0.0005
    }
    # Module-level on purpose: classify() falls back to the global `args`.
    args = parser.parse_args()

    model_param = {
        "gt_param": {
            "num_encoder_layers": args.num_layer,
            "embed_dim": args.dim * len(args.scales),
            "ffn_embed_dim": args.dim * len(args.scales),
            "num_attn_heads": 4,
            "dropout": 0.1,
            "attn_dropout": 0.1,
            "activation_dropout": 0.1,
            "layerdrop": 0.0,
            "encoder_normalize_before": False,
            "activation_fn": "gelu",
        },
        'gin_hidden_dim': 128,
        "num_atoms": 64,
        "num_atom_supp": args.num_supp,
        "gamma": args.gamma,
        "readout": "mean",
        'mlp_layers': 1,
    }
    print("#" * 50)
    print(args)
    print(fewshot_set)

    test_datasets = [
        'DD',
    ]
    main(args, test_datasets, fewshot_set, model_param)
