"""
Plot quantitative results. Please run this after running calculate_size.py

Usage:
    python scripts/250527_readout_vision/plot_results.py /path/to/config.yaml
"""
from typing import Any
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml


def parse_model_name(config) -> str:
    """
    Build the model name used as a directory name in the output path.

    Reads ``config['model']['pretrained']`` and swaps every ``/`` in it
    for ``_``.
    """
    pretrained = config['model']['pretrained']
    return '_'.join(pretrained.split('/'))
    

def parse_output_dir(config: dict[str, Any]) -> str:
    """
    Build the results directory of an experiment from its configuration.

    The path components, joined with ``os.path.join``, are
    ``output``, ``readout_vae``, ``results``, the model name returned by
    ``parse_model_name``, ``config['data']['dataset_name']`` and
    ``config['exp_name']``.
    """
    path_parts = (
        'output',
        'readout_vae',
        'results',
        parse_model_name(config),
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*path_parts)


def get_color():
    """
    Return the single color shared by the plots in this script.

    The color is the ``viridis`` colormap evaluated at the value 0 after
    normalizing over the range 0-100.
    """
    scale = plt.Normalize(vmin=0, vmax=100)
    viridis = plt.get_cmap('viridis')
    return viridis(scale(0))


def _style_dx_dh_axes(ax, model_name: str) -> None:
    """Apply the limits, ticks, labels and title shared by the d_H-d_X plots."""
    ticks = np.arange(0, 1.1, 0.2)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_aspect('equal', adjustable='box')
    ax.set_ylabel('$d_X$')
    ax.set_xlabel('$d_H$')
    ax.set_title(f'{model_name}')


def _save_and_close(fig, out_dir: str, stem: str) -> None:
    """Write ``fig`` to ``out_dir`` as ``<stem>.png`` and ``<stem>.pdf``, then close it."""
    # Lay out this specific figure rather than whatever pyplot considers current.
    fig.tight_layout()
    try:
        for ext in ('png', 'pdf'):
            fig.savefig(os.path.join(out_dir, f'{stem}.{ext}'), dpi=300, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it even if saving fails so figures do not accumulate.
        plt.close(fig)


def main(config: dict):
    """
    Plot the d_H-d_X summary figures for one experiment.

    Reads ``experiment_db.json`` from the experiment directory given by
    ``parse_output_dir(config)`` and writes, under ``<exp_dir>/summary/dx_dh``:

    * ``correlation_distance.{png,pdf}``: mean of ``pixel_correlation_distance``
      per ``target_corr`` value, with a +/- one standard deviation band.
    * ``correlation_distance_scatter.{png,pdf}``: one point per result row.

    Args:
        config: Experiment configuration. This function reads
            ``config['model']['pretrained']`` directly; ``parse_output_dir``
            reads the keys that determine the experiment directory.
    """
    # results path
    exp_dir = parse_output_dir(config)
    results = pd.read_json(os.path.join(exp_dir, 'experiment_db.json'))

    # save directory (makedirs also creates the intermediate 'summary' directory)
    dx_dh_dir = os.path.join(exp_dir, 'summary', 'dx_dh')
    os.makedirs(dx_dh_dir, exist_ok=True)

    # config info
    model_name = config['model']['pretrained']

    # get color
    color = get_color()

    ## 1. plot dh-dx graph
    dx = 'pixel_correlation_distance'  # column drawn on the y axis ($d_X$)
    dh = 'target_corr'                 # column drawn on the x axis ($d_H$)

    ### 1.1 line plot
    # Compute the statistics before creating the figure so a bad column name
    # fails without leaving an open figure behind.
    grouped = results.groupby(dh)[dx]
    mean = grouped.mean()
    std = grouped.std()

    fig, ax = plt.subplots(1, 1, figsize=(1.7, 1.7))
    ax.plot(mean.index, mean.values, color=color)
    ax.fill_between(mean.index, mean - std, mean + std, color=color, alpha=0.2)
    _style_dx_dh_axes(ax, model_name)
    _save_and_close(fig, dx_dh_dir, 'correlation_distance')

    ### 1.2 scatter plot
    fig, ax = plt.subplots(1, 1, figsize=(1.7, 1.7))
    ax.scatter(results[dh], results[dx], color=color, alpha=0.5, s=1)
    _style_dx_dh_axes(ax, model_name)
    _save_and_close(fig, dx_dh_dir, 'correlation_distance_scatter')


if __name__ == "__main__":
    # Command-line entry point: load the experiment config, apply the shared
    # matplotlib style, then generate the summary plots.
    parser = argparse.ArgumentParser(
        description="Plot quantitative results (d_X vs. d_H) from the experiment database."
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    # debug mode: only use one image
    parser.add_argument("--debug", action="store_true", help="Run in debug mode with a single image and seed.")
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    # NOTE(review): main() in this file does not read config['data']['image_names']
    # or config['noise']['noise_seeds']; the handling below is presumably kept for
    # parity with the scripts that produce the results. Confirm whether it is
    # still needed here.
    if 'image_names' not in config['data']:
        # use image_names_path
        with open(config['data']['image_names_path'], 'r') as f:
            config['data']['image_names'] = yaml.safe_load(f)

    if args.debug:
        config['data']['image_names'] = config['data']['image_names'][:1]  # use only the first image
        config['noise']['noise_seeds'] = [0]  # use only the first seed

    # load and apply plt config
    # The path is relative, so the script must be run from the directory that
    # contains 'scripts/' (as in the usage line of the module docstring).
    plt_config_path = 'scripts/250527_readout_vision/configs/plt_config.yaml'
    with open(plt_config_path, 'r') as f:
        plt.rcParams.update(yaml.safe_load(f))

    main(config)