from embedder import embedder
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
import torch
import datetime
import utils
from tqdm import trange
from layers import GCN, GAT, SAGE, SGC
from sklearn.decomposition import PCA
from utils import add_labels
from torch_geometric.data import NeighborSampler
import pickle

class GOODIE():
    """Multi-seed training driver for GOODIE.

    For each seed in ``range(args.epoch_start, args.epoch_start + args.n_runs)``
    it builds an ``embedder`` (which supplies ``x``, ``edge_index``, ``labels``,
    the train/val/test masks, ``lp_mat`` and ``evaluator``), derives pseudo
    labels from ``lp_mat``, trains a fresh ``modeler`` with early stopping on
    validation accuracy, and records the test accuracy / macro-F1 of the best
    validation epoch. The mean and standard deviation over seeds are printed
    and written to the log.
    """

    def __init__(self, args):
        # Experiment configuration (attribute-style namespace).
        self.args = args

    @staticmethod
    def _pseudo_label_weights(data, args, lp_output):
        """Return confidence weights for the pseudo-label contrastive loss.

        Returns ``None`` when ``args.lamb == 0.0`` (loss disabled). Otherwise the
        per-node confidence is the max of ``softmax(lp_output / args.lp_temp)``,
        forced to 1.0 on ``data.train_mask`` nodes. With ``args.scaled`` that
        per-node vector is returned; otherwise its outer product with itself
        (one weight per node pair).
        """
        if args.lamb == 0.0:
            return None
        lp_prediction = torch.softmax(lp_output / args.lp_temp, 1).max(1)[0]
        # Ground-truth training nodes are fully trusted.
        lp_prediction[data.train_mask] = 1.0
        if args.scaled:
            return lp_prediction
        return lp_prediction.unsqueeze(1) @ lp_prediction.unsqueeze(1).T

    @staticmethod
    def _dump_tsne(data, args, seed, best_embed, best_output):
        """Pickle the best-epoch embeddings/outputs plus labels and masks under ./tsne/."""
        import os

        # open(..., 'wb') does not create parent directories.
        os.makedirs('./tsne', exist_ok=True)
        ver = f'_ver_{args.ver}' if args.ver > 0 else ''
        prefix = f"./tsne/{args.dataset}_{args.missing_type}_mr_{args.missing_rate}_lamb_{args.lamb}_seed_{seed}"
        artifacts = [
            (f"{prefix}{ver}.pkl", best_embed),
            (f"{prefix}_output{ver}.pkl", best_output.detach().cpu().numpy()),
            (f"./tsne/{args.dataset}_seed_{seed}_y.pkl", data.labels.cpu().numpy()),
            (f"./tsne/{args.dataset}_seed_{seed}_train_mask.pkl", data.train_mask.cpu().numpy()),
            (f"./tsne/{args.dataset}_seed_{seed}_test_mask.pkl", data.test_mask.cpu().numpy()),
        ]
        for path, obj in artifacts:
            with open(path, 'wb') as f:
                pickle.dump(obj, f)

    def training(self):
        """Train/evaluate once per seed, then print and log averaged test metrics."""
        file = utils.set_filename(self.args)
        logger = utils.setup_logger('./', '-', file)

        seed_result = {'acc': [], 'macro_F': []}

        # `args` always refers to the most recent embedder's configuration: each
        # seed's embedder is built from the previous one's args, and the final
        # log reports the last run's configuration.
        args = self.args
        for seed in trange(args.epoch_start, args.epoch_start + args.n_runs):
            print(f'============== seed:{seed} ==============')
            utils.seed_everything(seed)
            print('seed:', seed, file)
            args.seed = seed
            data = embedder(args, seed, logger=logger)
            args = data.args

            # Obtain pseudo labels from the label-propagation output.
            lp_output = data.lp_mat.to(args.device)
            data.pseudo_labels = lp_output.argmax(1)
            lp_pred_mat = self._pseudo_label_weights(data, args, lp_output)

            # Main training
            model = modeler(args).to(args.device)
            optimizer = optim.Adam(model.parameters(), lr=args.lr)

            test_results = []
            best_metric = float('-inf')
            max_idx = 0

            for epoch in range(args.epochs):
                model.train()
                optimizer.zero_grad()

                loss_ce, loss_pseudo = model(
                    data.x, data.edge_index, data.labels, data.pseudo_labels, lp_pred_mat,
                    data.train_mask, data.val_mask, data.test_mask,
                )
                loss = loss_ce + args.lamb * loss_pseudo

                loss.backward()
                optimizer.step()

                # Valid: no autograd graph is needed here, and `best_output` would
                # otherwise keep the best epoch's graph alive.
                model.eval()
                with torch.no_grad():
                    embed = model.classifier1(data.x, data.edge_index, embed=True)
                    output = model.classifier2(embed, data.edge_index)

                acc_val, macro_F_val = utils.performance(output[data.val_mask], data.labels[data.val_mask], pre='valid', evaluator=data.evaluator)

                # `<=` keeps the latest epoch among ties; max_idx must name that
                # same epoch so patience and the logged best result match best_output.
                if best_metric <= acc_val:
                    best_metric = acc_val
                    max_idx = epoch
                    best_embed = embed.detach().cpu().numpy()
                    best_output = output

                # Test
                acc_test, macro_F_test = utils.performance(output[data.test_mask], data.labels[data.test_mask], pre='test', evaluator=data.evaluator)

                test_results.append([acc_test, macro_F_test])
                best_test_result = test_results[max_idx]

                if epoch % args.print_result == 0:
                    model_name = args.embedder
                    st = "[seed {}][{}-{}][{}][Epoch {}]".format(seed, args.dataset, args.missing_rate, model_name, epoch)
                    st += "[Val] ACC: {:.2f}, Macro-F1: {:.2f}|| ".format(acc_val, macro_F_val)
                    st += "[Test] ACC: {:.2f}, Macro-F1: {:.2f}\n".format(acc_test, macro_F_test)
                    st += "  [*Best Test Result*][Epoch {}] ACC: {:.2f}, Macro-F1: {:.2f}".format(max_idx, best_test_result[0], best_test_result[1])
                    print(st)

                stop_early = epoch - max_idx > args.patience
                if stop_early or epoch + 1 == args.epochs:
                    if stop_early:
                        self._dump_tsne(data, args, seed, best_embed, best_output)
                        print("Early stop")
                    best_test_result[0], best_test_result[1] = utils.performance(best_output[data.test_mask], data.labels[data.test_mask], pre='test', evaluator=data.evaluator)
                    print("[Best Test Result] ACC: {:.2f}, Macro-F1: {:.2f}".format(best_test_result[0], best_test_result[1]))
                    break

            seed_result['acc'].append(float(best_test_result[0]))
            seed_result['macro_F'].append(float(best_test_result[1]))

        acc = seed_result['acc']
        f1 = seed_result['macro_F']
        acc_mean, acc_std = np.mean(acc), np.std(acc)
        f1_mean, f1_std = np.mean(f1), np.std(f1)

        print('[Averaged result] ACC: {:.2f}+{:.2f}, Macro-F: {:.2f}+{:.2f}'.format(acc_mean, acc_std, f1_mean, f1_std))
        print('{:.2f}+{:.2f} {:.2f}+{:.2f}'.format(acc_mean, acc_std, f1_mean, f1_std))

        logger.info('')
        logger.info(datetime.datetime.now())
        logger.info(file)
        logger.info(f'----------- missing rate: {args.missing_rate} -----------')
        logger.info('{:.2f}+{:.2f} {:.2f}+{:.2f}'.format(acc_mean, acc_std, f1_mean, f1_std))
        logger.info('{:.2f}+{:.2f}'.format(acc_mean, acc_std))
        logger.info('{:.2f}+{:.2f}'.format(f1_mean, f1_std))
        logger.info(args)
        logger.info('=================================')


class modeler(nn.Module):
    """Two-stage GNN: an encoder ``classifier1`` producing node embeddings and a
    head ``classifier2`` producing class logits, trained with cross-entropy plus
    an optional pseudo-label contrastive loss on the encoder embeddings."""

    def __init__(self, args):
        super(modeler, self).__init__()
        self.args = args

        ## Model Selection ##
        # forward() calls ``classifier1(..., embed=True)`` and then
        # ``classifier2(embed, ...)``, so an encoder/head pair is required.
        # Only GCN is wired for that split; the other architectures would leave
        # classifier1/classifier2 undefined, so fail with a clear message.
        if args.gnn == 'GCN':
            self.classifier1 = GCN(1, args.n_feat, args.n_hid, args.n_class, normalize=True, is_add_self_loops=False)
            self.classifier2 = GCN(1, args.n_hid, args.n_hid, args.n_class, normalize=True, is_add_self_loops=False)
        elif args.gnn in ("GAT", "SAGE", "SGC"):
            raise NotImplementedError(f"{args.gnn} is not supported by the two-stage modeler; use 'GCN'.")
        else:
            raise NotImplementedError("Not Implemented Architecture!")

    def forward(self, x, edge_index, labels, pseudo_labels, weight_mask, idx_train, idx_val=None, idx_test=None):
        """Compute the node-classification loss and the pseudo-label contrastive loss.

        Args:
            x, edge_index: graph inputs passed to both classifiers.
            labels: ground-truth labels; squeezed on dim 1 when the dataset name
                contains 'OGBN'.
            pseudo_labels: per-node pseudo labels (used only when ``args.lamb != 0``).
            weight_mask: confidence weights; a per-node vector when ``args.scaled``,
                otherwise a pairwise matrix.
            idx_train: index/mask selecting nodes for the cross-entropy loss.
            idx_val, idx_test: unused; accepted for call-site compatibility.

        Returns:
            ``(loss_nodeclassification, pseudocon_loss)``; the second is ``0.0``
            when ``args.lamb == 0``.
        """
        embed = self.classifier1(x, edge_index, embed=True)
        if self.args.lamb != 0.0:
            if self.args.scaled:
                # (n_pseudo_classes, n_nodes) membership matrix as floats.
                label_mat = F.one_hot(pseudo_labels).T + 0.0
                # Confidence-weighted sum of member embeddings per pseudo class.
                label_weight = label_mat * weight_mask
                centroids = torch.mm(label_weight, embed)
                pseudocon_loss = self.pseudocon_loss(centroids, scaled=True)
            else:
                pseudocon_loss = self.pseudocon_loss(embed, pseudo_labels, weight_mask=weight_mask)
        else:
            pseudocon_loss = 0.0

        output = self.classifier2(embed, edge_index)

        if 'OGBN' in self.args.dataset:
            labels = labels.squeeze(1)
        loss_nodeclassification = F.cross_entropy(output[idx_train], labels[idx_train])

        return loss_nodeclassification, pseudocon_loss

    def pseudocon_loss(self, features, labels=None, mask=None, temp=0.07, base_temp=0.07, weight_mask=None, scaled=False):
        """Supervised-contrastive loss driven by pseudo labels.

        Args:
            features: row-wise feature matrix; rows are L2-normalised here.
            labels: per-row pseudo labels (required unless ``scaled``).
            mask: ignored; the positive mask is always recomputed. Kept for
                signature compatibility.
            temp, base_temp: only set the final ``temp / base_temp`` scale; the
                logits are divided by ``self.args.temp``.
            weight_mask: weights multiplied into the positive terms (required
                unless ``scaled``); must broadcast against an (N, N) matrix.
            scaled: if True, ``features`` are class centroids and each row's
                log-probabilities are summed over all columns.

        Returns:
            Scalar loss tensor. Rows with no positive partner (a node alone in
            its pseudo class) are excluded from the mean; if no row has a
            positive, a zero tensor is returned instead of NaN.
        """
        features = F.normalize(features, dim=-1)
        batch_size = features.shape[0]

        # Temperature-scaled cosine similarities; subtracting the detached row
        # max avoids exp() overflow without changing the log-softmax.
        anchor_dot_contrast = torch.matmul(features, features.T) / self.args.temp
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # 1 everywhere except the diagonal: a row is never contrasted with itself.
        logits_mask = 1.0 - torch.eye(batch_size, dtype=logits.dtype, device=features.device)

        # log-softmax over all non-self columns.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        if scaled:
            mean_log_prob_pos = log_prob.sum(1)
        else:
            labels = labels.contiguous().view(-1, 1)
            pos_mask = torch.eq(labels, labels.T).float() * logits_mask
            pos_count = pos_mask.sum(1)
            # A row without positives would give 0/0 = NaN and poison the mean.
            has_pos = pos_count > 0
            if not bool(has_pos.any()):
                return features.new_zeros(())
            weighted = (weight_mask * pos_mask * log_prob).sum(1)
            mean_log_prob_pos = weighted[has_pos] / pos_count[has_pos]

        loss = -(temp / base_temp) * mean_log_prob_pos
        return loss.mean()