import argparse
import os
import random
from ast import literal_eval

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.cm import viridis

# Publication-quality defaults: start from the seaborn "paper" style sheet and
# then override font family, grid and text sizes in one rcParams update.
# (The "ggplot" style sheet was previously tried as an alternative base style.)
plt.style.use("seaborn-v0_8-paper")
plt.rcParams.update(
    {
        "font.family": "serif",
        # To pin a specific serif face, additionally set
        # "font.serif": ["Times New Roman"].
        "axes.grid": True,
        "grid.alpha": 0.3,
        # Larger-than-default text sizes for axes, ticks and legends.
        "axes.labelsize": 12,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 11,
        "legend.title_fontsize": 12,
    }
)


def load_data(trainable_dir, config_id: int):
    """Load the per-epoch validation curve of one config from a trainable-layer directory.

    Reads ``<trainable_dir>/summary/full.csv``, selects the first row (in file
    order) whose ``id`` equals ``config_id`` and parses the stringified list
    stored in ``extra.validation_losses`` (preferred when present) or
    ``extra.validation_errors``.

    Args:
        trainable_dir: Directory that contains ``summary/full.csv``.
        config_id: Value matched against the ``id`` column.

    Returns:
        The parsed list (or tuple) of per-epoch values, or ``None`` when the
        file, the ``id`` column, the config row or the metric column is
        missing, or when the cell does not parse into a sequence. A message is
        printed for every ``None`` outcome.
    """
    csv_path = os.path.join(trainable_dir, "summary", "full.csv")
    if not os.path.exists(csv_path):
        print(f"File not found: {csv_path}")
        return None

    def _fail(reason):
        # Single place for the error message so every failure path reads the same.
        print(f"Error loading data from {csv_path} for config {config_id}: {reason}")
        return None

    try:
        data = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:
        # pandas' EmptyDataError / ParserError and decoding errors are ValueErrors.
        return _fail(e)

    if "id" not in data.columns:
        return _fail("missing 'id' column")

    # Pick a specific configuration (no sort needed: filtering keeps file order).
    config_data = data[data["id"] == config_id]
    if config_data.empty:
        print(f"No data for config_id {config_id} in {csv_path}")
        return None

    col = next(
        (
            name
            for name in ("extra.validation_losses", "extra.validation_errors")
            if name in data.columns
        ),
        None,
    )
    if col is None:
        return _fail("no 'extra.validation_losses' or 'extra.validation_errors' column")

    raw = config_data.iloc[0][col]
    if not isinstance(raw, str):
        # e.g. NaN for an empty cell, or a bare number parsed by pandas.
        return _fail(f"expected a stringified list in '{col}', got {raw!r}")

    # Convert the string "[0.123, 0.456, ...]" into a Python list of floats.
    try:
        val_errors = literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        return _fail(e)

    # Callers take len() and index [-1], so anything but a sequence is unusable.
    if not isinstance(val_errors, (list, tuple)):
        return _fail(f"'{col}' did not parse to a list: {val_errors!r}")

    return val_errors


def get_available_configs(base_dir):
    """Return the config IDs listed by the first readable trainable directory.

    Sub-directories of ``base_dir`` whose name contains ``_trainable`` are
    visited in sorted name order (``os.listdir`` order is arbitrary, which
    would make the chosen directory, and any seeded sampling done on the
    result, platform dependent). The first ``summary/full.csv`` that can be
    read and has an ``id`` column determines the result; unreadable files are
    reported and skipped.

    Args:
        base_dir: Directory containing the ``*_trainable*`` sub-directories.

    Returns:
        Unique values of the ``id`` column, in order of first appearance in
        the CSV, or an empty list when no usable CSV is found.

    Raises:
        OSError: If ``base_dir`` cannot be listed.
    """
    for item in sorted(os.listdir(base_dir)):
        if "_trainable" not in item:
            continue
        trainable_dir = os.path.join(base_dir, item)
        if not os.path.isdir(trainable_dir):
            continue
        csv_path = os.path.join(trainable_dir, "summary", "full.csv")
        if not os.path.exists(csv_path):
            continue
        try:
            data = pd.read_csv(csv_path)
            return data["id"].unique().tolist()
        except (OSError, ValueError, KeyError) as e:
            # A broken summary in one directory should not hide the others.
            print(f"Error reading configs from {csv_path}: {e}")
    return []


def plot_validation_trajectories_multi(base_dir: str, n_configs: int, model_name: str):
    """Plot validation trajectories per trainable-layer count for several configs.

    One subplot is drawn per sampled config; within a subplot there is one
    curve per ``<count>_trainable*`` directory found in ``base_dir``, coloured
    along the viridis colormap. The figure is written to
    ``plots/validation_trajectories_multi_<model_name>.pdf`` (the ``plots``
    directory is created if needed).

    Args:
        base_dir: Directory containing the ``<count>_trainable*`` sub-directories.
        n_configs: Number of config IDs to sample with the global ``random``
            state; if fewer are available, all of them are used.
        model_name: Suffix used in the output file name.

    Returns:
        None. Prints a message and returns early when ``n_configs`` is not
        positive or when no directories / configs are found.
    """
    if n_configs < 1:
        # random.sample / plt.subplots would otherwise fail with an opaque error.
        print(f"n_configs must be at least 1, got {n_configs}")
        return

    # Discover trainable layer counts, remembering the directory name that was
    # actually found so the path is not rebuilt from the parsed integer.
    dirs_by_count = {}
    for item in sorted(os.listdir(base_dir)):
        if "_trainable" not in item or not os.path.isdir(os.path.join(base_dir, item)):
            continue
        try:
            count = int(item.split("_")[0])
        except ValueError:
            continue
        dirs_by_count.setdefault(count, item)

    # Sort the trainable counts for consistent plotting
    trainable_counts = sorted(dirs_by_count)
    if not trainable_counts:
        print(f"No trainable layer directories found in {base_dir}")
        return

    # Get available config IDs
    available_configs = get_available_configs(base_dir)
    if not available_configs:
        print("No available configs found")
        return

    # Randomly sample n_configs from available configs
    if len(available_configs) > n_configs:
        sampled_configs = random.sample(available_configs, n_configs)
    else:
        sampled_configs = available_configs
        print(f"Only {len(sampled_configs)} configs available, using all of them")

    # One row of subplots; squeeze=False keeps `axes` 2-D even for one config.
    fig, axes = plt.subplots(
        1, len(sampled_configs), figsize=(3 * len(sampled_configs), 3), squeeze=False
    )
    axes = axes[0]  # Since we only have one row

    # Evenly spaced viridis colours; max(1, ...) avoids 0/0 for a single count.
    color_span = max(1, len(trainable_counts) - 1)
    colors = [viridis(j / color_span) for j in range(len(trainable_counts))]

    # First line drawn for each count, across all subplots, so the shared
    # legend stays complete even if some subplot lacks data for a count.
    legend_handles = {}

    for i, (ax, config_id) in enumerate(zip(axes, sampled_configs)):
        for color, count in zip(colors, trainable_counts):
            trainable_dir = os.path.join(base_dir, dirs_by_count[count])
            val_errors = load_data(trainable_dir, config_id)
            if val_errors is None:
                continue

            epochs = range(1, len(val_errors) + 1)
            (line,) = ax.plot(
                epochs,
                val_errors,
                marker="o",
                markersize=4,
                color=color,
                label=f"{count}",
            )
            legend_handles.setdefault(count, line)

        # Set up subplot
        ax.set_xlabel("Epoch")
        if i == 0:
            ax.set_ylabel("Validation Error")
        ax.set_title(f"Config {config_id}")
        ax.grid(True, linestyle="--", alpha=0.7)

    # Add a common legend at the right side of the figure
    plotted_counts = [count for count in trainable_counts if count in legend_handles]
    fig.legend(
        [legend_handles[count] for count in plotted_counts],
        [f"{count}" for count in plotted_counts],
        loc="center left",
        bbox_to_anchor=(1.01, 0.5),
        borderaxespad=0.0,
        fancybox=True,
    )

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)  # Extra bottom margin, kept from the original layout

    # Save the plot; bbox_inches="tight" keeps the outside legend in the PDF.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(
        f"plots/validation_trajectories_multi_{model_name}.pdf", dpi=500, bbox_inches="tight"
    )
    plt.close(fig)


def plot_final_performance(base_dir: str, config_id: int):
    """Plot final validation error vs number of trainable layers.

    For every ``<count>_trainable*`` sub-directory of ``base_dir`` the last
    value of the validation curve of ``config_id`` is collected, printed, and
    plotted against ``count``. The figure is written to
    ``plots/final_performance_config_<config_id>.pdf`` (the ``plots``
    directory is created if needed).

    Args:
        base_dir: Directory containing the ``<count>_trainable*`` sub-directories.
        config_id: Config whose final validation value is plotted.

    Returns:
        None. Prints a message and returns early when no data is found.
    """
    results = []  # (trainable layer count, last-epoch value) pairs
    for item in os.listdir(base_dir):
        trainable_dir = os.path.join(base_dir, item)
        if "_trainable" not in item or not os.path.isdir(trainable_dir):
            continue
        # Only the integer parse is guarded; other errors should surface.
        try:
            count = int(item.split("_")[0])
        except ValueError:
            continue

        val_errors = load_data(trainable_dir, config_id)
        if not val_errors:
            # Missing data or an empty curve: there is no final epoch to report.
            continue
        results.append((count, val_errors[-1]))  # Last epoch's error

    if not results:
        print(f"No data found for config_id {config_id}")
        return

    # Sort by trainable count only (stable), for proper left-to-right plotting.
    results.sort(key=lambda pair: pair[0])
    trainable_counts = [count for count, _ in results]
    final_errors = [error for _, error in results]

    # Print data for reference
    print(f"\nFinal validation errors by trainable layer count for config_id {config_id}:")
    for count, error in results:
        print(f"{count} trainable layers: {error:.4f}")

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(trainable_counts, final_errors, marker="o", linestyle="-", color=viridis(0.5))

    ax.set_xlabel("Number of Trainable Layers")
    ax.set_ylabel("Final Validation Error")
    ax.set_title(f"Final Validation Error vs Number of Trainable Layers (Config {config_id})")
    ax.set_xticks(trainable_counts)
    ax.grid(True, linestyle="--", alpha=0.7)

    # Save the plot
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/final_performance_config_{config_id}.pdf", dpi=500, bbox_inches="tight")
    print(f"Saved plot to plots/final_performance_config_{config_id}.pdf")
    plt.close(fig)


if __name__ == "__main__":
    # Command line: only the number of configs to sample is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--n_configs",
        type=int,
        default=4,
        help="Number of random config IDs to sample for subplots",
    )
    options = cli.parse_args()
    repo_base_dir = "<repo base dir>"

    # (experiment path relative to the repo base, sampling seed) pairs.
    experiments = [
        ("resnet18-full-grid/grid_search/c100/20_epochs", 42),
    ]
    for rel_path, rng_seed in experiments:
        # Seed the global RNG so the sampled configs are reproducible.
        random.seed(rng_seed)
        print(f"Plotting validation error trajectories for {options.n_configs} random configs...")
        plot_validation_trajectories_multi(
            os.path.join(repo_base_dir, rel_path),
            options.n_configs,
            rel_path.split("/")[0],  # leading path component doubles as the model name
        )

        # NOTE: plot_final_performance is not invoked here; calling it would
        # need a config id, and this parser defines no --config_id option.
