import argparse
import os

import git
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Maps the model identifiers used on the command line and in CSV file names
# to the display names shown in the row labels of the figure.
MODEL_NAMES = {
    **{f"resnet{depth}": f"ResNet-{depth}" for depth in (18, 34, 50, 101, 152)},
    "vit_b_16": "ViT-B/16",
}


def load_csv_data(model_name, batch_size):
    """Read the measurement CSV for one (model, batch size) pair.

    The file is looked up at
    ``<git working tree>/output/model_perf/csv_data/<model_name>_bs<batch_size>.csv``,
    where the working tree is the repository enclosing the current directory.

    Args:
        model_name: Model identifier used in the CSV file name.
        batch_size: Batch size used in the CSV file name.

    Returns:
        The ``pandas.DataFrame`` parsed from the file, or ``None`` (after
        printing a warning) when the file does not exist.
    """
    worktree = git.Repo(".", search_parent_directories=True).working_tree_dir
    csv_path = os.path.join(
        worktree, "output", "model_perf", "csv_data", f"{model_name}_bs{batch_size}.csv"
    )

    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)

    # Missing measurements are not fatal: the caller receives None.
    print(f"Warning: CSV file not found for {model_name} with batch size {batch_size}: {csv_path}")
    return None


# (CSV column, marker, colour, label) for each timing curve drawn on the left axis.
_TIME_SERIES = (
    ("median_fwd_time", "o", "blue", "Forward Time"),
    ("median_bwd_time", "s", "red", "Backward Time"),
    ("optimizer_time", "^", "green", "Optimizer Time"),
)


def _plot_series(ax, ax2, data):
    """Draw the three timing curves on ``ax`` and the memory curve on ``ax2``."""
    x_values = data["frac_trainable_params"]
    for column, marker, color, label in _TIME_SERIES:
        ax.plot(
            x_values,
            data[column],
            marker=marker,
            markersize=2,
            linestyle="-",
            color=color,
            label=label,
        )
    ax2.plot(
        x_values,
        data["memory"],
        marker="d",
        markersize=2,
        linestyle="--",
        color="purple",
        label="Memory (MB)",
    )


def _format_axes(ax, ax2, model_name, is_first_column, is_last_row, is_last_column):
    """Apply the labels, tick visibility and grid that depend on the cell's grid position."""
    if is_last_row:
        # NOTE(review): the label says "%" but the plotted column is
        # `frac_trainable_params` and the ticks are 0 / 0.5 / 1 -- confirm
        # whether the axis should be rescaled or the label reworded.
        ax.set_xlabel(r"% Trainable Parameters")
        ax.set_xticks([0, 0.5, 1], minor=True)
    else:
        # Only the bottom row carries x ticks and an x label.
        ax.set_xlabel("")
        ax.tick_params(axis="x", which="both", length=0)
        ax.set_xticks([], minor=True)
        ax.set_xticklabels([])

    # Left (time) axis: label and tick labels only in the first column.
    if is_first_column:
        # Fall back to the raw identifier for models without a display name.
        ax.set_ylabel(f"{MODEL_NAMES.get(model_name, model_name)}\nTime (ms)")
    ax.tick_params(axis="y", which="both", labelleft=is_first_column)

    # Right (memory) axis: label and tick labels only in the last column.
    if is_last_column:
        ax2.set_ylabel("Memory (MB)")
    ax2.tick_params(axis="y", which="both", labelright=is_last_column)

    ax.grid(True, linestyle="--", alpha=0.7)


def plot_model_metrics(
    ax,
    ax2,
    data,
    model_name,
    batch_size,
    is_first_column=False,
    is_first_row=False,
    is_last_row=False,
    is_last_column=False,
):
    """Plot all metrics for a specific model and batch size on a single subplot.

    Forward, backward and optimizer times are drawn on ``ax``; memory is drawn
    on ``ax2``. All curves use ``data["frac_trainable_params"]`` as x values.
    When ``data`` is ``None`` a "No data" placeholder is written into the cell
    instead of the curves. In both cases the position-dependent labels, tick
    visibility and grid are applied, so placeholder cells match their
    neighbours.

    Args:
        ax: Axes receiving the timing curves (left y-axis).
        ax2: Axes receiving the memory curve (right y-axis).
        data: Table with the columns ``frac_trainable_params``,
            ``median_fwd_time``, ``median_bwd_time``, ``optimizer_time`` and
            ``memory``, or ``None`` when no measurements are available.
        model_name: Model identifier; looked up in ``MODEL_NAMES`` for the row
            label, falling back to the identifier itself when it is unknown.
        batch_size: Batch size, shown only in the "No data" placeholder.
        is_first_column: Show the left y label and tick labels.
        is_first_row: Accepted for call compatibility; currently has no effect.
        is_last_row: Show the x label and ticks.
        is_last_column: Show the right y label and tick labels.
    """
    if data is None:
        ax.text(
            0.5,
            0.5,
            f"No data for\n{model_name}\nBS={batch_size}",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        _plot_series(ax, ax2, data)

    # Formatting runs for placeholder cells too; otherwise they would keep
    # default tick labels on both y-axes and lose their edge labels.
    _format_axes(ax, ax2, model_name, is_first_column, is_last_row, is_last_column)


def plot_grid(models, batch_sizes, save_path=None):
    """Create a grid of plots where rows are models and columns are batch sizes.

    Every cell has a left axis (timings) and a twin right axis (memory). Within
    a row, all left axes share their y-axis and all right axes share theirs. A
    single legend for the whole figure is anchored below the grid.

    Args:
        models: Sequence of model identifiers, one per row.
        batch_sizes: Sequence of batch sizes, one per column.
        save_path: Optional output file. Its parent directory is created when
            the path has one; a bare file name is written to the current
            directory.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``models`` or ``batch_sizes`` is empty.
    """
    n_rows = len(models)
    n_cols = len(batch_sizes)
    if n_rows == 0 or n_cols == 0:
        raise ValueError("plot_grid requires at least one model and one batch size")

    fig = plt.figure(figsize=(n_cols * 3, n_rows * 1.7))

    axes = {}  # (row, col) -> left axes (timings)
    axes2 = {}  # (row, col) -> twin right axes (memory)

    # Create every subplot first so the per-row y-axis sharing is in place
    # before anything is drawn.
    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            # The first column has nothing to share with yet (sharey=None).
            row_leader = axes.get((row_idx, 0))
            ax = fig.add_subplot(
                n_rows, n_cols, row_idx * n_cols + col_idx + 1, sharey=row_leader
            )
            ax2 = ax.twinx()
            if col_idx > 0:
                # Twin axes are created unshared in y; link them to the row's first twin.
                ax2.sharey(axes2[(row_idx, 0)])

            axes[(row_idx, col_idx)] = ax
            axes2[(row_idx, col_idx)] = ax2

    # Now plot the data
    for row_idx, model_name in enumerate(models):
        for col_idx, batch_size in enumerate(batch_sizes):
            plot_model_metrics(
                axes[(row_idx, col_idx)],
                axes2[(row_idx, col_idx)],
                load_csv_data(model_name, batch_size),
                model_name,
                batch_size,
                is_first_column=(col_idx == 0),
                is_first_row=(row_idx == 0),
                is_last_row=(row_idx == n_rows - 1),
                is_last_column=(col_idx == n_cols - 1),
            )

    # Proxy artists for the shared legend: (marker, colour, label, line style).
    legend_specs = (
        ("o", "blue", "Forward pass", "-"),
        ("s", "red", "Backward pass", "-"),
        ("^", "green", "Optimizer step", "-"),
        ("d", "purple", "Memory", "--"),
    )
    legend_elements = [
        Line2D([0], [0], marker=marker, color=color, label=label, markersize=4, linestyle=style)
        for marker, color, label, style in legend_specs
    ]

    # The negative y anchor puts the legend below the grid; bbox_inches="tight"
    # at save time keeps it inside the written file.
    fig.legend(
        handles=legend_elements,
        loc="lower center",
        ncol=1,
        bbox_to_anchor=(0.5, -0.14),
        fontsize=8,
        handlelength=1,
    )

    # Use the figure's own methods so another current pyplot figure cannot be
    # laid out or saved by mistake.
    fig.tight_layout()

    if save_path:
        out_dir = os.path.dirname(save_path)
        # A bare file name has an empty dirname, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=500, bbox_inches="tight")
        print(f"Plot saved to {save_path}")

    return fig


if __name__ == "__main__":
    # Command-line interface: which models (rows) and batch sizes (columns) to plot.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "vit_b_16"],
        help="List of models to plot",
    )
    cli.add_argument(
        "--batch_sizes",
        nargs="+",
        type=int,
        default=[1, 2, 4, 8, 16, 32, 64],
        help="List of batch sizes to plot",
    )
    options = cli.parse_args()

    # Plots are written to <git working tree>/output/model_perf/plots.
    worktree = git.Repo(".", search_parent_directories=True).working_tree_dir
    output_dir = os.path.join(worktree, "output", "model_perf", "plots")
    os.makedirs(output_dir, exist_ok=True)

    # One figure containing the whole model x batch-size grid.
    plot_grid(
        options.models,
        options.batch_sizes,
        save_path=os.path.join(output_dir, "hw_metrics_grid.pdf"),
    )

    print(f"Plots generated in {output_dir}")
