import argparse
import json
import os


# Default model ids for the "Proprietary Models" section of the table, used
# when --closed-models is not given. Rows are emitted in this list order, and
# only the part after the last "/" is shown as the row label.
DEFAULT_CLOSED_MODELS = [
    "openai/chatgpt-4o-latest",
    "google/gemini-2.5-pro",
    "openai/o4-mini",
    "google/gemini-2.5-flash",
    "anthropic/claude-sonnet-4",
    "anthropic/claude-opus-4",
]

# Default model ids for the "Open-weight Models" section of the table, used
# when --open-models is not given. Same ordering/labeling rules as above.
DEFAULT_OPEN_MODELS = [
    "qwen/qwen2.5-vl-72b-instruct",
    "meta-llama/llama-4-maverick",
    "mistralai/mistral-medium-3",
    "microsoft/phi-4-multimodal-instruct",
]


def format_model_name(model_name):
    """Return the short display label for a model id.

    The label is everything after the last "/" (e.g. "openai/o4-mini" ->
    "o4-mini"); an id without "/" is returned unchanged.
    """
    _head, _sep, tail = model_name.rpartition("/")
    return tail


def calculate_rankings(stats_data, models):
    """Rank *models* by their overall ``pass_1`` score, best first.

    Models missing from ``stats_data["overall"]`` are left out. Ranks start
    at 1. Equal scores keep the relative order they have in *models* and
    still receive distinct consecutive ranks.

    Returns:
        dict mapping model id -> 1-based rank.
    """
    scored = [
        (name, stats_data["overall"][name]["pass_1"])
        for name in models
        if name in stats_data["overall"]
    ]
    # sorted() is stable even with reverse=True, so ties keep input order.
    ordered = sorted(scored, key=lambda item: item[1], reverse=True)
    return {name: position for position, (name, _score) in enumerate(ordered, start=1)}


def find_best_in_category(stats_data, models, metric):
    """Return the largest ``metric`` value among *models* in ``stats_data["overall"]``.

    Models absent from ``stats_data["overall"]`` are ignored. The result is
    never below -1, and -1 is returned when no listed model has data, which
    acts as a "no best value" sentinel.
    """
    values = [
        stats_data["overall"][name][metric]
        for name in models
        if name in stats_data["overall"]
    ]
    # Seeding with -1 mirrors a running maximum that starts at -1.
    return max([-1, *values])


def format_value(value, is_best=False):
    """Format a fraction as a percentage with one decimal, e.g. 0.256 -> "25.6".

    When *is_best* is true the number is wrapped in ``\\textbf{...}``.
    """
    text = "{:.1f}".format(value * 100)
    if not is_best:
        return text
    return "\\textbf{" + text + "}"


def generate_latex_table(stats_data, closed_models, open_models, recaptcha_data=None):
    all_models = closed_models + open_models
    rankings = calculate_rankings(stats_data, all_models)

    closed_best_pass1 = find_best_in_category(stats_data, closed_models, "pass_1")
    closed_best_passk = find_best_in_category(stats_data, closed_models, "pass_k")
    closed_best_rel = find_best_in_category(stats_data, closed_models, "reliability")

    open_best_pass1 = find_best_in_category(stats_data, open_models, "pass_1")
    open_best_passk = find_best_in_category(stats_data, open_models, "pass_k")
    open_best_rel = find_best_in_category(stats_data, open_models, "reliability")

    closed_best_recaptcha_pass1 = -1
    closed_best_recaptcha_passk = -1
    closed_best_recaptcha_rel = -1
    open_best_recaptcha_pass1 = -1
    open_best_recaptcha_passk = -1
    open_best_recaptcha_rel = -1

    if recaptcha_data:
        closed_best_recaptcha_pass1 = find_best_in_category(
            recaptcha_data, closed_models, "pass_1"
        )
        closed_best_recaptcha_passk = find_best_in_category(
            recaptcha_data, closed_models, "pass_k"
        )
        closed_best_recaptcha_rel = find_best_in_category(
            recaptcha_data, closed_models, "reliability"
        )
        open_best_recaptcha_pass1 = find_best_in_category(
            recaptcha_data, open_models, "pass_1"
        )
        open_best_recaptcha_passk = find_best_in_category(
            recaptcha_data, open_models, "pass_k"
        )
        open_best_recaptcha_rel = find_best_in_category(
            recaptcha_data, open_models, "reliability"
        )

    lines = []

    lines.append("\\begin{tabular}{r|c|ccc|ccc}")
    lines.append(
        "  & & \\multicolumn{3}{c|}{\\textbf{Spatial Captcha}} & \\multicolumn{3}{c}{\\textbf{reCaptcha}} \\\\"
    )
    lines.append(
        " \\textbf{Methods} & \\textbf{Rank} & \\textbf{Pass@1} & \\textbf{Pass@k} & \\textbf{$\\frac{k}{k}$ reliability} & \\textbf{Pass@1} & \\textbf{Pass@k} & \\textbf{$\\frac{k}{k}$ reliability} \\\\"
    )
    lines.append(" \\midrule")
    lines.append("")
    lines.append(" \\rowcolor{catrow}")
    lines.append(" \\multicolumn{8}{l}{\\textit{Baseline}} \\\\")
    lines.append(
        " Chance level (Random) & -- & 25.0 & 57.0 & 1.5 & 0.2 & 0.6 & 0.0 \\\\"
    )
    lines.append(" \\midrule")
    lines.append("")
    lines.append(" \\rowcolor{catrow}")
    lines.append(" \\multicolumn{8}{l}{\\textit{Proprietary Models}} \\\\")

    for model in closed_models:
        if model not in stats_data["overall"]:
            continue

        data = stats_data["overall"][model]
        display_name = format_model_name(model)
        rank = rankings[model]

        pass1_val = data["pass_1"]
        passk_val = data["pass_k"]
        rel_val = data["reliability"]

        pass1_str = format_value(pass1_val, pass1_val == closed_best_pass1)
        passk_str = format_value(passk_val, passk_val == closed_best_passk)
        rel_str = format_value(rel_val, rel_val == closed_best_rel)

        recaptcha_pass1 = "0.0"
        recaptcha_passk = "0.0"
        recaptcha_rel = "0.0"
        if recaptcha_data and model in recaptcha_data["overall"]:
            recaptcha_pass1_val = recaptcha_data["overall"][model]["pass_1"]
            recaptcha_passk_val = recaptcha_data["overall"][model]["pass_k"]
            recaptcha_rel_val = recaptcha_data["overall"][model]["reliability"]

            recaptcha_pass1 = format_value(
                recaptcha_pass1_val, recaptcha_pass1_val == closed_best_recaptcha_pass1
            )
            recaptcha_passk = format_value(
                recaptcha_passk_val, recaptcha_passk_val == closed_best_recaptcha_passk
            )
            recaptcha_rel = format_value(
                recaptcha_rel_val, recaptcha_rel_val == closed_best_recaptcha_rel
            )

        lines.append(
            f" {display_name} & {rank} & {pass1_str} & {passk_str} & {rel_str} & {recaptcha_pass1} & {recaptcha_passk} & {recaptcha_rel} \\\\"
        )

    lines.append(" \\midrule")
    lines.append("")
    lines.append(" \\rowcolor{catrow}")
    lines.append(" \\multicolumn{8}{l}{\\textit{Open-weight Models}} \\\\")

    for model in open_models:
        if model not in stats_data["overall"]:
            continue

        data = stats_data["overall"][model]
        display_name = format_model_name(model)
        rank = rankings[model]

        pass1_val = data["pass_1"]
        passk_val = data["pass_k"]
        rel_val = data["reliability"]

        pass1_str = format_value(pass1_val, pass1_val == open_best_pass1)
        passk_str = format_value(passk_val, passk_val == open_best_passk)
        rel_str = format_value(rel_val, rel_val == open_best_rel)

        recaptcha_pass1 = "0.0"
        recaptcha_passk = "0.0"
        recaptcha_rel = "0.0"
        if recaptcha_data and model in recaptcha_data["overall"]:
            recaptcha_pass1_val = recaptcha_data["overall"][model]["pass_1"]
            recaptcha_passk_val = recaptcha_data["overall"][model]["pass_k"]
            recaptcha_rel_val = recaptcha_data["overall"][model]["reliability"]

            recaptcha_pass1 = format_value(
                recaptcha_pass1_val, recaptcha_pass1_val == open_best_recaptcha_pass1
            )
            recaptcha_passk = format_value(
                recaptcha_passk_val, recaptcha_passk_val == open_best_recaptcha_passk
            )
            recaptcha_rel = format_value(
                recaptcha_rel_val, recaptcha_rel_val == open_best_recaptcha_rel
            )

        lines.append(
            f" {display_name} & {rank} & {pass1_str} & {passk_str} & {rel_str} & {recaptcha_pass1} & {recaptcha_passk} & {recaptcha_rel} \\\\"
        )

    lines.append(" \\bottomrule")
    lines.append(" \\end{tabular}")

    return "\n".join(lines)


def main():
    parser = argparse.ArgumentParser(description="Generate LaTeX stats table")
    parser.add_argument(
        "--stats-json", type=str, default="stats.json", help="Path to stats JSON file"
    )
    parser.add_argument(
        "--recaptcha-json", type=str, help="Path to recaptcha stats JSON file"
    )
    parser.add_argument(
        "--closed-models", type=str, help="Comma-separated list of closed models"
    )
    parser.add_argument(
        "--open-models", type=str, help="Comma-separated list of open models"
    )

    args = parser.parse_args()

    if not os.path.exists(args.stats_json):
        raise FileNotFoundError(f"Stats file not found: {args.stats_json}")

    with open(args.stats_json, "r") as f:
        stats_data = json.load(f)

    recaptcha_data = None
    if args.recaptcha_json:
        if not os.path.exists(args.recaptcha_json):
            raise FileNotFoundError(
                f"Recaptcha stats file not found: {args.recaptcha_json}"
            )
        with open(args.recaptcha_json, "r") as f:
            recaptcha_data = json.load(f)

    closed_models = DEFAULT_CLOSED_MODELS
    if args.closed_models:
        closed_models = [model.strip() for model in args.closed_models.split(",")]

    open_models = DEFAULT_OPEN_MODELS
    if args.open_models:
        open_models = [model.strip() for model in args.open_models.split(",")]

    latex_table = generate_latex_table(
        stats_data, closed_models, open_models, recaptcha_data
    )
    print(latex_table)


if __name__ == "__main__":
    main()
