import torch
import copy
import os
from torch.autograd import Variable
import torch.nn.functional as F
from dgl.nn.pytorch import edge_softmax, GATConv
import torch.nn as nn
import dgl

def MultiClassCrossEntropy(logits, labels, T):
    """
    Temperature-scaled cross entropy between two sets of raw scores (soft-target loss).

    :param logits: Raw scores of the model being trained. The log-softmax is taken over dim 1.
    :param labels: Raw scores (not probabilities) that serve as soft targets. They are softmax-ed over dim 1 and treated as constants, i.e. no gradient flows back into them.
    :param T: Temperature; both inputs are divided by it before the (log-)softmax.
    :return: ``-mean_over_dim0( sum_over_dim1( softmax(labels/T) * log_softmax(logits/T) ) )``. For 2-D inputs this is a 0-dim tensor.
    """
    # detach() replaces the deprecated Variable(labels.data, requires_grad=False).
    # The targets follow logits.device instead of a bare .cuda(): .cuda() targets the
    # *current* CUDA device, which may differ from the device holding the logits, and
    # it fails outright on CPU-only machines.
    targets = labels.detach().to(logits.device)
    log_probs = torch.log_softmax(logits / T, dim=1)
    soft_targets = torch.softmax(targets / T, dim=1)
    per_sample = torch.sum(log_probs * soft_targets, dim=1, keepdim=False)
    return -torch.mean(per_sample, dim=0, keepdim=False)


def kaiming_normal_init(m):
    """
    Re-initialise the weight of a single module in place with Kaiming-normal values.

    Intended to be passed to ``Module.apply``. ``Conv2d`` weights are drawn with the
    'relu' gain and ``Linear`` weights with the 'sigmoid' gain. Any other module type,
    and every bias, is left untouched.

    :param m: The module to (possibly) re-initialise.
    """
    if isinstance(m, torch.nn.Conv2d):
        gain_kind = 'relu'
    elif isinstance(m, torch.nn.Linear):
        gain_kind = 'sigmoid'
    else:
        # Not a layer type handled here.
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity=gain_kind)
        
class NET(torch.nn.Module):
    """
        LwF baseline for NCGL tasks

        :param model: The backbone GNNs, e.g. GCN, GAT, GIN, etc.
        :param args: The arguments containing the configurations of the experiments including the training parameters like the learning rate, the setting configurations like class-IL and task-IL, etc. These arguments are initialized in the train.py file and can be specified by the users upon running the code.

        """
    def __init__(self,
                 model,
                 args):
        super(NET, self).__init__()

        self.args = args
        self.activation = F.elu

        # setup network
        self.net = model
        self.net.apply(kaiming_normal_init)

        # setup optimizer
        self.opt = torch.optim.Adam(self.net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # setup losses
        self.ce = torch.nn.functional.cross_entropy

        # Frozen copy of the backbone used as the distillation teacher. This class only
        # initialises it to None; whoever drives the training is expected to assign it
        # (e.g. a deep copy of self.net taken before a new task starts).
        self.prev_model = None
        # Class ids (plain Python ints) encountered so far, in order of first appearance.
        self.seen_classes = []

    def forward(self, g, features):
        """
                Two-stage forward pass: feature extractor, ELU, then a GAT layer.

                :param g: The input graph.
                :param features: Node features.
                :return: The output of ``self.gat``; 3-D outputs are averaged over dim 1.

                """
        # NOTE(review): self.feature_extractor and self.gat are not defined anywhere in the
        # visible code of this class, so calling forward() raises AttributeError unless they
        # are attached elsewhere. The training path below uses self.net.forward_batch instead.
        h = features
        h = self.feature_extractor(g, h)[0]
        if len(h.shape)==3:
            h = h.flatten(1)
        h = self.activation(h)
        h = self.gat(g, h)[0]
        if len(h.shape)==3:
            h = h.mean(1)
        return h

    def observe_minibatch(self, args, g, features, labels, train_ids, ids_per_cls):
        """
                The method for learning the given tasks under the class-IL setting.

                Performs one optimisation step: cross entropy on the sampled mini-batch over
                all classes seen so far, plus (when a previous model is available) the LwF
                distillation term on the older output columns.

                :param args: Same as the args in __init__().
                :param g: The graph of the current task.
                :param features: Node features of the current task (not used here; the input features are read from the sampled blocks' ``'feat'`` field).
                :param labels: Labels of the nodes in the current task.
                :param train_ids: The indices of the nodes participating in the training.
                :param ids_per_cls: Indices of the nodes in each class (not in use in the current baseline).

                """
        self.net.train()
        # Record class ids as Python ints so that membership tests and max() do not depend
        # on tensor comparisons (or on which device the label tensor lives).
        for label in labels.unique().tolist():
            if label not in self.seen_classes:
                self.seen_classes.append(label)

        self.net.zero_grad()
        offset1, offset2 = 0, max(self.seen_classes) + 1
        if args.sample_nbs:
            nb_sampler = dgl.dataloading.NeighborSampler(args.n_nbs_sample)
        else:
            nb_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(len(self.net.gat_layers))
        # as_tensor avoids the copy (and the UserWarning) torch.tensor() produces when
        # train_ids is already a tensor.
        seed_nodes = torch.as_tensor(train_ids).to(device='cuda:{}'.format(args.gpu))
        _, _, blocks = nb_sampler.sample_blocks(g, seed_nodes)
        input_features = blocks[0].srcdata['feat']
        logits, _ = self.net.forward_batch(blocks, input_features)
        if isinstance(logits,tuple):
            logits = logits[0]
        output_labels = labels[train_ids]

        if args.cls_balance:
            n_per_cls = [(output_labels == j).sum() for j in range(args.n_cls)]
            loss_w_ = [1. / max(i, 1) for i in n_per_cls]  # weight to balance the loss of different class
        else:
            loss_w_ = [1. for i in range(args.n_cls)]
        # The weights must live on the same device as the logits they are applied to.
        loss_w_ = torch.tensor(loss_w_).to(device=logits.device)
        loss = self.ce(logits[:, offset1:offset2], output_labels, weight=loss_w_[offset1: offset2])

        # Distillation term. The "> 2" / "- 2" arithmetic treats the last two seen classes as
        # the new ones (presumably two classes per task - TODO confirm against the data split).
        # Without a previous model there is nothing to distil from, so the term is skipped
        # rather than failing on None.
        if self.prev_model is not None and len(self.seen_classes) > 2:
            # The teacher output is only used as a constant target; skip building its graph.
            with torch.no_grad():
                target, _ = self.prev_model.forward_batch(blocks, input_features)
            if isinstance(target, tuple):
                target = target[0]
            o1, o2 = 0, len(self.seen_classes) - 2
            logits_dist = logits[:,o1:o2]
            dist_target = target[:,o1:o2]
            dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, args.lwf_args['T'])
            loss = loss + args.lwf_args['lambda_dist']*dist_loss

        loss.backward()
        self.opt.step()
