import argparse
import glob
import os
import os.path as osp
import subprocess
import sys

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf

sys.path.append('.')
from utils.utils import InputPadder, read, img2tensor
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim

# Command-line interface. Every option has a default, so the script can be
# started without arguments; `--help` lists each option with its default.
parser = argparse.ArgumentParser(
    prog='AMT',
    description='Xiph evaluation',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml',
                    help='path to the YAML config whose `network` section describes the model')
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth',
                    help='path to the checkpoint file to evaluate')
parser.add_argument('-r', '--root', default='data/xiph',
                    help='directory holding the extracted Xiph frames (created on demand)')
args = parser.parse_args()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

# Build the network described by the `network` section of the config file.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can still be loaded on a CPU-only machine.
ckpt = torch.load(ckpt_path, map_location=device)
# The second positional argument is presumably `strict` (as in
# torch.nn.Module.load_state_dict): with False, keys that do not match between
# checkpoint and model are skipped instead of raising.
model.load_state_dict(ckpt['state_dict'], False)
model = model.to(device)
model.eval()

############################################# Prepare Dataset #############################################
# Source videos; download_links[i] is the video for file_list[i].
download_links = [
    'https://media.xiph.org/video/derf/ElFuente/Netflix_BoxingPractice_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_Crosswalk_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/Chimera/Netflix_DrivingPOV_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket2_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_RitualDance_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_SquareAndTimelapse_4096x2160_60fps_10bit_420.y4m',
    'https://media.xiph.org/video/derf/ElFuente/Netflix_Tango_4096x2160_60fps_10bit_420.y4m',
]
file_list = ['BoxingPractice', 'Crosswalk', 'DrivingPOV', 'FoodMarket', 'FoodMarket2', 'RitualDance',
             'SquareAndTimelapse', 'Tango']

# Extract the first 100 frames of every video into <root>/<name>/001.png ... 100.png.
# A directory that already holds at least 100 PNG files is left untouched.
for file_name, link in zip(file_list, download_links):
    data_dir = osp.join(root, file_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)

    # glob.escape keeps special characters in the user-supplied root literal.
    existing_frames = glob.glob(osp.join(glob.escape(data_dir), '*.png'))
    if len(existing_frames) >= 100:
        continue

    # Pass the command as an argument list (no shell), so a root directory
    # containing spaces or shell metacharacters is handled verbatim.
    ffmpeg_cmd = ['ffmpeg', '-i', link, '-pix_fmt', 'rgb24', '-vframes', '100',
                  f'{data_dir}/%03d.png']
    try:
        return_code = subprocess.run(ffmpeg_cmd, check=False).returncode
    except OSError as err:
        # e.g. ffmpeg is not installed; keep going as the original did, but say so.
        print(f'warning: could not run ffmpeg for {file_name}: {err}', file=sys.stderr)
    else:
        if return_code != 0:
            print(f'warning: ffmpeg exited with code {return_code} for {file_name}', file=sys.stderr)
############################################### Prepare End ###############################################


# Evaluation: for every clip, predict each even-numbered frame from its two
# odd-numbered neighbours and compare the prediction with the real frame.
divisor = 32        # handed to InputPadder together with the input shape
scale_factor = 0.5  # forwarded unchanged to the model call
# Loop-invariant, so built once. Presumably the interpolation time step
# (0.5 = midway between the two input frames) — confirm against the model.
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

for category in ['resized-2k', 'cropped-4k']:
    psnr_list = []
    ssim_list = []
    pbar = tqdm.tqdm(file_list, total=len(file_list))
    for clip_name in pbar:
        dir_name = osp.join(root, clip_name)
        # Targets are frames 002, 004, ..., 098; inputs are the frames around them.
        for intFrame in range(2, 99, 2):
            img0 = read(f'{dir_name}/{intFrame - 1:03d}.png')
            img1 = read(f'{dir_name}/{intFrame + 1:03d}.png')
            imgt = read(f'{dir_name}/{intFrame:03d}.png')

            if category == 'resized-2k':
                # Downscale all three frames to 2048x1080.
                img0, img1, imgt = (
                    cv2.resize(src=img, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA)
                    for img in (img0, img1, imgt)
                )
            elif category == 'cropped-4k':
                # Trim 540 rows and 1024 columns from each border; for 4096x2160
                # frames this keeps the central 2048x1080 region (assumes HxWxC arrays).
                img0, img1, imgt = (img[540:-540, 1024:-1024, :] for img in (img0, img1, imgt))

            img0 = img2tensor(img0).to(device)
            imgt = img2tensor(imgt).to(device)
            img1 = img2tensor(img1).to(device)

            # Only the inputs are padded; the prediction is unpadded again
            # before it is compared with the (never padded) ground truth.
            padder = InputPadder(img0.shape, divisor)
            img0, img1 = padder.pad(img0, img1)

            with torch.no_grad():
                imgt_pred = model(img0, img1, embt, scale_factor=scale_factor, eval=True)['imgt_pred']
                imgt_pred = padder.unpad(imgt_pred)

            # float() accepts Python/NumPy scalars and single-element tensors
            # alike, so the lists always hold plain numbers for np.mean.
            psnr_list.append(float(calculate_psnr(imgt_pred, imgt)))
            ssim_list.append(float(calculate_ssim(imgt_pred, imgt)))

            # Average AFTER appending: the running mean then includes the
            # current frame and is never taken over an empty list.
            avg_psnr = np.mean(psnr_list)
            avg_ssim = np.mean(ssim_list)
            desc_str = f'[{network_name}/Xiph] [{category}/{clip_name}] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'

            pbar.set_description_str(desc_str)