import torch
import dgl

class NET(torch.nn.Module):
    """
        EWC (Elastic Weight Consolidation) baseline for NCGL tasks.

        The wrapper owns the backbone, its Adam optimizer and the EWC state
        (per-parameter Fisher estimates and the parameter snapshot they are
        anchored to).

        :param model: The backbone GNNs, e.g. GCN, GAT, GIN, etc.
        :param args: The arguments containing the configurations of the experiments including the training parameters like the learning rate, the setting configurations like class-IL and task-IL, etc. These arguments are initialized in the train.py file and can be specified by the users upon running the code.

        """
    def __init__(self,
                 model,
                 args):
        super(NET, self).__init__()
        # Strength of the quadratic EWC penalty.
        self.reg = args.ewc_args['memory_strength']

        # setup network
        self.net = model

        # setup optimizer
        self.opt = torch.optim.Adam(self.net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # setup losses
        self.ce = torch.nn.functional.cross_entropy

        # EWC state: one entry per backbone parameter, in self.net.parameters() order.
        self.fisher = []
        self.optpar = []
        # Number of training examples already folded into self.fisher.
        self.n_seen_examples = 0
        # Number of observe_minibatch() calls so far.
        self.epochs = 0
        # Class ids (Python numbers) encountered so far.
        self.seen_classes = []

    def forward(self, features):
        """Run the backbone on ``features`` and return its output unchanged."""
        return self.net(features)

    @staticmethod
    def _class_weights(output_labels, n_cls, balance, device):
        """Return one cross-entropy weight per class, placed on ``device``."""
        if balance:
            # Count the occurrences of every class id in one vectorised comparison.
            class_ids = torch.arange(n_cls, device=output_labels.device).unsqueeze(1)
            n_per_cls = (output_labels.reshape(1, -1) == class_ids).sum(dim=1)
            # Inverse frequency; absent classes are treated as having one example.
            weights = 1. / n_per_cls.clamp(min=1)
        else:
            weights = torch.ones(n_cls)
        return weights.to(device)

    def _batch_loss(self, blocks, input_features, output_labels, balance, n_cls, offset1, offset2):
        """Forward the sampled blocks and return the weighted cross-entropy loss."""
        output, _ = self.net.forward_batch(blocks, input_features)
        if isinstance(output, tuple):
            output = output[0]
        # Keep the weights on the same device as the logits instead of assuming a GPU id.
        loss_w_ = self._class_weights(output_labels, n_cls, balance, output.device)
        return self.ce(output[:, offset1:offset2], output_labels, weight=loss_w_[offset1: offset2])

    def _consolidate(self, n_new_examples):
        """Snapshot the parameters and merge the squared gradients into ``self.fisher``."""
        new_fisher = []
        new_optpar = []
        for p in self.net.parameters():
            new_optpar.append(p.data.clone())
            if p.grad is None:
                # The parameter did not contribute to the loss, so its gradient is zero.
                new_fisher.append(torch.zeros_like(p.data))
            else:
                new_fisher.append(p.grad.data.clone().pow(2))
        self.optpar = new_optpar

        if len(self.fisher) != 0:
            # Example-count-weighted average of the stored and the new estimates.
            n_total = self.n_seen_examples + n_new_examples
            for i, f in enumerate(new_fisher):
                self.fisher[i] = (self.fisher[i] * self.n_seen_examples + f * n_new_examples) / n_total
            self.n_seen_examples = n_total
        else:
            self.fisher.extend(new_fisher)
            self.n_seen_examples = n_new_examples

    def observe_minibatch(self, args, g, features, labels, train_ids, ids_per_cls):
        """
                The method for learning the given tasks under the class-IL setting.

                Performs one optimisation step on the nodes in ``train_ids`` and, every
                ``args.epochs``-th call, refreshes the EWC state (Fisher estimates and
                parameter snapshot) from the same sampled blocks.

                :param args: Same as the args in __init__().
                :param g: The graph of the current task.
                :param features: Node features of the current task (not read here; the input features are taken from the sampled blocks' ``'feat'`` field).
                :param labels: Labels of the nodes in the current task.
                :param train_ids: The indices of the nodes participating in the training.
                :param ids_per_cls: Indices of the nodes in each class (not in use in the current baseline).
                :return: The loss of this step (cross-entropy plus the EWC penalty once Fisher estimates exist).

                """
        # Store plain Python numbers so membership tests and max() do not operate on tensors.
        for label in labels.unique().tolist():
            if label not in self.seen_classes:
                self.seen_classes.append(label)

        # NOTE(review): this counter advances once per call; if the caller invokes this
        # method several times per epoch, consolidation fires every args.epochs calls
        # rather than every args.epochs epochs -- confirm against the training loop.
        self.epochs += 1
        last_epoch = self.epochs % args.epochs
        self.net.train()
        n_new_examples = len(train_ids)

        self.net.zero_grad()
        # Class-IL: the logits of every class seen so far take part in the loss.
        offset1, offset2 = 0, max(self.seen_classes)+1
        if args.sample_nbs:
            nb_sampler = dgl.dataloading.NeighborSampler(args.n_nbs_sample)
        else:
            nb_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(len(self.net.gat_layers))
        # as_tensor avoids a redundant copy (and a warning) when train_ids is already a tensor.
        seed_nodes = torch.as_tensor(train_ids).to(device='cuda:{}'.format(args.gpu))
        _, _, blocks = nb_sampler.sample_blocks(g, seed_nodes)
        input_features = blocks[0].srcdata['feat']
        output_labels = labels[train_ids]

        loss = self._batch_loss(blocks, input_features, output_labels,
                                args.cls_balance, args.n_cls, offset1, offset2)

        if len(self.fisher) != 0:
            # Quadratic EWC penalty pulling every parameter towards its snapshot.
            for i, p in enumerate(self.net.parameters()):
                l = self.reg * self.fisher[i]
                l = l * (p - self.optpar[i]).pow(2)
                loss += l.sum()
        loss.backward()
        self.opt.step()

        if last_epoch == 0:
            # Recompute the plain (penalty-free) loss gradients at the updated parameters.
            self.net.zero_grad()
            self._batch_loss(blocks, input_features, output_labels,
                             args.cls_balance, args.n_cls, offset1, offset2).backward()
            self._consolidate(n_new_examples)

        return loss
        