import torch
import dgl
from .base_learner import BaseOCGLearner
from .utils import ReservoirSSM, store_grad, overwrite_grad, ByClassReservoirSSM

class OCGLearner(BaseOCGLearner):
    """Learner that replays buffered subgraphs and projects conflicting gradients.

    A reservoir buffer of subgraphs is kept across passes. Before each
    training step the gradient of the loss on a replayed sample is stored as
    a reference; after the loss on the current batch has been back-propagated
    its gradient ``g`` is compared with the reference ``g_ref``. When
    ``g . g_ref < 0`` the current gradient is replaced by
    ``g - (g . g_ref / g_ref . g_ref) * g_ref`` (the A-GEM projection rule).

    NOTE(review): ``self.net``, ``self.ce`` and ``self.output_mask`` are not
    defined in this class; they are presumably set by ``BaseOCGLearner`` --
    confirm against the base class.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Build the replay buffer and the flat gradient storage.

        Args:
            model: forwarded unchanged to ``BaseOCGLearner.__init__``.
            neighborhood_processor: forwarded unchanged to the base class.
            args: configuration object; this class reads ``dataset``,
                ``memory_budget``, ``n_cls`` (only for ``'Elliptic'``) and the
                ``ssmagem_args`` mapping (keys ``'nei_budget'`` and
                ``'memory_proportion'``).
        """
        super().__init__(model, neighborhood_processor, args)

        ssm_args = args.ssmagem_args
        if args.dataset == 'Elliptic':
            # Elliptic uses the class-aware variant of the reservoir buffer.
            self.buffer = ByClassReservoirSSM(args.memory_budget, args.n_cls, ssm_args['nei_budget'])
        else:
            self.buffer = ReservoirSSM(args.memory_budget, ssm_args['nei_budget'])
        self.memory_proportion = ssm_args['memory_proportion']

        # Number of scalars in each parameter tensor, in parameter order.
        self.grad_dims = [param.data.numel() for param in self.net.parameters()]

        # Column 0 holds the reference (replay) gradient, column 1 the
        # gradient of the current batch. Zero-initialised so that entries are
        # well defined even if a helper leaves some of them untouched, and
        # placed on CUDA only when it is actually available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.grads = torch.zeros(sum(self.grad_dims), 2, device=device)

    def before_training(self, input_data, labels):
        """Compute and store the reference gradient from replayed subgraphs.

        Does nothing while the buffer is empty.

        Args:
            input_data: unused here; part of the hook signature.
            labels: labels of the current batch; only ``len(labels)`` is used,
                to size the replay sample.
        """
        if len(self.buffer) == 0:
            return

        # The sample size must be a whole number even when memory_proportion
        # is fractional, and at least one subgraph is needed to form a batch.
        n_samples = int(min(len(labels) * self.memory_proportion, len(self.buffer)))
        n_samples = max(1, n_samples)
        aux_subgraphs, aux_labels = self.buffer.sample(n_samples)

        batched_graph = dgl.batch(aux_subgraphs)
        # torch.nonzero returns one row per non-zero entry; column 0 is the
        # node index. Indexing the column (instead of a bare squeeze) keeps
        # the result 1-D even when there is exactly one target node.
        target_node_ids = torch.nonzero(batched_graph.ndata['target'], as_tuple=False)[:, 0]
        _, _, blocks = dgl.dataloading.MultiLayerFullNeighborSampler(2).sample_blocks(batched_graph, target_node_ids)

        self.net.zero_grad()
        output = self.net(blocks)[:, self.output_mask]
        loss_aux = self.ce(output, aux_labels)
        loss_aux.backward()
        # NOTE(review): store_grad is presumably the GEM-style helper that
        # flattens the parameters' .grad into the given column -- confirm in
        # .utils. The bound method (not its result) is passed on purpose,
        # exactly as the original code did.
        store_grad(self.net.parameters, self.grads, self.grad_dims, 0)

    def after_loss(self, input_data, labels):
        """Project the current gradient if it conflicts with the reference.

        Does nothing while the buffer is empty.

        Args:
            input_data: unused here; part of the hook signature.
            labels: unused here; part of the hook signature.
        """
        if len(self.buffer) == 0:
            return

        store_grad(self.net.parameters, self.grads, self.grad_dims, 1)
        ref_grad = self.grads[:, 0]
        cur_grad = self.grads[:, 1]

        dotp = torch.dot(cur_grad, ref_grad)
        if dotp < 0:
            # Negative inner product: remove the component of the current
            # gradient that points against the reference gradient.
            dotp_ref = torch.dot(ref_grad, ref_grad)
            projected_grad = cur_grad - (dotp / dotp_ref) * ref_grad
            # NOTE(review): overwrite_grad presumably writes the flat vector
            # back into the parameters' .grad -- confirm in .utils.
            overwrite_grad(self.net.parameters, projected_grad, self.grad_dims)

    def after_passes(self, input_data, labels):
        """Offer the just-processed batch to the replay buffer."""
        self.buffer.update(input_data, labels)
