
import sys


import sys
import os

# Make modules next to this script, and in its parent directory, importable
# when the file is executed directly (both are appended to the end of sys.path).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.extend([current_dir, parent_dir])
import numpy as np
import torch
import os, imageio, json
from tqdm import tqdm
from skimage.metrics import structural_similarity

from tqdm import tqdm, trange
import cv2
import numpy as np
import lpips
import torch
from pdb import set_trace as bp


# Torch device used for the LPIPS network: the current CUDA device when PyTorch
# reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def config_parser():
    """Build the command-line parser used by ``main``.

    ``configargparse`` is used when it is installed; otherwise the standard
    library ``argparse`` is used. No config-file or environment-variable
    options are declared, so both accept the same command lines here.

    Returns:
        An ``ArgumentParser`` carrying the evaluation options below.
    """
    try:
        import configargparse as argparse_impl
    except ImportError:
        # configargparse is optional: only plain add_argument options are used.
        import argparse as argparse_impl
    parser = argparse_impl.ArgumentParser()
    parser.add_argument("--fm_dir", type=str, default=None,
                        help='directory of foundation-model images to evaluate '
                             '(files named <dir basename>_<idx>.png); takes precedence over --hyfluid_dir.')
    parser.add_argument("--hyfluid_dir", type=str, default=None,
                        help='directory of HyFluid images to evaluate (files named rgb_<3-digit idx>.png).')
    parser.add_argument("--gt_dir", type=str,
                        help='directory of ground-truth RGB images (files named gt_<3-digit idx>.png).')
    parser.add_argument("--frame_start", type=int, default=0, help='index of the first frame')
    # Despite its name this is a frame COUNT: frames frame_start .. frame_start+frame_end-1 are scored.
    parser.add_argument("--frame_end", type=int, default=20,
                        help='number of frames to evaluate, starting at --frame_start')
    parser.add_argument("--out_name", type=str, default='',
                        help='name used in the log file path (defaults to a name derived from the input directory)')

    return parser


def main():
    parser = config_parser()
    args = parser.parse_args()

    psnrs = []
    ssims = []
    lpipss = []
    metrics = {
        'avg_psnr': list(),
        'avg_lpips': list(),
        'avg_ssim': list()
    }
    lpips_net = lpips.LPIPS().cuda()
    for idx in tqdm(range(args.frame_start, args.frame_start+args.frame_end)):
        if args.fm_dir:
            prefix = args.fm_dir.split('/')[-1]
            image = cv2.imread(os.path.join(args.fm_dir, f"{prefix}_{idx}.png"))
            f_name = args.fm_dir.split('/')[-3]
        else:
            image = cv2.imread(os.path.join(args.hyfluid_dir, f"rgb_{str(idx).zfill(3)}.png"))
            f_name = args.hyfluid_dir.split('/')[-3]
        rgb = image/255.
        gt_img = cv2.imread(os.path.join(args.gt_dir, f"gt_{str(idx).zfill(3)}.png"))/255.
        if len(args.out_name):
            path='logs/single_120_'+args.out_name+f'_s{args.frame_start}_e{args.frame_end}.txt'
        else:
            path='logs/single_120_'+f_name+f'_s{args.frame_start}_e{args.frame_end}.txt'

        if gt_img.shape[0] == rgb.shape[0] // 2:
            rgb = rgb[::2, ::2]
        assert gt_img.shape == rgb.shape
        gt_img = gt_img[90:960, 45:540]
        rgb = rgb[90:960, 45:540]
        p = -10. * np.log10(np.mean(np.square(rgb - gt_img)))
        ssim_value = structural_similarity(gt_img, rgb, data_range=1.0, channel_axis=2)
        gt_img = torch.tensor(gt_img.astype(np.float32)).cuda()
        rgb = torch.tensor(rgb.astype(np.float32)).cuda()
        lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item()
        print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}')
        metrics['avg_psnr'].append(float(p))
        metrics['avg_lpips'].append(float(lpips_value))
        metrics['avg_ssim'].append(float(ssim_value))
        with open(path, 'a') as f:
            f.write(f'{idx}\t{p}\t{lpips_value}\t{ssim_value}\n')
    print('Export to:'+path)
    print(metrics)  

if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Post-mortem debugging is only useful interactively; unattended runs
        # (no TTY on stdin) re-raise so the traceback and non-zero exit survive.
        if sys.stdin is None or not sys.stdin.isatty():
            raise
        import traceback
        traceback.print_exc()
        try:
            import ipdb as debugger
        except ImportError:  # ipdb is optional; pdb ships with Python
            import pdb as debugger
        debugger.post_mortem()
        sys.exit(1)

