from embedder import embedder
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
import torch
import datetime
import utils
from tqdm import trange
from layers import GCN, GAT, SAGE, SGC
from sklearn.decomposition import PCA
from utils import add_labels
from torch_geometric.data import NeighborSampler

class GOODIE():
    """Multi-seed experiment driver for GOODIE.

    For every seed, ``embedder`` builds the per-run data/state object. Its
    label-propagation scores (``lp_mat``) provide pseudo-labels and confidence
    weights, and a ``modeler`` is trained with cross-entropy plus
    ``args.lamb`` times a pseudo-label contrastive loss. The test result at the
    epoch with the best validation accuracy is kept per seed; mean/std over
    seeds are printed and written to the log.
    """

    def __init__(self, args):
        # Experiment configuration; training() writes `seed` into it.
        self.args = args

    @staticmethod
    def _model_name(args):
        """Build the model tag shown in progress lines from the run configuration."""
        name = args.gnn + '-' + args.filling_method
        if args.label_trick:
            name += f'_label_trick_mask_rate_{args.mask_rate}'
        if args.n_reuse > 0:
            name += f'_n_reuse_{args.n_reuse}'
        return name

    @staticmethod
    def _lp_confidence_weights(data):
        """Derive pseudo-labels and contrastive-loss weights from ``data.lp_mat``.

        Returns:
            (pseudo_labels, weights): ``pseudo_labels`` is the row-wise argmax of
            the LP scores. ``weights`` is None when ``args.lamb == 0``.
            Otherwise it is the per-node confidence (max of a softmax over the LP
            scores divided by ``args.lp_temp``, forced to 1.0 on training nodes)
            when ``args.scaled`` is set, or the outer product of that vector
            with itself when it is not.
        """
        args = data.args
        lp_output = data.lp_mat.to(args.device)
        pseudo_labels = lp_output.argmax(1)
        if args.lamb == 0.0:
            return pseudo_labels, None
        confidence = torch.softmax(lp_output / args.lp_temp, 1).max(1)[0]
        # Ground-truth training nodes are trusted fully.
        confidence[data.train_mask] = 1.0
        if args.scaled:
            return pseudo_labels, confidence
        column = confidence.unsqueeze(1)
        return pseudo_labels, column @ column.T

    @staticmethod
    def _train_one_epoch(data, model, optimizer, lp_pred_mat, pseudo_labels):
        """Run one optimisation epoch, mini-batched when ``data.train_loader`` is set."""
        model.train()
        if data.train_loader:
            for batch_size, n_id, adjs in data.train_loader:
                optimizer.zero_grad()
                # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
                adjs = [adj.to(data.args.device) for adj in adjs]
                # NOTE(review): `modeler` defines classifier1/classifier2 but no
                # `classifier` attribute, so this path fails as written — confirm.
                y_pred = model.classifier(data.x[n_id], adjs, sample=True)
                y_true = data.labels[n_id[:batch_size]].squeeze()
                loss = F.cross_entropy(y_pred, y_true)
                loss.backward()
                optimizer.step()
        else:
            optimizer.zero_grad()
            loss_ce, loss_pseudo = model(data.x, data.edge_index, data.labels, pseudo_labels, lp_pred_mat,
                                         data.train_mask, data.val_mask, data.test_mask)
            loss = loss_ce + data.args.lamb * loss_pseudo
            loss.backward()
            optimizer.step()

    @staticmethod
    def _predict(data, model):
        """Return class logits for all nodes, computed without building an autograd graph."""
        model.eval()
        with torch.no_grad():
            if data.inference_loader:
                # Layer-wise full-graph inference: each pass maps every node's
                # current representation through one layer, batch by batch.
                x_all = data.x.cpu()
                for i in range(data.args.num_layers):
                    xs = []
                    for batch_size, n_id, adj in data.inference_loader:
                        edge_index, _, size = adj.to(data.args.device)
                        x = x_all[n_id].to(data.args.device)
                        x_target = x[: size[1]]
                        # NOTE(review): `model.classifier` does not exist on `modeler`
                        # (only classifier1/classifier2) — confirm intended attribute.
                        x = model.classifier.conv1[i]((x, x_target), edge_index)
                        x = F.relu(x)
                        if i == data.args.num_layers - 1:
                            x = model.classifier.classifier(x)
                        xs.append(x.cpu())
                    x_all = torch.cat(xs, dim=0)
                return x_all
            embed = model.classifier1(data.x, data.edge_index, embed=True)
            return model.classifier2(embed, data.edge_index)

    def training(self):
        """Train/evaluate once per seed in ``range(args.n_runs)`` and report mean/std results."""
        log_name = utils.set_filename(self.args)
        logger = utils.setup_logger('./', '-', log_name)

        seed_result = {'acc': [], 'macro_F': []}
        # Each run receives the args object held by the previous run's embedder,
        # and the final summary logs the last one.
        run_args = self.args

        for seed in trange(0, 0 + run_args.n_runs):
            print(f'============== seed:{seed} ==============')
            utils.seed_everything(seed)
            print('seed:', seed, log_name)
            run_args.seed = seed
            data = embedder(run_args, seed)
            run_args = data.args

            pseudo_labels, lp_pred_mat = self._lp_confidence_weights(data)

            model = modeler(data.args).to(data.args.device)
            optimizer = optim.Adam(model.parameters(), lr=data.args.lr)

            test_results = []
            best_metric = 0
            max_idx = 0
            best_output = None

            for epoch in range(0, data.args.epochs):
                self._train_one_epoch(data, model, optimizer, lp_pred_mat, pseudo_labels)
                output = self._predict(data, model)

                acc_val, macro_F_val = utils.performance(output[data.val_mask], data.labels[data.val_mask],
                                                         pre='valid', evaluator=data.evaluator)
                if best_metric <= acc_val:
                    # `<=` keeps the latest epoch among ties; record that same epoch so the
                    # reported best epoch, `best_output` and the patience counter agree.
                    best_metric = acc_val
                    max_idx = epoch
                    best_output = output

                acc_test, macro_F_test = utils.performance(output[data.test_mask], data.labels[data.test_mask],
                                                           pre='test', evaluator=data.evaluator)
                test_results.append([acc_test, macro_F_test])
                best_test_result = test_results[max_idx]

                if epoch % data.args.print_result == 0:
                    st = "[seed {}][{}-{}][{}][Epoch {}]".format(seed, data.args.dataset, data.args.missing_rate,
                                                                 self._model_name(data.args), epoch)
                    st += "[Val] ACC: {:.2f}, Macro-F1: {:.2f}|| ".format(acc_val, macro_F_val)
                    st += "[Test] ACC: {:.2f}, Macro-F1: {:.2f}\n".format(acc_test, macro_F_test)
                    st += "  [*Best Test Result*][Epoch {}] ACC: {:.2f}, Macro-F1: {:.2f}".format(
                        max_idx, best_test_result[0], best_test_result[1])
                    print(st)

                stop_early = epoch - max_idx > data.args.patience
                if stop_early or (epoch + 1 == data.args.epochs):
                    if stop_early:
                        print("Early stop")
                    # Fresh list so the entry stored in `test_results` is not mutated.
                    best_acc, best_f1 = utils.performance(best_output[data.test_mask], data.labels[data.test_mask],
                                                          pre='test', evaluator=data.evaluator)
                    best_test_result = [best_acc, best_f1]
                    print("[Best Test Result] ACC: {:.2f}, Macro-F1: {:.2f}".format(best_test_result[0], best_test_result[1]))
                    torch.cuda.empty_cache()
                    break

            seed_result['acc'].append(float(best_test_result[0]))
            seed_result['macro_F'].append(float(best_test_result[1]))

        acc = seed_result['acc']
        f1 = seed_result['macro_F']

        print('[Averaged result] ACC: {:.2f}+{:.2f}, Macro-F: {:.2f}+{:.2f}'.format(np.mean(acc), np.std(acc), np.mean(f1), np.std(f1)))
        print('{:.2f}+{:.2f} {:.2f}+{:.2f}'.format(np.mean(acc), np.std(acc), np.mean(f1), np.std(f1)))

        logger.info('')
        logger.info(datetime.datetime.now())
        logger.info(log_name)
        logger.info(f'----------- missing rate: {run_args.missing_rate} -----------')
        logger.info('{:.2f}+{:.2f} {:.2f}+{:.2f}'.format(np.mean(acc), np.std(acc), np.mean(f1), np.std(f1)))
        logger.info('{:.2f}+{:.2f}'.format(np.mean(acc), np.std(acc)))
        logger.info('{:.2f}+{:.2f}'.format(np.mean(f1), np.std(f1)))
        logger.info(run_args)
        logger.info('=================================')


class modeler(nn.Module):
    """Two-stage GNN used by GOODIE.

    ``classifier1`` maps node features to hidden embeddings (called with
    ``embed=True``) and ``classifier2`` maps those embeddings to class logits.
    ``forward`` returns the cross-entropy loss on the training nodes and a
    pseudo-label contrastive loss computed on the hidden embeddings.
    """

    def __init__(self, args):
        super(modeler, self).__init__()
        self.args = args

        ## Model Selection ##
        if args.gnn == 'GCN':
            classifier1 = GCN(1, args.n_feat, args.n_hid, args.n_class, normalize=True, is_add_self_loops=False)
            classifier2 = GCN(1, args.n_hid, args.n_hid, args.n_class, normalize=True, is_add_self_loops=False)
        elif args.gnn in ('GAT', 'SAGE', 'SGC'):
            # These layer types are not wired up as a classifier1/classifier2 pair,
            # so fail loudly instead of hitting an unbound `classifier1` below.
            raise NotImplementedError(
                f"{args.gnn} is not supported by the two-stage GOODIE modeler; use 'GCN'.")
        else:
            raise NotImplementedError("Not Implemented Architecture!")
        self.classifier1 = classifier1
        self.classifier2 = classifier2

    def forward(self, x, edge_index, labels, pseudo_labels, weight_mask, idx_train, idx_val=None, idx_test=None):
        """Compute (node-classification CE loss, pseudo-label contrastive loss).

        Args:
            x, edge_index: graph inputs passed to both classifiers.
            labels: ground-truth labels; for datasets whose name contains
                'OGBN' they are squeezed along dim 1 first.
            pseudo_labels: integer pseudo-label per node.
            weight_mask: confidence weights. With ``args.scaled`` they scale the
                columns of the (classes, nodes) one-hot matrix, so presumably a
                per-node vector; otherwise they multiply the pairwise positive
                mask elementwise. Unused when ``args.lamb == 0``.
            idx_train: index/mask of nodes used for the CE loss.
            idx_val, idx_test: accepted for call-site compatibility; unused.

        Returns:
            ``(loss_nodeclassification, pseudocon_loss)``; the second is ``0.0``
            when ``args.lamb == 0``.
        """
        embed = self.classifier1(x, edge_index, embed=True)
        if self.args.lamb != 0.0:
            if self.args.scaled:
                # Float (classes, nodes) indicator of pseudo-label membership.
                label_mat = torch.nn.functional.one_hot(pseudo_labels).T + 0.0
                # Confidence-weighted class sums of embeddings. Scaling the columns by
                # broadcasting equals `label_mat @ diag(weight_mask)` without
                # materialising a dense nodes x nodes matrix.
                centroids = torch.mm(label_mat * weight_mask, embed)
                pseudocon_loss = self.pseudocon_loss(centroids, scaled=True)
            else:
                pseudocon_loss = self.pseudocon_loss(embed, pseudo_labels, weight_mask=weight_mask)
        else:
            pseudocon_loss = 0.0

        output = self.classifier2(embed, edge_index)

        if 'OGBN' in self.args.dataset:
            labels = labels.squeeze(1)
        loss_nodeclassification = F.cross_entropy(output[idx_train], labels[idx_train])

        return loss_nodeclassification, pseudocon_loss

    def pseudocon_loss(self, features, labels=None, mask=None, temp=0.07, base_temp=0.07, weight_mask=None, scaled=False):
        """Supervised-contrastive style loss over L2-normalised ``features``.

        Logits are cosine similarities divided by ``self.args.temp``; ``temp`` and
        ``base_temp`` only rescale the final loss by ``temp / base_temp``. The
        ``mask`` argument is ignored: the positive mask is rebuilt below.

        Args:
            features: one row per sample (nodes, or class centroids when scaled).
            labels: per-row labels defining positives (required unless ``scaled``).
            weight_mask: elementwise weights for the positive terms (non-scaled).
            scaled: if True, use an identity positive mask (removed again by the
                self-contrast mask) and sum ``log_prob`` over each row.

        Returns:
            Scalar tensor: mean over rows of the negative (weighted) log-probability.
        """
        # Normalize
        features = F.normalize(features, dim=-1)
        batch_size = features.shape[0]
        if scaled:
            mask = torch.eye(batch_size, dtype=torch.float32).to(self.args.device)
        else:
            labels = labels.contiguous().view(-1, 1)
            mask = torch.eq(labels, labels.T).float()

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.args.temp)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(self.args.device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        if scaled:
            mean_log_prob_pos = log_prob.sum(1)
        else:
            # A row whose pseudo-label occurs nowhere else has no positives; its
            # numerator is 0, so divide by 1 there instead of 0 (which gave NaN).
            pos_counts = torch.clamp(mask.sum(1), min=1.0)
            mean_log_prob_pos = (weight_mask * mask * log_prob).sum(1) / pos_counts

        loss = - (temp / base_temp) * mean_log_prob_pos
        loss = loss.mean()

        return loss