import torch
from .base_learner import BaseOCGLearner

class OCGLearner(BaseOCGLearner):
    """Online continual graph learner with a TWP-style weight-preserving penalty.

    After each round of training passes, :meth:`after_passes` records
    per-parameter importance estimates (squared gradients of the
    classification loss and of ``torch.norm(elist[0])``, where ``elist`` is
    the second value returned by ``self.net(..., twp=True)``) together with a
    snapshot of the parameters. :meth:`compute_auxiliary_loss` then penalizes
    moving parameters away from that snapshot, weighted by the importances.
    """

    def __init__(self, model, neighborhood_processor, args):
        super().__init__(model, neighborhood_processor, args)

        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if args.backbone != 'GCN':
            raise ValueError(
                f"OCGLearner (TWP) requires backbone 'GCN', got {args.backbone!r}"
            )
        # Running, example-weighted averages of squared gradients: one tensor
        # per parameter of self.net, in parameters() order. Empty until the
        # first call to after_passes.
        self.fisher_loss = []
        self.fisher_att = []
        # Detached copy of every parameter taken at the latest after_passes.
        self.optpar = []
        # Weights of the loss-based and elist-based importances, and of the
        # additive beta term (see compute_auxiliary_loss).
        self.lambda_l = args.twp_args['lambda_l']
        self.lambda_t = args.twp_args['lambda_t']
        self.beta = args.twp_args['beta']

    @staticmethod
    def _squared_grad(param):
        """Return a detached ``param.grad ** 2``, or zeros if there is no grad."""
        if param.grad is None:
            # Parameter did not take part in the backward pass (or is frozen):
            # treat it as having zero importance instead of crashing.
            return torch.zeros_like(param)
        return param.grad.detach().pow(2)

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return the weight-preserving penalty, or ``None`` before any update.

        ``input_data``, ``labels`` and ``logits`` are not used here; they are
        part of the hook signature (presumably defined by BaseOCGLearner —
        confirm).

        For each parameter ``p`` with snapshot ``p_star`` and importance
        ``w = lambda_l * fisher_loss + lambda_t * fisher_att`` this sums
        ``w * (p - p_star) ** 2 + beta * w`` over all elements.
        """
        if not self.fisher_loss:
            # after_passes has not run yet, so there is nothing to preserve.
            return None

        aux_loss = 0
        for p, f_loss, f_att, p_star in zip(
            self.net.parameters(), self.fisher_loss, self.fisher_att, self.optpar
        ):
            importance = self.lambda_l * f_loss + self.lambda_t * f_att
            # NOTE(review): ``importance`` is built from detached gradient
            # copies, so ``beta * importance`` is constant w.r.t. the
            # parameters: it shifts the loss value but adds no gradient.
            # Kept unchanged to preserve existing behaviour.
            term = importance * (p - p_star).pow(2) + self.beta * importance
            aux_loss += term.sum()
        return aux_loss

    def after_passes(self, input_data, labels):
        """Refresh the importance estimates and parameter snapshot from a batch.

        Runs ``self.net(input_data, twp=True)`` (which returns
        ``(output, elist)``) and computes:

        * ``fisher_loss``: squared gradients of ``self.ce`` on
          ``output[:, self.output_mask]``;
        * ``fisher_att``: squared gradients after additionally
          back-propagating ``torch.norm(elist[0])``;
        * ``optpar``: a detached copy of every parameter.

        The new estimates are merged into the running averages, weighted by
        ``len(labels)`` against ``self.n_seen_examples``.
        """
        batch_size = len(labels)
        params = list(self.net.parameters())

        self.net.zero_grad()
        output, elist = self.net(input_data, twp=True)
        # Keep the graph alive: elist[0] comes from the same forward pass and
        # is back-propagated below.
        self.ce(output[:, self.output_mask], labels).backward(retain_graph=True)

        new_optpar = [p.detach().clone() for p in params]
        new_fisher_loss = [self._squared_grad(p) for p in params]

        # NOTE(review): gradients are not zeroed before this backward, so
        # fisher_att holds (grad_ce + grad_elist) ** 2 rather than the squared
        # gradient of the elist norm alone. This preserves existing behaviour
        # (presumably mirroring the reference TWP code) — confirm intent.
        torch.norm(elist[0]).backward()
        new_fisher_att = [self._squared_grad(p) for p in params]

        # These gradients were only for importance estimation; clear them so
        # they cannot leak into a subsequent optimizer step.
        self.net.zero_grad()

        self.optpar = new_optpar
        if self.fisher_loss:
            seen = self.n_seen_examples
            total = seen + batch_size
            self.fisher_loss = [
                (old * seen + new * batch_size) / total
                for old, new in zip(self.fisher_loss, new_fisher_loss)
            ]
            self.fisher_att = [
                (old * seen + new * batch_size) / total
                for old, new in zip(self.fisher_att, new_fisher_att)
            ]
            self.n_seen_examples = total
        else:
            self.fisher_loss = new_fisher_loss
            self.fisher_att = new_fisher_att
            self.n_seen_examples = batch_size
