import argparse
import time
import shutil
import torch 
import os

from torch_geometric.loader import DataLoader
from model_mutiGIN import GraphModel
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

from process_data.utils import dataset_info
from process_data.split import split_tu_dataset
from modules.utils import check_nan_inf
from process_data.graph_dataset import TransTUDataset, load_graph_dataset

from plot import plot_acc

import numpy as np
import pynvml
import random

from process_data.align import align_feat
from resource_monitor import ResourceMonitor
def seed_all(seed=1234):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    exports ``PYTHONHASHSEED`` into the environment and puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



# Suppress every FutureWarning raised from this point on so that the training
# log is not flooded with deprecation notices.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
class Trainer():
    def __init__(
        self, 
        train_datasets, 
        test_dataset,
        num_epochs, 
        batchsize, 
        device,     
        hyparam_model: dict = ..., 
        hyparam_optim: dict = ..., 
        jumping_mode: str = None,   
        align=True, 
        lam = 0.1, 
        gamma_scale = 0.001, 

        root: str = './TUDataset', 
        save_dir: str = './ckpt', 
        dim=10, 
        scales=[0.25, 0.5, 1., 2., 5., 10.]
    ):
        """Store the training configuration; no model or data is built here.

        Args:
            train_datasets: Names of the pre-training datasets (iterated and
                indexed by ``config_loader``).
            test_dataset: Name of the held-out dataset; used by
                ``train_test`` and in checkpoint file names.
            num_epochs: Number of training epochs.
            batchsize: Batch size handed to every ``DataLoader``.
            device: Device the model and batches are moved to.
            hyparam_model: Hyper-parameters consumed by ``config_model``.
            hyparam_optim: Hyper-parameters consumed by ``config_optimizer``
                (``lr1`` and ``lr2`` are read there).
            jumping_mode: Forwarded to ``GraphModel``.
            align: When true, ``config_loader`` obtains its datasets from
                ``align_feat`` instead of loading them directly.
            lam: Stored on the instance; not read by any method of this class.
            gamma_scale: Forwarded to ``align_feat``.
            root: Root directory of the datasets.
            save_dir: Directory for checkpoints (also passed to
                ``align_feat`` as ``save_path``).
            dim: Feature dimension passed to the datasets and to the model
                (``feat_dim``).
            scales: Scale values passed to the datasets; their count is the
                model's ``n_graph``.
        """
        self.train_datasets = train_datasets
        self.test_dataset = test_dataset
        self.num_epochs = num_epochs
        self.batchsize = batchsize
        self.device = device
        self.hyparam_model = hyparam_model
        self.hyparam_optim = hyparam_optim
        self.jumping_mode = jumping_mode
        self.align = align
        self.lam = lam
        self.gamma_scale = gamma_scale

        self.root = root
        self.save_dir = save_dir
        self.dim = dim
        # Copy so that instances never share (and cannot mutate) the default
        # list object that lives on the function signature.
        self.scales = list(scales)
        
        

    def config_model(self, hp_param):
        """Create ``self.model`` from a hyper-parameter mapping and reset its weights.

        Args:
            hp_param: Mapping providing ``"gt_param"`` (extra keyword
                arguments forwarded verbatim to ``GraphModel``),
                ``"num_atoms"``, ``"num_atom_supp"``, ``"gamma"``,
                ``"readout"``, ``"gin_hidden_dim"`` and ``"gin_num_layer"``.
        """
        gt_kwargs = hp_param["gt_param"]
        leading_args = (
            hp_param["num_atoms"],
            hp_param["num_atom_supp"],
            hp_param["gamma"],
        )
        fixed_kwargs = dict(
            use_mlp_head=False,
            mlp_out_dim=None,
            jumping_mode=self.jumping_mode,
            readout=hp_param["readout"],
            n_graph=len(self.scales),  # as many graphs as there are scales
            feat_dim=self.dim,
            gin_hidden_dim=hp_param['gin_hidden_dim'],
            gin_num_layer=hp_param['gin_num_layer'],
        )
        # Two separate ** expansions keep the "duplicate keyword" TypeError
        # if gt_param ever repeats one of the fixed keywords.
        self.model = GraphModel(*leading_args, **fixed_kwargs, **gt_kwargs)
        self.model.reset_parameters()

    
    def config_optimizer(self, hp_param):

        optimizer = torch.optim.AdamW([
            {'params': self.model.gt.parameters()},
            {'params': self.model.gins.parameters()},  
            {'params': self.model.reference_layer.atoms}, 
            # {'params': self.model.mlp.parameters()}, 
            {'params': self.model.reference_layer.gamma, 'lr': hp_param['lr2']}, 
        ], lr=hp_param['lr1'], weight_decay=1e-5
        )
        self.optim = optimizer
    
    
    def config_loader(self):
        """Build one shuffled training ``DataLoader`` per pre-training dataset.

        With ``self.align`` set, the datasets are produced by ``align_feat``;
        otherwise every name in ``self.train_datasets`` is loaded as a
        ``TransTUDataset`` with ``align_feat=False``. For the dataset named
        ``'DD'`` every graph whose ``data.x`` has more than 1000 rows is
        dropped, and each dropped graph is reported on stdout.

        Returns:
            list: One ``DataLoader`` per dataset, in dataset order.
        """
        if not self.align:
            datasets = [
                TransTUDataset(
                    self.root,
                    name,
                    dim=self.dim,
                    scales=self.scales,
                    mode=None,
                    align_feat=False,
                )
                for name in self.train_datasets
            ]
        else:
            # presumably returns one dataset per training name, in the same
            # order -- the index lookup below relies on that; verify in align_feat
            datasets = align_feat(
                self.train_datasets,
                self.root,
                save_path=self.save_dir,
                test_name=self.test_dataset,
                dim=self.dim,
                gamma_scale=self.gamma_scale,
                scales=self.scales,
                device=self.device,
            )

        loaders = []
        for position, graphs in enumerate(datasets):
            if self.train_datasets[position] == 'DD':
                # NOTE(review): set_train_test_dataset keeps only graphs with
                # strictly fewer than 1000 rows; here exactly 1000 is kept too.
                kept = []
                for j, data in enumerate(graphs):
                    if data.x.shape[0] > 1000:
                        print(f'delete {j} data in class {data.y.item()}')
                        continue
                    kept.append(data)
                graphs = kept
            loaders.append(DataLoader(graphs, self.batchsize, shuffle=True))
        return loaders

    
    def set_train_test_dataset(self, name):
        """Split dataset ``name`` and wrap both halves in data loaders.

        The split directories come from ``split_tu_dataset``. For ``'DD'``
        only graphs whose ``data.x`` has fewer than 1000 rows are kept in
        either split.

        Args:
            name: Name of the dataset to split and load.

        Returns:
            tuple: ``(train_loader, test_loader)``; only the training loader
            shuffles.
        """
        train_root, test_root = split_tu_dataset(name, self.root)

        # Keyword arguments shared by both splits.
        common = dict(
            name=name,
            use_decomp='all_graphs',
            dim=self.dim,
            scales=self.scales,
            force_reload=False,
        )
        train_dataset = TransTUDataset(root=train_root, mode='train', **common)
        # Only the test split passes align_feat=False explicitly.
        test_dataset = TransTUDataset(
            root=test_root, mode='test', align_feat=False, **common
        )

        if name == 'DD':
            train_dataset = [g for g in train_dataset if g.x.shape[0] < 1000]
            test_dataset = [g for g in test_dataset if g.x.shape[0] < 1000]

        return (
            DataLoader(train_dataset, self.batchsize, shuffle=True),
            DataLoader(test_dataset, self.batchsize, shuffle=False),
        )
    
    def train_test(self):
        """Train on the train split of ``self.test_dataset`` and track KNN accuracy.

        A fresh model and optimizer are configured, then each epoch trains on
        the train split and scores the training embeddings with ``predict``
        (fitted and evaluated on the same embeddings). The test split is
        evaluated on the first epoch and on every even-numbered (1-based)
        epoch. On the remaining epochs the most recent test metrics are
        appended again, so all four history lists gain exactly one entry per
        epoch. The curves are finally handed to ``plot_acc``.
        """
        train_loader, test_loader = self.set_train_test_dataset(self.test_dataset)
        self.config_model(self.hyparam_model)
        self.config_optimizer(self.hyparam_optim)
        self.model.to(self.device)

        self.train_losses, self.train_accs = [], []
        self.val_losses, self.val_accs = [], []
        retained_acc = 0  # best test accuracy seen so far

        for epoch in range(self.num_epochs):
            Z_train, train_loss, train_y = self.train_epoch(train_loader)
            train_acc = self.predict(Z_train, Z_train, train_y, train_y)[1]
            print(f'Dataset: {self.test_dataset}, Epoch: {epoch + 1}, train loss: {train_loss:.4f}, train acc:{train_acc:.4f}')

            evaluate_now = epoch == 0 or epoch % 2 == 1
            if evaluate_now:
                Z_test, test_loss, test_y = self.feed_forward(test_loader)
                test_acc = self.predict(Z_train, Z_test, train_y, test_y)[1]
                retained_acc = max(retained_acc, test_acc)
                print(f'{self.test_dataset} test loss: {test_loss:.4f},Test acc: {test_acc:.4f}, retained acc: {retained_acc:.4f}')

            # test_loss / test_acc hold the latest evaluation (epoch 0 always
            # evaluates, so they are defined from the first iteration on).
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(test_loss)
            self.val_accs.append(test_acc)

        plot_acc(
            self.num_epochs,
            self.test_dataset,
            self.train_accs,
            self.val_accs,
            self.train_losses,
            self.val_losses,
            root='./figures/single_graph_train',
        )

    def train_epochfirst(self):
        """Pre-train a single model on all training datasets, epoch by epoch.

        A fresh model and optimizer are configured, the training loaders are
        built with ``config_loader``, and every epoch then trains on each
        loader in turn, printing the loss and the KNN accuracy measured on
        the training embeddings. After every 10th epoch the model's
        ``state_dict`` is written to ``self.save_dir``; the directory is
        created on demand so that a missing path cannot abort a run that has
        already trained for ten epochs.
        """
        self.config_model(self.hyparam_model)
        print(self.model)
        self.config_optimizer(self.hyparam_optim)
        self.model.to(self.device)

        self.model.train()
        print('Loading and Datasets...')
        train_loader_list = self.config_loader()

        for epoch in range(self.num_epochs):
            print('='*10+f"Epoch {epoch + 1} of {self.num_epochs}"+'='*10)
            for idx, train_loader in enumerate(train_loader_list):
                Z_train, train_loss, train_y = self.train_epoch(train_loader)
                _, acc = self.predict(Z_train, Z_train, train_y, train_y, n_neighbors=5)
                print(f'Train Dataset: {self.train_datasets[idx]}, Epoch: {epoch + 1}, train loss: {train_loss:.4f}, train acc:{acc:.4f}')
            if (epoch + 1) % 10 == 0:
                # torch.save does not create missing parent directories.
                os.makedirs(self.save_dir, exist_ok=True)
                ckpt_path = os.path.join(
                    self.save_dir,
                    f'{self.test_dataset}_pretrain_model_ef_epoch{epoch+1}.pth',
                )
                torch.save(self.model.state_dict(), ckpt_path)
                print(f'Save model at epoch {epoch + 1}!')
            torch.cuda.empty_cache()

    def criterion_InfoNCE(self, Z, y, temperature=0.1, neg_weight=2.):
        Z = Z.to(self.device)
        y = y.to(self.device)
        Z = F.normalize(Z, dim=1)
        sim_matrix = torch.matmul(Z, Z.T)
        labels_matrix = torch.eq(y.unsqueeze(0), y.unsqueeze(1))
        labels_matrix.fill_diagonal_(False)
        
        P_size = labels_matrix.sum(dim=1)
        P_size = torch.clamp(P_size, min=1)
        
        exp_sim = torch.exp(sim_matrix / temperature)
        exp_sim_sum = exp_sim.clone()   # .clone() remains grads of the original tensor
        exp_sim_sum.fill_diagonal_(0)
        # denominator = exp_sim_sum.sum(dim=1)
        denominator = (exp_sim_sum * labels_matrix.float() + neg_weight * exp_sim_sum * (~labels_matrix).float()).sum(dim=1)
        numerator = (exp_sim * labels_matrix.float()).sum(dim=1)
        loss_per_sample = -torch.log(1.0 / P_size * numerator / denominator + 1e-7)

        loss = loss_per_sample.mean()
        return loss
    

                        
    def predict(self, Z_train, Z_test, y_train, y_test, n_neighbors=5):
        """Score test embeddings with a k-nearest-neighbour classifier.

        Args:
            Z_train: Tensor of training embeddings the classifier is fitted on.
            Z_test: Tensor of embeddings to classify.
            y_train: Tensor of labels for ``Z_train``.
            y_test: Tensor of ground-truth labels for ``Z_test``.
            n_neighbors: Number of neighbours used by the classifier.

        Returns:
            tuple: ``(cnt, acc)`` -- the number of correct predictions and
            that number divided by ``len(y_test)``.
        """
        with torch.no_grad():
            knn = KNeighborsClassifier(n_neighbors=n_neighbors)
            knn.fit(Z_train.cpu().numpy(), y_train.cpu().numpy())
            predicted = torch.tensor(knn.predict(Z_test.cpu().numpy()))
            targets = y_test.cpu()
            cnt = (predicted == targets).sum().item()
            acc = cnt / len(targets)
        return cnt, acc

    def train_epoch(self, loader):
        """Run one optimisation pass over ``loader``.

        Args:
            loader: Iterable of batches exposing ``batch``, ``y`` and
                ``to(device)``.

        Returns:
            tuple: ``(Z, mean_loss, y)`` -- the detached CPU embeddings of all
            processed batches concatenated along dim 0, the mean of the
            per-batch losses, and the concatenated CPU labels.
        """
        self.model.train()
        batch_losses, embeddings, labels = [], [], []

        for data in loader:
            # presumably data.batch holds per-node graph indices
            # (torch_geometric convention), so a maximum of 0 would mean a
            # single-graph batch -- confirm; such batches are skipped.
            if torch.max(data.batch) == 0:
                continue

            labels.append(data.y.detach().cpu())
            data = data.to(self.device)
            out = self.model(data)
            loss = self.criterion_InfoNCE(out, data.y)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # Keep gamma at or above 5 after every optimizer step.
            gamma = self.model.reference_layer.gamma
            gamma.data = torch.clamp(gamma.data, min=5.)

            batch_losses.append(loss.item())
            embeddings.append(out.detach().cpu())
            torch.cuda.empty_cache()

        Z = torch.cat(embeddings, dim=0)
        y = torch.cat(labels)
        return Z, sum(batch_losses) / len(batch_losses), y
    
    def feed_forward(self, loader):
        """Embed every batch of ``loader`` in eval mode without gradients.

        Unlike ``train_epoch``, no batch is skipped here.

        Args:
            loader: Iterable of batches exposing ``y`` and ``to(device)``.

        Returns:
            tuple: ``(Z, mean_loss, y)`` -- the CPU embeddings concatenated
            along dim 0, the mean of the per-batch losses, and the
            concatenated labels as yielded by the loader.
        """
        self.model.to(self.device)
        self.model.eval()

        embeddings, batch_losses, labels = [], [], []
        with torch.no_grad():
            for data in loader:
                # Labels are collected before the batch is moved to the device.
                labels.append(data.y)
                data = data.to(self.device)
                out = self.model(data)
                batch_loss = self.criterion_InfoNCE(out, data.y)
                batch_losses.append(batch_loss.cpu().item())
                embeddings.append(out.cpu())

        Z = torch.cat(embeddings, dim=0)
        y = torch.cat(labels)
        return Z, sum(batch_losses) / len(batch_losses), y
    
                
        
def main(args, train_datasets, test_datasets, model_param):
    """Load every training dataset, then run one epoch-first pre-training pass.

    Args:
        args: Parsed command-line namespace; ``root``, ``output``, ``dim``,
            ``scales``, ``epoch``, ``batch_size``, ``lr1``, ``lr2``,
            ``lr_schedule``, ``align``, ``lam`` and ``gamma_scale`` are read.
        train_datasets: Names of the datasets to pre-train on.
        test_datasets: Names of candidate held-out datasets. Only whether the
            collection is non-empty matters: a single run is performed and
            the held-out name is fixed to ``'ALL'``.
        model_param: Model hyper-parameters forwarded to ``Trainer``.
    """
    # Fall back to the CPU instead of failing on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)

    # Instantiate each dataset once up front and print its statistics.
    for name in train_datasets:
        dataset = TransTUDataset(
            root=args.root,
            name=name,
            use_node_attr=True,
            use_decomp='all_graphs',
            dim=args.dim,
            scales=args.scales,
            mode=None,
            force_reload=False,
            align_feat=False
        )
        dataset_info(dataset)

    for _ in test_datasets:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(args.root, exist_ok=True)
        os.makedirs(args.output, exist_ok=True)

        # 'ALL' matches no dataset name, so nothing is excluded below and the
        # model is pre-trained on every training dataset.
        test_dataset = 'ALL'

        train_list = [name for name in train_datasets if name != test_dataset]
        print(f'\nPre-training without {test_dataset}')
        trainer = Trainer(
            train_list,
            test_dataset,
            args.epoch,
            args.batch_size,
            device,
            hyparam_model=model_param,
            hyparam_optim={"lr1": args.lr1, "lr2": args.lr2, "lr_schedule": args.lr_schedule},
            dim=args.dim,
            scales=args.scales,
            align=args.align,
            lam=args.lam,
            gamma_scale=args.gamma_scale,
            root=args.root,
            save_dir=args.output
        )

        trainer.train_epochfirst()
        # Single-dataset training is available through trainer.train_test().
        print("-" * 100)
        # One pre-training run is enough because the held-out name is fixed.
        break
    


    
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="pre-training"
    )

    parser.add_argument("-b", "--batch_size", default=64, type=int)
    parser.add_argument("-e", "--epoch", default=50, type=int)
    parser.add_argument("-s", "--seed", default=555, type=int)

    parser.add_argument("-nl", "--num_layer", default=3, type=int)
    parser.add_argument("-ns", "--num_supp", default=16, type=int)

    parser.add_argument("-ls", "--lr_schedule", default=0.99, type=float)
    parser.add_argument("-lr1", "--lr1", default=5e-4, type=float)
    parser.add_argument("-lr2", "--lr2", default=1e-1, type=float)

    parser.add_argument("-g1", "--gamma", default=100., type=float)
    parser.add_argument('-d', '--dim', default=32, type=int)
    # One float per value, e.g. ``--scales 0.25 0.5 1``; ``type=list`` would
    # split the argument string into single characters.
    parser.add_argument('--scales', default=[0.25, 0.5, 1., 2., 5., 10.],
                        type=float, nargs='+')
    parser.add_argument("-lam", "--lam", default=0.1, type=float)
    # ``--align False`` now really disables alignment.
    parser.add_argument('--align', default=True, type=str2bool)
    parser.add_argument("--gamma_scale", default=0.01, type=float)

    parser.add_argument('-r', '--root', default='./TUDataset', type=str)
    parser.add_argument('-o', '--output', 
                        default='./ckpt/ef_ker_d32_4ds_multiGIN_ALIGN_gammascale0.01_btwn_atoms64_nsup16_negweight2_all', 
                        type=str)

    args = parser.parse_args()
    # The embedding widths are the per-scale feature dimension times the
    # number of scales.
    model_param = {
        "gt_param":{
            "num_encoder_layers":args.num_layer,
            "embed_dim":args.dim * len(args.scales),
            "ffn_embed_dim":args.dim * len(args.scales),
            "num_attn_heads":4, 
            "dropout":0.1,
            "attn_dropout":0.1,
            "activation_dropout":0.1,
            "layerdrop":0.0,
            "encoder_normalize_before":False,
            "activation_fn":"gelu",
        }, 
        'gin_hidden_dim': 128, 
        "num_atoms":64,
        "num_atom_supp":args.num_supp,
        "gamma":args.gamma, 
        "readout": "mean", 
        'gin_num_layer': 3, 
    }
    seed_all(args.seed)

    print("#"*50)
    print(args)
    print(model_param)

    train_datasets=[
        'ENZYMES',
        "NCI1", 
        'NCI109', 
        'DD', 
        'Mutagenicity',
    ]

    test_datasets = [
        'ENZYMES', 
        "NCI1", 
        'NCI109', 
        'DD', 
        'Mutagenicity',        
    ]

    monitor = ResourceMonitor()
    with monitor:
        # Pass the training list in the training slot; the two lists are
        # currently identical, so this only removes a latent mix-up.
        main(args, train_datasets, test_datasets, model_param)
    monitor.print_summary()
