import os
import json
import argparse
from typing import Dict, List, Tuple

import pandas as pd


def calculate_model_averages(stats_data: Dict) -> Dict[str, Dict[str, float]]:
    """Return the overall ``pass_k`` and ``reliability`` value for every model.

    No averaging is computed here: the two metrics are copied as-is from
    ``stats_data["overall"][<model>]``, preserving the model order of that
    mapping. Any other per-model keys are dropped.

    Raises:
        KeyError: if ``stats_data`` has no ``"overall"`` section, or a model
            entry lacks ``"pass_k"`` or ``"reliability"``.
    """
    overall_section = stats_data["overall"]
    return {
        name: {
            "pass_k": metrics["pass_k"],
            "reliability": metrics["reliability"],
        }
        for name, metrics in overall_section.items()
    }


def get_model_data(stats_data: Dict) -> Dict[str, Dict[str, float]]:
    """Collect each model's overall ``pass_k`` and ``reliability`` for plotting.

    Returns a mapping ``model name -> {"pass_k": ..., "reliability": ...}``
    taken from ``stats_data["overall"]``, in that section's model order.
    When the stats have no ``"overall"`` section, an empty dict is returned.

    Raises:
        KeyError: if a model entry lacks ``"pass_k"`` or ``"reliability"``.
    """
    if "overall" not in stats_data:
        return {}

    metric_keys = ("pass_k", "reliability")
    return {
        model: {key: entry[key] for key in metric_keys}
        for model, entry in stats_data["overall"].items()
    }


# Translation table for characters that are special in LaTeX text mode.
# Applied with str.translate (a single pass), so the backslashes introduced
# by one replacement are never re-escaped by another.
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _escape_latex(text: str) -> str:
    """Return *text* with LaTeX special characters escaped for text mode."""
    return text.translate(_LATEX_ESCAPES)


def generate_tikz_scatter(model_data: Dict[str, Dict[str, float]]) -> str:
    """Build pgfplots/TikZ source for a Pass@k-vs-reliability scatter plot.

    Each model becomes one marker at ``(pass_k, reliability)`` (formatted to
    three decimals), followed directly by its legend entry. Models are emitted
    in sorted key order, and marks and colors cycle through fixed palettes.
    Legend labels use only the part of the model name after the last ``/``,
    LaTeX-escaped so names such as ``gpt_4`` compile.

    Args:
        model_data: mapping of model name to a dict with ``"pass_k"`` and
            ``"reliability"`` values.

    Returns:
        The TikZ code as a single newline-joined string. The color definitions
        it uses (``myred`` ... ``mycyan``) are emitted inside the picture.

    Raises:
        KeyError: if a model entry lacks ``"pass_k"`` or ``"reliability"``.
    """
    model_marks = [
        "mark=*", "mark=square*", "mark=triangle*", "mark=diamond*",
        "mark=pentagon*", "mark=star", "mark=otimes*", "mark=+",
        "mark=x", "mark=o"
    ]

    model_colors = [
        "myred", "myorange", "mybrown", "mygreen", "myteal",
        "myblue", "mypurple", "mycyan"
    ]

    def format_model_name(model_name: str) -> str:
        # Strip an "org/" style prefix, keeping the last path component.
        return model_name.rsplit("/", 1)[-1]

    lines = []
    lines.append("\\begin{tikzpicture}")
    lines.append("\\begin{axis}[")
    lines.append("    width=\\linewidth,")
    lines.append("    height=\\linewidth,")
    lines.append("    xlabel={\\small Pass@k Coverage},")
    lines.append("    ylabel={\\small $k$-of-$k$ Reliability},")
    lines.append("    xmin=0, xmax=1.05,")
    lines.append("    ymin=0, ymax=1.05,")
    lines.append("    axis lines=left,")
    lines.append("    tick style={black, thick},")
    lines.append("    tick label style={font=\\scriptsize},")
    lines.append("    axis line style={thick, -{Latex[length=3pt]}},")
    lines.append("    xtick={0.2,0.4,0.6,0.8,1.0},")
    lines.append("    ytick={0.2,0.4,0.6,0.8,1.0},")
    lines.append("    enlargelimits=false,")
    lines.append("    grid=both,")
    lines.append("    title={\\textbf{Pass@k vs. Reliability}},")
    lines.append("    title style={yshift=-1.2em, font=\\small},")
    lines.append("    legend style={")
    lines.append("        at={(0.98,0.98)},")
    lines.append("        anchor=north east,")
    lines.append("        font=\\tiny,")
    lines.append("        fill=white,")
    lines.append("        draw=black,")
    lines.append("        opacity=0.9")
    lines.append("    },")
    lines.append("]")
    lines.append("")

    color_definitions = [
        "\\definecolor{myred}{RGB}{230,97,1}",
        "\\definecolor{myorange}{RGB}{253,184,99}",
        "\\definecolor{mybrown}{RGB}{178,171,210}",
        "\\definecolor{mygreen}{RGB}{120,198,121}",
        "\\definecolor{myteal}{RGB}{27,158,119}",
        "\\definecolor{myblue}{RGB}{66,146,198}",
        "\\definecolor{mypurple}{RGB}{158,154,200}",
        "\\definecolor{mycyan}{RGB}{102,194,165}",
    ]
    lines.extend(color_definitions)
    lines.append("")

    # Background triangles: below the diagonal is shaded red, above it green.
    # "forget plot" keeps these out of the legend list.
    lines.append("% --- Colored background regions")
    lines.append(
        "\\addplot [draw=none, fill=red!5, forget plot] coordinates {(0,0) (1.05,0) (1.05,1.05)};")
    lines.append(
        "\\addplot [draw=none, fill=green!10, forget plot] coordinates {(0,0) (0,1.05) (1.05,1.05)};")
    lines.append("")

    lines.append("% --- Diagonal y=x line")
    lines.append("\\addplot [domain=0:1.05, dashed, thick, black, forget plot] {x};")
    lines.append("")

    lines.append("% --- Region Labels")
    lines.append(
        "\\node[font=\\scriptsize, align=left, text width=3.2cm, anchor=south east, text=black] at (axis cs:1.00,0.05) {\\textbf{Overconfident}\\\\ High coverage, low reliability};")
    lines.append(
        "\\node[font=\\scriptsize, align=left, text width=3.2cm, anchor=north west, text=black] at (axis cs:0.05,1.00) {\\textbf{Well-calibrated}\\\\ Reliability $\\geq$ Coverage};")
    lines.append("")

    for i, model in enumerate(sorted(model_data)):
        data = model_data[model]
        mark = model_marks[i % len(model_marks)]
        color = model_colors[i % len(model_colors)]
        label = _escape_latex(format_model_name(model))

        lines.append("\\addplot[")
        lines.append(f"    {mark},")
        lines.append("    draw=black,")
        lines.append(f"    fill={color},")
        lines.append("    mark size=2.8pt,")
        lines.append("    only marks")
        lines.append(
            f"] coordinates {{({data['pass_k']:.3f}, {data['reliability']:.3f})}};")
        # The legend entry goes right after its plot so pgfplots pairs text
        # and marker unambiguously, reusing the plot's own "only marks" style.
        lines.append(f"\\addlegendentry{{{label}}}")
        lines.append("")

    lines.append("\\end{axis}")
    lines.append("\\end{tikzpicture}")

    return "\n".join(lines)


def main():
    """Command-line entry point: read a stats JSON file and emit TikZ code.

    Options:
        --stats-json: path to the stats JSON file (required).
        --output: file to write the TikZ code to; when omitted, the code is
            printed to stdout.

    Raises:
        FileNotFoundError: if the stats file does not exist.
    """
    import sys

    parser = argparse.ArgumentParser(
        description="Generate TikZ scatter plot for Pass@k vs Reliability by scene"
    )
    parser.add_argument(
        "--stats-json",
        type=str,
        required=True,
        help="Path to stats JSON file",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output file path (if not specified, prints to stdout)",
    )

    args = parser.parse_args()

    if not os.path.exists(args.stats_json):
        raise FileNotFoundError(f"Stats file not found: {args.stats_json}")

    # JSON is UTF-8 by specification. Don't depend on the platform's locale
    # encoding, which may be e.g. cp1252 on Windows.
    with open(args.stats_json, "r", encoding="utf-8") as f:
        stats_data = json.load(f)

    model_data = get_model_data(stats_data)

    if not model_data:
        # Diagnostic goes to stderr so it never ends up in TikZ output that
        # was redirected from stdout into a .tex file.
        print("No model data found in stats file.", file=sys.stderr)
        return

    tikz_code = generate_tikz_scatter(model_data)

    if args.output:
        # Model names may contain non-ASCII characters, so write UTF-8 explicitly.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(tikz_code)
        print(f"TikZ code saved to {args.output}")
    else:
        print(tikz_code)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
