import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.cm import viridis

# Publication-oriented defaults: start from the seaborn "paper" style sheet,
# then layer the font, grid, text-size, and marker overrides on top of it.
plt.style.use("seaborn-v0_8-paper")
plt.rcParams.update(
    {
        # Typeface (a specific serif face such as Times New Roman could be
        # pinned through "font.serif"; it is left at the style default here)
        "font.family": "serif",
        # Faint background grid on every axes
        "axes.grid": True,
        "grid.alpha": 0.1,
        # Text sizes
        "axes.labelsize": 10,
        "axes.titlesize": 11,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 10,
        # Marker size shared by all plotted lines
        "lines.markersize": 5,
    }
)

# Number of legend columns used when the full set of series is shown
LEGEND_NCOL = 4

# Human-readable subplot titles keyed by model identifier.  Every language-model
# key exists in a "_tied" and an "_untied" variant, and both variants of a key
# map to the same display label.
_VARIANT_SUFFIXES = ("tied", "untied")
_LANGUAGE_MODEL_LABELS = {
    "gpt_scales_14m": "GPT-2 (14M)",
    "gpt_scales_l24": "GPT-2 (127M)",
    "gpt_scales_pythia_1.4b": "Pythia (1.4B)",
}
MODEL_NAMES = {
    f"{base_key}_{suffix}": label
    for base_key, label in _LANGUAGE_MODEL_LABELS.items()
    for suffix in _VARIANT_SUFFIXES
}
MODEL_NAMES["resnet101"] = "ResNet-101"

# Series colors, sampled from the viridis colormap at fixed positions in [0, 1].
# NOTE(review): "mem" uses the same position as "bwd", and "runtime" the same
# position as "fwd", so each pair renders in an identical color -- confirm that
# this is intended.
_VIRIDIS_POSITIONS = {
    "fwd": 0.1,  # dark purple
    "bwd": 0.4,  # blue
    "opt": 0.7,  # green
    "mem": 0.4,  # same blue as "bwd"
    "runtime": 0.1,  # same purple as "fwd"; used for the cumulative runtime line
}
COLORS = {series: viridis(position) for series, position in _VIRIDIS_POSITIONS.items()}


def load_model_data_with_batch_sizes(model_batch_sizes, data_dir="output/model_perf/csv_data"):
    """Load the benchmark CSV of each requested model into one DataFrame.

    For every ``model_name -> batch_size`` entry the file
    ``<data_dir>/<model_name>_bs<batch_size>.csv`` is read.  Each loaded frame
    gets three extra columns: ``model_name``, ``batch_size`` and
    ``total_runtime`` (the sum of the ``median_fwd_time``, ``median_bwd_time``
    and ``optimizer_time`` columns).  Models whose file does not exist are
    skipped with a warning printed to stdout.

    Args:
        model_batch_sizes: Dictionary mapping model names to batch sizes
        data_dir: Directory containing the CSV files (string or path-like).
            Defaults to the location used by the benchmark scripts.

    Returns:
        A single DataFrame with the rows of all files that were found, in the
        iteration order of ``model_batch_sizes`` and with a fresh integer index.

    Raises:
        ValueError: If none of the requested files exists.
        KeyError: If a CSV lacks one of the three timing columns.
    """
    data_dir = Path(data_dir)
    frames = []

    for model_name, batch_size in model_batch_sizes.items():
        file_path = data_dir / f"{model_name}_bs{batch_size}.csv"
        if not file_path.exists():
            print(f"Warning: File not found for model {model_name} with batch size {batch_size}")
            continue

        df = pd.read_csv(file_path)
        df["model_name"] = model_name
        df["batch_size"] = batch_size
        # Total step time = forward + backward + optimizer update
        df["total_runtime"] = (
            df["median_fwd_time"] + df["median_bwd_time"] + df["optimizer_time"]
        )
        frames.append(df)

    if not frames:
        raise ValueError("No data found for the specified models with their batch sizes")

    return pd.concat(frames, ignore_index=True)


# (data column, marker, COLORS key, legend label) for each line drawn on the
# left-hand time axis, for the two display modes.
_SEPARATE_TIMING_SERIES = (
    ("median_fwd_time", "o", "fwd", "Forward"),
    ("median_bwd_time", "s", "bwd", "Backward"),
    ("optimizer_time", "^", "opt", "Optimizer"),
)
_CUMULATIVE_TIMING_SERIES = (("total_runtime", "o", "runtime", "Total Runtime"),)


def plot_model_metrics(
    df,
    layout="row",
    show_cumulative=False,
    show_legend=True,
    show_x_label=True,
    output_dir="plots/hw_metrics",
):
    """Draw one dual-axis subplot per model and save the figure as a PDF.

    Each subplot plots timing columns against ``n_trainable`` on the left
    y-axis and the ``memory`` column on a twinned right y-axis.  The figure is
    written to ``<output_dir>/metrics_<model>_bs<batch>[_...].pdf`` and closed.

    Args:
        df: DataFrame containing model data.  Must provide the columns
            ``model_name``, ``batch_size``, ``n_trainable``, ``memory`` and
            either ``total_runtime`` (cumulative mode) or ``median_fwd_time``,
            ``median_bwd_time`` and ``optimizer_time``.
        layout: 'row' for horizontal layout; any other value gives the
            vertical ('column') layout.
        show_cumulative: If True, shows cumulative runtime instead of separate timings
        show_legend: If True, adds one shared legend below the subplots
        show_x_label: If True, shows the x-axis label (on every subplot in row
            layout, on the bottom subplot only in column layout)
        output_dir: Directory the PDF is written to; created if missing.

    Returns:
        The path of the saved PDF as a string.

    Raises:
        ValueError: If ``df`` contains no models.
    """
    model_names = df["model_name"].unique()
    num_models = len(model_names)
    if num_models == 0:
        raise ValueError("Cannot plot metrics: the DataFrame contains no models")

    # Batch size of a model = value in its first row
    batch_sizes = {
        model: df[df["model_name"] == model]["batch_size"].iloc[0] for model in model_names
    }

    is_row = layout == "row"
    subplot_size = 3  # Edge length in inches of each (square) subplot
    n_rows, n_cols = (1, num_models) if is_row else (num_models, 1)
    # squeeze=False always yields a 2-D array, so a single model needs no special case
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(subplot_size * n_cols, subplot_size * n_rows),
        squeeze=False,
    )
    axes = axes.ravel()

    timing_series = _CUMULATIVE_TIMING_SERIES if show_cumulative else _SEPARATE_TIMING_SERIES
    # In cumulative mode the left axis is tinted like its single line
    time_axis_color = COLORS["runtime"] if show_cumulative else "black"

    legend_handles = []
    legend_labels = []

    for i, model_name in enumerate(model_names):
        ax = axes[i]
        ax2 = ax.twinx()  # Right-hand axis for memory
        model_data = df[df["model_name"] == model_name]
        x_values = model_data["n_trainable"]

        # Timing line(s) on the left y-axis
        handles = [
            ax.plot(
                x_values,
                model_data[column],
                marker=marker,
                color=COLORS[color_key],
                label=label,
            )[0]
            for column, marker, color_key, label in timing_series
        ]
        # Memory on the right y-axis
        handles.append(
            ax2.plot(
                x_values,
                model_data["memory"],
                marker="D",
                color=COLORS["mem"],
                label="Memory",
            )[0]
        )

        # All subplots show the same series, so the first one feeds the shared legend
        if i == 0:
            legend_handles = handles
            legend_labels = [label for _, _, _, label in timing_series] + ["Memory"]

        if show_x_label and (is_row or i == num_models - 1):
            ax.set_xlabel("# Trainable Layers")

        # Row layout: only the left-most subplot carries the time label.
        # Column layout: every subplot is labelled, since nothing sits beside it.
        if i == 0 or not is_row:
            ax.set_ylabel("Time (ms)", color=time_axis_color)
        ax2.set_ylabel("Memory (MB)", color=COLORS["mem"])

        # Title: known display name, else a prettified identifier without the
        # "_tie..." suffix, followed by the batch size
        base_model_name = model_name.split("_tie")[0]
        display_name = MODEL_NAMES.get(model_name, base_model_name.replace("_", " ").title())
        ax.set_title(f"{display_name} (bs={batch_sizes[model_name]})")

        # Tick labels follow the color of the axis they belong to
        ax.tick_params(axis="y", labelcolor=time_axis_color)
        ax2.tick_params(axis="y", labelcolor=COLORS["mem"])

        # Both y-axes start at zero
        ax.set_ylim(bottom=0)
        ax2.set_ylim(bottom=0)

        # Make the axes box itself square (height / width = 1)
        ax.set_box_aspect(1)

    if show_legend:
        # Two entries in cumulative mode, otherwise the configured column count
        legend_ncol = 2 if show_cumulative else LEGEND_NCOL

        # Reserve a strip at the bottom of the figure for the shared legend
        fig.tight_layout(rect=[0, 0.08, 1, 0.98])
        fig.legend(
            legend_handles,
            legend_labels,
            loc="lower center",
            bbox_to_anchor=(0.5, -0.05),
            handlelength=0.05,  # Very short line segment next to each marker
            handletextpad=1,  # Gap between the legend handle and its text
            ncol=legend_ncol,
            frameon=False,  # No frame
            markerscale=1.3,
        )
    else:
        fig.tight_layout()

    # File name lists every model together with its batch size
    name_stem = "_".join(f"{name}_bs{batch_sizes[name]}" for name in model_names)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"metrics_{name_stem}.pdf")

    # Save and close this figure explicitly rather than pyplot's "current" figure
    fig.savefig(output_path, bbox_inches="tight", dpi=500)
    plt.close(fig)
    return output_path


def parse_args(argv=None):
    """Parse the command-line options of the plotting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; mainly useful for programmatic use and tests.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_batch_sizes``
        (list of raw ``model:batch_size`` strings), ``layout`` ("row" or
        "column"), and the booleans ``cumulative``, ``legend`` and
        ``show_x_label``.  The three boolean flags are opt-in: each is False
        unless its switch is given.
    """
    parser = argparse.ArgumentParser(description="Plot hardware metrics for models")
    parser.add_argument(
        "--model_batch_sizes",
        nargs="+",
        required=True,
        help="List of model:batch_size pairs (e.g., gpt_scales_0:1 resnet101:128)",
    )
    parser.add_argument(
        "--layout",
        choices=["row", "column"],
        default="row",
        help="Layout of subplots: row for horizontal, column for vertical",
    )
    parser.add_argument(
        "--cumulative",
        action="store_true",
        help="Show cumulative runtime instead of separate timings",
    )
    parser.add_argument(
        "--legend",
        action="store_true",
        help="Show legend",
    )
    parser.add_argument(
        "--show_x_label",
        action="store_true",
        help="Show x-axis label",
    )
    # argv=None makes argparse fall back to sys.argv[1:]
    return parser.parse_args(argv)


def _parse_model_batch_sizes(pairs):
    """Turn ``model:batch_size`` strings into a ``{model: batch_size}`` dict.

    Malformed entries (no colon, empty model name, non-integer batch size) are
    skipped with a warning printed to stdout.  If a model is listed more than
    once, the last valid entry wins.
    """
    parsed = {}
    for pair in pairs:
        # Split on the first colon only; the remainder is the batch size
        model, separator, raw_batch_size = pair.partition(":")
        if not separator or not model:
            print(f"Warning: Skipping invalid pair '{pair}', format should be model:batch_size")
            continue
        try:
            parsed[model] = int(raw_batch_size)
        except ValueError:
            print(
                f"Warning: Skipping invalid batch size in '{pair}', batch size must be an integer"
            )
    return parsed


def main():
    """Script entry point: parse the CLI, load the CSV data, and write the plot.

    Raises:
        SystemExit: With an error message (and therefore a non-zero exit
            status) if no valid ``model:batch_size`` pair was supplied.
    """
    args = parse_args()

    model_batch_sizes = _parse_model_batch_sizes(args.model_batch_sizes)
    if not model_batch_sizes:
        # Exit with a failure status so calling scripts can detect the problem
        raise SystemExit("Error: No valid model:batch_size pairs provided")

    # Load data
    df = load_model_data_with_batch_sizes(model_batch_sizes)

    # Generate plots
    plot_model_metrics(
        df,
        layout=args.layout,
        show_cumulative=args.cumulative,
        show_legend=args.legend,
        show_x_label=args.show_x_label,
    )


# Run the pipeline exactly once when executed as a script
if __name__ == "__main__":
    main()
