from .base_learner import BaseOCGLearner
from .utils import ReservoirSamplingBuffer, ByClassReservoirSamplingBuffer

class OCGLearner(BaseOCGLearner):
    """Experience-replay learner that mixes buffered past samples into the loss.

    A reservoir-sampling memory buffer holds (features, labels) pairs from
    earlier passes. ``compute_auxiliary_loss`` draws a replay sample from the
    buffer, sized relative to the current batch. It returns ``self.ce`` on the
    network's masked outputs for that sample. ``after_passes`` adds the current
    batch to the buffer.

    Attributes read but not set in this class (``net``, ``ce``,
    ``output_mask``, ``center_features``, ``feature_mean``) are presumably
    initialized by ``BaseOCGLearner`` — confirm against the base class.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Build the learner and its replay buffer.

        Args:
            model: network, passed through to ``BaseOCGLearner``.
            neighborhood_processor: passed through to ``BaseOCGLearner``.
            args: configuration namespace; this method reads ``dataset``,
                ``memory_budget``, ``n_cls`` (Elliptic only) and
                ``er_args['memory_proportion']``.
        """
        super().__init__(model, neighborhood_processor, args)

        # The Elliptic dataset gets a per-class reservoir; every other dataset
        # uses a single reservoir shared by all classes.
        if args.dataset == 'Elliptic':
            self.buffer = ByClassReservoirSamplingBuffer(args.memory_budget, args.n_cls)
        else:
            self.buffer = ReservoirSamplingBuffer(args.memory_budget)
        # Replay sample size, expressed as a multiple of the current batch size.
        self.memory_proportion = args.er_args['memory_proportion']

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return the replay loss on a sample drawn from the memory buffer.

        Args:
            input_data: current batch input (not used here; kept for the
                learner interface).
            labels: labels of the current batch; only ``len(labels)`` is used,
                to size the replay sample.
            logits: current batch logits (not used here; kept for the learner
                interface).

        Returns:
            ``self.ce`` applied to the network's ``output_mask``-selected
            outputs on the replayed samples. Returns ``None`` when the buffer
            is empty or the replay sample size works out to zero.
        """
        if len(self.buffer) == 0:
            return None

        # memory_proportion may be a float, but the buffer needs an integer
        # sample count. int() truncates and is a no-op for integer products.
        n_samples = min(int(len(labels) * self.memory_proportion), len(self.buffer))
        if n_samples <= 0:
            # Nothing to replay. Skip rather than compute a loss over an empty
            # batch, which is NaN under mean reduction.
            return None

        aux_features, aux_labels = self.buffer.sample(n_samples)
        if self.center_features:
            # The buffer holds un-centered features (after_passes adds the mean
            # back before storing), so re-center them with the current mean.
            aux_features = aux_features - self.feature_mean
        output = self.net(aux_features)[:, self.output_mask]
        return self.ce(output, aux_labels)

    def after_passes(self, input_data, labels):
        """Add the current batch to the replay buffer.

        Args:
            input_data: a feature tensor, or a list whose last element exposes
                ``dstdata['feat']``. For a list, the last element's
                destination-node features are stored. The list is presumably of
                DGL sampled blocks — confirm.
            labels: labels aligned with the stored features.
        """
        if isinstance(input_data, list):
            input_data = input_data[-1].dstdata['feat']
        if self.center_features:
            # Undo centering so the buffer stores raw features; they are
            # re-centered when replayed. Assumes the incoming features are
            # already centered — TODO confirm for the list/block path.
            input_data = input_data + self.feature_mean
        self.buffer.update(input_data, labels)
