import argparse
import pathlib
from typing import Dict, TypedDict

import matplotlib.pyplot as plt
import pandas as pd
import yaml
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# Name of the unweighted success-rate column (not referenced elsewhere in the
# visible part of this module).
SUCCESS_RATE_COLUMN = "success_rate_at_cost"
# Name of the column whose values are plotted on the y-axis of each agent line.
WEIGHTED_SUCCESS_RATE_COLUMN = "weighted_success_rate_at_cost"
# Maps each supported weighting column to the human-readable description that
# is shown in parentheses on the second line of the chart title.
WEIGHT_COL_NAMES = {
    "equal_task_weight": "Weighting Tasks Equally",
    "invsqrt_task_weight": "Weighted by Task Diversity",
}
# Figure size handed to ``plt.subplots(figsize=...)`` as (width, height).
FIGSIZE = (10, 6)
# Chart title, keyed by the usage column that is plotted on the x-axis.
TITLE = {
    "generation_cost": "Agent Performance on HCAST & RE-Bench by Cost",
    "action_count": "Agent Performance on HCAST & RE-Bench by Actions",
}
# y-axis label (the same for every usage column).
YLABEL = "Weighted Success Rate"
# x-axis label, keyed by the usage column that is plotted on the x-axis.
XLABEL = {
    "generation_cost": "Allowed Cost per Run ($)",
    "action_count": "Allowed Actions per Run",
}


def _plot_agent_line(
    df_agent: pd.DataFrame,
    ax: Axes,
    usage_column: str,
    color: str = "black",
    linestyle: str = "-",
    label: str | None = None,
) -> None:
    """Draw one agent's weighted success rate against usage as a step line.

    Args:
        df_agent: Rows belonging to a single agent. Must be non-empty and
            contain ``usage_column``, ``"task_source"`` and the
            ``WEIGHTED_SUCCESS_RATE_COLUMN`` column.
        ax: Axes to draw the line on.
        usage_column: Column providing the x-values.
        color: Line colour.
        linestyle: Matplotlib line style.
        label: Legend label for the line, or ``None`` for no label.

    Raises:
        ValueError: If ``df_agent`` has no rows.
    """
    if df_agent.empty:
        raise ValueError("df_agent must contain at least one row to plot")

    # Exclude SWAA rows *before* de-duplicating. Doing it the other way round
    # lets a SWAA row displace a non-SWAA row that shares its usage value and
    # then get filtered out itself, silently losing that data point.
    graph_df = df_agent[df_agent["task_source"] != "SWAA"]
    # If several rows share the same usage value, keep the last one.
    graph_df = graph_df.drop_duplicates(subset=usage_column, keep="last")
    graph_df = graph_df.sort_values(by=usage_column)

    ax.plot(
        graph_df[usage_column],
        graph_df[WEIGHTED_SUCCESS_RATE_COLUMN],
        linewidth=1.75,
        color=color,
        alpha=0.95,
        label=label,
        # "steps-post": each value is held until the next usage level.
        drawstyle="steps-post",
        dash_capstyle="round",
        linestyle=linestyle,
    )


class AgentStyling(TypedDict):
    """Styling entry for one agent, as stored in the ``agent_styling`` mapping."""

    # Colour used for the agent's line in _score_line_chart.
    unique_color: str
    # Not read anywhere in this module; presumably a matplotlib marker
    # style for the agent -- TODO confirm against the params file.
    marker: str
    # Not read anywhere in this module; presumably a colour shared by agents
    # from the same lab -- TODO confirm against the params file.
    lab_color: str


def _score_line_chart(
    df: pd.DataFrame,
    agents: list[list[str]],
    agent_styling: Dict[str, AgentStyling],
    weighting_column: str,
    usage_column: str,
    score_column: str,
) -> Figure:
    """Build a line chart of success rate versus allowed usage, one line per agent.

    Every alias present in ``df`` gets a line. Lines are drawn (and therefore
    listed in the legend) in descending order of ``score_column`` taken from
    each alias's last row in ``df``.

    Args:
        df: Data with at least the ``"alias"``, ``usage_column`` and
            ``score_column`` columns, plus whatever ``_plot_agent_line`` needs.
        agents: Groups of agent aliases. The groups are flattened and only
            used to pick the rows from which the x-axis limits are computed.
        agent_styling: Styling per alias; ``unique_color`` is the line colour.
        weighting_column: Key into ``WEIGHT_COL_NAMES`` for the title suffix.
        usage_column: x-axis column; also a key into ``TITLE`` and ``XLABEL``.
        score_column: Column used to order the agents.

    Returns:
        The figure containing the chart.

    Raises:
        KeyError: If an alias has no entry in ``agent_styling``, or if
            ``usage_column`` / ``weighting_column`` is not a known key.
    """
    fig, ax = plt.subplots(figsize=FIGSIZE)

    # Last row per alias (in the order rows appear in ``df``), used to rank
    # the agents from highest to lowest score.
    last_entries = df.groupby("alias").last().reset_index()
    agent_order = last_entries.sort_values(by=score_column, ascending=False)["alias"]

    for agent in agent_order:
        _plot_agent_line(
            df_agent=df[df["alias"] == agent],
            ax=ax,
            usage_column=usage_column,
            label=str(agent),
            color=agent_styling[agent]["unique_color"],
        )

    # Legend sits outside the axes, vertically centred on the right edge.
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    ax.grid(True, linestyle="--", alpha=0.7)
    # Adjust this figure explicitly instead of pyplot's global current figure.
    # NOTE(review): fig.tight_layout() below likely supersedes these margins --
    # confirm whether this call is still needed.
    fig.subplots_adjust(right=0.5, bottom=0.2)
    ax.set_ylabel(YLABEL, wrap=True)
    ax.set_title(f"{TITLE[usage_column]}\n({WEIGHT_COL_NAMES[weighting_column]})")
    ax.set_ylim(0, 1)
    ax.set_xlabel(XLABEL[usage_column])
    ax.set_xscale("log")

    all_agents = [agent for agent_group in agents for agent in agent_group]
    usage = df.loc[df["alias"].isin(all_agents), usage_column]
    # A log axis cannot have a non-positive limit, so derive the limits from
    # strictly positive usage values only (a usage of 0 would otherwise yield
    # a lower limit of 0). With all-positive data this matches min/max exactly.
    positive_usage = usage[usage > 0]
    if not positive_usage.empty:
        ax.set_xlim(0.5 * positive_usage.min(), 10 * positive_usage.max())

    fig.tight_layout()
    return fig


def main(
    wrangled_resource_file: pathlib.Path,
    output_file: pathlib.Path,
    params_file: pathlib.Path,
    weighting_column: str,
    score_column: str,
    x: str,
) -> None:
    """Load the wrangled data and params, render the chart and save it.

    Args:
        wrangled_resource_file: JSON-lines file with one record per line.
        output_file: Destination of the figure; missing parent directories
            are created.
        params_file: YAML file providing ``plots.agent_styling``.
        weighting_column: Weighting column name, forwarded to the chart.
        score_column: Column used to order the agents in the chart.
        x: Usage column plotted on the x-axis.
    """
    df = pd.read_json(wrangled_resource_file, lines=True, orient="records")
    # Use a context manager so the params file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(params_file, encoding="utf-8") as params_handle:
        params = yaml.safe_load(params_handle)

    fig = _score_line_chart(
        df,
        # A single group holding every alias present in the data.
        agents=[df["alias"].unique().tolist()],
        agent_styling=params["plots"]["agent_styling"],
        usage_column=x,
        score_column=score_column,
        weighting_column=weighting_column,
    )

    output_file.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving figure to {output_file}")
    fig.savefig(output_file)
    # Release the figure from pyplot's registry once it has been written.
    plt.close(fig)


# Command-line entry point: parse the arguments and forward them to main().
# The argparse destination names match main()'s parameter names, which is what
# makes the ``main(**vars(args))`` call below work.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wrangled-resource-file",
        type=pathlib.Path,
        required=True,
        help="Path to the wrangled resource data file (jsonl format)",
    )
    parser.add_argument(
        "--output-file",
        type=pathlib.Path,
        help="Path to save the output figure",
        required=True,
    )
    parser.add_argument(
        "--params-file",
        type=pathlib.Path,
        help="Path to the params file",
        required=True,
    )
    parser.add_argument(
        "--score-column",
        type=str,
        required=True,
        help="Column to use for scoring. E.g. score_binarized",
    )
    parser.add_argument(
        "--weighting-column",
        type=str,
        required=True,
        # Only keys of WEIGHT_COL_NAMES can be rendered in the title; reject
        # anything else here rather than with a KeyError after plotting.
        choices=sorted(WEIGHT_COL_NAMES),
        help="Column to use for weighting (either equal_task_weight or invsqrt_task_weight)",
    )
    parser.add_argument(
        "--x",
        required=True,
        type=str,
        # The x column must have both a title and an axis label defined.
        choices=sorted(set(TITLE) & set(XLABEL)),
        help="Column to use for x-axis. E.g. generation_cost or action_count",
    )
    args = parser.parse_args()
    main(**vars(args))
