from .base_learner import BaseOCGLearner

class OCGLearner(BaseOCGLearner):
    """Learner that regularises training with a quadratic parameter penalty.

    After each call to :meth:`after_passes` the learner stores a snapshot of
    the network parameters (``optpar``) together with a running average of
    squared loss gradients (``fisher``).  :meth:`compute_auxiliary_loss` then
    penalises the squared distance of every parameter from its snapshot,
    weighted element-wise by ``fisher`` and scaled by ``reg``.

    ``self.net``, ``self.ce`` and ``self.output_mask`` are not defined here;
    they are presumably provided by ``BaseOCGLearner`` (verify against the
    base class).
    """

    def __init__(self, model, neighborhood_processor, args):
        """Initialise the base learner and the (empty) penalty statistics.

        Args:
            model: Forwarded unchanged to ``BaseOCGLearner.__init__``.
            neighborhood_processor: Forwarded unchanged to the base class.
            args: Configuration object; ``args.ewc_args['memory_strength']``
                is read as the penalty scale.  Also forwarded to the base
                class.
        """
        super().__init__(model, neighborhood_processor, args)

        # Scale factor applied to the quadratic penalty.
        self.reg = args.ewc_args['memory_strength']
        # Per-parameter running average of squared gradients; empty until the
        # first call to ``after_passes``.
        self.fisher = []
        # Per-parameter snapshot taken by the latest ``after_passes`` call.
        self.optpar = []
        # Number of examples that have contributed to ``fisher`` so far.
        self.n_seen_examples = 0

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return the quadratic penalty tying parameters to their snapshot.

        Args:
            input_data: Unused by this implementation.
            labels: Unused by this implementation.
            logits: Unused by this implementation.

        Returns:
            ``None`` while no statistics have been collected yet (i.e. before
            the first ``after_passes`` call); otherwise the sum over all
            parameters of ``reg * fisher * (p - optpar) ** 2``.
        """
        if not self.fisher:
            # Nothing to anchor to yet; keep the historical ``None`` result so
            # existing callers see the same value as before.
            return None

        aux_loss = 0
        for param, weight, anchor in zip(
            self.net.parameters(), self.fisher, self.optpar
        ):
            aux_loss += (self.reg * weight * (param - anchor).pow(2)).sum()
        return aux_loss

    def after_passes(self, input_data, labels):
        """Refresh the parameter snapshot and the squared-gradient average.

        Runs one forward/backward pass of ``self.ce`` on the masked network
        output, snapshots every parameter, and merges the squared gradients
        into ``fisher`` as an average weighted by the number of examples.

        Args:
            input_data: Batch passed to ``self.net``.
            labels: Targets passed to ``self.ce``; ``len(labels)`` is used as
                the weight of this batch in the running average.
        """
        self.net.zero_grad()
        output = self.net(input_data)[:, self.output_mask]
        self.ce(output, labels).backward()

        # Build the new state in locals and publish it only after the
        # forward/backward pass succeeded, so a failure there cannot leave
        # ``optpar`` empty while ``fisher`` is still populated.
        snapshot = []
        new_fisher = []
        for param in self.net.parameters():
            snapshot.append(param.detach().clone())
            if param.grad is None:
                # Parameters that received no gradient (frozen or not used in
                # this forward pass) contribute zero weight instead of
                # crashing on ``None``.
                new_fisher.append(param.detach().new_zeros(param.shape))
            else:
                new_fisher.append(param.grad.detach().clone().pow(2))
        self.optpar = snapshot

        n_batch = len(labels)
        if self.fisher:
            n_old = self.n_seen_examples
            n_total = n_old + n_batch
            # Example-count-weighted average of the old and new statistics.
            self.fisher = [
                (old * n_old + new * n_batch) / n_total
                for old, new in zip(self.fisher, new_fisher)
            ]
            self.n_seen_examples = n_total
        else:
            self.fisher = new_fisher
            self.n_seen_examples = n_batch
