import argparse
import os
import re
from ast import literal_eval

import matplotlib.pyplot as plt
import pandas as pd
import rich
import scienceplots  # noqa
from matplotlib.cm import viridis
from scipy.stats import spearmanr

# Publication-quality defaults: start from the SciencePlots "science" style,
# then override font family, grid, text sizes and line/marker sizes.
plt.style.use("science")
plt.rcParams.update(
    {
        # Serif text with a faint background grid
        "font.family": "serif",
        "axes.grid": True,
        "grid.alpha": 0.3,
        # Larger text than the base style
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        # Marker size and line width applied to every plotted line
        "lines.markersize": 5,
        "lines.linewidth": 1.5,
    }
)


def load_results(results_dir, trainable_layers, epoch: int | None = None):
    """Loads results from full.csv for a given fidelity and epoch."""
    file_path = os.path.join(results_dir, f"{trainable_layers}_trainable", "summary", "full.csv")
    if not os.path.exists(file_path):
        print(f"Warning: File not found {file_path}")
        return None

    df = pd.read_csv(file_path)

    col_name = (
        "extra.validation_errors"
        if "extra.validation_errors" in df.columns
        else "extra.validation_losses"
    )
    if epoch is None:
        return df[col_name].apply(lambda x: literal_eval(x)[-1]).to_numpy()
    else:
        return df[col_name].apply(lambda x: literal_eval(x)[epoch - 1]).to_numpy()


def _load_precomputed_correlations(csv_path, model_name):
    """Read per-fidelity Spearman correlations from a pre-computed CSV file.

    The first column identifies each fidelity, either as integer layer counts
    or as percentages (a ``%`` sign or any non-integer value marks the column
    as percentages); the ``spearman`` column holds rho.
    Returns ``(percentages, correlations)`` or ``(None, None)`` on error.
    """
    rich.print(f"Loading pre-computed correlations for {model_name} from {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:
        # pandas' ParserError/EmptyDataError and UnicodeDecodeError subclass ValueError.
        rich.print(f"Error loading CSV file {csv_path}: {e}")
        return None, None

    if "spearman" not in df.columns:
        rich.print(f"Error: CSV file {csv_path} does not have 'spearman' column")
        return None, None
    if df.empty:
        rich.print(f"Error: CSV file {csv_path} contains no rows")
        return None, None

    index_col = df.columns[0]
    index_text = df[index_col].astype(str)
    has_percent_sign = bool(index_text.str.contains("%", regex=False).any())
    index_values = pd.to_numeric(index_text.str.replace("%", "", regex=False), errors="coerce")
    if index_values.isna().any():
        rich.print(
            f"Error: Index column in {csv_path} cannot be interpreted as layers or percentages"
        )
        return None, None

    # Layer counts are whole numbers; anything fractional must be a percentage
    # (casting it to int would silently truncate, e.g. 12.5 -> 12).
    is_percentage = has_percent_sign or not bool((index_values % 1 == 0).all())
    rhos = df["spearman"].tolist()

    if is_percentage:
        keys = [float(v) for v in index_values]
        # No layer counts are available, so the percentage doubles as the key.
        percentages = {percent: percent for percent in keys}
    else:
        keys = [int(v) for v in index_values]
        max_layer = max(keys)
        if max_layer <= 0:
            rich.print(f"Error: Layer counts in {csv_path} must include a positive value")
            return None, None
        percentages = {layer: (layer / max_layer) * 100 for layer in keys}
    correlations = dict(zip(keys, rhos))

    rich.print(f"Loaded {len(correlations)} correlation values for {model_name}")
    return percentages, correlations


def _compute_correlations(results_dir, model_name, epochs):
    """Correlate every ``<N>_trainable`` fidelity with the largest one.

    The largest fidelity is read at its last epoch; every other fidelity at
    ``epochs``. Fidelities with missing results or a different number of
    configurations are skipped. Returns ``(percentages, correlations)`` keyed
    by layer count, or ``(None, None)`` on error.
    """
    if not os.path.isdir(results_dir):
        rich.print(f"Error: {results_dir} is not a directory")
        return None, None

    fidelities = []
    for item in os.listdir(results_dir):
        # fullmatch: ignore look-alikes such as "4_trainable_backup".
        match = re.fullmatch(r"(\d+)_trainable", item)
        if match and os.path.isdir(os.path.join(results_dir, item)):
            fidelities.append(int(match.group(1)))

    if not fidelities:
        rich.print(f"Error: No '*_trainable' directories found in {results_dir}")
        return None, None

    fidelities.sort()
    full_fidelity_layers = fidelities[-1]

    rich.print(f"{model_name} - Full fidelity: {full_fidelity_layers} trainable layers")

    full_results = load_results(results_dir, full_fidelity_layers, None)
    if full_results is None:
        rich.print(
            f"Error: Could not load results for {model_name} at full fidelity "
            f"({full_fidelity_layers} trainable layers)."
        )
        return None, None

    correlations = {}
    for layers in fidelities:
        rich.print(f"Processing {model_name} - {layers} trainable layers...")
        if layers == full_fidelity_layers:
            # Reference fidelity is compared with itself (rho = 1); no need to re-read it.
            lower_results = full_results
        else:
            lower_results = load_results(results_dir, layers, epochs)

        if lower_results is None:
            rich.print(f"  Skipping: no results for {layers} trainable layers")
            continue
        if len(lower_results) != len(full_results):
            rich.print(
                f"  Skipping: {len(lower_results)} configs vs "
                f"{len(full_results)} at full fidelity"
            )
            continue

        rho, _ = spearmanr(full_results, lower_results)
        correlations[layers] = rho
        rich.print(f"  Rank Correlation (rho): {rho:.4f}")

    # Percentage of trainable layers relative to the full fidelity.
    percentages = {layers: (layers / full_fidelity_layers) * 100 for layers in correlations}
    return percentages, correlations


def process_model(results_dir, model_name, epochs):
    """Return ``(percentages, correlations)`` for one model, or ``(None, None)`` on error.

    ``results_dir`` is either a ``.csv`` file of pre-computed correlations or a
    directory of ``<N>_trainable`` folders (``epochs`` is only used for the
    latter). Both returned dicts share the same keys (layer counts, or
    percentages when the CSV only provides percentages): ``percentages`` maps a
    key to percent of trainable layers, ``correlations`` to Spearman's rho.
    """
    if results_dir.endswith(".csv"):
        return _load_precomputed_correlations(results_dir, model_name)
    return _compute_correlations(results_dir, model_name, epochs)


def main():
    """Command-line entry point: plot Spearman's rho vs. % trainable layers per model.

    Each ``--models`` entry is ``name:path`` (or just ``path``, named after its
    basename); ``path`` is a directory of ``<N>_trainable`` folders or a CSV of
    pre-computed correlations. Models that fail to load are skipped; no file is
    written if none succeed.
    """
    parser = argparse.ArgumentParser(
        description="Plot rank correlation of HP configs across fidelities for multiple models."
    )
    parser.add_argument(
        "--models",
        type=str,
        required=True,
        help="Comma-separated list of model_name:path pairs (e.g., 'ResNet18:/path/to/resnet18,ViT:/path/to/vit'). "
        "Path can be either a directory containing *_trainable folders or a CSV file with pre-computed correlations.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        required=True,
        help="Number of epochs to consider for the metric (when processing directories, not used for CSV files).",
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to save the output plot."
    )
    parser.add_argument(
        "--metric_col",
        type=str,
        default="objective_to_minimize",
        help="Column name in full.csv to use for ranking (e.g., 'objective_to_minimize' or 'extra.validation_errors'). "
        "NOTE: currently ignored; ranking always uses 'extra.validation_errors' "
        "(or 'extra.validation_losses' if absent).",
    )

    args = parser.parse_args()

    # --metric_col is not wired into the loading code; say so instead of silently ignoring it.
    if args.metric_col != parser.get_default("metric_col"):
        rich.print(
            f"Warning: --metric_col={args.metric_col!r} is currently ignored; ranking uses "
            "'extra.validation_errors' (or 'extra.validation_losses' if absent)."
        )

    # Parse model name:path pairs
    model_paths = {}
    for model_pair in args.models.split(","):
        model_pair = model_pair.strip()
        if not model_pair:
            # Tolerate stray commas (e.g. trailing ",") instead of treating "" as a path.
            continue
        if ":" in model_pair:
            model_name, path = (part.strip() for part in model_pair.split(":", 1))
        else:
            # If no name provided, use the directory name
            path = model_pair
            model_name = os.path.basename(os.path.normpath(path))
        model_paths[model_name] = path

    if not model_paths:
        rich.print("Error: No valid model paths provided.")
        return

    plt.figure(figsize=(4, 2.4))

    # One viridis color and one marker per model.
    num_models = len(model_paths)
    colors = [viridis(i / (num_models - 1 if num_models > 1 else 1)) for i in range(num_models)]
    markers = ["o", "s", "^", "D", "v", "*", "X", "p"]

    lowest_rho = 0.0
    num_plotted = 0
    for i, (model_name, path) in enumerate(model_paths.items()):
        percentages, correlations = process_model(path, model_name, args.epochs)

        if percentages is None or correlations is None:
            rich.print(f"Skipping {model_name} due to errors.")
            continue

        # Sort by fidelity key so the line is drawn left to right.
        sorted_keys = sorted(correlations)
        x_values = [percentages[key] for key in sorted_keys]
        y_values = [correlations[key] for key in sorted_keys]

        plt.plot(
            x_values,
            y_values,
            marker=markers[i % len(markers)],
            linestyle="-",
            color=colors[i],
            label=model_name,
        )
        num_plotted += 1

        finite = [y for y in y_values if pd.notna(y)]
        if finite:
            lowest_rho = min(lowest_rho, min(finite))

    if num_plotted == 0:
        rich.print("Error: No models could be processed; no plot was saved.")
        return

    plt.xlabel(r"\% Trainable Layers")
    plt.ylabel("Spearman's $\\rho$")
    plt.title("Rank Correlation vs. Fidelity")

    plt.grid(True, linestyle="--", alpha=0.3)
    # rho lies in [-1, 1]: keep the [0, 1.05] window unless some correlation is
    # negative, in which case extend the axis so those points stay visible.
    bottom = 0.0 if lowest_rho >= 0 else lowest_rho - 0.05
    plt.ylim([bottom, 1.05])

    # Add legend only if there's more than one model
    if len(model_paths) > 1:
        plt.legend()

    # bbox_inches='tight' prevents labels from being cut off
    plt.savefig(args.output_file, bbox_inches="tight", dpi=500)
    rich.print(f"Plot saved to {args.output_file}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
