import pickle
import gc
import random
import torch
import torch.nn as nn
import torch.optim as optim
import copy

import numpy as np
#import quadprog
import dgl
from .agem_utils import store_grad, overwrite_grad

class NET(nn.Module):
    """
    A-GEM baseline for NCGL tasks

    The current-task gradient is compared with a reference gradient computed
    on nodes replayed from an episodic memory; when the two conflict (negative
    dot product) the current gradient is projected so that it no longer
    points against the reference gradient.

    :param model: The backbone GNNs, e.g. GCN, GAT, GIN, etc.
    :param args: The arguments containing the configurations of the experiments including the training parameters like the learning rate, the setting configurations like class-IL and task-IL, etc. These arguments are initialized in the train.py file and can be specified by the users upon running the code.

    """
    def __init__(self,
                model,
                args):
        super(NET, self).__init__()
        self.net = model

        self.ce = torch.nn.functional.cross_entropy
        self.opt = optim.Adam(self.net.parameters(), lr = args.lr, weight_decay = args.weight_decay)

        # maximum number of node ids kept in the episodic memory
        self.budget = int(args.agem_args['n_memories'])
        # multiplier applied to args.batch_size to size the replayed sample
        self.memory_proportion = int(args.agem_args['memory_proportion'])
        self.buffer_node_ids = []
        self.aux_g = None
        self.n_seen_examples = 0
        self.epochs = 0

        # allocate episodic memory
        # for semi-supervised data, it will store the training mask for every old tasks
        self.memory_data = []

        # allocate temporary synaptic memory: column 0 holds the flattened
        # reference (replay) gradient, column 1 the current-task gradient.
        # The buffer is zero-initialised and placed on the device the backbone
        # lives on at construction time, so CPU-only runs work as well.
        self.grad_dims = [param.data.numel() for param in self.net.parameters()]
        grad_device = next(self.net.parameters()).device
        self.grads = torch.zeros(sum(self.grad_dims), 2, device=grad_device)

        # allocate counters
        self.current_task = -1
        self.mem_cnt = 0
        self.old_dataloaders = {}
        self.seen_classes = []

    def forward(self, features):
        """
        Run the backbone on the given input.

        :param features: Input forwarded unchanged to the backbone.
        :return: Whatever the backbone returns.
        """
        output = self.net(features)
        return output

    def _update_buffer(self, train_ids):
        """
        Insert the given node ids into the episodic memory by reservoir sampling.

        Free slots are filled first; every remaining id replaces a random slot
        with probability ``budget / (number of ids seen so far, itself included)``.

        :param train_ids: Node ids of the current call (list or 1-D tensor).
        """
        n_nodes = len(train_ids)
        place_left = max(0, self.budget - len(self.buffer_node_ids))
        n_direct = min(place_left, n_nodes)
        self.buffer_node_ids.extend(int(node) for node in train_ids[:n_direct])
        for i in range(n_direct, n_nodes):
            # torch.randint excludes the upper bound, hence the +1: the
            # candidate itself has to be part of the range that is drawn from.
            j = torch.randint(0, self.n_seen_examples + i + 1, (1,)).item()
            if j < self.budget:
                self.buffer_node_ids[j] = int(train_ids[i])
        self.n_seen_examples += n_nodes

    def _rebuild_memory_graph(self, g, device):
        """
        Rebuild the replay graph from the node ids currently in the buffer.

        All edges are dropped and self-loops are added, so every buffered node
        is replayed in isolation.

        :param g: The graph the buffered node ids refer to.
        :param device: Device the replay graph is moved to.
        """
        if not self.buffer_node_ids:
            # nothing to replay (e.g. a budget of 0); keep the previous state
            return
        subg = dgl.node_subgraph(g, self.buffer_node_ids, store_ids=True)
        n_edges = subg.edges()[0].shape[0]
        subg.remove_edges(list(range(n_edges)))
        subg = dgl.add_self_loop(subg)
        self.aux_g = subg.to(device)
        self.aux_features = self.aux_g.srcdata['feat']
        aux_labels = self.aux_g.dstdata['label'].squeeze()
        if aux_labels.dim() == 0:
            # squeeze() collapses a single buffered node to a scalar; restore
            # the 1-D shape the boolean-mask indexing of the replay step needs
            aux_labels = aux_labels.unsqueeze(0)
        self.aux_labels = aux_labels

    def observe_minibatch(self, args, g, features, labels, train_ids, ids_per_cls):
        """
                The method for learning the given tasks under the class-IL setting.

                One call performs one optimisation step. Every ``args.epochs``-th
                call additionally refreshes the episodic memory with ``train_ids``.

                :param args: Same as the args in __init__().
                :param g: The graph of the current task.
                :param features: Node features of the current task (only its device is used here; the inputs are read from the sampled blocks).
                :param labels: Labels of the nodes in the current task.
                :param train_ids: The indices of the nodes participating in the training.
                :param ids_per_cls: Indices of the nodes in each class (not in use in the current baseline).

                """

        # keep the seen classes as plain ints so they can be used for slicing
        for label in labels.unique().tolist():
            if label not in self.seen_classes:
                self.seen_classes.append(label)

        self.epochs += 1
        last_epoch = self.epochs % args.epochs

        buffer_size = len(self.buffer_node_ids)

        # class-IL: the logits of every class seen so far take part in the loss
        offset1, offset2 = 0, max(self.seen_classes) + 1

        # compute gradient on previous tasks
        if buffer_size > 0:
            sampled_mask = torch.zeros(buffer_size, dtype=torch.bool)
            n_samples = min(args.batch_size * self.memory_proportion, buffer_size)
            sampled_mask[torch.randperm(buffer_size)[:n_samples]] = True
            # calculate auxiliary loss based on replay
            self.net.zero_grad()
            output, _ = self.net(self.aux_g, self.aux_features)
            old_task_loss = self.ce(output[sampled_mask, offset1:offset2], self.aux_labels[sampled_mask])
            old_task_loss.backward()
            store_grad(self.net.parameters, self.grads, self.grad_dims, 0)

        # compute the grad on the current minibatch
        self.net.zero_grad()
        # NOTE: the full-neighbour sampler takes its depth from
        # self.net.gat_layers, so the backbone must expose that attribute
        # whenever args.sample_nbs is falsy
        nb_sampler = dgl.dataloading.NeighborSampler(args.n_nbs_sample) if args.sample_nbs else dgl.dataloading.MultiLayerFullNeighborSampler(len(self.net.gat_layers))
        _, _, blocks = nb_sampler.sample_blocks(g, torch.tensor(train_ids).to(device='cuda:{}'.format(args.gpu)))
        input_features = blocks[0].srcdata['feat']
        output, _ = self.net.forward_batch(blocks, input_features)
        output_labels = labels[train_ids]
        if args.cls_balance:
            n_per_cls = [(output_labels == j).sum().item() for j in range(args.n_cls)]
            loss_w_ = [1. / max(i, 1) for i in n_per_cls]  # weight to balance the loss of different class
        else:
            loss_w_ = [1. for i in range(args.n_cls)]
        # the class weights have to sit on the same device as the logits
        loss_w_ = torch.tensor(loss_w_).to(output.device)
        loss = self.ce(output[:, offset1:offset2], output_labels, weight=loss_w_[offset1: offset2])

        loss.backward()

        # check if gradient violates constraints
        if buffer_size > 0:
            # copy gradient
            store_grad(self.net.parameters, self.grads, self.grad_dims, 1)
            ref_grad = self.grads[:, 0]
            cur_grad = self.grads[:, 1]
            dotp = torch.dot(cur_grad, ref_grad)
            if dotp < 0:
                # remove the component of the current gradient that points
                # against the reference gradient
                dotp_ref = torch.dot(ref_grad, ref_grad)
                projected_grad = cur_grad - (dotp / dotp_ref) * ref_grad
                # copy gradients back
                overwrite_grad(self.net.parameters, projected_grad,
                               self.grad_dims)

        self.opt.step()

        if last_epoch == 0:
            # perform reservoir sampling, then rebuild the replay graph
            self._update_buffer(train_ids)
            self._rebuild_memory_graph(g, features.device)
