from .base_learner import BaseOCGLearner

class OCGLearner(BaseOCGLearner):
    """Learner that penalises parameters for drifting from a stored snapshot.

    After each round of passes the learner snapshots the network parameters
    and derives a per-parameter importance from the squared gradient of the
    mean squared (masked) network output.  The auxiliary loss is the
    importance-weighted squared distance between the current parameters and
    the snapshot, scaled by ``args.mas_args['memory_strength']``.

    NOTE(review): ``self.net`` and ``self.output_mask`` are not assigned in
    this class; presumably ``BaseOCGLearner`` provides them -- confirm.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Forward the arguments to the base learner and reset the state.

        Args:
            model: Passed unchanged to ``BaseOCGLearner.__init__``.
            neighborhood_processor: Passed unchanged to the base class.
            args: Configuration object; ``args.mas_args['memory_strength']``
                is read as the regularisation strength.
        """
        super().__init__(model, neighborhood_processor, args)

        # Multiplier applied to every term of the auxiliary loss.
        self.reg = args.mas_args['memory_strength']
        # Per-parameter importance tensors, in ``self.net.parameters()``
        # order.  Empty until ``after_passes`` has run at least once.
        self.importance = []
        # Parameter snapshot taken by the most recent ``after_passes`` call.
        self.optpar = []

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return the importance-weighted parameter-drift penalty.

        Args:
            input_data: Unused by this implementation.
            labels: Unused by this implementation.
            logits: Unused by this implementation.

        Returns:
            ``None`` while no importance has been recorded yet; otherwise the
            sum over all parameters of
            ``reg * importance * (param - snapshot) ** 2``.
        """
        if not self.importance:
            # Nothing to regularise against before the first after_passes().
            return None

        aux_loss = 0
        # Importance and snapshot lists are indexed in parameter order.
        for i, p in enumerate(self.net.parameters()):
            weight = self.reg * self.importance[i]
            aux_loss += (weight * (p - self.optpar[i]).pow(2)).sum()
        return aux_loss

    def after_passes(self, input_data, labels):
        """Refresh the parameter snapshot and the running importance.

        Runs one forward/backward pass on ``input_data`` using the mean of
        the squared masked outputs as the objective, then stores the squared
        gradients as the new importance.  Importance from earlier calls is
        merged as an average weighted by the number of examples seen.

        Args:
            input_data: Input fed to ``self.net``.
            labels: Only its length is used, as the example count of this
                call.
        """
        batch_size = len(labels)

        self.net.zero_grad()
        output = self.net(input_data)[:, self.output_mask].pow(2).mean()
        output.backward()
        # NOTE(review): the gradients computed here remain on the parameters
        # afterwards; presumably the training loop zeroes them before its next
        # backward pass -- confirm.

        # Build the new state locally so that a failure in the pass above
        # cannot leave a cleared snapshot next to a stale importance list.
        optpar = []
        new_importance = []
        for p in self.net.parameters():
            optpar.append(p.detach().clone())
            if p.grad is None:
                # The parameter received no gradient (e.g. frozen or not used
                # for the masked outputs): record zero importance so the list
                # stays aligned with ``self.net.parameters()``.
                new_importance.append(p.detach().new_zeros(p.shape))
            else:
                new_importance.append(p.grad.detach().clone().pow(2))
        self.optpar = optpar

        if self.importance:
            total = self.n_seen_examples + batch_size
            for i, s in enumerate(new_importance):
                # Example-count weighted running average of the importance.
                self.importance[i] = (
                    self.importance[i] * self.n_seen_examples + s * batch_size
                ) / total
            self.n_seen_examples = total
        else:
            self.importance = new_importance
            self.n_seen_examples = batch_size
