import torch.nn.functional as F
import time
import torch
from tqdm import tqdm
def get_regularized_loss(criterion, det_score, det_target, model, ewc_lambda, device):
    """Return the base task loss plus the EWC quadratic penalty.

    The base loss is ``criterion['MCE'](det_score, det_target)``. When
    ``model.reg_params`` holds both ``'fisher'`` and ``'optpar'`` (parallel
    lists of per-task snapshots, each a dict keyed by parameter name), every
    snapshot contributes ``sum(fisher * (optpar - param) ** 2)``. The summed
    penalty is scaled by ``ewc_lambda`` and added to the base loss.

    If a snapshot tensor is smaller than the current parameter (the parameter
    grew after the snapshot was taken), only the leading slice of the current
    parameter matching the snapshot's shape is penalized. Parameters that did
    not exist when a snapshot was taken have no prior in it and are skipped
    for that snapshot.

    Args:
        criterion: mapping holding the base loss callable under ``'MCE'``.
        det_score: prediction passed to the base loss.
        det_target: target passed to the base loss.
        model: module exposing ``named_parameters()`` and a ``reg_params`` dict.
        ewc_lambda: weight applied to the EWC penalty.
        device: device onto which the loss and stored snapshots are moved.

    Returns:
        Loss tensor: base loss + ``ewc_lambda`` * EWC penalty.
    """
    det_loss = criterion['MCE'](det_score, det_target).to(device)
    ewc_reg = model.reg_params
    if 'fisher' not in ewc_reg or 'optpar' not in ewc_reg:
        # No previous task has been consolidated yet: plain task loss.
        return det_loss

    penalty = 0.0
    for fisher_dict, optpar_dict in zip(ewc_reg['fisher'], ewc_reg['optpar']):
        for name, param in model.named_parameters():
            if name not in fisher_dict or name not in optpar_dict:
                # Parameter was created after this snapshot; nothing to anchor to.
                continue
            fisher = fisher_dict[name].to(device)
            optpar = optpar_dict[name].to(device)
            # Crop the live parameter to the snapshot's shape on every dim.
            # An empty tuple (0-dim snapshot) indexes the whole scalar.
            current = param[tuple(slice(0, size) for size in optpar.shape)]
            penalty = penalty + (fisher * (optpar - current).pow(2)).sum()

    # Out-of-place add: leaves the criterion's output tensor unmodified in the graph.
    return det_loss + penalty * ewc_lambda


def on_task_update(cfg, data_loaders, device, optimizer, model, num_tasks, criterion, scheduler):
    """Run one training pass over ``data_loaders['train']`` and store an EWC snapshot.

    Each batch is forwarded through the model. Its ``criterion['MCE']`` loss is
    back-propagated, and ``optimizer.step()`` / ``scheduler.step()`` are
    called, so this function also trains the model. The squared gradients of
    every batch are averaged to form the (empirical) Fisher estimate.

    After the pass, two snapshots are appended to ``model.reg_params``:
    ``'fisher'`` (mean squared gradient per parameter) and ``'optpar'`` (the
    parameter values). Both are created as empty lists if either key is
    missing.

    Args:
        cfg: config object; ``cfg.DATA.NUM_CLASSES`` is used to reshape
            scores and targets.
        data_loaders: mapping whose ``'train'`` entry yields batches. The last
            element of a batch is the target; the rest are model inputs.
        device: device the batch tensors are moved to.
        optimizer: optimizer stepped once per batch.
        model: module with a ``reg_params`` dict attribute.
        num_tasks: task index, shown in the progress-bar label.
        criterion: mapping holding the loss callable under ``'MCE'``.
        scheduler: LR scheduler stepped once per batch.

    Returns:
        ``model.reg_params`` with the new fisher/optpar snapshots appended.
    """
    ewc_reg = model.reg_params
    if 'fisher' not in ewc_reg or 'optpar' not in ewc_reg:
        ewc_reg['fisher'] = []
        ewc_reg['optpar'] = []

    model.train(True)
    # Running sum of squared gradients, one accumulator per parameter.
    grad_sq_sum = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
    num_batches = 0
    with torch.set_grad_enabled(True):
        pbar = tqdm(data_loaders['train'],
                    desc='task{}:'.format(num_tasks) + 'calculating fisher')
        for data in pbar:
            det_target = data[-1].to(device)
            det_score = model(*[x.to(device) for x in data[:-1]])
            det_score = det_score.reshape(-1, cfg.DATA.NUM_CLASSES)
            det_target = det_target.reshape(-1, cfg.DATA.NUM_CLASSES)
            det_loss = criterion['MCE'](det_score, det_target)

            pbar.set_postfix({
                'lr': '{:.7f}'.format(scheduler.get_last_lr()[0]),
                'det_loss': '{:.5f}'.format(det_loss.item()),
            })
            optimizer.zero_grad()
            det_loss.backward()
            # A parameter that did not take part in the forward pass has grad None;
            # it contributes zero importance instead of crashing.
            for name, param in model.named_parameters():
                if param.grad is not None:
                    grad_sq_sum[name] += param.grad.detach().pow(2)
            num_batches += 1
            optimizer.step()
            scheduler.step()

    # Mean over batches; max() guards an empty loader (all-zero Fisher).
    fisher_dict = {name: sq / max(num_batches, 1) for name, sq in grad_sq_sum.items()}
    optpar_dict = {name: param.detach().clone() for name, param in model.named_parameters()}

    ewc_reg['fisher'].append(fisher_dict)
    ewc_reg['optpar'].append(optpar_dict)

    return ewc_reg
