import argparse
import json
import os
import pandas as pd


# Models rendered under the "Proprietary Models" section of the table when
# --closed-models is not given. Only the part after the last "/" is displayed.
DEFAULT_CLOSED_MODELS = [
    "openai/chatgpt-4o-latest",
    "google/gemini-2.5-pro",
    "openai/o4-mini",
    "google/gemini-2.5-flash",
    "anthropic/claude-sonnet-4",
    "anthropic/claude-opus-4",
]

# Models rendered under the "Open-weight Models" section of the table when
# --open-models is not given.
DEFAULT_OPEN_MODELS = [
    "qwen/qwen2.5-vl-72b-instruct",
    "meta-llama/llama-4-maverick",
    "mistralai/mistral-medium-3",
    "microsoft/phi-4-multimodal-instruct",
]

# Ability category -> scene names averaged into that category's column.
# Scene names are used verbatim as top-level keys of the stats JSON and,
# lower-cased with spaces replaced by underscores, as manifest folder names
# and as "<scene>_score" CSV column prefixes.
# NOTE: "Pyramid" is listed under two categories, so it contributes to both
# per-category averages.
DEFAULT_CATEGORY_MAPPING = {
    "Spatial perception": ["Pyramid"],
    "Spatial orientation": ["Agent Sight", "Sun Direction"],
    "Mental objects rotation": ["Revolution", "Unfolded", "Pyramid"],
    "Spatial visualization": ["Polyomino", "Full Views"],
}


def format_model_name(model_name):
    """Return the display name of a model: the text after its last "/".

    A name without any "/" is returned unchanged.
    """
    _, _, short_name = model_name.rpartition("/")
    return short_name


def calculate_rankings(stats_data, models):
    """Rank models by their overall pass@1 score, best first.

    Args:
        stats_data: Stats mapping; ``stats_data["overall"][model]["pass_1"]``
            is the score used for ranking.
        models: Model identifiers to rank. Models without an entry under
            ``"overall"`` are left out of the result.

    Returns:
        Dict mapping each ranked model to its 1-based rank. Ties receive
        distinct consecutive ranks in the order the models were given,
        because the sort is stable.
    """
    scored = [
        (name, stats_data["overall"][name]["pass_1"])
        for name in models
        if name in stats_data["overall"]
    ]
    ordered = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return {name: position for position, (name, _) in enumerate(ordered, start=1)}


def find_best_in_category(stats_data, models, metric):
    """Return the highest overall value of ``metric`` among ``models``.

    Models missing from ``stats_data["overall"]`` are ignored. The result is
    never lower than the sentinel -1, which is also what is returned when no
    model has data.
    """
    observed = [
        stats_data["overall"][name][metric]
        for name in models
        if name in stats_data["overall"]
    ]
    # Seeding with -1 keeps the "no data" sentinel as a floor for the maximum.
    return max([-1, *observed])


def find_second_best_in_category(stats_data, models, metric):
    """Return the second-highest overall value of ``metric`` among ``models``.

    Models missing from ``stats_data["overall"]`` are ignored. Duplicate
    values are not collapsed, so the result equals the best value when two
    models tie for first. Returns -1 when fewer than two models have data.
    """
    ranked = sorted(
        (
            stats_data["overall"][name][metric]
            for name in models
            if name in stats_data["overall"]
        ),
        reverse=True,
    )
    return -1 if len(ranked) < 2 else ranked[1]


def format_value(value, is_best=False, is_second_best=False):
    """Format a 0-1 fraction as a percentage cell with one decimal place.

    The cell is wrapped in ``\\cellcolor{darkgray}`` when ``is_best`` is set,
    otherwise in ``\\cellcolor{lightgray}`` when ``is_second_best`` is set,
    and is left plain when neither flag is set.
    """
    percent = f"{value * 100:.1f}"
    if is_best:
        shade = "darkgray"
    elif is_second_best:
        shade = "lightgray"
    else:
        return percent
    return "\\cellcolor{" + shade + "}" + percent


def format_rank(rank, rank_position):
    """Format a rank cell, tinting the top three positions.

    Args:
        rank: The rank to print.
        rank_position: 1-based position that selects the tint: 1 gets the
            full ``highlight`` colour, 2 gets ``highlight!50`` and 3 gets
            ``highlight!20``. Any other position is printed without colour.
    """
    podium = ((1, "highlight"), (2, "highlight!50"), (3, "highlight!20"))
    for position, colour in podium:
        if rank_position == position:
            return "\\cellcolor{" + colour + "}" + f"{rank}"
    return str(rank)


def calculate_category_average(stats_data, model, category_scenes):
    """Average a model's pass@1 over the scenes of one category.

    Only scenes that are keys of ``stats_data`` and contain an entry for
    ``model`` contribute. Returns 0 when no scene contributes.
    """
    scores = [
        stats_data[scene][model]["pass_1"]
        for scene in category_scenes
        if scene in stats_data and model in stats_data[scene]
    ]
    if not scores:
        return 0
    return sum(scores) / len(scores)


def find_best_in_category_avg(stats_data, models, category_scenes):
    """Return the highest per-category pass@1 average among ``models``.

    Each model's average comes from ``calculate_category_average``. The
    result is never lower than the sentinel -1, which is also returned when
    ``models`` is empty.
    """
    averages = [
        calculate_category_average(stats_data, name, category_scenes)
        for name in models
    ]
    # Seeding with -1 keeps the "no models" sentinel as a floor for the maximum.
    return max([-1, *averages])


def find_second_best_in_category_avg(stats_data, models, category_scenes):
    """Return the second-highest per-category pass@1 average among ``models``.

    Each model's average comes from ``calculate_category_average``. Ties are
    not collapsed. Returns -1 when fewer than two models are given.
    """
    ranked = sorted(
        (
            calculate_category_average(stats_data, name, category_scenes)
            for name in models
        ),
        reverse=True,
    )
    return -1 if len(ranked) < 2 else ranked[1]


def read_manifest(scenes_dir, scene_name):
    """Load the ``manifest.json`` of a scene.

    The scene folder is the scene name lower-cased with spaces replaced by
    underscores, located directly under ``scenes_dir``.

    Args:
        scenes_dir: Directory that contains one sub-folder per scene.
        scene_name: Human-readable scene name, e.g. "Agent Sight".

    Returns:
        The parsed JSON content of the manifest.

    Raises:
        SystemExit: With exit code 1 when the manifest file does not exist.
    """
    # Imported locally so this block does not rely on a module-level import.
    import sys

    scene_folder = scene_name.lower().replace(" ", "_")
    manifest_path = os.path.join(scenes_dir, scene_folder, "manifest.json")
    if not os.path.exists(manifest_path):
        # Report on stderr: stdout carries the generated LaTeX table and is
        # typically redirected into a .tex file.
        print(
            f"Error: Manifest not found for scene '{scene_name}' at path: {manifest_path}",
            file=sys.stderr,
        )
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive sessions and is absent when Python runs without `site`.
        sys.exit(1)
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(manifest_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _chance_metrics(num_variants, k):
    """Return random-guess pass@1, pass@k and k-of-k reliability for one scene."""
    pass1 = 1.0 / num_variants
    return {
        "pass_1": pass1,
        "pass_k": 1 - (1 - pass1) ** k,
        "reliability": pass1**k,
    }


def _mean_metrics(per_scene, fallback):
    """Average each metric over ``per_scene``; copy ``fallback`` when it is empty."""
    if not per_scene:
        return dict(fallback)
    return {
        metric: sum(stats[metric] for stats in per_scene) / len(per_scene)
        for metric in ("pass_1", "pass_k", "reliability")
    }


def calculate_random_chances(scenes_dir, k=3):
    """Compute chance-level metrics from the scene manifests.

    For every scene named in ``DEFAULT_CATEGORY_MAPPING`` the number of answer
    variants is read from ``manifest["task"]["answer"]["num_variants"]``
    (4 when absent). A uniformly random guess then has pass@1 of
    ``1 / num_variants``, pass@k of ``1 - (1 - pass@1) ** k`` and k-of-k
    reliability of ``pass@1 ** k``.

    Args:
        scenes_dir: Directory that contains one sub-folder per scene.
        k: Number of attempts used for pass@k and reliability. Defaults to 3,
            the value that was previously hard-coded.

    Returns:
        A ``(overall, per_category)`` tuple. ``overall`` holds the metrics
        averaged over the distinct scenes; ``per_category`` maps each category
        name to the metrics averaged over that category's scenes. Both use the
        keys "pass_1", "pass_k" and "reliability".
    """
    default_variants = 4
    # Used when there is nothing to average. Derived from the same formulas
    # instead of rounded literals so it stays exact and follows `k`.
    fallback = _chance_metrics(default_variants, k)

    # A scene may belong to several categories; dict.fromkeys de-duplicates
    # while keeping first-seen order so each manifest is read only once.
    unique_scenes = dict.fromkeys(
        scene for scenes in DEFAULT_CATEGORY_MAPPING.values() for scene in scenes
    )

    scene_stats = {}
    for scene in unique_scenes:
        manifest = read_manifest(scenes_dir, scene)
        if manifest and "task" in manifest and "answer" in manifest["task"]:
            num_variants = manifest["task"]["answer"].get(
                "num_variants", default_variants
            )
        else:
            num_variants = default_variants
        scene_stats[scene] = _chance_metrics(num_variants, k)

    overall = _mean_metrics(list(scene_stats.values()), fallback)
    category_stats = {
        category: _mean_metrics(
            [scene_stats[scene] for scene in scenes if scene in scene_stats],
            fallback,
        )
        for category, scenes in DEFAULT_CATEGORY_MAPPING.items()
    }
    return overall, category_stats


def read_human_results(human_csv_path):
    """Compute human accuracy from a results CSV.

    For every scene named in ``DEFAULT_CATEGORY_MAPPING`` the CSV must have a
    column called ``<scene lower-cased, spaces as underscores>_score``. Values
    equal to -1 and missing values are ignored; the remaining values are
    averaged and divided by 100.

    Args:
        human_csv_path: Path of the CSV file to read.

    Returns:
        A ``(overall, per_category)`` tuple: the mean of the per-scene
        averages, and a dict mapping each category to the mean of its scenes'
        averages (0 when none of its scenes has valid scores).

    Raises:
        SystemExit: With exit code 1 when a required column is missing.
    """
    # Imported locally so this block does not rely on a module-level import.
    import sys

    df = pd.read_csv(human_csv_path)

    # De-duplicate scenes with dict.fromkeys rather than a set: iteration order
    # then follows the mapping, so the float summation below is reproducible
    # from run to run instead of depending on string hash randomisation.
    all_scenes = dict.fromkeys(
        scene for scenes in DEFAULT_CATEGORY_MAPPING.values() for scene in scenes
    )

    human_stats = {}
    for scene_name in all_scenes:
        csv_col = scene_name.lower().replace(" ", "_") + "_score"
        if csv_col not in df.columns:
            # Report on stderr: stdout carries the generated LaTeX table.
            print(
                f"Error: Column '{csv_col}' not found in CSV for scene '{scene_name}'",
                file=sys.stderr,
            )
            # sys.exit rather than the site-provided interactive exit().
            sys.exit(1)

        scores = df[csv_col]
        # Drop the -1 sentinel and missing cells; a column with only missing
        # cells would otherwise average to NaN and poison the overall mean.
        valid_scores = scores[scores != -1].dropna()
        if len(valid_scores) > 0:
            human_stats[scene_name] = valid_scores.mean() / 100.0

    overall_scores = list(human_stats.values())
    overall_avg = sum(overall_scores) / len(overall_scores) if overall_scores else 0

    category_averages = {}
    for category, scenes in DEFAULT_CATEGORY_MAPPING.items():
        category_scores = [
            human_stats[scene] for scene in scenes if scene in human_stats
        ]
        if category_scores:
            category_averages[category] = sum(category_scores) / len(category_scores)
        else:
            category_averages[category] = 0

    return overall_avg, category_averages


def _section_header(title):
    """Return the lines that open a titled, full-width section of the table."""
    return [
        "",
        " \\rowcolor{catrow}",
        " \\multicolumn{9}{l}{\\textit{" + title + "}} \\\\",
    ]


def _format_model_rows(
    stats_data,
    models,
    rankings,
    global_best,
    open_best,
    category_order,
    global_best_categories,
    open_best_categories,
):
    """Return one formatted table row per model that has overall stats."""
    present = [model for model in models if model in stats_data["overall"]]
    # Ranks of this group in ascending order; a model's index in this list
    # is its position inside the group and selects the rank tint.
    group_ranks = sorted(rankings[model] for model in present)

    rows = []
    for model in present:
        data = stats_data["overall"][model]
        rank = rankings[model]
        rank_str = format_rank(rank, group_ranks.index(rank) + 1)

        metric_cells = []
        for metric in ("pass_1", "pass_k", "reliability"):
            value = data[metric]
            is_global_best = value == global_best[metric]
            is_open_best = value == open_best[metric] and not is_global_best
            metric_cells.append(format_value(value, is_global_best, is_open_best))

        category_cells = []
        for category in category_order:
            avg_val = calculate_category_average(
                stats_data, model, DEFAULT_CATEGORY_MAPPING[category]
            )
            is_global_best = avg_val == global_best_categories[category]
            is_open_best = (
                avg_val == open_best_categories[category] and not is_global_best
            )
            category_cells.append(format_value(avg_val, is_global_best, is_open_best))

        rows.append(
            f" {format_model_name(model)} & {rank_str} & {' & '.join(metric_cells)} & {' & '.join(category_cells)} \\\\"
        )
    return rows


def generate_latex_table(
    stats_data, closed_models, open_models, scenes_dir=None, human_results_path=None
):
    """Render the benchmark results as a LaTeX ``tabular`` environment.

    The table has a chance-level baseline row, an optional human row, then one
    row per proprietary model and one per open-weight model. Each model row
    shows its global rank (by overall pass@1), the overall pass@1 / pass@k /
    k-of-k values, and the per-category pass@1 averages. The best value over
    all models is shaded dark gray; the best value among open models is shaded
    light gray when it is not also the global best.

    Args:
        stats_data: Stats mapping with an "overall" entry and one entry per
            scene, each mapping model name to its metrics.
        closed_models: Model identifiers for the "Proprietary Models" section.
        open_models: Model identifiers for the "Open-weight Models" section.
        scenes_dir: Optional scenes directory; when given, chance levels are
            computed from the scene manifests. Otherwise every scene is
            assumed to offer 4 answer variants with k = 3 attempts.
        human_results_path: Optional CSV path; when given, a human row is
            added.

    Returns:
        The table as a single newline-joined string.
    """
    metrics = ("pass_1", "pass_k", "reliability")
    # Column order of the per-ability block; it must match the SP / SO / MOR /
    # SV header cells, so every row is built from this one list.
    category_order = [
        "Spatial perception",
        "Spatial orientation",
        "Mental objects rotation",
        "Spatial visualization",
    ]

    all_models = closed_models + open_models
    rankings = calculate_rankings(stats_data, all_models)

    global_best = {
        metric: find_best_in_category(stats_data, all_models, metric)
        for metric in metrics
    }
    open_best = {
        metric: find_best_in_category(stats_data, open_models, metric)
        for metric in metrics
    }

    global_best_categories = {}
    open_best_categories = {}
    for category in category_order:
        scenes = DEFAULT_CATEGORY_MAPPING[category]
        global_best_categories[category] = find_best_in_category_avg(
            stats_data, all_models, scenes
        )
        open_best_categories[category] = find_best_in_category_avg(
            stats_data, open_models, scenes
        )

    if scenes_dir:
        overall_stats, category_stats = calculate_random_chances(scenes_dir)
    else:
        # No manifests available: assume 4 answer variants and k = 3. The
        # values are derived rather than written as rounded literals so this
        # row matches what the manifest-based path prints for the same setup.
        default_pass1 = 1.0 / 4
        default_k = 3
        default_chance = {
            "pass_1": default_pass1,
            "pass_k": 1 - (1 - default_pass1) ** default_k,
            "reliability": default_pass1**default_k,
        }
        overall_stats = dict(default_chance)
        category_stats = {
            category: dict(default_chance) for category in DEFAULT_CATEGORY_MAPPING
        }

    lines = [
        "\\begin{tabular}{r|c|ccc|cccc}",
        "  & & Pass@1 & pass@k & k-of-k & SP & SO & MOR & SV \\\\",
        " \\textbf{Methods} & \\textbf{Rank} & \\multicolumn{3}{c|}{\\cellcolor{yellow!10}\\textbf{Overall Metrics}} & \\multicolumn{4}{c}{\\cellcolor{orange!10}\\textbf{Per-Ability Pass@1}} \\\\",
        " \\midrule",
    ]

    lines.extend(_section_header("Baseline"))
    category_chance_values = [
        f"{category_stats[category]['pass_1'] * 100:.1f}"
        for category in category_order
    ]
    lines.append(
        f" Chance level (Random) & -- & {overall_stats['pass_1'] * 100:.1f} & {overall_stats['pass_k'] * 100:.1f} & {overall_stats['reliability'] * 100:.1f} & {' & '.join(category_chance_values)} \\\\"
    )

    if human_results_path:
        human_overall, human_categories = read_human_results(human_results_path)
        lines.extend(_section_header("\\textsc{Spatial-CAPTCHA-Bench} (tiny)"))
        human_category_values = [
            f"{human_categories[category] * 100:.1f}" for category in category_order
        ]
        lines.append(
            f" Human Level & -- & {human_overall * 100:.1f} & -- & -- & {' & '.join(human_category_values)} \\\\"
        )

    for title, models in (
        ("Proprietary Models", closed_models),
        ("Open-weight Models", open_models),
    ):
        lines.append(" \\midrule")
        lines.extend(_section_header(title))
        lines.extend(
            _format_model_rows(
                stats_data,
                models,
                rankings,
                global_best,
                open_best,
                category_order,
                global_best_categories,
                open_best_categories,
            )
        )

    lines.append(" \\bottomrule")
    lines.append(" \\end{tabular}")

    return "\n".join(lines)


def main():
    """Command-line entry point: read the stats JSON and print the LaTeX table.

    Model lists given on the command line are comma-separated and each entry
    is stripped of surrounding whitespace; when a list is omitted the
    module-level default list is used.

    Raises:
        FileNotFoundError: When the stats JSON file does not exist.
    """
    cli = argparse.ArgumentParser(description="Generate LaTeX stats table")
    cli.add_argument(
        "--stats-json", type=str, default="stats.json", help="Path to stats JSON file"
    )
    optional_flags = (
        ("--closed-models", "Comma-separated list of closed models"),
        ("--open-models", "Comma-separated list of open models"),
        ("--scenes-dir", "Path to scenes directory containing manifests"),
        ("--human-results", "Path to human results CSV file"),
    )
    for flag, help_text in optional_flags:
        cli.add_argument(flag, type=str, help=help_text)

    options = cli.parse_args()

    stats_path = options.stats_json
    if not os.path.exists(stats_path):
        raise FileNotFoundError(f"Stats file not found: {stats_path}")

    with open(stats_path, "r") as handle:
        stats_data = json.load(handle)

    if options.closed_models:
        closed_models = [name.strip() for name in options.closed_models.split(",")]
    else:
        closed_models = DEFAULT_CLOSED_MODELS

    if options.open_models:
        open_models = [name.strip() for name in options.open_models.split(",")]
    else:
        open_models = DEFAULT_OPEN_MODELS

    print(
        generate_latex_table(
            stats_data,
            closed_models,
            open_models,
            options.scenes_dir,
            options.human_results,
        )
    )


if __name__ == "__main__":
    main()
