"""
Plot reconstructed images from config. Saves the plots under
/{exp_dir}/summary/recon_images/

Usage:
    python plot_images.py /path/to/config.yaml
"""
import argparse
from typing import Any
import os
from itertools import product
from PIL import Image
import yaml
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import tqdm

# Size (width, height) in pixels of the gray placeholder image drawn in place
# of a reconstruction whose file is missing.
IMG_SIZE = (224, 224)



def parse_model_name(config) -> str:
    """
    Turn the pretrained model identifier into a single path component.

    Every ``/`` in ``config['model']['pretrained']`` becomes ``_`` so the
    identifier can be used as one directory name.
    """
    pretrained = config['model']['pretrained']
    return '_'.join(pretrained.split('/'))


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's output directory from the configuration.

    The result is
    ``output/readout_vae/results/<model>/<data.dataset_name>/<exp_name>``,
    where ``<model>`` comes from ``parse_model_name``.
    """
    parts = (
        'output',
        'readout_vae',
        'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*parts)


def load_recon_images(exp_dir: str, distances: list[float], image_name: str, seed: int) -> 'dict[float, Image.Image | None]':
    """
    Load the reconstructions of one image/seed pair for every distance.

    Each reconstruction is read from
    ``{exp_dir}/corr_dist_{dist}/noise_seed_{seed}/{image_name}.png``.

    Args:
        exp_dir: Experiment output directory.
        distances: Correlation distances to load; one image per distance.
        image_name: File stem of the image (without ``.png``).
        seed: Noise seed of the reconstruction.

    Returns:
        Mapping distance -> image. A missing file maps to ``None`` (and a
        message is printed) so the caller can draw a placeholder instead.
    """
    images = {}
    for dist in distances:
        path = os.path.join(exp_dir, f'corr_dist_{dist}', f'noise_seed_{seed}', f'{image_name}.png')
        try:
            # Image.open is lazy and keeps the file handle open; copy() pulls
            # the pixel data into memory so the handle is closed right away.
            with Image.open(path) as img:
                images[dist] = img.copy()
        except FileNotFoundError:
            print(f'File not found: {path}')
            images[dist] = None
    return images


def main(config: dict[str, Any]):
    """
    Plot and save a row of reconstructions for every image/seed pair.

    For each ``(image_name, seed)`` in ``data.image_names`` x
    ``noise.noise_seeds``, draws one row with one column per entry of
    ``noise.target_corr_dists``. Each row is saved as
    ``{exp_dir}/summary/recon_images/{image_name}/{seed}.png`` and ``.pdf``.
    A missing reconstruction is drawn as a gray ``IMG_SIZE`` placeholder.

    Args:
        config: Parsed experiment configuration. It is also passed to
            ``parse_output_dir`` to locate the experiment directory.

    Raises:
        ValueError: If ``noise.target_corr_dists`` is empty.
    """
    exp_dir = parse_output_dir(config)
    distances = config['noise']['target_corr_dists']
    image_names = config['data']['image_names']
    seeds = config['noise']['noise_seeds']
    if not distances:
        # Fail clearly instead of dividing by zero in the layout below.
        raise ValueError("config['noise']['target_corr_dists'] must not be empty")

    # general parameters
    n_clms = len(distances)
    n_rows = 1
    width = 5.5  # figure width (inches)

    # spacing parameters (inches)
    h_title = 0  # height reserved above the grid (no suptitle is drawn)
    w_label = 1  # width reserved left of the grid (no row label is drawn)
    w_space = 0.01  # horizontal gap between images
    h_space = 0.02  # vertical gap between images
    # Side length of each grid cell; the figure height is derived from it so
    # that every cell is s x s inches.
    s = (width - w_label - (n_clms - 1) * w_space) / n_clms
    height = h_title + s * n_rows + (n_rows - 1) * h_space

    # convert the inch-based layout into the figure fractions GridSpec expects
    left = w_label / width
    right = 1.0
    top = 1.0 - h_title / height
    bottom = 0.0
    # GridSpec wspace/hspace are fractions of the average cell width/height (= s)
    wspace = w_space / s
    hspace = h_space / s

    # one title per column
    titles = [rf'$d_H = {d}$' for d in distances]

    total = len(image_names) * len(seeds)
    for image_name, seed in tqdm.tqdm(product(image_names, seeds), total=total):
        # dict[distance] -> image, or None when the file is missing
        recon_images = load_recon_images(exp_dir, distances, image_name, seed)

        fig = plt.figure(figsize=(width, height))
        gs = gridspec.GridSpec(
            n_rows, n_clms,
            figure=fig,
            left=left,
            right=right,
            top=top,
            bottom=bottom,
            wspace=wspace,
            hspace=hspace,
        )

        for j, (dist, title) in enumerate(zip(distances, titles)):
            ax = fig.add_subplot(gs[0, j])
            img = recon_images[dist]
            if img is None:
                img = Image.new('RGB', IMG_SIZE, color='gray')
            ax.imshow(img)

            # Hide ticks and the left/bottom spines; the column title is kept.
            # NOTE(review): top/right spines are left as configured by rcParams.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.spines['left'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_title(title, fontsize=6)

        save_dir = os.path.join(exp_dir, 'summary', 'recon_images', image_name)
        os.makedirs(save_dir, exist_ok=True)
        for ext in ('png', 'pdf'):
            fig.savefig(os.path.join(save_dir, f'{seed}.{ext}'), bbox_inches='tight', pad_inches=0, dpi=600)

        # Close this specific figure and release the loaded images; lazily
        # opened PIL images keep their file open until closed.
        plt.close(fig)
        for recon in recon_images.values():
            if recon is not None:
                recon.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    # debug mode: only use one image
    parser.add_argument("--debug", action="store_true", help="Run in debug mode with a single image and seed.")
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    if args.debug:
        config['data']['image_names'] = config['data']['image_names'][:1]  # use only the first image
        config['noise']['noise_seeds'] = [0]  # use only the first seed

    # load and apply plt config
    plt_config_path = 'scripts/250527_readout_vision/configs/plt_config.yaml'
    with open(plt_config_path, 'r') as f:
        plt_config = yaml.safe_load(f)
    plt.rcParams.update(plt_config)

    main(config)