import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import mpl_toolkits.axisartist as AA
import pandas as pd
import scienceplots  # noqa
from matplotlib.cm import viridis
from mpl_toolkits.axes_grid1 import host_subplot

# Publication-style defaults: start from the "science" style sheet (presumably
# registered by the scienceplots import above), then override fonts, grid,
# text sizes and marker size in a single rcParams update.
plt.style.use("science")
plt.rcParams.update(
    {
        "font.family": "serif",
        "axes.grid": True,
        "grid.alpha": 0.3,
        # Larger text for labels, titles, ticks and legend
        "axes.labelsize": 12,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        # Marker size shared by all line plots
        "lines.markersize": 5,
    }
)

# One color per plotted series, each sampled from the viridis colormap at the
# given position in [0, 1].
COLORS = {
    series: viridis(position)
    for series, position in (
        ("validation", 0.1),  # validation error curve
        ("runtime", 0.7),  # relative runtime bars
        ("memory", 0.3),  # relative memory bars
    )
}


def load_hw_metrics(csv_path):
    """
    Read the hardware-metrics CSV and expose the per-loop time as ``total_runtime``.

    Args:
        csv_path: Path to the CSV file with hardware metrics. It must contain a
            ``median_loop_time`` column.

    Returns:
        DataFrame with all CSV columns plus ``total_runtime``, which is a copy of
        ``median_loop_time``.
    """
    metrics = pd.read_csv(csv_path)
    # The measured loop time is used directly as the total runtime.
    return metrics.assign(total_runtime=metrics["median_loop_time"])


def load_validation_data(experiment_dir, config_id, epoch):
    """
    Load validation data for a specific config across different fidelity levels.

    Every sub-directory of ``experiment_dir`` whose name ends in ``_trainable`` is
    treated as one fidelity level; the integer before the first underscore is the
    number of trainable layers. For each level, ``summary/full.csv`` is searched
    for the row whose ``id`` column equals ``config_id``.

    Args:
        experiment_dir: Base directory containing the experiment results
        config_id: Value of the ``id`` column identifying the configuration
        epoch: Epoch at which to read the validation error (0-indexed). ``None``
            selects the last recorded epoch; an out-of-range value also falls
            back to the last epoch, with a warning.

    Returns:
        DataFrame with one row per fidelity level, ordered by ``n_trainable``, with
        columns ``n_trainable``, ``validation_error``, ``config_id``, ``optimizer``,
        ``learning_rate`` and ``weight_decay``.

    Raises:
        ValueError: If no fidelity level yields a usable row, or if a stored
            validation-error list is not a plain Python literal.
    """
    # Local import keeps this block self-contained; only needed for parsing the
    # serialized validation-error list.
    import ast

    base_path = Path(experiment_dir)
    # Parse the layer count once per directory and order the levels by it.
    fidelity_levels = sorted(
        (
            (int(d.name.split("_")[0]), d)
            for d in base_path.iterdir()
            if d.is_dir() and d.name.endswith("_trainable")
        ),
        key=lambda level: level[0],
    )

    all_data = []

    for n_trainable, fidelity_dir in fidelity_levels:
        # Only full.csv is needed; short.csv was previously read but never used.
        full_csv = fidelity_dir / "summary" / "full.csv"
        full_df = pd.read_csv(full_csv)

        # Find the corresponding row in full_df
        config_rows = full_df[full_df["id"] == config_id]
        if config_rows.empty:
            print(f"Warning: Could not find configuration with ID {config_id} in {full_csv}")
            continue

        config = config_rows.iloc[0]

        # NOTE(security): this cell was previously passed to eval(), which would
        # execute arbitrary code stored in the CSV. literal_eval only accepts
        # Python literals (e.g. a list of numbers).
        validation_errors = ast.literal_eval(config["extra.validation_errors"])
        if not validation_errors:
            # Nothing to index into; skip instead of failing on index -1.
            print(f"Warning: No validation errors recorded in {full_csv}, skipping")
            continue

        # Select specified epoch or use the last one
        last_idx = len(validation_errors) - 1
        epoch_idx = epoch if epoch is not None else last_idx
        if epoch_idx > last_idx or epoch_idx < 0:
            print(
                f"Warning: Epoch {epoch_idx} is out of bounds for {fidelity_dir}, using last epoch"
            )
            epoch_idx = last_idx

        all_data.append(
            {
                "n_trainable": n_trainable,
                "validation_error": validation_errors[epoch_idx],
                "config_id": config_id,
                "optimizer": config["config.optimizer_name"],
                "learning_rate": config["config.learning_rate"],
                "weight_decay": config["config.weight_decay"],
            }
        )

    if not all_data:
        raise ValueError("No data found that matches the specified criteria")

    return pd.DataFrame(all_data)


def plot_validation_vs_runtime(
    validation_df,
    hw_metrics_df,
    output_dir,
    config_id,
    metric="runtime",
    alpha=0.7,
    grid_alpha=0.3,
    show_x_label=False,
    show_title=False,
    title: str | None = None,
):
    """
    Plot validation error and either runtime or memory against number of trainable layers.

    The validation error is drawn as a line on the left y-axis; the selected
    metric is drawn as bars on the right y-axis, expressed as a percentage of the
    value at the highest ``n_trainable`` (the full model). The figure is saved as
    ``<output_dir>/val_vs_runtime/validation_vs_<metric>_config_<config_id>.pdf``.

    Args:
        validation_df: DataFrame containing validation data (needs ``n_trainable``
            and ``validation_error`` columns)
        hw_metrics_df: DataFrame containing hardware metrics (needs ``n_trainable``
            and ``total_runtime``, or ``memory`` when plotting memory)
        output_dir: Directory to save the plot
        config_id: ID of the configuration to plot (used in the file name)
        metric: Which metric to plot on the second y-axis ("runtime" or "memory");
            any value other than "runtime" is treated as memory
        alpha: Transparency level for plot elements (0-1)
        grid_alpha: Transparency level for grid lines (0-1)
        show_x_label: Whether to show the x-axis label
        show_title: Whether to set ``title`` as the plot title
        title: Title text, only used when ``show_title`` is true

    Raises:
        ValueError: If the two frames share no ``n_trainable`` value.
        KeyError: If the column required for ``metric`` is missing.
    """
    # Merge the datasets on n_trainable
    df = pd.merge(validation_df, hw_metrics_df, on="n_trainable", how="inner")

    if df.empty:
        raise ValueError("No matching data between validation and hardware metrics")

    # The baseline below is taken from the last row, so make sure the last row
    # really is the highest n_trainable regardless of the input order.
    df = df.sort_values("n_trainable")

    # Per-metric settings; everything else is shared between the two variants.
    if metric == "runtime":
        column, color_key = "total_runtime", "runtime"
        bar_label = "Relative Runtime"
        ylabel = r"Runtime (\% of full model)"
    else:  # memory
        column, color_key = "memory", "memory"
        bar_label = "Relative Memory"
        ylabel = r"Memory (\% of full model)"
    metric_color = COLORS[color_key]

    # Express the metric as a percentage of the full-model value. Computed
    # before any figure is created so a missing column does not leak a figure.
    baseline = df[column].iloc[-1]
    percentage = (df[column] / baseline) * 100

    # Create the plot with two y-axes using host_subplot
    fig = plt.figure(figsize=(4, 2.4))
    host = host_subplot(111, axes_class=AA.Axes)
    plt.subplots_adjust(right=0.85)

    par1 = host.twinx()

    # Create new fixed axis with specific offset
    new_fixed_axis1 = par1.get_grid_helper().new_fixed_axis
    par1.axis["right"] = new_fixed_axis1(loc="right", axes=par1, offset=(0, 0))

    # Only the parasite axis draws the right spine; the host hides right and top.
    host.axis["right"].set_visible(False)
    host.axis["top"].set_visible(False)
    par1.axis["right"].set_visible(True)

    # Grid alpha is applied on the axes directly (original author's note: the
    # rcParams grid setting is not picked up properly by host_subplot).
    host.grid(True, alpha=grid_alpha)

    # Plot validation error on first y-axis (host)
    (p1,) = host.plot(
        df["n_trainable"],
        df["validation_error"],
        marker="o",
        color=COLORS["validation"],
        label="Validation Error",
        alpha=alpha,
    )

    # Plot selected metric as bars on second y-axis (par1)
    par1.bar(
        df["n_trainable"],
        percentage,
        color=metric_color,
        alpha=alpha,
        width=0.6,  # Width of the bars
        label=bar_label,
    )
    # Y-range from slightly below the smallest percentage to slightly above 100%
    par1.set_ylim(percentage.min() * 0.95, 105)

    # Set labels
    host.set_xlabel(r"\# Trainable Layers" if show_x_label else "")
    host.set_ylabel("Validation Error", color=COLORS["validation"])
    par1.set_ylabel(ylabel, color=metric_color)

    # Set colors for y-axis labels
    host.yaxis.label.set_color(p1.get_color())
    par1.yaxis.label.set_color(metric_color)

    # Set colors for y-axis ticks
    tkw = {"size": 4, "width": 1.5}
    host.tick_params(axis="y", colors=p1.get_color(), **tkw)
    par1.tick_params(axis="y", colors=metric_color, **tkw)

    if show_title:
        host.set_title(title)

    # Create the directory the file is actually written to. Previously only
    # output_dir was created, so saving failed when the sub-directory was missing.
    plot_dir = os.path.join(output_dir, "val_vs_runtime")
    os.makedirs(plot_dir, exist_ok=True)

    # Save plot
    filename = f"validation_vs_{metric}_config_{config_id}.pdf"
    filepath = os.path.join(plot_dir, filename)
    fig.savefig(filepath, bbox_inches="tight", dpi=500)

    print(f"Plot saved to {filepath}")
    plt.close(fig)


def parse_args():
    """
    Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        (
            "--experiment_dir",
            {
                "help": "Directory containing experiment data (e.g., output/resnet18-full-grid/grid_search/c100/20_epochs)",
            },
        ),
        (
            "--direct_csv",
            {
                "help": "Direct CSV file containing metrics data with columns: n_trainable, val_loss, median_loop_time",
            },
        ),
        (
            "--hw_metrics",
            {
                "default": "output/model_perf/csv_data/resnet18_bs1024.csv",
                "help": "Path to the CSV file with hardware metrics (not used when --direct_csv is provided)",
            },
        ),
        (
            "--epoch",
            {
                "type": int,
                "default": None,
                "help": "Epoch at which to plot validation results (0-indexed). If not specified, uses the final epoch.",
            },
        ),
        ("--output_dir", {"default": "plots", "help": "Directory to save the plots"}),
        (
            "--dataset",
            {
                "default": None,
                "help": "Dataset name for plot title. If not provided, it will be inferred from the experiment directory.",
            },
        ),
        ("--config_id", {"type": int, "default": 0, "help": "Config ID to plot"}),
        (
            "--metric",
            {
                "choices": ["runtime", "memory"],
                "default": "runtime",
                "help": "Which metric to plot on the second y-axis (runtime or memory)",
            },
        ),
        (
            "--alpha",
            {
                "type": float,
                "default": 1,
                "help": "Transparency level for plot elements (0-1)",
            },
        ),
        (
            "--grid_alpha",
            {
                "type": float,
                "default": 0.3,
                "help": "Transparency level for grid lines (0-1)",
            },
        ),
        (
            "--show_x_label",
            {
                "action": "store_true",
                "default": False,
                "help": "Whether to show the x-axis label",
            },
        ),
        (
            "--show_title",
            {
                "action": "store_true",
                "default": False,
                "help": "Whether to show the title",
            },
        ),
        (
            "--title",
            {
                "type": str,
                "default": None,
                "help": "Title for the plot",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Plot validation error and runtime vs fidelity level"
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def load_direct_csv(csv_path):
    """
    Load data directly from a CSV file with metrics.

    The CSV must contain ``n_trainable``, ``val_loss`` and ``median_loop_time``
    columns. An optional ``memory`` column is passed through to the hardware
    metrics so that callers checking for it can use real memory data.

    Args:
        csv_path: Path to the CSV file containing metrics data

    Returns:
        Tuple ``(validation_df, hw_metrics_df)``:
        ``validation_df`` has ``n_trainable``, ``validation_error`` (from
        ``val_loss``) and placeholder ``config_id``/``optimizer``/
        ``learning_rate``/``weight_decay`` columns; ``hw_metrics_df`` has
        ``n_trainable``, ``total_runtime`` (from ``median_loop_time``) and, when
        present in the CSV, ``memory``.
    """
    raw = pd.read_csv(csv_path)

    # Validation frame; the configuration columns are placeholders because a
    # direct CSV carries no configuration information.
    validation_df = pd.DataFrame(
        {
            "n_trainable": raw["n_trainable"],
            "validation_error": raw["val_loss"],
            "config_id": 0,
            "optimizer": "unknown",
            "learning_rate": 0.0,
            "weight_decay": 0.0,
        }
    )

    # Hardware metrics frame
    hw_columns = {
        "n_trainable": raw["n_trainable"],
        "total_runtime": raw["median_loop_time"],
    }
    if "memory" in raw.columns:
        # Previously dropped, which made memory data in the CSV unreachable.
        hw_columns["memory"] = raw["memory"]
    hw_metrics_df = pd.DataFrame(hw_columns)

    return validation_df, hw_metrics_df


def _infer_dataset_name(experiment_dir):
    """Return the upper-cased path component following ``grid_search``, else ``"Unknown"``."""
    # Path.parts splits on the platform's separator, unlike a literal "/" split.
    parts = Path(experiment_dir).parts
    if "grid_search" in parts:
        following = parts.index("grid_search") + 1
        if following < len(parts):
            # e.g. output/resnet18-full-grid/grid_search/c100/20_epochs -> "C100"
            return parts[following].upper()
    return "Unknown"


def main():
    """
    Command-line entry point: load the data and produce one plot.

    Exactly one of ``--experiment_dir`` and ``--direct_csv`` must be given. With
    an experiment directory, validation data and hardware metrics come from
    separate sources; with a direct CSV, both come from the same file.

    Raises:
        ValueError: If neither or both of the two input options are provided.
    """
    # Parse command-line arguments
    args = parse_args()

    # Validate input arguments
    if not args.experiment_dir and not args.direct_csv:
        raise ValueError("Either --experiment_dir or --direct_csv must be provided")
    if args.experiment_dir and args.direct_csv:
        raise ValueError("Cannot provide both --experiment_dir and --direct_csv")

    # Infer dataset name if not provided
    # NOTE(review): dataset_name is only printed below; it is not passed to the
    # plotting function, although the --dataset help text mentions the plot title.
    dataset_name = args.dataset
    if dataset_name is None:
        if args.experiment_dir:
            dataset_name = _infer_dataset_name(args.experiment_dir)
        else:
            # For direct CSV, use a default name
            dataset_name = "Direct CSV"

    # Print info about what we're doing
    if args.experiment_dir:
        print(f"Loading validation data from {args.experiment_dir}")
        print(f"Loading hardware metrics from {args.hw_metrics}")
    else:
        print(f"Loading data from direct CSV: {args.direct_csv}")

    print(f"Dataset: {dataset_name}")
    print(f"Epoch: {args.epoch if args.epoch is not None else 'Last'}")
    print(f"Metric: {args.metric}")
    print(f"Alpha: {args.alpha}")
    print(f"Grid Alpha: {args.grid_alpha}")
    print(f"Show X Label: {args.show_x_label}")
    print(f"Show Title: {args.show_title}")

    # Load data
    if args.experiment_dir:
        hw_metrics_df = load_hw_metrics(args.hw_metrics)
        validation_df = load_validation_data(
            args.experiment_dir,
            config_id=args.config_id,
            epoch=args.epoch,
        )
    else:
        # For direct CSV, extract both validation and hardware metrics from the same file
        validation_df, hw_metrics_df = load_direct_csv(args.direct_csv)

        # The memory plot needs a "memory" column; fall back to runtime values
        # (clearly announced) when the loaded metrics do not provide one.
        if args.metric == "memory" and "memory" not in hw_metrics_df.columns:
            print(
                "Note: No memory data available in CSV. Using runtime values as placeholders for memory."
            )
            hw_metrics_df["memory"] = hw_metrics_df["total_runtime"]

    # Generate plots
    plot_validation_vs_runtime(
        validation_df,
        hw_metrics_df,
        args.output_dir,
        config_id=args.config_id,
        metric=args.metric,
        alpha=args.alpha,
        grid_alpha=args.grid_alpha,
        show_x_label=args.show_x_label,
        show_title=args.show_title,
        title=args.title,
    )


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
