import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np

from dig.auggraph.method.SMixup.model.GCN import GCN
from dig.auggraph.method.SMixup.model.GIN import GIN
from dig.auggraph.method.SMixup.model.GraphMatching import GraphMatching
from dig.auggraph.method.SMixup.utils.sinkhorn import Sinkhorn
from dig.auggraph.method.SMixup.utils.utils import NormalizedDegree, triplet_loss
from dig.auggraph.dataset.aug_dataset import TripleSet

from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_dense_adj, dense_to_sparse, degree
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch, Data
from torch.nn.functional import softmax

from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score

from utils import set_seed
from dataset.existing_datasets import GraphPropertyPredictionDataset


class smixup():
    r"""
    The S-Mixup from the `"Graph Mixup with Soft Alignments" <https://icml.cc/virtual/2023/poster/24930>`_ paper.

    A graph matching network (GMNET) is trained with a triplet loss; its node
    embeddings are then used to softly align pairs of graphs before they are
    mixed up for training a graph classifier.

    Args:
        data_name (string): Dataset name. Used as the prefix of the checkpoint file name.
        dataset: Graph dataset. It must expose :obj:`y`, support :obj:`len()` and indexing by a list of indices, and its items must carry node features :obj:`x`.
        GMNET_conf (dict): Hyperparameters of the graph matching network which is used to compute the soft alignments. Expected keys: :obj:`nlayers`, :obj:`nhidden`, :obj:`bs`, :obj:`lr` and :obj:`epochs`.
    """

    # Option values accepted by train_test / Mixup.
    _CLS_MODELS = ('GCN', 'GIN')
    _SIM_METHODS = ('cos', 'abs_diff')
    _NORMALIZE_METHODS = ('softmax', 'sinkhorn')

    def __init__(self, data_name, dataset, GMNET_conf):
        self.data_name = data_name
        self.dataset = dataset
        # Number of classes = number of distinct label values in the dataset.
        self.num_cls = len(self.dataset.y.unique())
        self.GMNET_conf = GMNET_conf
        self._get_GMNET(self.GMNET_conf['nlayers'], self.GMNET_conf['nhidden'], self.GMNET_conf['bs'], self.GMNET_conf['lr'], self.GMNET_conf['epochs'])

    def _get_GMNET(self, GMNET_nlayers, GMNET_hidden, GMNET_bs, GMNET_lr, GMNet_epochs):
        r"""
        Build the graph matching network and store it in :obj:`self.GMNET`.

        Args:
            GMNET_nlayers (int): Passed to :class:`GraphMatching` as :obj:`num_layers`.
            GMNET_hidden (int): Passed to :class:`GraphMatching` as :obj:`hidden`.
            GMNET_bs (int): Not used by the constructor; :meth:`train_GMNET` reads :obj:`self.GMNET_conf['bs']` directly. Kept for interface compatibility.
            GMNET_lr (float): Not used by the constructor; see :obj:`GMNET_bs`.
            GMNet_epochs (int): Not used by the constructor; see :obj:`GMNET_bs`.
        """
        dis_param = {
            'num_layers': GMNET_nlayers,
            'hidden': GMNET_hidden,
            'model_type': 'gmnet',
            'pool_type': 'sum',
            'fuse_type': 'abs_diff',
            # Input dimension is the node-feature width of the first graph.
            'in_dim': self.dataset[0].x.shape[1],
        }
        self.GMNET = GraphMatching(**dis_param)

    def _build_classifier(self, cls_model, cls_nlayers, cls_hidden, cls_dropout):
        """Instantiate the GCN or GIN classifier selected by ``cls_model``."""
        in_dim = self.dataset[0].x.shape[1]
        if cls_model == 'GCN':
            return GCN(in_dim, self.num_cls, cls_nlayers, cls_hidden, cls_dropout)
        if cls_model == 'GIN':
            return GIN(in_dim, self.num_cls, cls_nlayers, cls_hidden, cls_dropout)
        raise ValueError("cls_model must be one of {}, got {!r}".format(self._CLS_MODELS, cls_model))

    @staticmethod
    def _evaluate(model, loader):
        """Return the accuracy of ``model`` over every batch of ``loader`` (no gradients tracked)."""
        y_label = []
        y_pred = []
        model.eval()
        with torch.no_grad():
            for batch in loader:
                batch = batch.cuda()
                _, output = model(batch)
                pred = torch.argmax(output, dim = 1).long()
                y_pred.extend(pred.cpu().numpy().flatten().tolist())
                y_label.extend(batch.y.long().cpu().numpy().flatten().tolist())
        # scikit-learn's signature is accuracy_score(y_true, y_pred).
        return accuracy_score(y_label, y_pred)

    def train_test(self, batch_size, cls_model, cls_nlayers, cls_hidden, cls_dropout, cls_lr, cls_epochs, alpha, ckpt_path, sim_method = 'cos'):
        r"""
        Run 10-fold cross validation. In every fold this method first trains the GMNET on the
        training split and then uses the GMNET to perform S-Mixup while training the classifier.
        The checkpoint with the best validation accuracy is evaluated on the test split.

        Args:
            batch_size (int): Batch size of training the classifier.
            cls_model (string): Use GCN or GIN as the backbone of the classifier.
            cls_nlayers (int): Number of GNN layers of the classifier.
            cls_hidden (int): Number of hidden units of the classifier.
            cls_dropout (float): Dropout ratio of the classifier.
            cls_lr (float): Initial learning rate of training the classifier.
            cls_epochs (int): Training epochs of the classifier.
            alpha (float): Mixup ratio.
            ckpt_path (string): Location for saving checkpoints. Created if it does not exist.
            sim_method (string): Similarity function used to compute the assignment matrix. (default: :obj:`cos`)

        Raises:
            ValueError: If :obj:`cls_model` or :obj:`sim_method` is not a supported value.
        """
        import os  # function-scope import, same pattern as `import random` in Mixup

        # Fail fast on bad options instead of after the GMNET has been trained.
        if cls_model not in self._CLS_MODELS:
            raise ValueError("cls_model must be one of {}, got {!r}".format(self._CLS_MODELS, cls_model))
        if sim_method not in self._SIM_METHODS:
            raise ValueError("sim_method must be one of {}, got {!r}".format(self._SIM_METHODS, sim_method))

        # Same path string as before; only the containing directory is ensured to exist.
        ckpt_file = ckpt_path + f"/{self.data_name}_best_val.pth"
        os.makedirs(os.path.dirname(ckpt_file), exist_ok=True)

        criterion = nn.CrossEntropyLoss()
        test_accs = []
        kf = KFold(n_splits=10, shuffle=True)
        for i, (train_idx, test_idx) in enumerate(kf.split(list(range(len(self.dataset))))):
            # Hold out 10% of the fold's training indices for validation.
            train_idx, val_idx = train_test_split(train_idx, test_size=0.1)
            train_set = self.dataset[train_idx.tolist()]
            val_set = self.dataset[val_idx.tolist()]
            test_set = self.dataset[test_idx.tolist()]

            # NOTE(review): self.GMNET is created once in __init__ and is not re-initialised
            # here, so it keeps the weights learned on earlier folds' training splits.
            # Confirm that carrying weights across folds is intended.
            self.train_GMNET(train_set)

            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers = 8)
            val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers = 8)
            test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers = 8)

            model = self._build_classifier(cls_model, cls_nlayers, cls_hidden, cls_dropout)
            model.cuda()

            optimizer = optim.Adam(model.parameters(),
                                lr=cls_lr, weight_decay=1e-5)

            best_acc = 0.0
            for epoch in range(1, cls_epochs + 1):
                # _evaluate() switches the model to eval mode, so re-enable training each epoch.
                model.train()
                loss_accum = 0.0
                num_batches = 0
                for batch in train_loader:
                    batch = self.Mixup(batch, alpha = alpha, sim_method = sim_method)

                    optimizer.zero_grad()

                    _, output = model(batch)

                    # Mixup loss: convex combination of the losses w.r.t. both source labels.
                    loss = batch.lam * criterion(output, batch.y1.long()) + (1-batch.lam) * criterion(output, batch.y2.long())
                    loss.backward()
                    optimizer.step()

                    loss_accum += loss.item()
                    num_batches += 1

                train_loss = loss_accum / max(num_batches, 1)
                print("Epoch [{}] Train_loss {}".format(epoch, train_loss))

                acc_val = self._evaluate(model, val_loader)
                print("Epoch [{}] Test results:".format(epoch),
                        "acc_val: {:.4f}".format(acc_val),)

                # ">=" keeps the latest checkpoint among equally good epochs.
                if acc_val >= best_acc:
                    best_acc = acc_val
                    torch.save(model.state_dict(), ckpt_file)

            model.load_state_dict(torch.load(ckpt_file))

            acc_test = self._evaluate(model, test_loader)

            print("Split {}: acc_test: {:.4f}".format(i, acc_test),)

            test_accs.append(acc_test)

        print("Final result: acc_test: {:.4f}+-{:.4f}".format(np.mean(test_accs), np.std(test_accs)))

    def train_GMNET(self, train_set):
        r"""
        Train :obj:`self.GMNET` with a triplet loss and leave it in eval mode.

        Args:
            train_set: Training graphs. They are wrapped in :class:`TripleSet`, whose batches
                are unpacked as (anchor, positive, negative).
        """
        self.GMNET.cuda()

        train_set = TripleSet(train_set)
        train_loader = DataLoader(train_set, batch_size = self.GMNET_conf['bs'], shuffle = True, num_workers = 8)
        optimizer = optim.Adam(self.GMNET.parameters(), lr = self.GMNET_conf['lr'], weight_decay=1e-4)

        for epoch in range(1, self.GMNET_conf['epochs'] + 1):
            print("====epoch {} ====".format(epoch))
            self.GMNET.train()
            train_loss = 0.0
            for anchor_data, pos_data, neg_data in train_loader:
                anchor_data, pos_data, neg_data = anchor_data.cuda(), pos_data.cuda(), neg_data.cuda()

                optimizer.zero_grad()

                x_1, y = self.GMNET(anchor_data, pos_data, pred_head = False)
                x_2, z = self.GMNET(anchor_data, neg_data, pred_head = False)

                loss = torch.mean(triplet_loss(x_1, y, x_2, z))
                loss.backward()
                optimizer.step()
                # Accumulate a Python float, not the tensor, so the log shows a plain number.
                train_loss += loss.item()

            print("Epoch [{}] Train_loss {}".format(epoch, train_loss / max(len(train_loader), 1)))

        print("GMNET training done.")
        self.GMNET.eval()

    @staticmethod
    def _soft_alignment(emb1, emb2, sim_method, normalize_method, temperature):
        """Return the normalised node-to-node assignment matrix between two embedding sets."""
        if sim_method == 'cos':
            # Clamp the norms so an all-zero embedding row yields zeros instead of NaN.
            unit1 = emb1 / emb1.norm(dim = 1, keepdim = True).clamp_min(1e-12)
            unit2 = emb2 / emb2.norm(dim = 1, keepdim = True).clamp_min(1e-12)
            match = unit1 @ unit2.T / temperature
        elif sim_method == 'abs_diff':
            # Negative pairwise Euclidean distance: closer nodes get larger scores.
            match = -(emb1.unsqueeze(1) - emb2.unsqueeze(0)).norm(dim = -1)
        else:
            raise ValueError("unknown sim_method {!r}".format(sim_method))

        if normalize_method == 'softmax':
            # Softmax over dim 0: every column of the assignment matrix sums to one.
            return softmax(match, dim = 0)
        if normalize_method == 'sinkhorn':
            return Sinkhorn(match.detach().clone())
        raise ValueError("unknown normalize_method {!r}".format(normalize_method))

    def Mixup(self, batch, alpha, sim_method = 'cos', normalize_method = 'softmax', temperature = 1.0,):
        r"""
        Mix every graph of :obj:`batch` with a randomly chosen partner graph from the same batch.

        Node embeddings from the trained GMNET give a soft assignment matrix between the two
        graphs; the partner's adjacency matrix and node features are mapped through it and
        combined with the first graph using the ratio :obj:`lam`.

        Args:
            batch: Mini-batch of graphs. It is moved to the GPU.
            alpha (float): Parameter of the Beta(alpha, alpha) distribution from which the mixup ratio is drawn. A ratio of 0.5 is used if :obj:`alpha <= 0`.
            sim_method (string): :obj:`cos` or :obj:`abs_diff`. (default: :obj:`cos`)
            normalize_method (string): :obj:`softmax` or :obj:`sinkhorn`. (default: :obj:`softmax`)
            temperature (float): Divisor applied to the cosine similarities. It is not used by :obj:`abs_diff`. (default: :obj:`1.0`)

        Returns:
            A batch of mixed graphs with fields :obj:`x`, :obj:`edge_index`, :obj:`edge_weights`,
            :obj:`y1`, :obj:`y2` and the scalar attribute :obj:`lam`.

        Raises:
            ValueError: If :obj:`sim_method` or :obj:`normalize_method` is not a supported value.
        """
        import random

        if sim_method not in self._SIM_METHODS:
            raise ValueError("sim_method must be one of {}, got {!r}".format(self._SIM_METHODS, sim_method))
        if normalize_method not in self._NORMALIZE_METHODS:
            raise ValueError("normalize_method must be one of {}, got {!r}".format(self._NORMALIZE_METHODS, normalize_method))

        lam = np.random.beta(alpha, alpha) if alpha > 0 else 0.5
        # Keep the first graph dominant in the mix.
        lam = max(lam, 1 - lam)

        batch = batch.cuda()
        batch1 = batch.clone()
        data_list = list(batch1.to_data_list())

        # The partner list is a shuffled *shallow* copy: it holds the very same Data objects
        # as data_list, so nothing per-pair may be stored on those objects.
        data_list2 = data_list.copy()
        random.shuffle(data_list2)
        batch2 =  Batch.from_data_list(data_list2).cuda()

        # The embeddings are only consumed as constants, so skip building the autograd graph.
        with torch.no_grad():
            h1, h2 = self.GMNET.dis_encoder(batch1, batch2, node_emd = True)
        h1, h2 = h1.detach(), h2.detach()

        # Node offsets of every graph inside the two batches.
        slices1 = batch1._slice_dict['x']
        slices2 = batch2._slice_dict['x']

        mixed_data_list = []
        for i, (g1, g2) in enumerate(zip(data_list, data_list2)):
            # Slice the embeddings per pair instead of attaching them to the (shared) graphs.
            emb1 = h1[slices1[i] : slices1[i + 1], :]
            emb2 = h2[slices2[i] : slices2[i + 1], :]

            normalized_match = self._soft_alignment(emb1, emb2, sim_method, normalize_method, temperature)
            align = normalized_match.double()

            adj1 = to_dense_adj(g1.edge_index, max_num_nodes=g1.num_nodes)[0].double()
            adj2 = to_dense_adj(g2.edge_index, max_num_nodes=g2.num_nodes)[0].double()

            # Map the partner's adjacency into the first graph's node space, then mix.
            mixed_adj = lam * adj1 + (1-lam) * align @ adj2 @ align.T
            # Entries below 0.1 are zeroed before conversion to a sparse edge list.
            mixed_adj[mixed_adj < 0.1] = 0

            mixed_x = lam * g1.x + (1-lam) * normalized_match.float() @ g2.x

            edge_index, edge_weights = dense_to_sparse(mixed_adj)

            mixed_data_list.append(
                Data(x = mixed_x.float(), edge_index = edge_index, edge_weights = edge_weights, y1 = g1.y, y2 = g2.y)
            )

        b = Batch.from_data_list(mixed_data_list)
        b.lam = lam
        return b


import argparse

if __name__ == "__main__":
    # Command-line options: dataset, GMNET hyperparameters, then classifier hyperparameters.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='NCI1')
    cli.add_argument(
        '--GMNET_nlayers', type=int, default=5,
        help='Number of layers of GMNET',
    )
    cli.add_argument(
        '--GMNET_hidden', type=int, default=100,
        help='Number of hidden units of GMNET',
    )
    cli.add_argument(
        '--GMNET_bs', type=int, default=256,
        help='Batch size of training GMNET',
    )
    cli.add_argument(
        '--GMNET_lr', type=float, default=1e-3,
        help='Initial learning rate of GMNET',
    )
    cli.add_argument(
        '--GMNET_epochs', type=int, default=10,
        help='Number of epochs to train GMNET',
    )
    cli.add_argument(
        '--batch_size', type=int, default=256,
        help='Batch size during training',
    )
    cli.add_argument('--model', type=str, default='GIN', choices=['GCN', 'GIN'])
    cli.add_argument(
        '--nlayers', type=int, default=5,
        help='Number of GNN layers.',
    )
    cli.add_argument(
        '--hidden', type=int, default=64,
        help='Number of hidden units.',
    )
    cli.add_argument(
        '--dropout', type=float, default=0.5,
        help='Dropout ratio.',
    )
    cli.add_argument(
        '--lr', type=float, default=1e-2,
        help='Initial learning rate.',
    )
    cli.add_argument(
        '--epochs', type=int, default=100,
        help='Number of epochs to train the classifier.',
    )
    cli.add_argument(
        '--alpha', type=float, default=0.1,
        help='mixup ratio.',
    )
    cli.add_argument(
        '--ckpt_path', type=str, default='./archived/smixup_ckpts/',
        help='Location for saving checkpoints',
    )
    cli.add_argument(
        '--seed', type=int, default=1,
        help='seed',
    )
    cli.add_argument('--sim_method', type=str, default='cos')

    opts = cli.parse_args()

    set_seed(opts.seed)

    # The file at cluster_path holds a 4-tuple; only its first and last items are used here.
    cluster_path = '../../../data/misc/full_network_repository/processed/entr10_dens0.1_denm0_degm3_degv3_node4000_edge50000-feat_none-all_10clusters_ne.pt'
    kmeans, _, _, scaler_dict = torch.load(cluster_path)
    graph_dataset = GraphPropertyPredictionDataset(
        data_name=opts.dataset,
        dir_path='../../../data',
        scaler_dict=scaler_dict,
        kmeans=kmeans,
        extract_attributes=False,
        one_hot=False,
    )

    # Hyperparameters of the graph matching network, keyed as smixup expects them.
    gmnet_settings = {
        'nlayers': opts.GMNET_nlayers,
        'nhidden': opts.GMNET_hidden,
        'bs': opts.GMNET_bs,
        'lr': opts.GMNET_lr,
        'epochs': opts.GMNET_epochs,
    }

    trainer = smixup(opts.dataset, graph_dataset, gmnet_settings)

    # Train the GMNET and the classifier, then report the cross-validated test accuracy.
    trainer.train_test(
        opts.batch_size,
        opts.model,
        cls_nlayers=opts.nlayers,
        cls_hidden=opts.hidden,
        cls_dropout=opts.dropout,
        cls_lr=opts.lr,
        cls_epochs=opts.epochs,
        alpha=opts.alpha,
        ckpt_path=opts.ckpt_path,
        sim_method=opts.sim_method,
    )