"""
Plot quantitative results. Please run this after running calculate_size.py

Usage:
    python scripts/250527_readout_vision/plot_results.py /path/to/config.yaml
"""
from typing import Any
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml


# TODO: import these helpers from the shared common scripts instead of duplicating them here
def parse_model_name(config) -> str:
    """
    Derive the model name used when building the output directory.

    Resolution order:
      1. ``config['model']['model_alias']`` when present and truthy.
      2. For transformer models (``name`` ending in ``'-tfm'``), the
         ``pretrained`` identifier with every ``'/'`` replaced by ``'_'``.
      3. Otherwise ``config['model']['name']`` unchanged.
    """
    model_cfg = config['model']

    alias = model_cfg.get('model_alias')
    if alias:
        return alias

    name = model_cfg['name']
    if not name.endswith('-tfm'):
        return name

    # transformer checkpoints look like "org/model"; '/' is unsafe in a path component
    return model_cfg['pretrained'].replace('/', '_')


def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the experiment's output directory path from the configuration.

    The result is
    ``output/readout_vision/results/<model_name>/<dataset_name>/<exp_name>``,
    where ``<model_name>`` comes from ``parse_model_name``,
    ``<dataset_name>`` from ``config['data']['dataset_name']`` and
    ``<exp_name>`` from ``config['exp_name']``.
    """
    parts = (
        'output',
        'readout_vision',
        'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*parts)


def assign_layer_colors(config: dict) -> dict:
    """
    Map each layer to a viridis color according to its relative depth.

    Layers are taken in the order of ``config['layer_mapping']``. If the first
    layer name contains ``'conv'`` (treated as a CNN), only layers whose names
    contain ``'conv'`` are kept and colored. The k-th of n kept layers gets
    relative depth ``100 * k / n``, which is mapped through viridis on a 0-100
    scale.

    Returns:
        dict[str, tuple]: layer name -> color (RGBA tuple from the colormap)
    """
    layer_names = list(config['layer_mapping'])
    if 'conv' in layer_names[0]:
        # CNN: depth is measured over the convolutional stack only
        layer_names = [name for name in layer_names if 'conv' in name]

    colormap = plt.get_cmap('viridis')
    scale = plt.Normalize(vmin=0, vmax=100)
    total = len(layer_names)

    colors = {}
    for position, name in enumerate(layer_names, start=1):
        depth = 100 * position / total
        colors[name] = colormap(scale(depth))
    return colors


def _make_dx_dh_grid(n_layers: int, n_col: int, model_name: str):
    """Create one shared-axis panel per layer, hide surplus grid cells, and add the d_h-d_x title."""
    n_row = (n_layers + n_col - 1) // n_col  # ceil(n_layers / n_col)
    fig, axes = plt.subplots(n_row, n_col, figsize=(5.5, 1.5 * n_row + 0.2), sharex=True, sharey=True)
    axes = np.asarray(axes).ravel()
    # cells beyond the last layer would otherwise be drawn as empty framed axes
    for ax in axes[n_layers:]:
        ax.set_visible(False)
    fig.suptitle(f'Feature correlation distance $d_h$ and pixel correlation distance $d_x$: {model_name}')
    return fig, axes


def _style_dx_dh_axis(ax, index: int, n_layers: int, n_col: int) -> None:
    """Apply the shared limits, ticks, aspect ratio and edge labels of one d_h-d_x panel."""
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xticks(np.arange(0, 1.1, 0.2))
    ax.set_yticks(np.arange(0, 1.1, 0.2))
    ax.set_aspect('equal', adjustable='box')
    if index % n_col == 0:  # y-label on the first column
        ax.set_ylabel('$d_x$')
    # x-label on the bottom-most *used* panel of each column. When the last row is
    # only partially filled, the panel above an empty cell is the bottom of its
    # column, and sharex has hidden its tick labels, so turn them back on.
    if index + n_col >= n_layers:
        ax.set_xlabel('$d_h$')
        ax.tick_params(axis='x', labelbottom=True)


def _save_figure(fig, out_dir: str, stem: str) -> None:
    """Save ``fig`` as ``<stem>.png`` and ``<stem>.pdf`` under ``out_dir``, then close it."""
    os.makedirs(out_dir, exist_ok=True)
    for ext in ('png', 'pdf'):
        fig.savefig(os.path.join(out_dir, f'{stem}.{ext}'), dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure; this script creates several


def main(config: dict):
    """
    Plot d_h-d_x correlation-distance panels and the representation-size bar chart.

    Reads ``experiment_db.json`` and ``summary/representation_size/size_aggregated.csv``
    from the experiment output directory, and writes PNG and PDF figures under
    ``summary/dx_dh`` and ``summary/representation_size``.

    Args:
        config: experiment configuration. Reads ``layers`` and ``layer_mapping``,
            plus the keys used to resolve the output directory and model name.
    """
    # results path
    exp_dir = parse_output_dir(config)
    exp_db_path = os.path.join(exp_dir, 'experiment_db.json')
    size_dir = os.path.join(exp_dir, 'summary', 'representation_size')

    # load results
    results = pd.read_json(exp_db_path)
    size_agg = pd.read_csv(os.path.join(size_dir, 'size_aggregated.csv'))
    size_agg.index = size_agg['layer']  # index by layer so rows can be reordered with .loc

    # save directory
    summary_dir = os.path.join(exp_dir, 'summary')
    os.makedirs(summary_dir, exist_ok=True)

    # config info
    layers = config['layers']  # layers used in the experiment
    n_layers = len(layers)
    model_name = parse_model_name(config)
    layer_colors = assign_layer_colors(config)  # dict[layer_name, color]

    ## 1. d_h - d_x plots
    n_col = 4
    dx = 'pixel_correlation_distance'
    dh = 'true_target_correlation_distance'
    dx_dh_dir = os.path.join(summary_dir, 'dx_dh')

    ### 1.1 line plot: mean +/- std of d_x for each target correlation distance
    fig, axes = _make_dx_dh_grid(n_layers, n_col, model_name)
    for i, layer in enumerate(layers):
        ax = axes[i]
        grouped = results[results['layer'] == layer].groupby('target_corr_dist')[dx]
        mean, std = grouped.mean(), grouped.std()

        color = layer_colors[layer]
        ax.plot(mean.index, mean.values, color=color)
        ax.fill_between(mean.index, mean - std, mean + std, color=color, alpha=0.2)
        ax.set_title(layer)
        _style_dx_dh_axis(ax, i, n_layers, n_col)

    fig.tight_layout()
    _save_figure(fig, dx_dh_dir, 'correlation_distance')

    ### 1.2 scatter plot: every sample's achieved d_h against its d_x
    fig, axes = _make_dx_dh_grid(n_layers, n_col, model_name)
    for i, layer in enumerate(layers):
        ax = axes[i]
        layer_results = results[results['layer'] == layer]
        ax.scatter(layer_results[dh], layer_results[dx], color=layer_colors[layer], s=1, alpha=0.5)
        ax.set_title(layer)
        _style_dx_dh_axis(ax, i, n_layers, n_col)

    fig.tight_layout()
    _save_figure(fig, dx_dh_dir, 'correlation_distance_scatter')

    ## 2. representation size as a bar plot
    # widen the figure with the number of layers, capped at 5.5 inches
    width = min(5.5, 2.5 / 12 * (n_layers - 4) + 3)
    fig, ax = plt.subplots(figsize=(width, 3))
    size_agg = size_agg.loc[layers, :]  # order rows like config['layers']
    ax.bar(
        size_agg['layer'],
        size_agg['size_mean'],
        yerr=size_agg['size_std'],
        color=[layer_colors[layer] for layer in size_agg['layer']],
        zorder=6,  # draw bars above the grid lines
    )
    ax.set_ylim(0, 1.05)
    ax.yaxis.grid(True, linestyle='-', alpha=0.5, zorder=0)
    ax.xaxis.grid(False)

    ax.set_title(f'Representation size: {model_name}')
    ax.set_ylabel('Representation size')
    ax.set_xlabel('Layer')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    _save_figure(fig, os.path.join(summary_dir, 'representation_size'), 'representation_size')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reconstruct images from the original and perturbed features.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    # debug mode: only use one image
    parser.add_argument("--debug", action="store_true", help="Run in debug mode with a single image and seed.")
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    if args.debug:
        config['data']['image_names'] = config['data']['image_names'][:1]  # use only the first image
        config['noise']['noise_seeds'] = [0]  # use only the first seed

    # load and apply plt config
    plt_config_path = 'scripts/250527_readout_vision/configs/plt_config.yaml'
    with open(plt_config_path, 'r') as f:
        plt.rcParams.update(yaml.safe_load(f))

    main(config)