import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pytorch_fid.inception import InceptionV3
from lpips_score import cal_lpips_given_paths
from math import log10, sqrt
import cv2
import numpy as np
import glob
import csv
from pyuac import main_requires_admin

# ---------------------------------------------------------------------------
# Command-line options. ``--path`` and ``--save-stats`` are parsed here but are
# not read anywhere else in this script.
# ---------------------------------------------------------------------------
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=4,
                    help='Batch size to use')
parser.add_argument('--num-workers', type=int, default=1,
                    help='Number of processes to use for data loading')
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))
parser.add_argument('--path', type=str, nargs=2,
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))

args = parser.parse_args()

# Torch device: an explicit --device wins; otherwise use CUDA when available.
# CUDA_VISIBLE_DEVICES is set at the top of the file, so "cuda" refers to the
# GPU(s) that variable exposes.
if args.device is None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device(args.device)

# Resolve the data-loading worker count. --num-workers defaults to 1, so the
# CPU auto-detection branch below only runs if that default is changed to None.
if args.num_workers is None:
    try:
        num_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available on Windows; os.cpu_count is
        # used instead, although it may over-report the CPUs actually
        # available to this process.
        num_cpus = os.cpu_count()

    num_workers = min(num_cpus, 8) if num_cpus is not None else 0
else:
    num_workers = args.num_workers

def PSNR(original, compressed, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        original: Reference image (array-like, e.g. a ``cv2.imread`` result).
        compressed: Image to compare; must be broadcast-compatible with
            ``original``.
        max_pixel: Largest representable pixel value (255 for 8-bit images).

    Returns:
        The PSNR in decibels, or ``100`` when the images are identical (the
        MSE is zero and the true PSNR would be infinite).
    """
    # Promote to float *before* subtracting: 8-bit images arrive as uint8, and
    # uint8 arithmetic wraps modulo 256, which silently corrupts the squared
    # error (e.g. a per-pixel difference of 16 squares to 256 == 0 in uint8).
    diff = (np.asarray(original, dtype=np.float64)
            - np.asarray(compressed, dtype=np.float64))
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * log10(max_pixel / sqrt(mse))

def cal_psnr_given_paths(directory):
    """Return the mean PSNR over all target/generated image pairs in *directory*.

    Every ``<name>_HR.png`` (target) is paired with ``<name>_SR.png``
    (generated) from the same directory and scored with ``PSNR``.

    Raises:
        FileNotFoundError: if *directory* contains no ``*_HR.png`` files, or a
            target or its ``_SR.png`` counterpart cannot be read.
    """
    # Sorted so the accumulation order is deterministic.
    file_list = sorted(glob.glob(os.path.join(directory, '*_HR.png')))  # Targets
    if not file_list:
        raise FileNotFoundError(f"No '*_HR.png' images found in {directory}")

    total_psnr = 0.0
    for f in file_list:
        # '<name>_HR.png' -> '<name>'
        set_name = os.path.basename(f).rsplit('_', 1)[0]
        generated_image_path = os.path.join(directory, set_name + '_SR.png')

        original = cv2.imread(f)
        compressed = cv2.imread(generated_image_path)
        # cv2.imread reports failure by returning None instead of raising.
        if original is None:
            raise FileNotFoundError(f"Could not read image {f}")
        if compressed is None:
            raise FileNotFoundError(f"Could not read image {generated_image_path}")

        # Compare in float so uint8 wrap-around cannot distort the error.
        total_psnr += PSNR(original.astype(np.float64),
                           compressed.astype(np.float64))

    return total_psnr / len(file_list)

def cal_rmse_given_paths(directory):
    """Return the mean per-image RMSE over all target/generated pairs in *directory*.

    Every ``<name>_HR.png`` (target) is paired with ``<name>_SR.png``
    (generated) from the same directory. The RMSE of each pair is computed on
    float32 pixel values, and the per-image RMSEs are averaged.

    Raises:
        FileNotFoundError: if *directory* contains no ``*_HR.png`` files, or a
            target or its ``_SR.png`` counterpart cannot be read.
    """
    # Sorted so the accumulation order is deterministic.
    file_list = sorted(glob.glob(os.path.join(directory, '*_HR.png')))  # Targets
    if not file_list:
        raise FileNotFoundError(f"No '*_HR.png' images found in {directory}")

    total_rmse = 0.0
    for f in file_list:
        # '<name>_HR.png' -> '<name>'
        set_name = os.path.basename(f).rsplit('_', 1)[0]
        generated_image_path = os.path.join(directory, set_name + '_SR.png')

        original = cv2.imread(f)
        compressed = cv2.imread(generated_image_path)
        # cv2.imread reports failure by returning None instead of raising.
        if original is None:
            raise FileNotFoundError(f"Could not read image {f}")
        if compressed is None:
            raise FileNotFoundError(f"Could not read image {generated_image_path}")

        # Cast to float so the subtraction cannot wrap around in uint8.
        original = original.astype('float32')
        compressed = compressed.astype('float32')
        total_rmse += np.sqrt(np.mean(np.square(original - compressed)))

    return total_rmse / len(file_list)

def evaluate_model(directory, args):
    """Compute FID, PSNR, LPIPS and RMSE for one model's output directory.

    *directory* holds ``<name>_HR.png`` targets next to ``<name>_SR.png``
    generated images. FID and LPIPS take two folders, so this function creates
    ``real_images`` / ``fake_images`` sub-folders filled with symlinks to those
    files. The sub-folders are emptied and removed afterwards, even if a metric
    raises.

    Args:
        directory: Path to the model's output folder.
        args: Parsed CLI namespace; ``batch_size``, ``dims`` and
            ``num_workers`` are read from it.

    Returns:
        Tuple ``(fid_value, psnr_value, lpips_value, rmse_value)``.
    """
    real_images = glob.glob(os.path.join(directory, '*_HR.png'))
    fake_images = glob.glob(os.path.join(directory, '*_SR.png'))

    real_path = os.path.join(directory, 'real_images')
    fake_path = os.path.join(directory, 'fake_images')

    # Fall back to the module-level resolved worker count if the namespace
    # leaves num_workers unset.
    workers = args.num_workers if args.num_workers is not None else num_workers

    def remove_link_dir(path):
        """Delete every entry of the flat link folder *path*, then the folder."""
        if not os.path.isdir(path):
            return
        # Materialize the listing before deleting from the directory.
        for entry in list(os.scandir(path)):
            os.remove(entry.path)
        os.rmdir(path)

    def link_all(images, dest_dir):
        """Symlink each image into *dest_dir* under its own basename."""
        os.makedirs(dest_dir, exist_ok=True)
        for image in images:
            link = os.path.join(dest_dir, os.path.basename(image))
            # lexists (not exists) also catches dangling symlinks, which would
            # otherwise make os.symlink raise FileExistsError.
            if os.path.lexists(link):
                os.remove(link)
            os.symlink(os.path.abspath(image), link)

    # Start from clean folders so links left behind by an interrupted run
    # cannot leak into the FID/LPIPS image sets.
    remove_link_dir(real_path)
    remove_link_dir(fake_path)
    try:
        link_all(real_images, real_path)
        link_all(fake_images, fake_path)

        fid_value = calculate_fid_given_paths([real_path, fake_path],
                                              args.batch_size,
                                              device,
                                              args.dims,
                                              workers)
        psnr_value = cal_psnr_given_paths(directory)
        lpips_value = cal_lpips_given_paths([real_path, fake_path],
                                            args.batch_size,
                                            device,
                                            workers)
        rmse_value = cal_rmse_given_paths(directory)
    finally:
        remove_link_dir(real_path)
        remove_link_dir(fake_path)

    print(f'Results for model in {directory}: FID {fid_value:.4f}, PSNR {psnr_value:.4f}, LPIPS {lpips_value:.4f}, RMSE {rmse_value:.4f}')
    return fid_value, psnr_value, lpips_value, rmse_value

@main_requires_admin
def main(parent_dir, args):
    """Evaluate every model sub-directory of *parent_dir* and report the best scores.

    Each immediate sub-directory is scored with ``evaluate_model``. Per-model
    metrics and a short "best of" summary are written to
    ``<parent_dir>/evaluation_results.csv`` and echoed to stdout.

    The function is wrapped by pyuac's ``main_requires_admin``. Presumably
    elevation is needed because ``evaluate_model`` creates symlinks, which on
    Windows usually requires admin rights. TODO confirm.
    """
    # (result key, display label, True if a larger value is better)
    metrics = (
        ('fid', 'FID', False),
        ('psnr', 'PSNR', True),
        ('lpips', 'LPIPS', False),
        ('rmse', 'RMSE', False),
    )
    # Best (value, checkpoint) per metric, seeded with the worst possible value
    # so any finite score replaces it.
    best = {key: (float('-inf') if higher else float('inf'), '')
            for key, _, higher in metrics}
    model_results = {}

    with open(os.path.join(parent_dir, 'evaluation_results.csv'), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Model', 'FID', 'PSNR', 'LPIPS', 'RMSE', 'Checkpoint'])  # Header of CSV

        # Sorted so row order is deterministic across runs and platforms.
        for subdir in sorted(os.listdir(parent_dir)):
            full_model_path = os.path.join(parent_dir, subdir)
            if not os.path.isdir(full_model_path):
                continue

            results = dict(zip(('fid', 'psnr', 'lpips', 'rmse'),
                               evaluate_model(full_model_path, args)))
            model_results[subdir] = results
            # The model directory itself identifies the "checkpoint".
            ckpt_path = full_model_path

            for key, _, higher in metrics:
                value = results[key]
                current = best[key][0]
                improved = (value > current) if higher else (value < current)
                if improved:
                    best[key] = (value, ckpt_path)

            writer.writerow([subdir, results['fid'], results['psnr'],
                             results['lpips'], results['rmse'], ckpt_path])

        # Summary: a blank row, then one single-cell row per metric. The cells
        # carry no embedded newlines; csv would quote those into multi-line fields.
        summary = [f"Best {label}: {best[key][0]} (Checkpoint: {best[key][1]})"
                   for key, label, _ in metrics]
        writer.writerow([])
        writer.writerows([line] for line in summary)

    # Output the summary of results and best values
    for model_name, results in model_results.items():
        print(f"Model: {model_name}")
        for key, label, _ in metrics:
            print(f"    {label}: {results[key]}")

    for line in summary:
        print(line)


if __name__ == '__main__':
    # Root folder whose immediate sub-directories each hold one model's
    # *_HR.png / *_SR.png outputs; evaluation_results.csv is written here too.
    parent_dir = (
        r"E:\Improved_BBDM\outputs\BBDM-20231107-2345"
    )
    main(parent_dir, args)