from .utils.continual_model import ContinualModel
from .optimizers import LR_Scheduler, get_apd_optimizer

import torch
from .utils.continual_model import ContinualModel

def get_base_model_size(name):
    """Return the parameter count of the reference backbone named by *name*.

    The count is used as the denominator when the model size reported at the
    end of each task is expressed relative to the base network.

    Args:
        name: backbone identifier; matched by substring, first hit wins.

    Returns:
        int: number of parameters of the base backbone.

    Raises:
        NotImplementedError: if *name* is a known-but-unsupported backbone
            ('vit-s4') or matches no known backbone.
    """
    sizes = (
        ('resnet18', 11220132),
        # NOTE(review): the original tested 'resnet18' twice, which made this
        # entry unreachable; 'resnet50' is the presumed intent -- confirm.
        ('resnet50', 23705252),
        ('vit-t4', 9798756),
    )
    for key, size in sizes:
        if key in name:
            return size
    if 'vit-s4' in name:
        raise NotImplementedError('%s is not supported yet' % name)
    raise NotImplementedError('%s is not correct network' % name)


def count_parameters(parameters, thr=None):
    """Count values in an iterable of tensors.

    Args:
        parameters: iterable of tensors.
        thr: optional magnitude threshold. When given (including 0), count
            the entries whose absolute value is strictly greater than
            ``thr``, regardless of ``requires_grad``. When ``None``, count
            every element of the tensors that have ``requires_grad=True``.

    Returns:
        int: the count (0 for an empty iterable).
    """
    # `is not None` so that thr=0 means "count non-zero entries" instead of
    # silently falling back to the requires_grad-based count.
    if thr is not None:
        return sum(int((torch.abs(p) > thr).sum()) for p in parameters)
    return sum(p.numel() for p in parameters if p.requires_grad)

class Apd(ContinualModel):
    """APD (Additive Parameter Decomposition) continual learner.

    For each shared key, a task's weight is reconstructed as
    ``msk[t] * tsh + tad[t]`` (see ``end_task`` / ``apd_loss``). The three
    components come from the dictionaries in ``self.net.key_weights``.
    Training adds two terms to the cross-entropy loss:
    - an L1 penalty on the task-adaptive weights;
    - a retroactive penalty that keeps earlier tasks' reconstructions close
      to the snapshots stored in ``key_weights['tmp_sol']``.
    """
    NAME = 'apd'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, len_train_loader, transform):
        super(Apd, self).__init__(backbone, loss, args, len_train_loader, transform)
        # Stored here because set_task() rebuilds the LR scheduler every task.
        self.len_train_loader = len_train_loader
        # Coefficient of the L1 penalty on task-adaptive weights; also the
        # magnitude threshold used when counting adaptive weights in end_task.
        self.c = self.args.reg_lambda

    def set_task(self, dummy, task_id):
        """Prepare the optimizer and LR scheduler for task ``task_id``.

        On the first task, ``dummy`` is run through the network with the
        total number of tasks (as a string) as the task argument. Presumably
        this lets the backbone create its per-task weights; confirm against
        the backbone.

        A fresh APD optimizer and scheduler are then built. All learning
        rates are multiplied by ``batch_size / 256``.
        """
        train_cfg = self.args.train

        def _scaled(lr):
            # Same evaluation order as before: (lr * batch_size) / 256.
            return lr * train_cfg.batch_size / 256

        if task_id == 0:
            dataset_cfg = self.args.dataset
            num_tasks = int(dataset_cfg.num_total_classes / dataset_cfg.num_classes_per_task)
            _ = self.net(dummy.to(self.device), '%s' % num_tasks)

        # Set new optimizer for current task
        self.opt = get_apd_optimizer(
            train_cfg.optimizer.name, self.net, task_id,
            lr=_scaled(train_cfg.base_lr),
            momentum=train_cfg.optimizer.momentum,
            weight_decay=train_cfg.optimizer.weight_decay,
            base_parameters=self.opt.param_groups[0]['params'],
        )
        self.lr_scheduler = LR_Scheduler(
            optimizer=self.opt,
            warmup_epochs=train_cfg.warmup_epochs,
            warmup_lr=_scaled(train_cfg.warmup_lr),
            num_epochs=train_cfg.num_epochs,
            base_lr=_scaled(train_cfg.base_lr),
            final_lr=_scaled(train_cfg.final_lr),
            iter_per_epoch=self.len_train_loader,
            constant_predictor_lr=True,  # see the end of section 4.2 predictor
        )
        self.opt.zero_grad()

    def end_task(self, task_id):
        """Report the model size and snapshot every seen task's weights.

        The printed value count sums:
        - the optimizer's first parameter group;
        - the shared weights;
        - the masks of tasks ``0..task_id``;
        - the adaptive weights of those tasks whose magnitude exceeds
          ``self.c``.
        It is also printed as a ratio to the base backbone size.

        Next, ``msk[t] * tsh + tad[t]`` is stored in
        ``key_weights['tmp_sol']`` for each ``t <= task_id``; ``apd_loss``
        reads these snapshots. Finally the LR scheduler is reset.
        """
        weights = self.net.key_weights
        tsh = weights['tsh']
        msk = weights['msk']
        tad = weights['tad']

        def _seen(table):
            # The task id is parsed from the text after the last '_t' in the
            # key; keep the tensors of tasks 0..task_id.
            return [table[k].data for k in table.keys()
                    if task_id >= int(k.split('_t')[-1])]

        seen_wmsk = _seen(msk)
        seen_wtad = _seen(tad)

        qe = count_parameters(self.opt.param_groups[0]['params'])
        qe += count_parameters(list(tsh.values()))
        # `.data` tensors never require grad, so count_parameters() without a
        # threshold would report 0 for them; count the mask elements directly.
        qe += sum(m.numel() for m in seen_wmsk)
        qe += count_parameters(seen_wtad, self.c)
        # NOTE(review): "MB" below is qe * 1e-6, i.e. millions of values.
        print('Model numel: ', qe, 'Model size: %.4f MB (x%.4f)'%(qe * 0.000001, qe/get_base_model_size(self.args.model.backbone)))

        # apd_loss() only uses these snapshots detached, so build them without
        # recording an autograd graph that would otherwise be kept alive.
        with torch.no_grad():
            for tid in range(task_id + 1):
                for tsh_key, shared in tsh.items():
                    key = tsh_key.replace('shared', 't%s' % tid)
                    weights['tmp_sol'][key] = (
                        torch.einsum('i,i...->i...', msk[key], shared) + tad[key]
                    )

        self.lr_scheduler.reset()

    def observe(self, inputs, labels, notaug_inputs, task_id, eval=False):
        """Run one training step, or only a forward pass when ``eval``.

        Returns:
            When ``eval`` is true, the raw network outputs for ``inputs``.
            Otherwise, after one backward/optimizer/scheduler step, a dict
            with ``loss``, ``ce_loss``, ``penalty`` and the current ``lr``.
        """
        self.opt.zero_grad()
        outputs = self.net(inputs.to(self.device), task_id)
        if eval:
            return outputs

        masked_output = self.task_masking(outputs, labels, task_id)
        penalty = self.apd_loss(task_id)
        ce_loss = self.loss(masked_output, labels.to(self.device))

        loss = ce_loss + penalty
        data_dict = {
            'loss': loss.item(),
            'ce_loss': ce_loss.item(),
            # apd_loss() yields a plain float when there is nothing to
            # penalise (no adaptive weights, first task); float() handles
            # both that and a one-element tensor.
            'penalty': float(penalty),
        }
        loss.backward()
        self.opt.step()
        self.lr_scheduler.step()
        data_dict['lr'] = self.lr_scheduler.get_lr()
        return data_dict

    def apd_loss(self, task_id):
        """Return the APD regulariser for training on task ``task_id``.

        The result is the L1 penalty on all task-adaptive weights, plus one
        term per previous task ``t < task_id`` and shared key: ``ret_hyp``
        times the norm of the difference between the stored snapshot
        ``tmp_sol`` and the current reconstruction ``msk * tsh + tad``.
        """
        weights = self.net.key_weights
        tsh = weights['tsh']
        msk = weights['msk']
        tad = weights['tad']
        tmp_sol = weights['tmp_sol']
        ret_hyp = self.args.ret_hyp

        penalty = 0.
        # L1 loss for task-adaptive parameters
        penalty += self.l1_loss(tad)
        # Retroactive loss: keep earlier tasks' reconstructions near snapshots.
        for tid in range(task_id):
            for tsh_key, shared in tsh.items():
                # Per-task keys replace the substring 'shared' with 't<id>'.
                key = tsh_key.replace('shared', 't%d' % tid)
                current = torch.einsum('i,i...->i...', msk[key], shared) + tad[key]
                penalty += ret_hyp * torch.norm(tmp_sol[key].detach() - current)
        return penalty

    def l1_loss(self, params):
        """Return ``self.c`` times the sum over entries of mean(|param|).

        Returns the float 0. when ``params`` is empty.
        """
        penalty = 0.
        for value in params.values():
            penalty += self.c * torch.mean(torch.abs(value))
        return penalty

    def task_masking(self, outputs, labels, task_id, is_replay=False):
        """Suppress, in place, the logits of classes outside task ``task_id``.

        With ``nc`` classes per task, columns ``[0, task_id*nc)`` and
        ``[(task_id+1)*nc, num_total_classes)`` of ``outputs`` are filled
        with -10e10. The write goes through ``.data``, so autograd does not
        track it. ``labels`` is unused.

        Raises:
            NotImplementedError: when ``is_replay`` is true (not supported).
        """
        if is_replay:
            raise NotImplementedError()
        nc = self.args.dataset.num_classes_per_task
        num_total = self.args.dataset.num_total_classes
        offset1 = int(task_id * nc)
        offset2 = int((task_id + 1) * nc)
        if offset1 > 0:
            outputs[:, :offset1].data.fill_(-10e10)
        if offset2 < num_total:
            outputs[:, offset2:num_total].data.fill_(-10e10)
        return outputs
