import sys
import tqdm
import torch
import argparse
import numpy as np
import os.path as osp
from omegaconf import OmegaConf

sys.path.append('.')
from utils.utils import read, img2tensor
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim

# Command-line interface: network config, checkpoint path and dataset root.
parser = argparse.ArgumentParser(
                prog = 'AMT',
                description = 'Vimeo90K evaluation (with Test-Time Augmentation)',
                )
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml',
                    help='config file loaded with OmegaConf; its `network` '
                         'section is used to build the model')
# The short flag must start with '-': a bare 'p' next to '--ckpt' makes
# argparse raise ValueError ("invalid option string") when this line runs.
parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth',
                    help='checkpoint file whose "state_dict" entry holds the '
                         'model weights')
parser.add_argument('-r', '--root', default='data/vimeo_triplet',
                    help='dataset root containing tri_testlist.txt and '
                         'the sequences/ directory')
args = parser.parse_args()

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

# Build the network from the `network` section of the config file.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)
# Without map_location, torch.load restores tensors to the device they were
# saved from, so a GPU-saved checkpoint fails on the CPU fallback above.
# Loading to CPU works everywhere and avoids a transient second GPU copy of
# the weights; the model is moved to `device` right after.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()

# Each line of tri_testlist.txt names one triplet directory under sequences/.
with open(osp.join(root, 'tri_testlist.txt'), 'r') as fr:
    file_list = fr.readlines()

psnr_list = []
ssim_list = []

# The interpolation time (1/2) is identical for every triplet, so it is built
# once instead of on every iteration.
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

pbar = tqdm.tqdm(file_list, total=len(file_list))
# Evaluation only: disable autograd so no graph/activations are kept around.
with torch.no_grad():
    for name in pbar:
        name = name.strip()
        # Skip blank (or degenerate one-character) lines, e.g. a trailing newline.
        if len(name) <= 1:
            continue
        dir_path = osp.join(root, 'sequences', name)
        # im1 / im3 are the model inputs; im2 is the reference the prediction
        # is scored against.
        I0 = img2tensor(read(osp.join(dir_path, 'im1.png'))).to(device)
        I1 = img2tensor(read(osp.join(dir_path, 'im2.png'))).to(device)
        I2 = img2tensor(read(osp.join(dir_path, 'im3.png'))).to(device)

        # Test-time augmentation: predict on the original pair and on the pair
        # flipped along dim 2 (height, assuming NCHW tensors -- TODO confirm
        # img2tensor layout), un-flip the second prediction and average both.
        I1_pred1 = model(I0, I2, embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        I1_pred2 = model(torch.flip(I0, [2]), torch.flip(I2, [2]), embt,
                         scale_factor=1.0, eval=True)['imgt_pred']
        I1_pred = I1_pred1 / 2 + torch.flip(I1_pred2, [2]) / 2
        psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy()
        ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy()

        psnr_list.append(psnr)
        ssim_list.append(ssim)
        # Running averages over all triplets processed so far.
        avg_psnr = np.mean(psnr_list)
        avg_ssim = np.mean(ssim_list)
        desc_str = f'[{network_name}/Vimeo90K] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
        pbar.set_description_str(desc_str)

# Report the final averages explicitly; otherwise they only appear in the last
# progress-bar description. Guarded so an empty test list does not hit an
# undefined `desc_str`.
if psnr_list:
    print(desc_str)

