import torch
from image_param_list import params_mlnp
from net import *
from data_prepare import *
from torchvision.utils import make_grid
from PIL import Image
import os


def evalution_meta_net(meta_net,
                  dim,
                  eval_loader,
                  num_c_points=None,order=False,num_z=32,eval_mode='t'):
    """Evaluate a conditional image model using random Bernoulli context masks.

    For each batch, every pixel is independently marked as "context" with
    probability ``num_context / (dim * dim)``. ``meta_net.conditional_predict``
    is then called three times, and each returned NLL is normalised:

    * ``t``: called as ``(images, mask)``; divided by ``dim * dim``.
    * ``c``: called as ``(images, mask, mask)``; divided by the mean number of
      context pixels per image.
    * ``p``: called as ``(images, mask, 1 - mask)``; divided by the mean number
      of non-context pixels per image.

    The MSE between the predicted mean (4th return value of the first call)
    and the full image is also recorded.

    Args:
        meta_net: Model exposing ``eval()`` and ``conditional_predict``, whose
            result is unpacked as ``(mu, logvar, nll, y_mean)``.
        dim: Image side length. 28 (1 channel) or 32 (3 channels).
        eval_loader: Sized iterable of ``(images, labels)`` batches. Images are
            permuted with ``(0, 2, 3, 1)``, so presumably ``(B, C, dim, dim)``
            (TODO confirm).
        num_c_points: Expected number of context pixels per image. If ``None``,
            a value in ``[1, dim*dim - 2]`` is drawn per batch.
        order: Unused. Kept for backward compatibility (the context indices it
            used to control never affected the returned metrics).
        num_z: Unused. Kept for backward compatibility.
        eval_mode: Unused. Kept for backward compatibility.

    Returns:
        ``(nll_t, nll_c, nll_p, mse)``: each is the sum of the per-batch values
        divided by ``len(eval_loader)``, or ``nan`` if the loader has no batches.

    Raises:
        ValueError: If ``dim`` is not 28 or 32.
    """
    if dim == 28:
        n_channels = 1
    elif dim == 32:
        n_channels = 3
    else:
        raise ValueError(f"unsupported image size dim={dim}; expected 28 or 32")
    n_total = dim * dim

    meta_net.eval()
    epoch_test_nll_t = []
    epoch_test_nll_c = []
    epoch_test_nll_p = []
    epoch_test_mse = []

    with torch.no_grad():
        # Index of every pixel; loop-invariant, so build it once.
        idx_all_tensor = torch.arange(n_total, dtype=torch.long).cuda()
        for y_all, _ in eval_loader:
            batch_size = y_all.shape[0]
            # Channel-last, flattened pixels: (B, dim*dim, n_channels).
            y_all = y_all.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, n_channels).cuda()
            y = idx_to_y(idx_all_tensor, y_all).cuda()
            # Back to (B, C, dim, dim) for the model.
            y_ = y.permute(0, 2, 1).view(y.size(0), y.size(2), dim, dim)

            if num_c_points is None:
                num_context = int(torch.empty(1).uniform_(1, n_total - 1).item())
            else:
                num_context = num_c_points
            # num_context is the *expected* context count per image, not an exact one.
            mask = y_.new_empty(y_.size(0), 1, y_.size(2), y_.size(3)).bernoulli_(p=num_context / n_total)
            mask_p = torch.ones_like(mask) - mask

            _, _, b_nll_t, y_mean = meta_net.conditional_predict(y_, mask)
            _, _, b_nll_c, _ = meta_net.conditional_predict(y_, mask, mask)
            _, _, b_nll_p, _ = meta_net.conditional_predict(y_, mask, mask_p)

            b_avg_nll_t = b_nll_t / (mask.size(-1) * mask.size(-2))
            b_avg_nll_c = b_nll_c / (torch.sum(mask) / mask.size(0))
            b_avg_nll_p = b_nll_p / (torch.sum(mask_p) / mask_p.size(0))
            b_avg_mse = F.mse_loss(y_mean, y)

            # .sum() matches the previous np.array(...).sum() aggregation,
            # whether or not the NLLs come back as scalars.
            epoch_test_nll_t.append(b_avg_nll_t.sum().item())
            epoch_test_nll_c.append(b_avg_nll_c.sum().item())
            epoch_test_nll_p.append(b_avg_nll_p.sum().item())
            epoch_test_mse.append(b_avg_mse.sum().item())

    n_batches = len(eval_loader)
    if n_batches == 0:
        nan = float('nan')
        return nan, nan, nan, nan
    avg_te_nll_t = sum(epoch_test_nll_t) / n_batches
    avg_te_nll_c = sum(epoch_test_nll_c) / n_batches
    avg_te_nll_p = sum(epoch_test_nll_p) / n_batches
    avg_te_mse = sum(epoch_test_mse) / n_batches

    return avg_te_nll_t, avg_te_nll_c, avg_te_nll_p, avg_te_mse
def evalution(check_lvm,image_data,alpha=0.5,eval_mode='average',writer=1,load_writer=4,num_z=32,num_c_points=None,order=False,seed=7,save=False):
    """Load a trained ``conv_net`` checkpoint and evaluate it on CIFAR-10.

    Paths:

    * The checkpoint is read from
      ``./final_results_<image_data>/<check_lvm>/<seed>/<load_writer>/<check_lvm>.pth``.
    * When ``save`` is true, metrics are written to
      ``./final_eval_results_<image_data>/<check_lvm>/<seed>/<writer>/te_nll_list.csv``.

    The data always comes from ``cifar10_metadataset(b_size=8)`` at 32x32 with
    3 channels. ``image_data`` only selects directories.

    Args:
        check_lvm: Stored in ``args.type``. Also used as the checkpoint's
            directory name and file stem.
        image_data: Tag used in the input and output directory names.
        alpha: Stored in ``args.alpha``. Any non-zero value also sets
            ``args.eval_mode = 'CVAR'``.
        eval_mode: Currently ignored.
        writer: Run id used in the output directory.
        load_writer: Run id used in the checkpoint directory.
        num_z: Forwarded to ``evalution_meta_net``.
        num_c_points: Forwarded to ``evalution_meta_net``.
        order: Forwarded to ``evalution_meta_net``.
        seed: Only selects directories. No RNG is seeded here.
        save: If true, create the output directory and write the CSV.

    Returns:
        ``np.ndarray`` ``[-nll_t, -nll_c, -nll_p, mse]``.
    """
    args, _, device = params_mlnp()
    args.type = check_lvm
    _, _, eval_loader = cifar10_metadataset(b_size=8)
    dim = 32
    args.y_dim = 3
    args.alpha = alpha
    if alpha != 0.:
        args.eval_mode = 'CVAR'

    load_dir = os.path.join('.', 'final_results_' + image_data, check_lvm, str(seed), str(load_writer))
    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto `device`.
    check_state = torch.load(os.path.join(load_dir, check_lvm + '.pth'), map_location=device)
    meta_net = conv_net(args).to(device)
    meta_net.load_state_dict(check_state)

    avg_te_nll_t, avg_te_nll_c, avg_te_nll_p, avg_te_mse = evalution_meta_net(
        meta_net, dim, eval_loader,
        num_c_points=num_c_points, order=order, num_z=num_z)

    # NLLs are negated; MSE is reported as-is.
    meta_te_results = np.array([-avg_te_nll_t, -avg_te_nll_c, -avg_te_nll_p, avg_te_mse])

    if save:
        # Only touch the filesystem when results are actually being written.
        save_dir = os.path.join('.', 'final_eval_results_' + image_data, check_lvm, str(seed), str(writer))
        os.makedirs(save_dir, exist_ok=True)
        np.savetxt(os.path.join(save_dir, 'te_nll_list.csv'), meta_te_results)
    return meta_te_results
