from tqdm import tqdm
import numpy as np
import torch
import copy

def rmse(a, b):
    """Return the root-mean-square error between ``a`` and ``b`` as a 0-d tensor.

    Element-wise squared errors that are NaN or +/-inf are excluded before
    averaging. If no finite squared error remains, the mean is taken over an
    empty tensor and the result is NaN.
    """
    sq_err = (a - b) ** 2
    kept = sq_err[torch.isfinite(sq_err)]
    return kept.mean().sqrt()

def eval_model(model, loader, device='cuda:1', at_transform=None, extra_mask=None, probs=False):
    """Return the per-sample RMSE of ``model`` predictions over ``loader``.

    Args:
        model: torch module, called as ``model(x)`` with ``x`` moved to
            ``device``. With ``probs=True`` its output is unpacked as
            ``(prediction, _)``.
        loader: iterable of ``(x, ats)`` batches; ``ats`` are the targets and
            stay on the CPU.
        device: device used for the forward pass.
        at_transform: optional object with an ``unnormalize`` method. When not
            None, targets and predictions are both passed through it before
            scoring.
        extra_mask: optional index/mask applied to every sample
            (``ats[j][extra_mask]``) so only the selected entries are scored.
        probs: set when the model returns a 2-tuple instead of a tensor.

    Returns:
        ``np.ndarray`` with one ``rmse(target, prediction)`` value per sample.

    Side effects: the model is switched to eval mode and moved back to the CPU
    when evaluation finishes, including when an exception is raised.
    """
    val_losses = []
    model = model.eval().to(device)
    try:
        with torch.no_grad():
            for x, ats in tqdm(loader):
                if probs:
                    mlp_yhat, _ = model(x.to(device))
                else:
                    mlp_yhat = model(x.to(device))
                mlp_yhat = mlp_yhat.cpu()
                if at_transform is not None:
                    ats = at_transform.unnormalize(ats)
                    mlp_yhat = at_transform.unnormalize(mlp_yhat)
                for j in range(ats.size(0)):
                    target, pred = ats[j], mlp_yhat[j]
                    if extra_mask is not None:
                        target, pred = target[extra_mask], pred[extra_mask]
                    val_losses.append(rmse(target, pred))
    finally:
        # Move the model off `device` even if a batch raises, so a failed
        # evaluation does not leave it occupying accelerator memory.
        model.cpu()
    return np.array(val_losses)
def eval_rec_model(model, loader, device='cuda:1', ssp_transform=None):
    """Return the per-sample reconstruction RMSE of ``model`` over ``loader``.

    Each batch's inputs ``x`` are passed as ``model(x, decode=True)``. Element
    ``[2]`` of the output is taken as the reconstruction ``xhat`` (presumably
    the decoder output; confirm against the model's forward). The batch's
    second element (targets) is ignored.

    Args:
        model: torch module supporting the ``decode=True`` keyword.
        loader: iterable of ``(x, targets)`` batches.
        device: device used for the forward pass.
        ssp_transform: optional object with an ``unnormalize`` method. When not
            None, both ``x`` and ``xhat`` are passed through it before scoring.

    Returns:
        ``np.ndarray`` with one ``rmse(x[j], xhat[j])`` value per sample.

    Side effects: the model is switched to eval mode and moved back to the CPU
    when evaluation finishes, including when an exception is raised.
    """
    val_losses = []
    model = model.eval().to(device)
    try:
        with torch.no_grad():
            for x, _ in tqdm(loader):
                xhat = model(x.to(device), decode=True)[2].cpu()
                if ssp_transform is not None:
                    x = ssp_transform.unnormalize(x)
                    xhat = ssp_transform.unnormalize(xhat)
                for j in range(x.size(0)):
                    val_losses.append(rmse(x[j], xhat[j]))
    finally:
        # Move the model off `device` even if a batch raises.
        model.cpu()
    return np.array(val_losses)


def get_losses(model, dataloader, device='cpu', probs=False, at_transform=None,ssp_transform=None, decode=False, extra_mask=None,print_out = True):
    """Evaluate ``model`` under every applicable RMSE variant and return them.

    Always computes ``'val'`` (no transform, no mask). Depending on the
    arguments it also adds:

    * ``'unnorm_val'``: when ``at_transform`` is given;
    * ``'masked_val'``: when ``extra_mask`` is given;
    * ``'masked_unnorm_val'``: when both ``extra_mask`` and ``at_transform``
      are given;
    * ``'xrec_val'``: when ``decode`` is true (via ``eval_rec_model``, no
      transform);
    * ``'xrec_unnorm_val'``: when ``decode`` is true and ``ssp_transform`` is
      given.

    Transforms are deep-copied before use, presumably so that any state touched
    by ``unnormalize`` stays away from the caller's objects.

    Args:
        print_out: when true, print one tab-separated row per key containing
            the mean per-sample RMSE and the mean per-sample squared RMSE.

    Returns:
        dict mapping each computed key to an ``np.ndarray`` of per-sample RMSEs.
    """
    losses = {'val': eval_model(model, dataloader, device=device, probs=probs)}

    if at_transform is not None:
        at_transform = copy.deepcopy(at_transform)
        losses['unnorm_val'] = eval_model(model, dataloader, device=device, probs=probs,
                                          at_transform=at_transform)

    if extra_mask is not None:
        losses['masked_val'] = eval_model(model, dataloader, device=device, probs=probs,
                                          extra_mask=extra_mask)
        if at_transform is not None:
            losses['masked_unnorm_val'] = eval_model(model, dataloader, device=device,
                                                     at_transform=at_transform, probs=probs,
                                                     extra_mask=extra_mask)

    if decode:
        losses['xrec_val'] = eval_rec_model(model, dataloader, device=device, ssp_transform=None)
        if ssp_transform is not None:
            losses['xrec_unnorm_val'] = eval_rec_model(model, dataloader, device=device,
                                                       ssp_transform=copy.deepcopy(ssp_transform))

    if print_out:
        print('type\trmse\tmse')
        for key, per_sample in losses.items():
            # "rmse" column: mean per-sample RMSE; "mse": mean per-sample MSE.
            print(key, per_sample.mean(), np.power(per_sample, 2).mean(), sep='\t')

    return losses

@torch.no_grad()
def ssprmse(a, b):
    """Return the RMSE between ``a`` and ``b`` over their last two dimensions.

    Leading dimensions are kept: inputs of shape ``(..., m, n)`` give a result
    of shape ``(...)``. Runs under ``torch.no_grad``, so the result never
    tracks gradients.
    """
    sq_err = (a - b) ** 2
    per_item_mse = sq_err.mean(dim=(-2, -1))
    return per_item_mse.sqrt()
def eval_na(model, dataloader, device='cpu', inits=None, opt_z=False, lr=50):
    """Run ``model.forward`` on every batch and score its output against ``x``.

    For each ``(x, y)`` batch an initial guess is built: zeros shaped like
    ``x``, or the next ``x.size(0)`` rows of ``inits`` (consumed in order).
    With ``opt_z`` the guess is flattened to ``(-1, 11*231)`` and mapped
    through ``model.net.net.query_proj`` to give ``zinit``. Otherwise it is
    passed directly as ``xinit``. ``model.forward(y, ...)`` returns
    ``(xhat, _)``.

    Each ``xhat`` is scored with ``ssprmse`` against the ``x`` of the SAME
    batch, within a single pass over ``dataloader``. The earlier version
    re-iterated the loader to fetch targets, which pairs predictions with the
    wrong targets when the loader shuffles, or finds nothing when it is a
    one-shot iterable.

    Args:
        lr: forwarded to ``model.forward`` (default 50, previously hard-coded).

    Returns:
        ``(xhats, rmses)``: ``xhats`` is the list of per-batch ``xhat`` as
        returned by the model. ``rmses`` is a float64 ``np.ndarray`` holding
        the ``ssprmse`` values of all batches, concatenated in order.
    """
    model = model.to(device)
    xhats = []
    rmse_parts = []
    offset = 0  # next unused row of `inits`
    for x, y in tqdm(dataloader):
        if inits is None:
            xinit = torch.zeros_like(x).to(device)
        else:
            bs = x.size(0)
            xinit = inits[offset:offset + bs].to(device)
            offset += bs

        if opt_z:
            # NOTE(review): the 11*231 flattening is hard-coded; presumably it
            # matches the per-sample size of x. Confirm against the model.
            zinit = model.net.net.query_proj(xinit.view(-1, 11 * 231))
            xhat, _ = model.forward(y, zinit=zinit.clone(), device=device, lr=lr)
        else:
            xhat, _ = model.forward(y, xinit=xinit.clone(), lr=lr)
        xhats.append(xhat)

        with torch.no_grad():
            rmse_parts.append(ssprmse(xhat.cpu(), x.cpu()).numpy())

    # The leading empty float64 array keeps the original float64 result dtype
    # and yields an empty array for an empty loader.
    rmses = np.concatenate([np.array([]), *rmse_parts])
    return xhats, rmses