import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from skimage.metrics import structural_similarity as ssim
from pytorch_fid.fid_score import calculate_fid_given_paths
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pytorch_fid.inception import InceptionV3
from lpips_score import cal_lpips_given_paths
from math import log10, sqrt
import cv2
import numpy as np
import glob
import csv

# Command-line options. ArgumentDefaultsHelpFormatter appends each option's
# default value to its --help text automatically.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=4,
                    help='Batch size to use')
# The help text used to claim a `min(8, num_cpus)` default, contradicting the
# actual default of 1 that the formatter prints next to it.
parser.add_argument('--num-workers', type=int, default=1,
                    help='Number of processes to use for data loading')
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
# NOTE(review): --save-stats is parsed but never read by this script.
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))
parser.add_argument('--path', type=str, nargs=2,
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))

args = parser.parse_args()

# Resolve the torch device: an explicit --device wins, otherwise use CUDA
# when it is available and fall back to the CPU.
if args.device is None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device(args.device)

# Resolve the data-loader worker count. The automatic branch only runs if
# --num-workers ends up None, which cannot happen with the default of 1.
if args.num_workers is None:
    try:
        num_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available under Windows, use
        # os.cpu_count instead (which may not return the *available* number
        # of CPUs).
        num_cpus = os.cpu_count()

    num_workers = min(num_cpus, 8) if num_cpus is not None else 0
else:
    num_workers = args.num_workers

def PSNR(original, compressed, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        original: Reference image as an array-like (e.g. the uint8 array
            returned by ``cv2.imread``).
        compressed: Image compared against ``original``; must have the same
            (or a broadcast-compatible) shape.
        max_pixel: Largest possible pixel value; 255.0 for 8-bit images.

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))``, or 100 when the images are
        identical (MSE == 0), as a sentinel for "infinite" PSNR.
    """
    # Promote to float before subtracting: with uint8 inputs the difference and
    # its square would wrap modulo 256 and produce a wrong MSE.
    original = np.asarray(original, dtype=np.float64)
    compressed = np.asarray(compressed, dtype=np.float64)
    mse = np.mean((original - compressed) ** 2)
    if mse == 0:
        return 100
    return 20 * log10(max_pixel / sqrt(mse))

def cal_psnr_given_paths(directory):
    """Return the mean PSNR over every target PNG and its same-named generated image.

    Args:
        directory: Sequence ``[generated_dir, target_dir]``. Every ``*.png``
            in ``target_dir`` is paired with the file of the same basename in
            ``generated_dir``.

    Returns:
        Average of ``PSNR(target, generated)`` over all pairs.

    Raises:
        ValueError: If ``target_dir`` contains no ``*.png`` files.
        FileNotFoundError: If a target or its generated counterpart cannot be read.
    """
    generated_dir, target_dir = directory[0], directory[1]
    file_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))  # Targets
    if not file_list:
        raise ValueError(f'No .png target images found in {target_dir!r}')

    total_psnr = 0
    for target_path in file_list:
        generated_path = os.path.join(generated_dir, os.path.basename(target_path))
        original = cv2.imread(target_path)
        compressed = cv2.imread(generated_path)
        # cv2.imread returns None for a missing/unreadable file instead of
        # raising, so check explicitly to get a clear error message.
        if original is None:
            raise FileNotFoundError(f'Could not read target image {target_path!r}')
        if compressed is None:
            raise FileNotFoundError(f'Could not read generated image {generated_path!r}')
        total_psnr += PSNR(original, compressed)

    return total_psnr / len(file_list)

def cal_rmse_given_paths(directory):
    """Return the mean per-image RMSE between target PNGs and same-named generated images.

    Args:
        directory: Sequence ``[generated_dir, target_dir]``. Every ``*.png``
            in ``target_dir`` is paired with the file of the same basename in
            ``generated_dir``.

    Returns:
        Average over all pairs of ``sqrt(mean((target - generated) ** 2))``,
        computed in float32 on the raw pixel values.

    Raises:
        ValueError: If ``target_dir`` contains no ``*.png`` files.
        FileNotFoundError: If a target or its generated counterpart cannot be read.
    """
    generated_dir, target_dir = directory[0], directory[1]
    file_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))  # Targets
    if not file_list:
        raise ValueError(f'No .png target images found in {target_dir!r}')

    total_rmse = 0
    for target_path in file_list:
        generated_path = os.path.join(generated_dir, os.path.basename(target_path))
        original = cv2.imread(target_path)
        compressed = cv2.imread(generated_path)
        # cv2.imread returns None for a missing/unreadable file instead of
        # raising, so check explicitly to get a clear error message.
        if original is None:
            raise FileNotFoundError(f'Could not read target image {target_path!r}')
        if compressed is None:
            raise FileNotFoundError(f'Could not read generated image {generated_path!r}')
        # Cast to float32 so the difference can go negative instead of
        # wrapping around as uint8 arithmetic would.
        original = np.asarray(original, dtype=np.float32)
        compressed = np.asarray(compressed, dtype=np.float32)
        total_rmse += np.sqrt(np.mean(np.square(original - compressed)))

    return total_rmse / len(file_list)

def cal_ssim_given_paths(directory):
    """Return the mean grayscale SSIM between target PNGs and same-named generated images.

    Args:
        directory: Sequence ``[generated_dir, target_dir]``. Every ``*.png``
            in ``target_dir`` is paired with the file of the same basename in
            ``generated_dir``.

    Returns:
        Average of ``ssim(target, generated)`` over all pairs, with both
        images loaded as single-channel grayscale.

    Raises:
        ValueError: If ``target_dir`` contains no ``*.png`` files.
        FileNotFoundError: If a target or its generated counterpart cannot be read.
    """
    generated_dir, target_dir = directory[0], directory[1]
    file_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))  # Targets
    if not file_list:
        raise ValueError(f'No .png target images found in {target_dir!r}')

    total_ssim = 0
    for target_path in file_list:
        generated_path = os.path.join(generated_dir, os.path.basename(target_path))
        # Both images are converted to grayscale on load, so ssim compares
        # single-channel arrays.
        original = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
        compressed = cv2.imread(generated_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None for a missing/unreadable file instead of
        # raising, so check explicitly to get a clear error message.
        if original is None:
            raise FileNotFoundError(f'Could not read target image {target_path!r}')
        if compressed is None:
            raise FileNotFoundError(f'Could not read generated image {generated_path!r}')
        total_ssim += ssim(original, compressed)

    return total_ssim / len(file_list)

def evaluate_model(directory, args):
    """Compute FID, PSNR, LPIPS, RMSE and SSIM for one generated/target directory pair.

    Args:
        directory: Sequence ``[generated_dir, target_dir]``, passed unchanged
            to every metric helper.
        args: Parsed command-line arguments; ``batch_size``, ``dims`` and
            ``num_workers`` are read here.

    Returns:
        Tuple ``(fid, psnr, lpips, rmse, ssim)``.
    """
    # `device` is the module-level torch.device resolved from --device.
    fid_value = calculate_fid_given_paths(directory,
                                          args.batch_size,
                                          device,
                                          args.dims,
                                          args.num_workers)
    psnr_value = cal_psnr_given_paths(directory)
    lpips_value = cal_lpips_given_paths(directory,
                                        args.batch_size,
                                        device,
                                        args.num_workers)
    rmse_value = cal_rmse_given_paths(directory)
    ssim_value = cal_ssim_given_paths(directory)

    # All metrics use the same `.4f` spec (the SSIM one previously had a stray
    # space flag that printed a leading blank).
    print(f'Results for model in {directory}: FID {fid_value:.4f}, '
          f'PSNR {psnr_value:.4f}, LPIPS {lpips_value:.4f}, '
          f'RMSE {rmse_value:.4f}, SSIM {ssim_value:.4f}')
    return fid_value, psnr_value, lpips_value, rmse_value, ssim_value

def evaluate_all_models(parent_dir, args):
    """Evaluate one generated/target pair, write the metrics to a CSV and print them.

    The CSV ``evaluation_results.csv`` is written to the parent directory of
    ``generated_dir`` (the current directory for a bare relative name).

    Args:
        parent_dir: Sequence ``[generated_dir, target_dir]``.
        args: Parsed command-line arguments forwarded to ``evaluate_model``.

    Returns:
        Dict with keys ``'fid'``, ``'psnr'``, ``'lpips'``, ``'rmse'`` and ``'ssim'``.
    """
    # Run the (slow) evaluation before opening the CSV so that a failure does
    # not truncate results left by a previous run.
    fid_value, psnr_value, lpips_value, rmse_value, ssim_value = evaluate_model(parent_dir, args)
    model_results = {
        'fid': fid_value,
        'psnr': psnr_value,
        'lpips': lpips_value,
        'rmse': rmse_value,
        'ssim': ssim_value,
    }

    # normpath drops a trailing separator so 'output/' and 'output' both put
    # the CSV beside the generated directory rather than inside it.
    csv_path = os.path.join(os.path.dirname(os.path.normpath(parent_dir[0])),
                            'evaluation_results.csv')
    with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Model', 'FID', 'PSNR', 'LPIPS', 'RMSE', 'SSIM', 'Checkpoint'])  # Header of CSV
        # One value per header column so each metric lands under its own
        # heading; no checkpoint information is available here.
        writer.writerow([parent_dir[0], fid_value, psnr_value, lpips_value,
                         rmse_value, ssim_value, ''])

    # Output the summary of results
    print(f"FID: {model_results['fid']}")
    print(f"PSNR: {model_results['psnr']}")
    print(f"LPIPS: {model_results['lpips']}")
    print(f"RMSE: {model_results['rmse']}")
    print(f"SSIM: {model_results['ssim']}")
    return model_results


if __name__ == '__main__':
    # --path GENERATED_DIR TARGET_DIR overrides the default pair; without it
    # the script evaluates ./output against ./target as before.
    parent_dir = args.path if args.path is not None else ['output', 'target']
    evaluate_all_models(parent_dir, args)