import torch
from datetime import datetime

def compute_stress(_model, tar_Wio):
    """Return the squared-error loss between ``_model.forward()`` and ``tar_Wio``.

    Two forms of model output are supported:

    * A tensor: the loss is the mean squared error over all of its elements.
    * A list with one entry per kernel position: ``tar_Wio`` must then be 4-D,
      ``(num_output, num_input, kernel_h, kernel_w)``, and list entry
      ``row * kernel_w + col`` is compared with ``tar_Wio[:, :, row, col]``.
      The loss is the sum over kernel positions of the per-position mean
      squared error (normalised by ``num_output * num_input``).

    NOTE(review): in the list case the loss is ``kernel_h * kernel_w`` times
    larger than a plain MSE over all elements, so the two branches are scaled
    differently. Kept as-is because the optimiser step size depends on it.

    Args:
        _model: object whose ``forward()`` takes no arguments.
        tar_Wio: target tensor.

    Returns:
        A scalar tensor. In the list case with zero kernel positions this is
        the int ``0``.
    """
    _W_io = _model.forward()

    if not isinstance(_W_io, list):
        return torch.sum(((_W_io - tar_Wio) ** 2) / _W_io.numel())

    # Read both kernel dimensions. The previous code reused the height for the
    # width, which broke on non-square kernels.
    num_output, num_input, kernel_h, kernel_w = tar_Wio.shape
    per_position = num_output * num_input
    loss_Wio = 0
    for row_id in range(kernel_h):
        for col_id in range(kernel_w):
            # List entries are laid out row-major over the kernel grid.
            diff = _W_io[row_id * kernel_w + col_id] - tar_Wio[:, :, row_id, col_id]
            loss_Wio += torch.sum((diff ** 2) / per_position)
    return loss_Wio


def _dyn_fit_error(_model, tar_Wio):
    """Return ``(max_abs_err, mean_abs_err)`` of ``_model.forward()`` vs ``tar_Wio``.

    Both values are Python floats rounded to 8 decimals. A list output (one
    entry per kernel position, row-major, as accepted by ``compute_stress``)
    is stacked back into the 4-D shape of ``tar_Wio`` before comparison.
    """
    # Evaluation only: no autograd graph is needed for the metrics.
    with torch.no_grad():
        out = _model.forward()
        if isinstance(out, list):
            num_output, num_input, kernel_h, kernel_w = tar_Wio.shape
            # Stacking on the last dim gives (O, I, kh*kw) with index
            # row*kw+col, which reshapes to (O, I, kh, kw).
            out = torch.stack(out[:kernel_h * kernel_w], dim=-1).reshape(
                num_output, num_input, kernel_h, kernel_w)
        abs_err = torch.abs(out - tar_Wio)
        mean_loss = round(torch.sum(abs_err).item() / tar_Wio.numel(), 8)
        max_loss = round(torch.max(abs_err).item(), 8)
    return max_loss, mean_loss


def update_DyN(_model, _optim, _sche, tar_Wio, dyn_epochs, meanL_thres=0.01, logger=None):
    """Optimise ``_model`` so its forward output matches ``tar_Wio``.

    Each epoch takes one optimiser step on ``compute_stress(_model, tar_Wio)``
    and then steps the scheduler, if one is given. Every ``dyn_epochs // 20``
    epochs (every epoch if that is 0, and always at epoch 0), the max and mean
    absolute errors are re-measured. If a logger is given, they are logged.
    Training stops early once the most recently measured mean absolute error
    is below ``meanL_thres``.

    Args:
        _model: object whose ``forward()`` takes no arguments and returns a
            tensor comparable with ``tar_Wio``, or the per-kernel-position
            list form handled by ``compute_stress``.
        _optim: optimiser providing ``zero_grad()`` / ``step()``.
        _sche: scheduler stepped once per epoch, or ``None``.
        tar_Wio: target tensor.
        dyn_epochs: maximum number of epochs.
        meanL_thres: early-stop threshold on the mean absolute error.
        logger: object with an ``info(str)`` method, or ``None`` to disable
            logging.

    Returns:
        ``(_model, max_loss, mean_loss)``, where the two losses are the last
        measured max and mean absolute errors (``0`` if no epoch ran).
    """
    # `dyn_epochs // 20` is 0 for short runs, which previously made the
    # modulo below raise ZeroDivisionError.
    log_every = max(1, dyn_epochs // 20)
    max_loss, mean_loss = 0, 0
    for _ep in range(dyn_epochs):
        loss_W = compute_stress(_model, tar_Wio)
        _optim.zero_grad()
        loss_W.backward()
        _optim.step()
        if _ep % log_every == 0:
            max_loss, mean_loss = _dyn_fit_error(_model, tar_Wio)
            if logger is not None:
                logger.info('---{}- DyN Loss:{}- Max Loss:{}- Mean Loss:{}- Time:{}'.format(_ep, loss_W.item(), max_loss, mean_loss, datetime.now().time()))
        if _sche is not None:
            _sche.step()
        # mean_loss is refreshed at epoch 0, so this never fires on the
        # initial placeholder value.
        if mean_loss < meanL_thres:
            return _model, max_loss, mean_loss

    return _model, max_loss, mean_loss