import argparse
import copy
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F

from data_loader import load_data
from model import Adaptor, GCN, GCL, AGG
from graph_learners import *
from utils import *
from params import *
from augment import *
from sklearn.cluster import KMeans
from kmeans_pytorch import kmeans as KMeans_py
from sklearn.metrics import f1_score
import os

import random

# Small numerical epsilon; not referenced by any of the code in this file that is visible here.
EOS: float = 1e-10
# Run configuration, built once at import time by `set_params()` (presumably provided by the
# star-imported `params` module). Methods of `Experiment` (e.g. `loss_gcl`, `print_results`)
# read this module-level object directly.
args = set_params()

class Experiment:
    """Run contrastive graph-structure learning trials and evaluate the learned graphs.

    ``train`` pre-trains an omic ``Adaptor``, one ``ATT_learner`` per view, a
    fused ``ATT_learner`` and a ``GCL`` encoder with ``loss_gcl``. Every
    ``args.eval_freq`` epochs the fused graph learned from the current features
    is evaluated, either by training a ``GCN`` classifier
    (``evaluate_adj_by_cls``) or by k-means clustering of encoder embeddings.

    NOTE: ``loss_gcl`` (and ``print_results`` by default) read the module-level
    ``args`` object rather than a parameter.
    """

    def __init__(self):
        super().__init__()
        # True while the contrastive objective is being optimised; forwarded to
        # ``graph_augment`` as its ``training`` flag in ``loss_gcl``.
        self.training = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_device(x):
        """Return ``x.cuda()`` when CUDA is available, otherwise ``x`` unchanged.

        Used for tensors, modules and the input graphs, all of which this script
        moves with ``.cuda()``.
        """
        return x.cuda() if torch.cuda.is_available() else x

    @classmethod
    def _omic_tensors(cls, omic_features):
        """Convert a nested list (list of lists) of omic arrays into tensors on the compute device."""
        return [[cls._to_device(torch.tensor(feature)) for feature in group] for group in omic_features]

    @staticmethod
    def _count_trainable(modules):
        """Return the number of parameters with ``requires_grad=True`` across ``modules``."""
        return sum(p.numel() for module in modules for p in module.parameters() if p.requires_grad)

    @staticmethod
    def _index_nodes(z, idx):
        """Select rows ``idx`` of a tensor, or of every tensor in a list/tuple of tensors."""
        if isinstance(z, (list, tuple)):
            return [t[idx] for t in z]
        return z[idx]

    @classmethod
    def _batched_cal_loss(cls, model, z_specific_adjs, z_aug_adjs, z_fused_adjs, batches, num_nodes):
        """Evaluate ``model.cal_loss`` on node mini-batches and combine the results.

        ``model.cal_loss`` returns a tuple of loss terms (five in this script).
        Each term is weighted by ``len(batch) / num_nodes`` and summed over the
        batches, i.e. a batch-size-weighted average of the per-batch losses.

        Raises:
            ValueError: if ``batches`` is empty (there would be nothing to backpropagate).
        """
        totals = None
        for batch in batches:
            weight = len(batch) / num_nodes
            parts = model.cal_loss(cls._index_nodes(z_specific_adjs, batch),
                                   cls._index_nodes(z_aug_adjs, batch),
                                   cls._index_nodes(z_fused_adjs, batch))
            weighted = [part * weight for part in parts]
            totals = weighted if totals is None else [t + w for t, w in zip(totals, weighted)]
        if totals is None:
            raise ValueError('no node batches to compute the contrastive loss on')
        return tuple(totals)

    # ------------------------------------------------------------------ seeding

    def setup_seed(self, seed):
        """Seed torch (CPU and all CUDA devices), NumPy and ``random``; make cuDNN deterministic."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        np.random.seed(seed)
        random.seed(seed)

    # ---------------------------------------------------------- classification

    def test_cls(self, model, mask, features, adj, labels):
        """Evaluate a classifier on the nodes selected by ``mask``.

        Returns:
            ``(loss, accuracy, auc, f1_macro, f1_micro)`` where ``loss`` is the mean
            NLL of the log-softmax outputs and the F1 scores are 0-dim tensors.
        """
        logits = model(features, adj)
        logp = F.log_softmax(logits, 1)
        probabilities = torch.exp(logp)
        loss = F.nll_loss(logp[mask], labels[mask], reduction='mean')
        test_accu = accuracy(logp[mask], labels[mask])

        preds = torch.argmax(logp, dim=1)
        test_f1_macro = torch.tensor(f1_score(labels[mask].cpu(), preds[mask].cpu(), average='macro'))
        test_f1_micro = torch.tensor(f1_score(labels[mask].cpu(), preds[mask].cpu(), average='micro'))
        auc = AUC(probabilities[mask], labels[mask])

        return loss, test_accu, auc, test_f1_macro, test_f1_micro

    def loss_cls(self, model, mask, features, adj, labels, mode):
        """Compute classification loss, accuracy and AUC.

        In ``'train'`` mode every output row is scored against ``labels[mask]``,
        so ``features`` must already be restricted to the masked nodes. In any
        other mode the outputs are first indexed with ``mask``.

        Returns:
            ``(loss, accuracy, auc)``.
        """
        logits = model(features, adj)
        probabilities = F.softmax(logits, 1)
        if mode == 'train':
            scores, probs = logits, probabilities
        else:
            scores, probs = logits[mask], probabilities[mask]
        targets = labels[mask]
        # cross_entropy applies log_softmax internally, so it must get raw logits.
        # Feeding it softmax outputs (as before) applied softmax twice, squashing
        # its inputs into [0, 1] and weakening the gradients.
        loss = F.cross_entropy(scores, targets)
        accu = accuracy(probs, targets)
        auc = AUC(probs, targets)
        return loss, accu, auc

    def evaluate_adj_by_cls(self, Adj, learned_fused_adj_test, features, features_test, nfeats, labels, nclasses, train_mask, test_mask, args):
        """Train a GCN classifier on a learned graph and report its test metrics.

        ``features`` / ``features_test`` are lists of per-view tensors and are
        concatenated along dim 1 (the GCN input width is ``nfeats * 2``, i.e. two
        views). Every 10 epochs the model is scored on ``test_mask``, and the
        best-scoring copy is kept. If no round improves on 0 accuracy (or
        ``args.epochs_cls < 10``), the final model is used instead.

        NOTE(review): model selection and the reported test metrics both use
        ``test_mask``, so the test score is optimistically biased; a separate
        validation split would be needed to avoid that.

        Returns:
            ``(best_val_accuracy, test_accuracy, test_auc, test_f1_macro, test_f1_micro)``.
        """
        model = GCN(in_channels=nfeats*2, hidden_channels=args.hidden_dim_cls, out_channels=nclasses, num_layers=args.nlayers_cls,
                    dropout=args.dropout_cls, dropout_adj=args.dropedge_cls, sparse=args.sparse)
        # Move the model before building its optimizer, as PyTorch recommends.
        model = self._to_device(model)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_cls, weight_decay=args.w_decay_cls)

        # Concatenate the views on every device (previously only on CUDA, so the
        # CPU path handed a list to the GCN).
        features = self._to_device(torch.cat(features, dim=1))
        features_test = self._to_device(torch.cat(features_test, dim=1))
        train_mask = self._to_device(train_mask)
        test_mask = self._to_device(test_mask)
        labels = self._to_device(labels)

        best_val = 0
        best_model = None

        for epoch in range(1, args.epochs_cls + 1):
            model.train()
            loss, train_accu, train_auc = self.loss_cls(model, train_mask, features, Adj, labels, mode='train')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                with torch.no_grad():
                    model.eval()
                    val_loss, val_accu, val_auc = self.loss_cls(model, test_mask, features_test, learned_fused_adj_test, labels, mode='test')

                    print("------Epoch {:05d} | Train Loss {:.4f} | Train AUC  {:.4f}| Val Loss {:.4f} | Val ACC {:.4f} | Val AUC  {:.4f}"
                          .format(epoch, loss.item(), train_auc, val_loss.item(), val_accu.item(), val_auc))
                    if val_accu > best_val:
                        best_val = val_accu
                        best_model = copy.deepcopy(model)

        if best_model is None:
            best_model = model
        best_model.eval()
        with torch.no_grad():
            test_loss, test_accu, test_auc, test_f1_macro, test_f1_micro = self.test_cls(
                best_model, test_mask, features_test, learned_fused_adj_test, labels)
        return best_val, test_accu, test_auc, test_f1_macro, test_f1_micro

    # ------------------------------------------------------- contrastive stage

    def loss_gcl(self, model, specific_graph_learner, fused_graph_learner, features, adjs,
                 optimizer_cl, optimizer_leaner, discriminator=None):
        """Take one optimisation step of the multi-view contrastive objective.

        Learns one graph per view and one fused graph from the concatenated
        views, encodes the nodes on those graphs and on augmented copies of
        ``adjs``, then backpropagates ``model.cal_loss`` and steps both
        optimizers. View 0 is encoded with mode ``"omic"`` and every other view
        with ``'path'``.

        Args:
            discriminator: unused; kept for interface compatibility.

        Returns:
            ``(loss_fused, loss_smi, loss_umi, loss_aug_fused, loss)`` from
            ``model.cal_loss``. With ``args.contrast_batch_size`` set, each term is
            the batch-size-weighted sum over shuffled node mini-batches.
        """
        optimizer_cl.zero_grad()
        optimizer_leaner.zero_grad()

        num_views = len(adjs)
        view_modes = ["omic" if i == 0 else 'path' for i in range(num_views)]

        # One learned graph per view.
        learned_specific_adjs = []
        for i in range(num_views):
            view_embedding = specific_graph_learner[i](features[i])
            learned_specific_adjs.append(specific_graph_learner[i].graph_process(view_embedding))

        # One graph learned from all views concatenated.
        fused_embedding = fused_graph_learner(torch.cat(features, dim=1))
        learned_fused_adj = fused_graph_learner.graph_process(fused_embedding)

        z_specific_adjs = [model(features[i], learned_specific_adjs[i], mode=view_modes[i]) for i in range(num_views)]
        z_fused_adjs = model(fused_embedding, learned_fused_adj, mode="fused")

        # Augmented input graphs are detached so no gradient flows into them.
        adjs_aug = graph_augment(adjs, args.dropedge_rate, training=self.training, sparse=args.sparse)
        if args.sparse:
            for i in range(num_views):
                adjs_aug[i].edata['w'] = adjs_aug[i].edata['w'].detach()
        else:
            adjs_aug = [a.detach() for a in adjs_aug]

        z_aug_adjs = [model(features[i], adjs_aug[i], mode=view_modes[i]) for i in range(num_views)]

        if args.contrast_batch_size:
            num_nodes = features[0].shape[0]
            node_idxs = list(range(num_nodes))
            random.shuffle(node_idxs)
            batches = split_batch(node_idxs, args.contrast_batch_size)
            # Previously this branch multiplied cal_loss's tuple by a float, iterated
            # the fused output as a list, and left the returned terms unassigned.
            loss_fused, loss_smi, loss_umi, loss_aug_fused, loss = self._batched_cal_loss(
                model, z_specific_adjs, z_aug_adjs, z_fused_adjs, batches, num_nodes)
        else:
            loss_fused, loss_smi, loss_umi, loss_aug_fused, loss = model.cal_loss(z_specific_adjs, z_aug_adjs, z_fused_adjs)

        loss.backward()
        optimizer_cl.step()
        optimizer_leaner.step()

        return loss_fused, loss_smi, loss_umi, loss_aug_fused, loss

    # ------------------------------------------------------------- evaluation

    @staticmethod
    def _learn_eval_graphs(omic_adaptor, specific_graph_learner, fused_graph_learner,
                           features_train_omic_list, features_test_omic_list,
                           features_train_path, features_test_path, sparse):
        """Build gradient-free train/test features and fused graphs for evaluation.

        Modules are expected to be in eval mode already. Everything runs under
        ``torch.no_grad()``, so no autograd history is retained. Previously the
        test graph kept its history alive for the whole evaluation.

        Returns:
            ``(features_train_new, features_test_new, learned_fused_adj, learned_fused_adj_test)``
            where the feature entries are ``[omic, path]`` lists.
        """
        with torch.no_grad():
            features_train_omic = omic_adaptor(features_train_omic_list).detach()
            features_test_omic = omic_adaptor(features_test_omic_list).detach()
            features_train_new = [features_train_omic, features_train_path]
            features_test_new = [features_test_omic, features_test_path]

            fused_embedding = fused_graph_learner(torch.cat(features_train_new, dim=1))
            learned_fused_adj = fused_graph_learner.graph_process(fused_embedding)

            fused_embedding_test = fused_graph_learner(torch.cat(features_test_new, dim=1))
            learned_fused_adj_test = fused_graph_learner.graph_process(fused_embedding_test)

            # The per-view graphs are not used downstream. They are still computed,
            # in the original order, in case the learners consume RNG state.
            for learner, view_features in zip(specific_graph_learner, features_train_new):
                learner.graph_process(learner(view_features))

            if sparse:
                learned_fused_adj.edata['w'] = learned_fused_adj.edata['w'].detach()
            else:
                learned_fused_adj = learned_fused_adj.detach()

        return features_train_new, features_test_new, learned_fused_adj, learned_fused_adj_test

    def _evaluate_clustering(self, model, features_pt, learned_fused_adj, labels, nclasses, trial, epoch, args):
        """Cluster encoder embeddings with k-means (10 runs) and log the mean ACC/NMI/F1/ARI.

        NOTE(review): the encoder is fed the raw deep-copied ``features_pt`` (not
        the adapted, device-resident features used in training) together with the
        graph learned from the *training* features; confirm that this matches
        ``GCL``'s expected inputs.

        Returns:
            ``(acc, nmi, f1, ari)`` averaged over the 10 runs.
        """
        with torch.no_grad():
            embedding_ = model(features_pt, learned_fused_adj).detach()
        embedding_np = None if args.sparse else embedding_.cpu().numpy()
        kmeans_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        acc_mr, nmi_mr, f1_mr, ari_mr = [], [], [], []
        for clu_trial in range(10):
            if args.sparse:
                y_pred, _ = KMeans_py(X=embedding_, num_clusters=nclasses, distance='euclidean',
                                      device=kmeans_device)
                predict_labels = y_pred.cpu().numpy()
            else:
                kmeans = KMeans(n_clusters=nclasses, random_state=clu_trial).fit(embedding_np)
                predict_labels = kmeans.predict(embedding_np)
            cm_all = clustering_metrics(labels.cpu().numpy(), predict_labels)
            acc_, nmi_, f1_, ari_ = cm_all.evaluationClusterModelFromLabel(print_results=False)
            acc_mr.append(acc_)
            nmi_mr.append(nmi_)
            f1_mr.append(f1_)
            ari_mr.append(ari_)

        acc, nmi, f1, ari = np.mean(acc_mr), np.mean(nmi_mr), np.mean(f1_mr), np.mean(ari_mr)
        print("Epoch {:05d} | acc {:.4f} | f1 {:.4f} | nmi {:.4f} | ari {:.4f}".format(epoch, acc, f1, nmi, ari))

        with open("result_" + args.dataset + "_NMI&ARI.txt", "a") as fh:
            fh.write(
                'Trial=%f, Epoch=%f, ACC=%f, f1_macro=%f,  NMI=%f, ADJ_RAND_SCORE=%f' % (trial, epoch, acc, f1, nmi, ari))
            fh.write('\r\n')
        return acc, nmi, f1, ari

    # ------------------------------------------------------------------ driver

    def train(self, args):
        """Run ``args.ntrials`` seeded pre-training trials and evaluate each one.

        Each trial:
          * re-seeds all RNGs with the trial index;
          * rebuilds every module on deep copies of the loaded data;
          * optimises ``loss_gcl`` for ``args.epochs`` epochs;
          * checkpoints every 10 epochs to ``args.save_path``;
          * evaluates every ``args.eval_freq`` epochs.

        The per-trial summary uses the *last* evaluation of the trial.

        Raises:
            ValueError: if a downstream task is requested but
                ``args.epochs < args.eval_freq``. No evaluation would ever run, so
                there would be nothing to report; previously this failed with
                ``NameError`` after training had finished.
        """
        print(args)
        if args.downstream_task in ('classification', 'clustering') and args.epochs < args.eval_freq:
            raise ValueError(f'args.epochs ({args.epochs}) < args.eval_freq ({args.eval_freq}): '
                             f'no evaluation would run for downstream task {args.downstream_task!r}')
        os.makedirs(args.save_path, exist_ok=True)
        if torch.cuda.is_available():
            torch.cuda.set_device(args.gpu)

        # The train/test input graphs returned by load_data are not used by this loop.
        (features_pt_original, features_train_original, features_test_original,
         nfeats, train_mask, test_mask,
         adjs_pt_original, adjs_train_original, adjs_test_original,
         labels, nclasses, omic_sizes) = load_data(args)

        if args.downstream_task == 'classification':
            test_accuracies = []
            test_maf1 = []
            test_mif1 = []
            validation_accuracies = []

        to_dev = self._to_device
        for trial in range(args.ntrials):
            self.setup_seed(trial)

            # Fresh copies so no trial sees another trial's in-place changes.
            adjs_pt = copy.deepcopy(adjs_pt_original)
            features_pt = copy.deepcopy(features_pt_original)
            features_train = copy.deepcopy(features_train_original)
            features_test = copy.deepcopy(features_test_original)

            num_views = len(adjs_pt)
            specific_graph_learner = [ATT_learner(2, nfeats, args.k, 6, args.dropedge_rate, args.sparse, args.activation_learner)
                                      for _ in range(num_views)]
            fused_graph_learner = ATT_learner(2, nfeats * num_views, args.k, 6, args.dropedge_rate, args.sparse, args.activation_learner)
            omic_adaptor = Adaptor(omic_sizes)
            model = GCL(nlayers=args.nlayers, in_dim=nfeats, hidden_dim=args.hidden_dim,
                        emb_dim=args.rep_dim, proj_dim=args.proj_dim,
                        dropout=args.dropout, sparse=args.sparse, num_g=num_views)

            # Move modules before building optimizers, as PyTorch recommends.
            omic_adaptor = to_dev(omic_adaptor)
            model = to_dev(model)
            specific_graph_learner = [to_dev(learner) for learner in specific_graph_learner]
            fused_graph_learner = to_dev(fused_graph_learner)

            optimizer_cl = torch.optim.Adam([{'params': model.parameters()}], lr=args.lr, weight_decay=args.w_decay)
            optimizer_leaner = torch.optim.Adam([{'params': omic_adaptor.parameters()}] +
                                                [{'params': learner.parameters()} for learner in specific_graph_learner] +
                                                [{'params': fused_graph_learner.parameters()}], lr=args.lr, weight_decay=args.w_decay)

            n_parameters = self._count_trainable([model, *specific_graph_learner, fused_graph_learner, omic_adaptor])
            print('number of requires_grad params (M): %.2f' % (n_parameters / 1.e6))

            # Data is moved on every device path (previously these names were only
            # defined when CUDA was available).
            train_mask = to_dev(train_mask)
            test_mask = to_dev(test_mask)
            features_pt_path = to_dev(features_pt[1])
            features_train_path = to_dev(features_train[1])
            features_test_path = to_dev(features_test[1])
            features_pt_omic_list = self._omic_tensors(features_pt[0])
            features_train_omic_list = self._omic_tensors(features_train[0])
            features_test_omic_list = self._omic_tensors(features_test[0])
            labels = to_dev(labels)
            adjs_pt = [to_dev(adj) for adj in adjs_pt]

            for epoch in range(1, args.epochs + 1):
                omic_adaptor.train()
                omic_feature_pt = omic_adaptor(features_pt_omic_list)

                model.train()
                for learner in specific_graph_learner:
                    learner.train()
                fused_graph_learner.train()
                self.training = True
                features_pt_new = [omic_feature_pt, features_pt_path]
                loss_fused, loss_smi, loss_umi, loss_aug_fused, loss = self.loss_gcl(
                    model, specific_graph_learner, fused_graph_learner, features_pt_new, adjs_pt,
                    optimizer_cl, optimizer_leaner)
                print("Epoch {:05d} | CL Loss {:.4f}| fused Loss {:.4f}| smi Loss {:.4f}| umi Loss {:.4f}| aug_fused Loss {:.4f}"
                      .format(epoch, loss.item(), loss_fused.item(), loss_smi.item(), loss_umi.item(), loss_aug_fused.item()))

                if epoch % 10 == 0:
                    torch.save({
                        'epoch': epoch,
                        'omic_graph_learner': specific_graph_learner[0].state_dict(),
                        'path_graph_learner': specific_graph_learner[1].state_dict(),
                        'fused_graph_learner': fused_graph_learner.state_dict(),
                        'omic_adaptor': omic_adaptor.state_dict(),
                        'GCL_model': model.state_dict(),
                    }, os.path.join(args.save_path, f'{epoch}_checkpoint.pth.tar'))

                if epoch % args.eval_freq != 0:
                    continue

                omic_adaptor.eval()
                model.eval()
                for learner in specific_graph_learner:
                    learner.eval()
                fused_graph_learner.eval()
                self.training = False

                features_train_new, features_test_new, learned_fused_adj, learned_fused_adj_test = \
                    self._learn_eval_graphs(omic_adaptor, specific_graph_learner, fused_graph_learner,
                                            features_train_omic_list, features_test_omic_list,
                                            features_train_path, features_test_path, args.sparse)

                if args.downstream_task == 'classification':
                    val_accu, test_accu, test_auc, test_f1_macro, test_f1_micro = self.evaluate_adj_by_cls(
                        learned_fused_adj, learned_fused_adj_test, features_train_new, features_test_new,
                        nfeats, labels, nclasses, train_mask, test_mask, args)
                    print('EPOCH:', epoch, ' val_acc:', val_accu, ' test_acc', test_accu,' test_auc', test_auc, ' test maf1:', test_f1_macro, ' test mif1:', test_f1_micro)
                    current_time = datetime.now().strftime('%Y-%m-%d %H')
                    with open(args.out_path + "result_" + args.dataset + f"_Class_{current_time}.txt", "a") as fh:
                        fh.write(
                            'Trial=%f, Epoch=%f, test_auc=%f, test_accu=%f, test_f1_macro=%f,  test_f1_micro=%f' % (trial, epoch, test_auc, test_accu, test_f1_macro, test_f1_micro))
                        fh.write('\r\n')

                if args.downstream_task == 'clustering':
                    acc, nmi, f1, ari = self._evaluate_clustering(
                        model, features_pt, learned_fused_adj, labels, nclasses, trial, epoch, args)

            self.training = False
            if args.downstream_task == 'classification':
                # float() accepts both 0-dim tensors and plain numbers (best_val may be 0).
                validation_accuracies.append(float(val_accu))
                test_accuracies.append(float(test_accu))
                test_maf1.append(float(test_f1_macro))
                test_mif1.append(float(test_f1_micro))
                # "Best" here means the values from the trial's last evaluation.
                print("Trial: ", trial + 1)
                print("Best val ACC: ", float(val_accu))
                print("Best test ACC: ", float(test_accu))
                print("Best test MaF1: ", float(test_f1_macro))
                print("Best test MiF1: ", float(test_f1_micro))
            elif args.downstream_task == 'clustering':
                print("Final ACC: ", acc)
                print("Final NMI: ", nmi)
                print("Final F-score: ", f1)
                print("Final ARI: ", ari)

        # Aggregate only when there is more than one trial (was `trial != 0`,
        # which raised NameError when ntrials == 0).
        if args.downstream_task == 'classification' and args.ntrials > 1:
            self.print_results(validation_accuracies, test_accuracies, test_maf1, test_mif1, cfg=args)

    def print_results(self, validation_accu, test_accu, test_maf1, test_mif1, cfg=None):
        """Print mean +/- std of per-trial metrics and append the test metrics to a summary file.

        Args:
            validation_accu, test_accu, test_maf1, test_mif1: per-trial values.
            cfg: run configuration providing ``out_path`` and ``dataset``; defaults
                to the module-level ``args`` for backward compatibility.
        """
        cfg = args if cfg is None else cfg
        s_val = "Val accuracy: {:.4f} +/- {:.4f}".format(np.mean(validation_accu), np.std(validation_accu))
        s_test = "Test accuracy: {:.4f} +/- {:.4f}".format(np.mean(test_accu), np.std(test_accu))
        maf1_test = "Test maf1: {:.4f} +/- {:.4f}".format(np.mean(test_maf1), np.std(test_maf1))
        mif1_test = "Test mif1: {:.4f} +/- {:.4f}".format(np.mean(test_mif1), np.std(test_mif1))
        print(s_val)
        print(s_test)
        print(maf1_test)
        print(mif1_test)

        with open(str(cfg.out_path) + cfg.dataset + "_Class.txt", "a") as fh:
            for line in (s_test, maf1_test, mif1_test):
                fh.write(line)
                fh.write('\r\n')


if __name__ == '__main__':
    # Script entry point: build the experiment driver and run every trial with the module-level args.
    Experiment().train(args)