from re import T
import math
import numpy as np
CUDA_VISIBLE_DEVICES='0'
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import logging

from config.args import test_options
from config.config import model_config
from pytorch_msssim import ms_ssim
from compressai.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import ImageFile, Image
from models import *
from utils.testing import test_model, test_one_epoch
from utils.logger import setup_logger
from modules.transform import *
from loss import RateDistortionLoss


# Name of the text file, created inside the experiment's save directory,
# that per-image and averaged metrics are appended to.
log_file = "log.txt"


def compute_psnr(a, b, data_range=1.0):
    """Return the peak signal-to-noise ratio (dB) between tensors ``a`` and ``b``.

    Args:
        a: reconstructed tensor.
        b: reference tensor of the same shape as ``a``.
        data_range: peak value of the signal. The default 1.0 matches inputs
            scaled to [0, 1] and reproduces the previous hard-coded formula
            ``-10 * log10(mse)`` exactly.

    Returns:
        PSNR as a Python float. Returns ``float('inf')`` when the inputs are
        identical (MSE == 0) instead of raising a math domain error.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # log10(0) is undefined; identical signals have unbounded PSNR.
        return float('inf')
    # 20*log10(R) - 10*log10(mse) is written this way so that R == 1 gives
    # bit-for-bit the same value as the original -10*log10(mse).
    return 20 * math.log10(data_range) - 10 * math.log10(mse)

def compute_msssim(a, b):
    """Return MS-SSIM between ``a`` and ``b`` as ``-10 * log10(1 - MS-SSIM)``.

    The comparison uses ``data_range=1.``, so ``a`` and ``b`` are assumed to
    hold values in [0, 1]. Confirm that callers clamp.

    Returns:
        The log-scaled score as a Python float. Returns ``float('inf')`` when
        MS-SSIM reaches 1 (e.g. identical images) instead of raising a math
        domain error from ``log10(0)``.
    """
    score = ms_ssim(a, b, data_range=1.).item()
    if score >= 1:
        # 1 - score would be <= 0, which log10 rejects.
        return float('inf')
    return -10 * math.log10(1 - score)

def compute_bpp(out_net):
    """Estimate bits per pixel from a model output dictionary.

    Each tensor in ``out_net['likelihoods']`` contributes ``-sum(log2(p))``
    bits. The total is normalised by ``N * H * W``, taken from dimensions
    0, 2 and 3 of ``out_net['x_hat']``. The channel dimension is not counted.

    Returns:
        The bits-per-pixel value as a Python float.
    """
    dims = out_net['x_hat'].size()
    pixel_count = dims[0] * dims[2] * dims[3]
    denominator = -math.log(2) * pixel_count
    bpp = 0
    for likelihood in out_net['likelihoods'].values():
        # Natural log divided by -ln(2) converts nats to bits (negated).
        bpp = bpp + torch.log(likelihood).sum() / denominator
    return bpp.item()

def crop(x, padding):
    """Bilinearly resample ``x`` to the spatial size ``padding``.

    Despite the names, nothing is sliced away. ``padding`` is the target
    ``(h, w)`` size passed to ``F.interpolate``, e.g. the original size
    returned by ``pad``.
    """
    target_size = padding
    return F.interpolate(x, size=target_size, mode='bilinear')

def pad(x, p, mode='bicubic'):
    """Resample ``x`` so its height and width become multiples of ``p``.

    Despite the name, this does not zero-pad. The tensor is resized with
    ``F.interpolate`` to the next multiple of ``p`` in each spatial dimension.
    The original size is returned so the caller can resize the output back.

    Args:
        x: 4-D tensor; ``x.size(2)`` and ``x.size(3)`` are read as height and
            width.
        p: required divisor for height and width.
        mode: interpolation mode for ``F.interpolate``. The default
            ``'bicubic'`` matches the previous hard-coded behavior. Note that
            bicubic output can fall slightly outside the input's value range.

    Returns:
        ``(x_resized, (h, w))``, where ``(h, w)`` is the original spatial size.
    """
    h, w = x.size(2), x.size(3)
    # Round each dimension up to the nearest multiple of p.
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    if (new_h, new_w) == (h, w):
        # Resampling to the same size is an identity; skip the extra work.
        return x, (h, w)
    x_padded = F.interpolate(x, size=(new_h, new_w), mode=mode)
    return x_padded, (h, w)


def _append_log(save_dir, text):
    """Append ``text`` to the metrics log file inside ``save_dir``."""
    # The context manager closes the handle even if the write fails.
    with open(os.path.join(save_dir, log_file), 'a') as log_fh:
        log_fh.write(text)


def _evaluate_image(net, img_path, device):
    """Run ``net`` on one image file and return ``(bpp, psnr, ms_ssim_db)``."""
    with Image.open(img_path) as img:
        # convert() returns an independent copy, so it outlives the handle.
        rgb = img.convert('RGB')
    # Batch of one, placed on the model's device. Leaving it on the CPU makes
    # the forward pass fail whenever the model lives on CUDA.
    img_tensor = transforms.ToTensor()(rgb).unsqueeze(0).to(device)
    # pad() resamples H and W up to multiples of 64 (presumably the model's
    # total stride; confirm) and returns the original (h, w) for crop().
    x_padded, orig_size = pad(img_tensor, 64)
    out_net = net(x_padded)
    # Back to the original resolution, clamped to [0, 1] because the metrics
    # below assume data_range=1.
    out_net['x_hat'] = crop(out_net['x_hat'], orig_size).clamp(0, 1)
    bpp = compute_bpp(out_net)
    psnr = compute_psnr(out_net['x_hat'], img_tensor)
    msssim = compute_msssim(out_net['x_hat'], img_tensor)
    return bpp, psnr, msssim


def main():
    """Evaluate an MLICPlusPlus checkpoint on every image in a folder.

    Options come from ``test_options()``; ``cuda``, ``checkpoint`` and
    ``experiment`` are read here. Each regular file in ``imagedir`` goes
    through the model's forward pass. Per-image bpp / PSNR / MS-SSIM and their
    averages are printed and appended to
    ``./experiments/<experiment>/q4/<log_file>``.
    """
    # Tolerate truncated files and lift Pillow's decompression-bomb limit.
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    Image.MAX_IMAGE_PIXELS = None

    args = test_options()
    config = model_config()

    torch.backends.cudnn.deterministic = True

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    net = MLICPlusPlus(config=config).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    # Build the path portably; a '.\\experiments' literal only acts as a
    # separator on Windows.
    save_dir = os.path.join('.', 'experiments', args.experiment, 'q4')
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): hard-coded dataset location.
    imagedir = 'E:/professional_valid_2020/valid/'
    # Skip sub-directories (Image.open cannot read them) and sort so the
    # evaluation order is reproducible.
    files = sorted(
        name for name in os.listdir(imagedir)
        if os.path.isfile(os.path.join(imagedir, name))
    )

    bpp_all = 0.0
    psnr_all = 0.0
    ms_ssim_all = 0.0
    with torch.no_grad():
        for file in files:
            bpp_loss, psnr, msssim = _evaluate_image(
                net, os.path.join(imagedir, file), device)
            bpp_all += bpp_loss
            psnr_all += psnr
            ms_ssim_all += msssim
            print_context = (file +
                             f'\tbpp: {bpp_loss:.3f}|'
                             f'\tpsnr: {psnr:.3f}'
                             f'\tms_ssim: {msssim:.3f}| \n')
            print(print_context)
            _append_log(save_dir, print_context)

    num = len(files)
    if num == 0:
        # Averaging over an empty folder would raise ZeroDivisionError.
        print(f'No image files found in {imagedir}')
        return

    bpp_avg = bpp_all / num
    psnr_avg = psnr_all / num
    ms_ssim_avg = ms_ssim_all / num
    print_context = (imagedir + f'\tzero_intra'
                     f'\tbpp_avg: {bpp_avg:.3f}|'
                     f'\tpsnr_avg: {psnr_avg:.3f}'
                     f'\tms_ssim_avg: {ms_ssim_avg:.3f}| \n')
    print(print_context)
    _append_log(save_dir, print_context)

if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script.
    main()