from models.utils.continual_model import ContinualModel
from models.optimizers import get_optimizer
from models.optimizers.lr_scheduler import LR_Scheduler
import torch

class Finetune(ContinualModel):
    """Plain fine-tuning baseline for continual learning.

    Trains the backbone on each task in sequence, masking logits of classes
    outside the current task, with no replay and no regularization penalty.
    """
    NAME = 'finetune'
    #COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, len_train_loader, transform):
        super().__init__(backbone, loss, args, len_train_loader, transform)

    def set_task(self, dummy, task_id):
        """Prepare the optimizer and LR schedule for task ``task_id``.

        The task's initial LR decays geometrically,
        ``init_lr * train.task_lr_decay ** task_id``, and every LR handed to the
        optimizer/scheduler is scaled linearly by ``batch_size / 256``.

        Args:
            dummy: unused here; kept for interface compatibility with callers.
            task_id: zero-based index of the task about to be trained.
        """
        lr_scale = self.args.batch_size / 256
        task_init_lr = self.args.init_lr * self.args.train.task_lr_decay ** task_id

        # Re-create the optimizer BEFORE building the scheduler. The scheduler
        # keeps a reference to the optimizer object it is constructed with, so
        # building it first would leave it stepping the discarded optimizer
        # while the new one trains with an unscheduled LR.
        if task_id > 0 and self.args.reinit_opt_per_task:
            self.opt = get_optimizer(
                self.args.optimizer, self.net,
                lr=task_init_lr * lr_scale,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay,
            )

        self.lr_scheduler = LR_Scheduler(
            optimizer=self.opt,
            warmup_epochs=self.args.train.warmup_epochs,
            warmup_lr=self.args.train.warmup_lr * lr_scale,
            num_epochs=self.args.num_epochs,
            base_lr=task_init_lr * lr_scale,
            final_lr=task_init_lr * self.args.final_lr_decay * lr_scale,
            # NOTE(review): attribute spelling must match whatever
            # ContinualModel stores; left unchanged because the base class is
            # not visible here — confirm before renaming.
            iter_per_epoch=self.len_train_lodaer,
            constant_predictor_lr=True,  # see the end of section 4.2 predictor
        )

    def end_task(self, dataset):
        """Reset the LR scheduler once a task finishes (``dataset`` is unused)."""
        self.lr_scheduler.reset()

    def observe(self, inputs, labels, _, task_id, eval=False):
        """Run one forward (and, unless ``eval``, one optimization) step.

        A random subset of at most ``int(train.select_ratio * batch_size)``
        samples is drawn from the batch first; this also happens when
        ``eval`` is True.

        Args:
            inputs: batch of inputs; moved to ``self.device`` before the forward.
            labels: class labels aligned with ``inputs``.
            _: unused positional argument kept for interface compatibility.
            task_id: index of the current task, used for logit masking.
            eval: if True, return the raw network outputs without training.

        Returns:
            The network outputs when ``eval`` is True; otherwise a dict with
            ``loss``, ``penalty`` (always 0.0), ``ce_loss``, ``train_acc`` and
            ``lr``.
        """
        self.opt.zero_grad()
        labels = labels.to(self.device)

        # Random sub-sampling of the incoming batch.
        n_keep = min(int(self.args.train.select_ratio * self.args.batch_size), len(inputs))
        pick = torch.randperm(len(inputs))[:n_keep]
        inputs, labels = inputs[pick], labels[pick]
        outputs = self.net(inputs.to(self.device))

        if eval:
            return outputs

        masked_output = self.task_masking(outputs, labels, task_id)
        loss = self.loss(masked_output, labels).mean()
        loss_value = loss.item()

        pred = masked_output.detach().argmax(dim=1)
        corrects = (pred == labels).sum().item()

        loss.backward()
        self.opt.step()
        self.lr_scheduler.step()

        return {
            'loss': loss_value,
            'penalty': 0.0,           # fine-tuning has no regularization term
            'ce_loss': loss_value,
            'train_acc': corrects / labels.shape[0],
            'lr': self.lr_scheduler.get_lr(),
        }

    def task_masking(self, outputs, labels, task_id, is_replay=False):
        """Suppress logits of classes outside task ``task_id``.

        Columns ``[0, task_id * nc)`` and ``[(task_id + 1) * nc,
        num_total_classes)`` are overwritten with ``-10e10`` IN PLACE, where
        ``nc = dataset.num_classes_per_task``. The fill goes through ``.data``
        so autograd does not record it. ``labels`` is unused.

        Returns:
            The same ``outputs`` tensor, modified in place.

        Raises:
            NotImplementedError: if ``is_replay`` is True.
        """
        if is_replay:
            raise NotImplementedError()

        nc = self.args.dataset.num_classes_per_task
        num_total = self.args.dataset.num_total_classes
        lo = int(task_id * nc)
        hi = int((task_id + 1) * nc)
        if lo > 0:
            outputs[:, :lo].data.fill_(-10e10)
        if hi < num_total:
            outputs[:, hi:num_total].data.fill_(-10e10)
        return outputs
