import torch
import dgl
from .base_learner import BaseOCGLearner
from .utils import ReservoirSSM, ByClassReservoirSSM

class OCGLearner(BaseOCGLearner):
    """Online continual graph learner with a subgraph replay buffer.

    Incoming data is pushed into a reservoir-style buffer of subgraphs in
    :meth:`after_passes`. A sample of the buffer is replayed as an auxiliary
    cross-entropy loss in :meth:`compute_auxiliary_loss`. For the
    ``'Elliptic'`` dataset a per-class buffer (``ByClassReservoirSSM``) is
    used; otherwise a single ``ReservoirSSM``.
    """

    def __init__(self, model, neighborhood_processor, args):
        super().__init__(model, neighborhood_processor, args)

        # NOTE(review): 'Elliptic' gets a class-partitioned buffer, presumably
        # to keep rare classes represented in memory -- confirm with authors.
        if args.dataset == 'Elliptic':
            self.buffer = ByClassReservoirSSM(args.memory_budget, args.n_cls, args.ssmer_args['nei_budget'])
        else:
            self.buffer = ReservoirSSM(args.memory_budget, args.ssmer_args['nei_budget'])
        # Multiplier applied to the current batch size to decide how many
        # buffered samples to replay; the product is truncated to an int in
        # compute_auxiliary_loss, so a fractional value is accepted.
        self.memory_proportion = args.ssmer_args['memory_proportion']

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return a replay loss computed on subgraphs sampled from the buffer.

        Args:
            input_data: Not used by this method (kept for the base-class
                interface, presumably).
            labels: Labels of the current batch; only ``len(labels)`` is used,
                to size the replay sample.
            logits: Not used by this method.

        Returns:
            ``self.ce`` evaluated on the network's masked output for the
            replayed subgraphs, or ``None`` when the buffer is empty or the
            computed sample size is zero.
        """
        if len(self.buffer) == 0:
            return None

        # int(): memory_proportion may be fractional, and the buffer's
        # sample() presumably expects an integer count.
        n_samples = int(min(len(labels) * self.memory_proportion, len(self.buffer)))
        if n_samples <= 0:
            # Nothing to replay (e.g. tiny batch times a small proportion);
            # there would be no subgraphs to batch.
            return None

        aux_subgraphs, aux_labels = self.buffer.sample(n_samples)
        batched_graph = dgl.batch(aux_subgraphs)
        # Node ids flagged by ndata['target'] (presumably one target node per
        # stored subgraph). as_tuple=True keeps this 1-D even when there is a
        # single target; the former .squeeze() produced a 0-d tensor then.
        target_node_ids = torch.nonzero(batched_graph.ndata['target'], as_tuple=True)[0]
        # Full 2-hop neighbourhood blocks around the targets.
        # NOTE(review): assumes self.net consumes exactly two blocks.
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
        _, _, blocks = sampler.sample_blocks(batched_graph, target_node_ids)
        # self.net, self.output_mask and self.ce are provided by the base class
        # (presumably the model, the active-class column mask and the loss).
        output = self.net(blocks)[:, self.output_mask]
        return self.ce(output, aux_labels)

    def after_passes(self, input_data, labels):
        """Insert the current batch into the replay buffer via ``buffer.update``."""
        self.buffer.update(input_data, labels)
