import torch
from datetime import datetime

def compute_stress(_model, tar_Wio):
    """Return the mean squared error between the model output and a target.

    Args:
        _model: Object whose zero-argument ``forward()`` returns a tensor
            (presumably the current weight matrix being fitted — TODO confirm).
        tar_Wio: Target tensor, compared element-wise against that output.

    Returns:
        A scalar tensor equal to the sum of squared differences, each divided
        by the element count of the model output. The tensor stays attached
        to the autograd graph, so callers can ``backward()`` through it.
    """
    predicted = _model.forward()
    squared_error = (predicted - tar_Wio) ** 2
    # Divide each element before summing, matching the original arithmetic
    # so floating-point results are bit-identical to a plain sum-then-mean.
    return (squared_error / predicted.numel()).sum()


def update_DyN(_model, _optim, _sche, tar_Wio, dyn_epochs, meanL_thres=0.01):
    """Fit ``_model`` to ``tar_Wio`` by gradient descent with periodic logging and early stop.

    Each epoch computes ``compute_stress(_model, tar_Wio)``, back-propagates it,
    steps ``_optim``, and then steps ``_sche`` when one is given. Every
    ``max(1, dyn_epochs // 20)`` epochs, always including epoch 0, the element-wise
    absolute error between ``_model.forward()`` and ``tar_Wio`` is measured after
    the optimizer step and printed, together with the pre-step stress loss.

    Args:
        _model: Object with a zero-argument ``forward()`` returning a tensor
            whose parameters are optimized by ``_optim``.
        _optim: Optimizer exposing ``zero_grad()`` and ``step()``.
        _sche: Optional learning-rate scheduler exposing ``step()``, or ``None``.
        tar_Wio: Target tensor the model output is fitted to.
        dyn_epochs: Maximum number of optimization epochs. Values <= 0 run nothing.
        meanL_thres: Training stops early once the last measured mean absolute
            error (rounded to 8 decimals) is below this value.

    Returns:
        ``(_model, max_loss, mean_loss)``. ``max_loss`` and ``mean_loss`` are the max and
        mean absolute error from the most recent logging epoch, rounded to 8
        decimals. Both are ``0`` if no epoch ran.
    """
    max_loss, mean_loss = 0, 0
    # Clamp to 1: a bare dyn_epochs // 20 is 0 for dyn_epochs < 20 and the
    # modulo below would raise ZeroDivisionError.
    log_every = max(1, dyn_epochs // 20)
    for _ep in range(dyn_epochs):
        loss_W = compute_stress(_model, tar_Wio)
        _optim.zero_grad()
        loss_W.backward()
        _optim.step()
        if _ep % log_every == 0:
            # One forward pass without autograd: avoids building throw-away
            # graphs and keeps max/mean consistent with each other.
            with torch.no_grad():
                abs_err = torch.abs(_model.forward() - tar_Wio)
            mean_loss = round(torch.sum(abs_err).item() / tar_Wio.numel(), 8)
            max_loss = round(torch.max(abs_err).item(), 8)
            # loss_W was computed before this epoch's optimizer step, while the
            # max/mean metrics reflect the parameters after it.
            print('---', _ep, '- DyN Loss:', loss_W.item(), '- Max Loss:', max_loss, '- Mean Loss:', mean_loss, '- Time:', datetime.now().time())
        if _sche is not None:
            _sche.step()
        # mean_loss only changes on logging epochs (epoch 0 always logs, so it
        # is never the initial placeholder here); between them this re-checks
        # the last measured value.
        if mean_loss < meanL_thres:
            return _model, max_loss, mean_loss

    return _model, max_loss, mean_loss