import torch.nn.functional as F
import time
import torch
from tqdm import tqdm


def get_regularized_loss(criterion, det_score, det_target, model, reg_lambda, device):
    """Return the task loss plus a quadratic penalty toward stored parameters.

    The base loss is ``criterion['MCE'](det_score, det_target)``.  When
    ``model.reg_params`` holds both an ``'importance'`` and an ``'optpar'``
    list, every snapshot ``i`` contributes, for each named parameter,
    ``reg_lambda * sum(importance_i * (optpar_i - param) ** 2)``.

    Args:
        criterion: Mapping whose ``'MCE'`` entry is a callable loss.
        det_score: Predictions, passed to the criterion unchanged.
        det_target: Targets, passed to the criterion unchanged.
        model: Object exposing ``reg_params`` (a dict) and
            ``named_parameters()``.
        reg_lambda: Multiplier applied to every penalty term.
        device: Device on which the loss and penalty terms are computed.

    Returns:
        The loss tensor on ``device``.  The tensor returned by the criterion
        is never modified in place.

    Raises:
        KeyError: If a snapshot has no entry for one of the model's
            parameter names.
    """
    det_loss = criterion['MCE'](det_score, det_target).to(device)

    mas_reg = model.reg_params
    if 'importance' not in mas_reg or 'optpar' not in mas_reg:
        return det_loss

    # The current parameters are the same for every snapshot, so move them to
    # the target device once instead of once per snapshot.
    current_params = [(name, param.to(device))
                      for name, param in model.named_parameters()]

    for importance_dict, optpar_dict in zip(mas_reg['importance'],
                                            mas_reg['optpar']):
        for name, current in current_params:
            importance = importance_dict[name].to(device)
            optpar = optpar_dict[name].to(device)

            # A snapshot may hold fewer rows than the current parameter; only
            # the leading rows present in the snapshot are penalised.  0-dim
            # (scalar) tensors have no first dimension to compare, so they
            # are always used whole.
            if (optpar.dim() > 0 and current.dim() > 0
                    and optpar.size(0) != current.size(0)):
                current = current[:optpar.size(0)]

            penalty = (importance * (optpar - current).pow(2)).sum() * reg_lambda
            # Out-of-place add: ``+=`` would write into the criterion's own
            # output tensor, which may be shared or needed by autograd.
            det_loss = det_loss + penalty

    return det_loss


def on_task_update(cfg, data_loaders, device, optimizer, model, num_tasks, criterion, scheduler):
    """Run one pass over the training data and snapshot the parameters.

    Every batch from ``data_loaders['train']`` is pushed through the model,
    the ``criterion['MCE']`` loss is back-propagated, and both the optimizer
    and the scheduler are stepped.  Afterwards a copy of each named parameter
    and the absolute value of the gradient it currently holds are appended to
    ``model.reg_params['optpar']`` and ``model.reg_params['importance']``.

    NOTE(review): gradients are zeroed before every backward pass, so the
    stored importance reflects only the final batch, and the model weights
    are updated during this pass even though the progress bar is labelled
    'calculating fisher'.  Confirm that both are intended.

    Args:
        cfg: Config object; only ``cfg.DATA.NUM_CLASSES`` is read.
        data_loaders: Mapping with a ``'train'`` iterable of batches.  The
            last element of each batch is the target; all preceding elements
            are passed positionally to the model.
        device: Device the batch tensors are moved to.
        optimizer: Optimizer stepped once per batch.
        model: Module exposing a ``reg_params`` dict.
        num_tasks: Task identifier, used only in the progress-bar label.
        criterion: Mapping whose ``'MCE'`` entry is a callable loss.
        scheduler: LR scheduler stepped once per batch.

    Returns:
        ``model.reg_params`` (the same dict object) with one new entry
        appended to each of its ``'importance'`` and ``'optpar'`` lists.
    """
    model.train()
    optimizer.zero_grad()

    mas_reg = model.reg_params
    # Start from empty histories unless both lists are already present.
    if 'importance' not in mas_reg or 'optpar' not in mas_reg:
        mas_reg['importance'] = []
        mas_reg['optpar'] = []

    with torch.set_grad_enabled(True):
        pbar = tqdm(data_loaders['train'],
                    desc='task{}:'.format(num_tasks) + 'calculating fisher')
        for data in pbar:
            det_target = data[-1].to(device)
            det_score = model(*[x.to(device) for x in data[:-1]])
            # Flatten scores and targets to (-1, NUM_CLASSES) for the loss.
            det_score = det_score.reshape(-1, cfg.DATA.NUM_CLASSES)
            det_target = det_target.reshape(-1, cfg.DATA.NUM_CLASSES)
            det_loss = criterion['MCE'](det_score, det_target)

            pbar.set_postfix({
                'lr': '{:.7f}'.format(scheduler.get_last_lr()[0]),
                'det_loss': '{:.5f}'.format(det_loss.item()),
            })
            optimizer.zero_grad()
            det_loss.backward()
            optimizer.step()
            scheduler.step()

    importance_dict = {}
    optpar_dict = {}
    for name, param in model.named_parameters():
        optpar_dict[name] = param.detach().clone()
        if param.grad is None:
            # No gradient reached this parameter (frozen, unused in the
            # forward pass, or the loader was empty): record zero importance
            # instead of failing on ``None``.
            importance_dict[name] = torch.zeros_like(param)
        else:
            importance_dict[name] = param.grad.detach().abs()

    mas_reg['importance'].append(importance_dict)
    mas_reg['optpar'].append(optpar_dict)

    return mas_reg
