import os
import torch
from torch import nn
import random
import numpy as np

from utils.quantizemodel import quantize_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def loss_to_psnr(loss, max=1):
    """Convert a mean-squared-error value into a PSNR in decibels.

    Args:
        loss: MSE value, or anything ``np.asarray`` accepts (conversion is
            element-wise for array-likes).
        max: Peak signal value.  The name shadows the builtin but is kept
            because callers may pass it by keyword.

    Returns:
        ``10 * log10(max**2 / loss)`` as a NumPy scalar or array.
    """
    mse = np.asarray(loss)
    peak_power = max ** 2
    return 10 * np.log10(peak_power / mse)

def compute_model_rate(model):
    """Sum the per-parameter rates reported by ``model.get_network_rate()``.

    ``get_network_rate()`` is expected to return a mapping
    ``{module_name: {param_name: rate}}``.

    Returns:
        A tuple ``(total, arm, conv)``: the sum over every module, the sum
        over the module named ``'arm'``, and the sum over the module named
        ``'conv_mod'``.  A sum that receives no contribution stays the plain
        float ``0.0``.
    """
    total = 0.0
    # Only these two module names are tracked separately from the total.
    tracked = {'arm': 0.0, 'conv_mod': 0.0}
    for module_name, param_rates in model.get_network_rate().items():
        for param_rate in param_rates.values():  # weight, bias
            if module_name in tracked:
                tracked[module_name] += param_rate
            total += param_rate
    return total, tracked['arm'], tracked['conv_mod']
def get_mgrid(w_sidelen,h_sidelen, dim=2):
    """Build a coordinate grid spanning ``[-1, 1]`` along each axis.

    Args:
        w_sidelen: Number of samples along the first (width) axis.
        h_sidelen: Number of samples along the second (height) axis.
        dim: Number of grid axes.  For ``dim == 2`` the width and height axes
            are used; otherwise the width axis is repeated ``dim`` times.
            NOTE: the final 4-axis ``permute`` only matches the ``dim == 2``
            layout.

    Returns:
        For ``dim == 2``, a tensor of shape ``(1, 2, h_sidelen, w_sidelen)``:
        channel 0 holds the width coordinate, channel 1 the height coordinate.
    """
    w_axis = torch.linspace(-1, 1, steps=w_sidelen)
    h_axis = torch.linspace(-1, 1, steps=h_sidelen)
    if dim == 2:
        axes = (w_axis, h_axis)
    else:
        axes = (w_axis, ) * dim
    # 'ij' indexing gives (w, h) planes; stacking appends the coordinate axis.
    stacked = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    # (w, h, 2) -> (1, w, h, 2) -> (1, 2, h, w)
    return stacked.unsqueeze(0).permute(0, 3, 2, 1)

def _rate_to_float(value):
    """Return ``value`` as a Python float, accepting tensors and plain numbers."""
    # compute_model_rate starts its accumulators at the float 0.0, so a rate
    # that never receives a tensor contribution has no ``.item()`` method.
    return value.item() if hasattr(value, 'item') else float(value)


def eval_model(target_mask,target_mask_lossless, args,model,dataloader,img_index):
    """Quantize ``model`` and report its PSNR and bit rates on ``dataloader``.

    For every batch the model is quantized with ``quantize_model``, run on the
    coordinate grid of the image, and scored with an MSE-derived PSNR over the
    pixels selected by ``target_mask_lossless``.  Tensors are moved to the GPU
    with ``.cuda()``, so a CUDA device is required.

    Args:
        target_mask: Index over the flattened pixel axis selecting the pixels
            the model predicts.  It is negated with ``~``, so presumably a
            boolean mask -- confirm against the caller.
        target_mask_lossless: Index over the flattened pixel axis selecting
            the pixels on which the distortion is measured.
        args: Options object.  ``args.eval_pix_num`` is read here as the
            divisor that turns rates into per-pixel rates; the whole object
            is forwarded to ``quantize_model``.
        model: Network that returns ``(output, rate)`` when called on the
            coordinate grid and exposes ``get_network_rate()``.
        dataloader: Iterable of ``(image, _)`` pairs with 4-D image tensors;
            the reshape to ``(batch, -1, 3)`` implies 3 channels last after
            the permute.
        img_index: Image number, used only in the summary message.

    Returns:
        A 9-tuple of scalars for the last batch processed:
        ``(psnr, rate_per_pixel, rate_total,
        network_rate_per_pixel, network_rate_total,
        arm_rate_per_pixel, arm_rate_total,
        conv_rate_per_pixel, conv_rate_total)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    criterion = nn.MSELoss().cuda()
    results = None

    for img_in, _ in dataloader:
        batch_size, _, height, width = img_in.shape
        pixels = img_in.permute(0, 2, 3, 1).view(batch_size, -1, 3).cuda()
        pixels1 = pixels[:, target_mask, :]
        pixels2 = pixels[:, target_mask_lossless, :]

        coords = get_mgrid(width, height, 2).cuda()
        print("********************Evalutation with quantization")
        print("********************Starting quantizing models")
        # The quantized model replaces the incoming one for the rest of the run.
        model = quantize_model(model, coords, pixels1, args)

        torch.cuda.empty_cache()
        model.eval()
        # Only scalar metrics leave this block, so no autograd graph is needed.
        with torch.no_grad():
            model_output, rate = model(coords)
            model_out = torch.zeros_like(pixels)
            model_out[:, target_mask, :] = model_output
            # Pixels outside the target mask are filled with the mean prediction.
            model_out[:, ~target_mask, :] = model_output.mean()
            model_out = model_out[:, target_mask_lossless, :]

            bits_rate_eval_num = rate.sum()
            bits_rate_eval = bits_rate_eval_num / args.eval_pix_num
            loss_mse = criterion(model_out, pixels2)

        psnr_eval = loss_to_psnr(loss_mse.item())

        print("full_image_psnr:",psnr_eval)

        # One call yields the totals; per-pixel values are derived from them.
        net_total, arm_total, conv_total = compute_model_rate(model)
        out_network_rate = _rate_to_float(net_total / args.eval_pix_num)
        out_network_rate_arm = _rate_to_float(arm_total / args.eval_pix_num)
        out_network_rate_conv = _rate_to_float(conv_total / args.eval_pix_num)

        print("********************Evaluation the Image %d-th, BEST PSNR: %0.6f, Print rate %0.6f, Network rate %0.6f. *************************" % (img_index, psnr_eval,bits_rate_eval.item(),out_network_rate))

        torch.cuda.empty_cache()

        results = (
            psnr_eval,
            bits_rate_eval.item(),
            bits_rate_eval_num.item(),
            out_network_rate,
            _rate_to_float(net_total),
            out_network_rate_arm,
            _rate_to_float(arm_total),
            out_network_rate_conv,
            _rate_to_float(conv_total),
        )

    if results is None:
        raise ValueError("eval_model: dataloader yielded no batches")
    return results


def input_mapping(x, B):
    """Apply a sinusoidal feature mapping to ``x`` using projection ``B``.

    Args:
        x: Input tensor whose last axis is projected.
        B: Projection matrix applied as ``B.T``, or ``None`` to disable the
            mapping.

    Returns:
        ``x`` unchanged when ``B`` is ``None``; otherwise the concatenation of
        ``sin`` and ``cos`` of ``(2*pi*x) @ B.T`` along the last axis.
    """
    if B is None:
        return x
    projected = (2.*np.pi*x) @ B.T
    features = [torch.sin(projected), torch.cos(projected)]
    return torch.cat(features, dim=-1)
