import torch
from .base_learner import BaseOCGLearner
from .utils import ReservoirSamplingBuffer, ByClassReservoirSamplingBuffer, store_grad, overwrite_grad

class OCGLearner(BaseOCGLearner):
    """Online continual graph learner with A-GEM-style gradient projection.

    A replay buffer of past node features/labels supplies a reference
    gradient before each training step. Whenever the current-task gradient
    has a negative dot product with it, the current gradient is projected:

        g <- g - (g . g_ref / g_ref . g_ref) * g_ref

    ``self.grads`` holds both flattened gradients side by side: column 0 is
    the reference (replay) gradient, column 1 the current-task gradient.
    """

    def __init__(self, model, neighborhood_processor, args):
        super().__init__(model, neighborhood_processor, args)

        # The Elliptic dataset uses a per-class reservoir buffer
        # (NOTE(review): presumably to keep minority classes represented -- confirm).
        if args.dataset == 'Elliptic':
            self.buffer = ByClassReservoirSamplingBuffer(args.memory_budget, args.n_cls)
        else:
            self.buffer = ReservoirSamplingBuffer(args.memory_budget)
        # Replay batch size expressed as a multiple of the incoming batch size.
        self.memory_proportion = args.agem_args['memory_proportion']

        # Element count of every parameter tensor; used to map the flat
        # gradient vectors back onto individual parameters.
        self.grad_dims = [param.data.numel() for param in self.net.parameters()]
        # Zero-initialised storage (instead of uninitialised memory).
        # Column 0: reference gradient, column 1: current-task gradient.
        self.grads = torch.zeros(sum(self.grad_dims), 2, device='cuda')

    def before_training(self, input_data, labels):
        """Compute and store the reference gradient on a replay batch.

        Does nothing while the buffer is empty. Otherwise draws
        ``len(labels) * memory_proportion`` buffered samples (at least one,
        at most the buffer size), backpropagates the cross-entropy loss of
        the masked outputs, and stores the flattened gradient in column 0 of
        ``self.grads``. ``input_data`` is not used.
        """
        if len(self.buffer) == 0:
            return

        # int(): memory_proportion may be fractional, and sample() needs a
        # count; max(1, ...) avoids an empty replay batch (NaN loss).
        n_requested = max(1, int(len(labels) * self.memory_proportion))
        n_samples = min(n_requested, len(self.buffer))
        aux_features, aux_labels = self.buffer.sample(n_samples)
        if self.center_features:
            # The buffer holds un-centred features (the mean is added back
            # in after_passes), so centre them again before the forward pass.
            aux_features = aux_features - self.feature_mean

        self.net.zero_grad()
        output = self.net(aux_features)[:, self.output_mask]
        loss_aux = self.ce(output, aux_labels)
        loss_aux.backward()
        # The bound method is passed, not its result; presumably store_grad
        # calls it to iterate over the parameters.
        store_grad(self.net.parameters, self.grads, self.grad_dims, 0)
        # Clear the replay gradient so it cannot leak into the current-task
        # gradient that after_loss() reads back from the parameters.
        self.net.zero_grad()

    def after_loss(self, input_data, labels):
        """Project the current gradient if it conflicts with the reference.

        Copies the parameters' current gradients into column 1 of
        ``self.grads``. If their dot product with the reference gradient
        (column 0) is negative, the parameter gradients are overwritten with
        the projected gradient. Does nothing while the buffer is empty.
        ``input_data`` and ``labels`` are not used.
        """
        if len(self.buffer) == 0:
            return

        store_grad(self.net.parameters, self.grads, self.grad_dims, 1)
        g_ref = self.grads[:, 0]
        g_cur = self.grads[:, 1]
        dotp = torch.dot(g_cur, g_ref)
        if dotp < 0:
            # A negative dotp implies g_ref is non-zero, so dotp_ref is
            # positive (barring floating-point underflow).
            dotp_ref = torch.dot(g_ref, g_ref)
            projected_grad = g_cur - (dotp / dotp_ref) * g_ref
            overwrite_grad(self.net.parameters, projected_grad, self.grad_dims)

    def after_passes(self, input_data, labels):
        """Offer the batch's node features and labels to the replay buffer.

        If ``input_data`` is a list (NOTE(review): looks like sampled
        message-flow blocks -- confirm), the features are taken from the last
        element's ``dstdata['feat']``. When ``center_features`` is set,
        ``feature_mean`` is added back before ``buffer.update``; this assumes
        the features arrive centred (TODO confirm against the caller).
        """
        if isinstance(input_data, list):
            input_data = input_data[-1].dstdata['feat']
        if self.center_features:
            # Out-of-place add: the caller's tensor is not modified.
            input_data = input_data + self.feature_mean
        self.buffer.update(input_data, labels)
