import argparse

import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from matplotlib import pyplot as plt
from scipy.stats import spearmanr

from freezes import CANVAS_BASE_PATH, ExpCanvas


def get_correlations_for_fidelity_range(
    data_dict: dict[int, pd.DataFrame],
    y_metric: str = "objective_to_minimize",
) -> tuple[list[list[float]], list[list[float]]]:
    """Compute pairwise Spearman rank correlations of ``y_metric`` across fidelities.

    Rows are compared positionally, so every DataFrame is expected to list the
    same configurations in the same order (the caller sorts by hyperparameters).

    Args:
        data_dict: Mapping from fidelity id to the DataFrame of runs at that
            fidelity. Keys need not be contiguous (fidelities may be missing);
            they are placed into the matrices in ascending key order.
        y_metric: Column whose values are rank-correlated.

    Returns:
        ``(corr_mat, p_value_mat)``: two square nested lists of size
        ``len(data_dict)``. Entry ``[i][j]`` holds the Spearman coefficient /
        p-value between the i-th and j-th smallest fidelity. Only the upper
        triangle (``i <= j``) is filled; the lower triangle, and any pair whose
        DataFrames differ in length, stays ``0.0``.
    """
    # Map fidelity keys to matrix positions; using the key itself as an index
    # breaks as soon as the keys are not exactly 0..n-1.
    fidelities = sorted(data_dict)
    size = len(fidelities)
    corr_mat = np.zeros((size, size)).tolist()
    p_value_mat = np.zeros((size, size)).tolist()
    for i, _fid1 in enumerate(fidelities):
        y1 = data_dict[_fid1][y_metric]
        # Start at i: the correlation is symmetric, so only the upper triangle is computed.
        for j in range(i, size):
            _fid2 = fidelities[j]
            y2 = data_dict[_fid2][y_metric]
            if len(y1) != len(y2):
                print(f"Dataframes for fidelity {_fid1} and {_fid2} have different lengths.")
                continue
            correlation, pvalue = spearmanr(y1, y2)
            corr_mat[i][j] = correlation
            p_value_mat[i][j] = pvalue

    return corr_mat, p_value_mat


def get_args() -> argparse.Namespace:
    """Parse the command-line options of the fidelity-correlation script.

    Returns:
        Namespace with ``canvas_access``, ``fid_range`` (list of ints whose first
        two entries the script uses as an inclusive start/end), ``hpo_config``,
        ``exp_group``, ``y_metric`` and ``cost_metric``.
    """
    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_specs = (
        ("--canvas_access", dict(type=str, default="global-meta")),
        ("--fid_range", dict(type=int, nargs="+", default=[1, 9])),
        ("--hpo_config", dict(type=str, default="s1_grid_search_config")),
        ("--exp_group", dict(type=str, default="s1_grid_search_neps_0.12.2")),
        ("--y_metric", dict(type=str, default="objective_to_minimize")),
        ("--cost_metric", dict(type=str, default="cost")),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


if __name__ == "__main__":
    args = get_args()

    exp_canvas = ExpCanvas(CANVAS_BASE_PATH, args.canvas_access)

    # Load the grid search config; its pipeline_space keys are the hyperparameters
    # used to align rows across fidelities.
    config_path = exp_canvas.freezes_root / "configs" / "neps" / f"{args.hpo_config}.yaml"
    with open(config_path, encoding="utf-8") as f:
        grid_search_config = yaml.safe_load(f)
    hps = [f"config.{k}" for k in grid_search_config["pipeline_space"]]
    metrics = [args.y_metric, args.cost_metric]

    def get_df(x):
        """Load ``summary/full.csv`` of the HPO run ``train_{x}`` in the experiment group."""
        return pd.read_csv(
            exp_canvas.results_root / args.exp_group / f"train_{x}" / "summary" / "full.csv"
        )

    # Inclusive fidelity range from the CLI; fidelity 0 (loaded from the
    # "train_all" run) is always added as a reference.
    fid_range = list(range(args.fid_range[0], args.fid_range[1] + 1))
    if 0 not in fid_range:
        fid_range.insert(0, 0)

    # Collecting data
    data_dict = {}
    for fid in fid_range:
        print(f"Loading data for fidelity {fid}")
        _hps = hps.copy()
        _fid = fid
        if fid == 0:
            _fid = "all"
            # NOTE(review): the "all" run is not aligned on layers_to_train;
            # drop it only if the config actually defines it.
            if "config.layers_to_train" in _hps:
                _hps.remove("config.layers_to_train")
        try:
            df = get_df(_fid)
        except FileNotFoundError:
            print(f"Data for fidelity {fid} not found. Skipping...")
            continue
        # Sorting by hyperparameters puts the same configs in the same row order
        # for every fidelity, which the positional rank correlation relies on.
        data_dict[fid] = df[_hps + metrics].dropna().sort_values(by=_hps)

    if not data_dict:
        raise SystemExit("No fidelity data could be loaded; nothing to plot.")

    print("Calculating correlations...")
    corr_mat, p_value_mat = get_correlations_for_fidelity_range(data_dict, args.y_metric)

    # Label both axes with the fidelities (ascending, matching the matrix order) so
    # skipped fidelities do not silently shift the labels; 0 is the "all" run.
    fid_labels = ["all" if fid == 0 else str(fid) for fid in sorted(data_dict)]
    plot_path = exp_canvas.plot_root / "correlation_matrix.png"

    # Save correlation matrix
    plt.clf()
    heatmap = sns.heatmap(
        corr_mat,
        annot=True,
        cmap="YlGnBu",
        fmt=".2f",
        xticklabels=fid_labels,
        yticklabels=fid_labels,
    )
    heatmap.set_title("LM Hyperparameter Rank Correlation")
    print(f"Saving correlation matrix to {plot_path}")
    plt.savefig(plot_path, dpi=150, bbox_inches="tight")
# end of file
