import torch
from .base_learner import BaseOCGLearner
from .utils import MultiClassCrossEntropy, kaiming_normal_init
from Backbones.model_factory import get_classifier

class OCGLearner(BaseOCGLearner):
    """Online continual graph learner regularised with Learning-without-Forgetting distillation.

    Every ``save_every`` processed batches (see ``after_passes``) a frozen copy
    of ``self.net`` is stored as ``self.prev_model``. On subsequent batches,
    ``compute_auxiliary_loss`` distills the current logits of classes that do
    not occur in the batch towards the snapshot's logits using
    ``MultiClassCrossEntropy`` at temperature ``T``, scaled by ``lambda_dist``.
    """

    def __init__(self, model, neighborhood_processor, args):
        super().__init__(model, neighborhood_processor, args)
        self.seen_classes = []
        # Frozen teacher snapshot of self.net; None until the first snapshot.
        self.prev_model = None
        # Factory for a fresh classifier; its state dict must be compatible
        # with self.net's, since after_passes loads self.net's weights into it.
        self.get_model = lambda: get_classifier(args).cuda(args.gpu)
        # Number of batches seen so far; drives the snapshot schedule.
        self.batch_count = 0
        self.net.apply(kaiming_normal_init)

        # LwF hyper-parameters: distillation temperature, loss weight, and
        # snapshot period (in batches).
        self.T = args.lwf_args['T']
        self.lambda_dist = args.lwf_args['lambda_dist']
        self.save_every = args.lwf_args['save_every']

    def compute_auxiliary_loss(self, input_data, labels, logits):
        """Return the weighted distillation loss for the current batch, or None.

        Args:
            input_data: Batch input, passed unchanged to the teacher snapshot.
            labels: Class indices of the batch; these columns are excluded from
                distillation.
            logits: Current-model logits. Assumed to already be restricted to
                the ``self.output_mask`` columns so that they align with the
                teacher output sliced the same way — TODO confirm against the
                base class.

        Returns:
            ``lambda_dist * MultiClassCrossEntropy(logits_dist, target_dist, T)``
            over the columns not present in ``labels``, or None when no
            snapshot exists yet or every column appears in ``labels``.
        """
        if self.prev_model is None:
            return None

        # Allocate the mask on the logits' device: a CPU mask indexed by CUDA
        # ``labels`` raises a device-mismatch error, while a CUDA mask indexed
        # by CPU labels is allowed. It also keeps the boolean column selection
        # below on the same device as logits/target.
        dist_mask = torch.ones(logits.shape[1], dtype=torch.bool, device=logits.device)
        dist_mask[labels.unique()] = False
        if not dist_mask.any():
            return None

        # The teacher's parameters are frozen; no_grad additionally skips
        # recording an autograd graph for its forward pass.
        with torch.no_grad():
            target = self.prev_model(input_data)[:, self.output_mask]

        logits_dist = logits[:, dist_mask]
        dist_target = target[:, dist_mask]
        dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, self.T)
        return self.lambda_dist * dist_loss

    def after_passes(self, input_data, labels):
        """Count one processed batch and refresh the teacher every ``save_every`` batches.

        ``input_data`` and ``labels`` are not used here.
        """
        self.batch_count += 1
        if self.batch_count % self.save_every != 0:
            return

        if self.prev_model is None:
            # Allocate and freeze the teacher once; later snapshots only copy
            # weights into it instead of building a new model each time.
            # load_state_dict copies values in place and keeps requires_grad=False.
            self.prev_model = self.get_model()
            self.prev_model.requires_grad_(False)
        self.prev_model.load_state_dict(self.net.state_dict())
