"""
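Plot quantitative results of the readout-language experiment: per-layer line
and scatter plots of d_X (column `top_1_error_rate_normal_tokens`) against
d_H (column `target_corr_dist`).

Run from the repository root: the output directory and the shared matplotlib
style file are resolved relative to the current working directory.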

Usage:
    python scripts/250610_readout_language/plot_results.py path/to/config.yaml
"""
from typing import Any
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml


def parse_output_dir(config):
    """
    Build the results directory path for an experiment from its configuration.

    The returned path is
    ``output/readout_language/results/<model>/<dataset>/<exp_name>`` where
    ``<model>`` is ``config['model']['pretrained']`` with every ``/`` replaced
    by ``_`` so that a hub-style model id becomes a single path component.
    ``<dataset>`` is ``config['data']['dataset_name']`` and ``<exp_name>`` is
    ``config['exp_name']``.
    """
    safe_model = config['model']['pretrained'].replace('/', '_')
    parts = (
        'output',
        'readout_language',
        'results',
        safe_model,
        config['data']['dataset_name'],
        config['exp_name'],
    )
    return os.path.join(*parts)


def _layer_colors(layers):
    """Map each layer index to a viridis color, evenly spaced by its position in ``layers``.

    For four layers this yields exactly the colors of the former hard-coded
    depths [0, 25, 50, 75] normalized to [0, 75]; unlike that list, it works
    for any number of layers instead of raising IndexError beyond four.
    """
    cmap = plt.get_cmap('viridis')
    positions = np.linspace(0.0, 1.0, num=len(layers))
    return {layer: cmap(pos) for layer, pos in zip(layers, positions)}


def _new_figure(title, layers):
    """Create a 1 x len(layers) figure with one square, shared [0, 1] d_H/d_X axis per layer.

    Returns ``(fig, axes)`` where ``axes`` is a 1-D array aligned with ``layers``.
    """
    # squeeze=False keeps a 2-D array even for a single layer, so indexing
    # works uniformly (a bare Axes object is not subscriptable).
    fig, axes = plt.subplots(
        1, len(layers), figsize=(5.5, 1.7), sharex=True, sharey=True, squeeze=False
    )
    axes = axes[0]
    fig.suptitle(title)
    ticks = np.arange(0, 1.1, 0.2)
    for i, (ax, layer) in enumerate(zip(axes, layers)):
        ax.set_title(f'layer_{layer}', fontsize=8)
        ax.set_ylim(-0.05, 1.05)
        ax.set_xlim(-0.05, 1.05)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xlabel('$d_H$')
        if i == 0:
            # y axis is shared; label only the leftmost panel.
            ax.set_ylabel('$d_X$')
        ax.set_aspect('equal', adjustable='box')
    return fig, axes


def _save_figure(fig, save_dir, stem):
    """Lay out ``fig``, write ``<stem>.png`` and ``<stem>.pdf`` into ``save_dir``, then close it."""
    fig.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    for ext in ('png', 'pdf'):
        fig.savefig(os.path.join(save_dir, f'{stem}.{ext}'))
    # Release the figure; otherwise every call leaves it registered with pyplot.
    plt.close(fig)


def main(config: dict):
    """
    Plot d_X against d_H for every configured layer and save the figures.

    Reads ``experiment_db.json`` from the directory given by
    ``parse_output_dir(config)``. The rows must have columns ``layer_idx``,
    ``target_corr_dist`` (d_H) and ``top_1_error_rate_normal_tokens`` (d_X).
    Writes two figures, each as PNG and PDF, into ``<output_dir>/summary/dh_dx``:
    a line plot of the mean d_X per d_H value with a +/- 1 std band, and a
    scatter plot of every observation.

    Args:
        config: Experiment configuration. Besides the keys used by
            ``parse_output_dir``, it needs ``config['layer_indices']``, the
            non-empty list of layers to plot (one panel each).

    Raises:
        ValueError: if ``config['layer_indices']`` is empty.
    """
    output_dir = parse_output_dir(config)
    results = pd.read_json(os.path.join(output_dir, 'experiment_db.json'))

    layers = config['layer_indices']
    if not layers:
        raise ValueError("config['layer_indices'] must list at least one layer")
    layer_colors = _layer_colors(layers)
    title = config['model']['pretrained']

    dh = 'target_corr_dist'
    dx = 'top_1_error_rate_normal_tokens'
    save_dir = os.path.join(output_dir, 'summary', 'dh_dx')

    # Line plot: mean d_X per distinct d_H value, with a +/- 1 std band.
    fig, axes = _new_figure(title, layers)
    for ax, layer in zip(axes, layers):
        layer_results = results[results['layer_idx'] == layer]
        grouped = layer_results.groupby(dh)[dx]
        mean = grouped.mean()
        # NaN where a d_H value has a single observation (pandas uses ddof=1).
        std = grouped.std()
        color = layer_colors[layer]
        ax.plot(mean.index, mean.values, color=color)
        ax.fill_between(mean.index, mean - std, mean + std, color=color, alpha=0.2)
    _save_figure(fig, save_dir, f'{dh}_{dx}_line')

    # Scatter plot: every individual (d_H, d_X) observation.
    fig, axes = _new_figure(title, layers)
    for ax, layer in zip(axes, layers):
        layer_results = results[results['layer_idx'] == layer]
        ax.scatter(
            layer_results[dh], layer_results[dx], color=layer_colors[layer], s=2, alpha=0.75
        )
    _save_figure(fig, save_dir, f'{dh}_{dx}_scatter')


if __name__ == "__main__":
    # Entry point: load the experiment config and the shared plot style, then plot.
    parser = argparse.ArgumentParser(
        description="Plot quantitative results of the readout-language experiment."
    )
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    args = parser.parse_args()

    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Load the shared matplotlib style.
    # NOTE(review): this path belongs to the readout_vision experiment and is
    # resolved relative to the current working directory, so the script must be
    # run from the repository root. Confirm that reusing this style is intended.
    with open('scripts/250527_readout_vision/configs/plt_config.yaml', 'r', encoding="utf-8") as f:
        plt.rcParams.update(yaml.safe_load(f))
    # Applied after the shared style so these tick-label sizes take precedence.
    plt.rcParams.update({'xtick.labelsize': 6, 'ytick.labelsize': 6})

    main(config)