"""
Plot reconstructed images from config. Saves the plots under
/{exp_dir}/summary/recon_images/

Usage:
    python plot_images.py /path/to/config.yaml
"""
import argparse
from typing import Any
import os
from itertools import product
from PIL import Image
import yaml
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import tqdm

IMG_SIZE: tuple[int, int] = (224, 224)  # (width, height) in pixels of the gray placeholder drawn when a reconstruction is missing


# TODO: import these config-parsing helpers from the common scripts instead of duplicating them here
def parse_model_name(config) -> str:
    """
    Derive the model name used to build output directories.

    Resolution order:
      1. ``config['model']['model_alias']`` when present and truthy;
      2. for transformer models (``name`` ending in ``-tfm``), the pretrained
         identifier with every ``/`` replaced by ``_``;
      3. otherwise ``config['model']['name']`` unchanged.
    """
    model_cfg = config['model']
    alias = model_cfg.get('model_alias')
    if alias:
        return alias
    name = model_cfg['name']
    if not name.endswith('-tfm'):
        return name
    # transformer checkpoints look like paths; flatten them into one directory name
    return model_cfg['pretrained'].replace('/', '_')


def parse_feature_dir(config: dict[str, Any]) -> str:
    """
    Return the directory that holds the extracted features.

    A top-level ``feature_dir`` entry in the config takes precedence; otherwise
    the path output/readout_vision/features/<model_name>/<dataset_name> is built.
    """
    if 'feature_dir' not in config:
        model_name = parse_model_name(config)
        dataset_name = config['data']['dataset_name']
        return os.path.join('output', 'readout_vision', 'features', model_name, dataset_name)
    return config['feature_dir']


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's results directory from the configuration:
    output/readout_vision/results/<model_name>/<dataset_name>/<exp_name>.
    """
    results_root = os.path.join('output', 'readout_vision', 'results')
    return os.path.join(
        results_root,
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )


def load_recon_images(exp_dir: str, layers: list[str], distances: list[float], image_name: str, seed: int) -> dict[str, dict[float, Image.Image]]:
    """
    Load the final reconstructions of one image for one noise seed.

    For each (layer, distance) pair the image is read from
    ``{exp_dir}/{layer}/corr_dist_{dist}/noise_seed_{s}/{image_name}/final.png``,
    where ``s`` is ``seed`` except at distance 0: a zero-distance run has no
    noise, so its directory is ``noise_seed_None`` regardless of ``seed``.

    Args:
        exp_dir: Experiment results directory.
        layers: Layer names (first-level subdirectories of ``exp_dir``).
        distances: Target correlation distances to load.
        image_name: Name of the image directory to load.
        seed: Noise seed for the non-zero distances.

    Returns:
        dict[str, dict[float, Image.Image | None]]: layer -> distance -> image.
        Missing files are reported on stdout and stored as ``None``.
    """
    images = {layer: {} for layer in layers}

    for layer, dist in product(layers, distances):
        s = seed if dist != 0 else None  # no noise = no noise random seed
        path = os.path.join(exp_dir, layer, f'corr_dist_{dist}', f'noise_seed_{s}', image_name, 'final.png')
        try:
            # Image.open is lazy and keeps the file handle open until pixels are
            # read; copy() forces the load so the handle is released right away.
            with Image.open(path) as src:
                img = src.copy()
        except FileNotFoundError:
            print(f'File not found: {path}')
            img = None
        images[layer][dist] = img
    return images


def _grid_layout(n_rows: int, n_clms: int, width: float) -> tuple[float, dict[str, float]]:
    """
    Compute the figure height and GridSpec spacing for a grid of square images.

    The figure width is fixed. Each image's side length is what remains of the
    width after reserving the left label margin and the inter-column gaps. The
    height is then chosen so every grid cell is square.

    Returns:
        (height, gridspec_kwargs): figure height in inches, and the
        left/right/top/bottom/wspace/hspace keyword arguments for GridSpec.
    """
    # spacing parameters (in inches)
    h_title = 0  # height reserved above the grid for a suptitle (none is drawn)
    w_label = 1  # width of the left margin holding the layer labels
    w_space = 0.01  # horizontal space between images
    h_space = 0.02  # vertical space between images
    s = (width - w_label - (n_clms - 1) * w_space) / n_clms  # side length of each image
    height = h_title + s * n_rows + (n_rows - 1) * h_space  # total height

    # GridSpec margins are fractions of the figure size, while wspace/hspace are
    # fractions of the average axes width/height -- convert the inches accordingly.
    gridspec_kwargs = {
        'left': w_label / width,
        'right': 1.0,
        'top': 1.0 - h_title / height,
        'bottom': 0.0,
        'wspace': w_space / s,
        'hspace': h_space / s,
    }
    return height, gridspec_kwargs


def main(config: dict[str, Any]):
    """
    Plot a (layer x distance) grid of reconstructed images for every
    (image, noise seed) pair in the config.

    Each grid is saved as PNG and PDF to
    {exp_dir}/summary/recon_images/{image_name}/{seed}.{png,pdf}.
    Missing reconstructions are drawn as a gray placeholder.

    Raises:
        ValueError: if the config lists no layers or no target distances.
    """
    exp_dir = parse_output_dir(config)
    layers = config['layers']
    distances = config['noise']['target_corr_dists']
    image_names = config['data']['image_names']
    seeds = config['noise']['noise_seeds']

    if not layers or not distances:
        raise ValueError("config must list at least one entry in 'layers' and in 'noise.target_corr_dists'")

    n_rows, n_clms = len(layers), len(distances)
    width = 5.5  # width of the figure (inches)
    height, gridspec_kwargs = _grid_layout(n_rows, n_clms, width)

    # titles for each column
    titles = [rf'$d_H = {d}$' for d in distances]
    # shown in cells whose reconstruction file is missing; created once and reused
    placeholder = Image.new('RGB', IMG_SIZE, color='gray')

    # plot each image and seed
    total = len(image_names) * len(seeds)
    for image_name, seed in tqdm.tqdm(product(image_names, seeds), total=total):
        # Load images for this seed: dict[layer][distance] = image (or None)
        recon_images = load_recon_images(exp_dir, layers, distances, image_name, seed)

        fig = plt.figure(figsize=(width, height))
        gs = gridspec.GridSpec(n_rows, n_clms, figure=fig, **gridspec_kwargs)

        for i, layer in enumerate(layers):
            for j, dist in enumerate(distances):
                ax = fig.add_subplot(gs[i, j])
                img = recon_images[layer][dist]
                ax.imshow(placeholder if img is None else img)

                # Hide ticks but keep labels
                ax.set_xticks([])
                ax.set_yticks([])
                ax.spines['left'].set_visible(False)
                ax.spines['bottom'].set_visible(False)

                # Add column titles to top row
                if i == 0:
                    ax.set_title(titles[j], fontsize=6)

                # Add layer labels to the left of the leftmost column
                if j == 0:
                    ax.text(
                        -0.1, 0.5, layer,
                        fontsize=6.5,
                        va='center', ha='right',
                        transform=ax.transAxes,
                    )

        save_dir = os.path.join(exp_dir, 'summary', 'recon_images', image_name)
        os.makedirs(save_dir, exist_ok=True)
        try:
            # save through the explicit figure handle rather than pyplot's current figure
            for ext in ('png', 'pdf'):
                fig.savefig(os.path.join(save_dir, f'{seed}.{ext}'), bbox_inches='tight', pad_inches=0, dpi=300)
        finally:
            plt.close(fig)  # always release the figure, even if saving fails


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    # debug mode: only use one image
    parser.add_argument("--debug", action="store_true", help="Run in debug mode with a single image and seed.")
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    if args.debug:
        config['data']['image_names'] = config['data']['image_names'][:1]  # use only the first image
        config['noise']['noise_seeds'] = [0]  # use only the first seed

    # load and apply plt config
    plt_config_path = 'scripts/250527_readout_vision/configs/plt_config.yaml'
    with open(plt_config_path, 'r') as f:
        plt_config = yaml.safe_load(f)
    plt.rcParams.update(plt_config)

    main(config)