"""
Calculate perceptual metrics (SSIM, PSNR, LPIPS, and DISTS) between
original and reconstructed images.
The results are saved as `perceptual_metrics.csv` under the experiment output directory.

Usage:
    python calculate_perceptual_metrics.py /path/to/config.yaml
"""
import argparse
import os
from itertools import product
from typing import Any

import yaml
import numpy as np
import pandas as pd
import torch
import tqdm
from matplotlib import pyplot as plt
import lpips
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import DISTS_pytorch

# Size passed to PIL's Image.resize (i.e. (width, height)) for the ground-truth
# images before they are compared with the reconstructions.
IMAGE_SIZE = (224, 224)


def parse_model_name(config) -> str:
    """
    Derive the model's directory name from ``config['model']``.

    Resolution order:
    1. a truthy ``model_alias``;
    2. for model names ending in ``'-tfm'``, the ``pretrained`` identifier with
       every ``'/'`` replaced by ``'_'``;
    3. otherwise the plain ``name``.
    """
    model_cfg = config['model']

    alias = model_cfg.get('model_alias')
    if alias:
        return alias

    model_name = model_cfg['name']
    if not model_name.endswith('-tfm'):
        return model_name

    # '/' would create nested directories; flatten it to a single path component
    return model_cfg['pretrained'].replace('/', '_')


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's results directory path.

    Layout: ``output/readout_vision/results/<model>/<dataset>/<exp_name>``,
    where ``<model>`` comes from ``parse_model_name``, ``<dataset>`` from
    ``config['data']['dataset_name']`` and ``<exp_name>`` from
    ``config['exp_name']``.
    """
    path_parts = (
        'output', 'readout_vision', 'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def load_images(config: dict):
    """
    Load the ground-truth images and every reconstruction listed by the config.

    Args:
        config: Experiment configuration. Reads ``data.image_dir``,
            ``data.image_names``, ``data.image_ext``, ``layers``,
            ``noise.target_corr_dists`` and ``noise.noise_seeds``, plus the
            keys ``parse_output_dir`` needs.

    Returns:
        A tuple ``(true_images, recon_images)``:

        - ``true_images`` maps each image name to an RGB PIL image resized to
          ``IMAGE_SIZE``.
        - ``recon_images`` is a list of dicts with keys ``name``, ``layer``,
          ``distance``, ``seed`` and ``image`` (an RGB PIL image), one per
          (name, layer, distance, seed) combination. Distance ``0`` uses a
          single ``None`` seed.

    All images are fully decoded and their files closed before returning.
    """
    image_dir = config['data']['image_dir']
    image_names = config['data']['image_names']
    image_ext = config['data']['image_ext']

    # --- ground-truth images -------------------------------------------------
    true_images = {}
    for name in image_names:
        path = os.path.join(image_dir, f"{name}{image_ext}")
        with Image.open(path) as img:
            # convert() forces a full decode into a new image, so the file
            # handle can be released as soon as the `with` block exits.
            true_images[name] = img.convert('RGB').resize(IMAGE_SIZE)

    # --- reconstructed images ------------------------------------------------
    recon_images = []
    exp_dir = parse_output_dir(config)
    for name, layer, dist in product(
        image_names,
        config['layers'],
        config['noise']['target_corr_dists']
    ):
        # distance 0 is evaluated once, with seed None ('noise_seed_None' dir)
        seeds = config['noise']['noise_seeds'] if dist != 0 else [None]
        for seed in seeds:
            path = os.path.join(
                exp_dir, layer, f'corr_dist_{dist}', f'noise_seed_{seed}',
                name, 'final.png'
            )
            with Image.open(path) as img:
                # Image.open is lazy and would keep the file open until GC;
                # convert() loads the pixels into an independent copy. Forcing
                # RGB also normalises RGBA / L / P PNGs to the 3-channel layout
                # the metrics in measure_similarity require.
                image = img.convert('RGB')
            recon_images.append({
                'name': name,
                'layer': layer,
                'distance': dist,
                'seed': seed,
                'image': image,
            })
    return true_images, recon_images


def pil_to_numpy(img_pil):
    """
    Return the image's pixels as a float32 NumPy array divided by 255.

    For 8-bit images this maps pixel values into [0, 1].
    """
    as_float = np.asarray(img_pil, dtype=np.float32)
    return as_float / 255.0


def pil_to_torch(img_pil):
    """
    Convert an HWC image to a float32 torch tensor of shape (1, C, H, W).

    Pixel values are first scaled by ``pil_to_numpy`` and then mapped with
    ``x * 2 - 1`` (so [0, 1] becomes [-1, 1]).
    """
    hwc = torch.tensor(pil_to_numpy(img_pil))
    chw = hwc.permute(2, 0, 1)
    signed = chw.mul(2).sub(1)  # [0, 1] -> [-1, 1]
    return signed.unsqueeze(0).float()


def measure_similarity(
        img_pil_1: Image.Image, img_pil_2: Image.Image,
        lpips_fn: lpips.LPIPS, dists_fn
        ):
    """
    Compute SSIM, PSNR, LPIPS and DISTS between two images of the same size.

    Args:
        img_pil_1: Reference image (RGB).
        img_pil_2: Image compared against the reference (RGB, same size).
        lpips_fn: An ``lpips.LPIPS`` model; called on tensors in [-1, 1]
            (its default ``normalize=False`` input range).
        dists_fn: A ``DISTS_pytorch.DISTS`` model; called on tensors in [0, 1],
            the range it expects (it applies ImageNet mean/std internally).

    Returns:
        Dict with keys ``"SSIM"``, ``"PSNR"``, ``"LPIPS"`` and ``"DISTS"``.

    Raises:
        ValueError: If the two images have different sizes.
    """
    # Only a size check -- no resizing is performed here.
    if img_pil_1.size != img_pil_2.size:
        raise ValueError("Images must have the same size for comparison")

    # HWC float32 arrays in [0, 1]
    img1_np = pil_to_numpy(img_pil_1)
    img2_np = pil_to_numpy(img_pil_2)

    # SSIM and PSNR on the [0, 1] arrays
    ssim_value = ssim(img1_np, img2_np, channel_axis=2, data_range=1.0)
    psnr_value = psnr(img1_np, img2_np, data_range=1.0)

    # LPIPS input: (1, C, H, W) tensors in [-1, 1]
    img1_t = pil_to_torch(img_pil_1)
    img2_t = pil_to_torch(img_pil_2)
    # DISTS input: (1, C, H, W) tensors in [0, 1]. Feeding it the [-1, 1]
    # LPIPS tensors would skew its ImageNet normalisation and the score.
    img1_unit = torch.from_numpy(img1_np).permute(2, 0, 1).unsqueeze(0)
    img2_unit = torch.from_numpy(img2_np).permute(2, 0, 1).unsqueeze(0)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        lpips_value = lpips_fn(img1_t, img2_t).item()
        dists_value = dists_fn(img1_unit, img2_unit).item()

    return {
        "SSIM": ssim_value,
        "PSNR": psnr_value,
        "LPIPS": lpips_value,
        "DISTS": dists_value
    }
    

def main(config: dict):
    """
    Score every reconstruction against its ground-truth image and write a CSV.

    Each row of ``perceptual_metrics.csv`` (written to the directory returned by
    ``parse_output_dir``) holds the reconstruction's ``name``, ``layer``,
    ``distance`` and ``seed``, followed by its SSIM, PSNR, LPIPS and DISTS
    scores.
    """
    # originals: name -> PIL image; reconstructions: list of per-image dicts
    originals, reconstructions = load_images(config)

    # metric networks, put in eval mode
    lpips_model = lpips.LPIPS(net='vgg').eval()
    dists_model = DISTS_pytorch.DISTS().eval()

    progress = tqdm.tqdm(reconstructions, desc="Calculating perceptual metrics")
    for entry in progress:
        # Take the image out of the row so the CSV only gets metadata + scores.
        candidate = entry.pop('image')
        reference = originals[entry['name']]
        entry.update(
            measure_similarity(reference, candidate, lpips_model, dists_model)
        )

    table = pd.DataFrame(reconstructions)
    csv_path = os.path.join(parse_output_dir(config), 'perceptual_metrics.csv')
    table.to_csv(csv_path, index=False)


if __name__ == "__main__":
    # Command-line entry point: load the YAML config and run main().
    parser = argparse.ArgumentParser(
        description="Calculate perceptual metrics (SSIM, PSNR, LPIPS, DISTS) "
                    "between original and reconstructed images."
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    args = parser.parse_args()

    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # The image list may be given inline (data.image_names) or in a separate
    # YAML file referenced by data.image_names_path.
    if 'image_names' not in config['data']:
        with open(config['data']['image_names_path'], 'r', encoding="utf-8") as f:
            config['data']['image_names'] = yaml.safe_load(f)

    main(config)