import argparse
import os
from ast import literal_eval
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rich
import seaborn as sns
from scipy.stats import spearmanr

# Publication-quality defaults: start from the seaborn "paper" style sheet,
# then override the font family and the individual font sizes in one go.
plt.style.use("seaborn-v0_8-paper")
# Optional tweaks, currently disabled:
# plt.rcParams["font.serif"] = ["Times New Roman"]
# plt.rcParams["axes.grid"] = True
# plt.rcParams["grid.alpha"] = 0.3
plt.rcParams.update(
    {
        "font.family": "serif",
        # Font sizes (points) for axis labels, titles, ticks and legend.
        "axes.labelsize": 10,
        "axes.titlesize": 11,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 10,
    }
)


def load_validation_losses(base_dir, n_trainable, n_configs):
    """Read the per-epoch validation curves of one fidelity level.

    The summary file ``<base_dir>/<n_trainable>_trainable/summary/full.csv``
    is read, and the ``extra.validation_losses`` column is used when present,
    otherwise ``extra.validation_errors``. Each cell holds the string form of
    a Python list, which is parsed with ``ast.literal_eval``.

    Args:
        base_dir: Directory containing the ``<k>_trainable`` sub-directories.
        n_trainable: Number of trainable layers selecting the sub-directory.
        n_configs: Only the first ``n_configs`` rows of the file are used.

    Returns:
        A NumPy array built by stacking the parsed lists, one row per
        configuration (the lists must therefore all have the same length).
    """
    csv_file = os.path.join(base_dir, f"{n_trainable}_trainable", "summary", "full.csv")
    table = pd.read_csv(csv_file)

    # Prefer the "losses" column; older summaries only provide "errors".
    column = (
        "extra.validation_losses"
        if "extra.validation_losses" in table.columns
        else "extra.validation_errors"
    )

    serialized_curves = table[column][:n_configs]
    parsed_curves = [literal_eval(cell) for cell in serialized_curves]
    return np.stack(parsed_curves)


def find_max_trainable(base_dir):
    """Return the largest ``<k>`` among ``<k>_trainable`` entries of ``base_dir``.

    Every entry whose name ends in ``_trainable`` is inspected; the text
    before the first underscore is parsed as an integer, and entries whose
    prefix is not an integer are ignored. The result is reported via
    ``rich.print`` and is 0 when no entry qualifies.
    """
    suffix = "_trainable"
    # Seed with 0 so that an empty directory yields 0, as documented.
    layer_counts = [0]

    for entry in os.listdir(base_dir):
        if not entry.endswith(suffix):
            continue
        prefix = entry.split("_")[0]
        try:
            layer_counts.append(int(prefix))
        except ValueError:
            # Non-numeric prefix: not a fidelity directory.
            pass

    best = max(layer_counts)
    rich.print(f"Found maximum trainable layers: {best}")
    return best


def compute_rank_correlation_matrix(
    base_dir: str,
    n_configs: int = 100,
    max_epochs: int = 20,
    max_fidelity: int | None = None,
):
    """
    Compute a matrix of Spearman's rank correlations between validation errors.

    Entry ``[i, j]`` of the matrix is the rank correlation between:
    - the values of the configs after ``i + 1`` epochs with ``j + 1``
      trainable layers, and
    - the values of the same configs at full fidelity, i.e. the last recorded
      epoch of the ``max_fidelity`` trainable-layer run.

    Args:
        base_dir: Directory containing the ``<k>_trainable`` sub-directories.
        n_configs: Number of configurations (leading rows of each summary
            file) to include.
        max_epochs: Number of epochs, i.e. rows of the returned matrix.
        max_fidelity: Largest number of trainable layers, i.e. columns of the
            returned matrix. When ``None`` it is discovered with
            ``find_max_trainable`` and falls back to 5 if nothing is found.

    Returns:
        A ``(max_epochs, max_fidelity)`` float array. Columns whose
        ``<k>_trainable`` directory does not exist are left as NaN so that
        missing data is not mistaken for a correlation of zero.
    """
    rich.print(f"Base directory: {base_dir}")
    rich.print(f"Checking if directory exists: {os.path.exists(base_dir)}")

    # List available directories
    rich.print(f"Available directories: {os.listdir(base_dir)}")

    # Determine max_fidelity if not provided
    if max_fidelity is None:
        max_fidelity = find_max_trainable(base_dir)
        if max_fidelity == 0:
            rich.print("No trainable directories found. Defaulting to max_fidelity=5")
            max_fidelity = 5

    # Full-fidelity reference: last recorded epoch of the max_fidelity run.
    reference_losses = load_validation_losses(
        base_dir=base_dir, n_trainable=max_fidelity, n_configs=n_configs
    )
    full_fidelity_values = reference_losses[:, -1]

    rich.print(f"Collected {len(full_fidelity_values)} full fidelity values")

    # Start from NaN rather than 0: a skipped fidelity level has *no*
    # correlation value, and NaN entries can be masked when plotting.
    correlation_matrix = np.full((max_epochs, max_fidelity), np.nan)

    # Compute correlation for each combination of epochs and trainable layers
    for n_trainable in range(1, max_fidelity + 1):
        rich.print(f"Processing {n_trainable}_trainable directory...")
        trainable_dir = os.path.join(base_dir, f"{n_trainable}_trainable")
        if not os.path.exists(trainable_dir):
            rich.print(f"Directory not found: {trainable_dir}")
            continue

        # The summary file does not depend on the epoch, so read and parse it
        # once per fidelity level instead of once per (fidelity, epoch) pair.
        val_losses = load_validation_losses(
            base_dir=base_dir, n_trainable=n_trainable, n_configs=n_configs
        )

        for epoch in range(1, max_epochs + 1):
            correlation, _ = spearmanr(full_fidelity_values, val_losses[:, epoch - 1])
            correlation_matrix[epoch - 1, n_trainable - 1] = correlation

    return correlation_matrix


def plot_heatmap(
    correlation_matrix,
    output_path: str | Path,
    title: str = "ResNet-18 on CIFAR-100",
):
    """Plot the correlation matrix as an annotated heatmap and save it.

    Rows are labelled as epochs and columns as trainable layers, both
    starting at 1. NaN entries are masked (left blank). The colour scale is
    fixed to the range [0, 1].

    Args:
        correlation_matrix: 2-D array of correlation values.
        output_path: File to write the figure to. Its parent directory is
            created if needed; a bare filename is saved to the current
            working directory.
        title: Title shown above the heatmap.
    """
    rich.print("Creating heatmap...")
    plt.figure(figsize=(5, 5))

    # Create a mask for NaN values
    mask = np.isnan(correlation_matrix)

    # Create the heatmap
    ax = sns.heatmap(
        correlation_matrix,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        mask=mask,
        square=True,
        vmin=0,
        vmax=1,
        cbar_kws={
            "label": "Spearman's $\\rho$",
            "shrink": 0.72,  # Make colorbar same height as the heatmap
        },
        annot_kws={"size": 7},  # Smaller text size for annotations
    )

    # Set labels and title
    ax.set_xlabel("# Trainable Layers")
    ax.set_ylabel("# Epochs")
    ax.set_title(title)

    # Ticks sit at cell centres (+0.5) and are labelled 1..N instead of 0..N-1.
    n_rows, n_cols = correlation_matrix.shape
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels([str(layer) for layer in range(1, n_cols + 1)])
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels([str(epoch) for epoch in range(1, n_rows + 1)])

    plt.tight_layout()

    rich.print(f"Saving heatmap to {output_path}")
    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path, dpi=300, bbox_inches="tight")


# Command-line entry point: compute the rank-correlation matrix for one
# model/dataset grid search and save it as a heatmap.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="File to save the heatmap to (default: a PDF under <repo>/plots).",
    )
    parser.add_argument("--model_name", type=str, default="resnet18")
    parser.add_argument(
        "--max_fidelity",
        type=int,
        default=None,
        help="Maximum number of trainable layers (default: auto-detect).",
    )
    parser.add_argument("--max_epochs", type=int, default=20)
    parser.add_argument(
        "--n_configs",
        type=int,
        default=100,
        help="Number of configurations to include.",
    )
    parser.add_argument("--dataset", type=str, default="c10")
    args = parser.parse_args()

    # Get the absolute path to ensure proper file access
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_dir = os.path.dirname(script_dir)
    base_dir = os.path.join(
        repo_dir, "output", args.model_name, "grid_search", args.dataset, "20_epochs"
    )

    rich.print(f"Script directory: {script_dir}")
    rich.print(f"Repository directory: {repo_dir}")
    rich.print(f"Base directory: {base_dir}")

    # Compute the correlation matrix
    correlation_matrix = compute_rank_correlation_matrix(
        base_dir=base_dir,
        n_configs=args.n_configs,
        max_epochs=args.max_epochs,
        max_fidelity=args.max_fidelity,
    )

    # Honour --output_path when given (previously it was parsed but ignored);
    # otherwise fall back to the default location inside the repository.
    if args.output_path is not None:
        # Absolute path so the parent directory is never the empty string.
        output_path = os.path.abspath(args.output_path)
    else:
        output_path = os.path.join(
            repo_dir,
            "plots",
            "rank_correlation_heatmap",
            f"rank_correlation_heatmap_{args.n_configs}_{args.model_name}_{args.dataset}.pdf",
        )

    # Plot the heatmap
    plot_heatmap(correlation_matrix, output_path=output_path)

    print("Heatmap created successfully.")
