from pathlib import Path
import pickle

from ale_bench.data import RankingCalculator, RatingCalculator
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from tqdm import tqdm

import stats_utils


# Filesystem layout, all resolved relative to this file's location.
# ROOT_DIR is the parent of the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parents[1]
# Directory holding the baseline runs (one sub-directory per setting).
BASELINE_RESULTS_DIR = ROOT_DIR.joinpath("baselines", "results")
# Every entry under <ROOT_DIR>/openhands whose name starts with "experiments".
OPENHANDS_RESULTS_DIRS = list(ROOT_DIR.joinpath("openhands").glob("experiments*"))

# Order in which models are listed, grouped by model family.
MODEL_ORDER = [
    # OpenAI GPT series
    "gpt-4o-mini",
    "gpt-4o",
    "gpt-4.1-nano",
    "gpt-4.1-mini",
    "gpt-4.1",
    # OpenAI o-series
    "o1-high",
    "o3-mini-high",
    "o3-high",
    "o4-mini-high",
    # Google Gemini
    "gemini-1.5-flash-8b",
    "gemini-1.5-flash",
    "gemini-1.5-pro",
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash",
    "gemini-2.5-flash-thinking",
    "gemini-2.5-pro-thinking",
    # Anthropic Claude
    "claude-3.5-haiku",
    "claude-3.5-sonnet",
    "claude-3.7-sonnet",
    "claude-3.7-sonnet-thinking",
    # DeepSeek
    "deepseek-v3",
    "deepseek-r1",
]

# Maps each model identifier to the name shown in tables and figures.
MODEL_NAME_MAP = {
    "gpt-4o-mini": "GPT-4o mini",
    "gpt-4o": "GPT-4o",
    "gpt-4.1-nano": "GPT-4.1 nano",
    "gpt-4.1-mini": "GPT-4.1 mini",
    "gpt-4.1": "GPT-4.1",
    "o1-high": "o1-high",
    "o3-mini-high": "o3-mini-high",
    "o3-high": "o3-high",
    "o4-mini-high": "o4-mini-high",
    "gemini-1.5-flash-8b": "Gemini 1.5 Flash-8B",
    "gemini-1.5-flash": "Gemini 1.5 Flash",
    "gemini-1.5-pro": "Gemini 1.5 Pro",
    "gemini-2.0-flash-lite": "Gemini 2.0 Flash-Lite",
    "gemini-2.0-flash": "Gemini 2.0 Flash",
    "gemini-2.5-flash-thinking": "Gemini 2.5 Flash",
    "gemini-2.5-pro-thinking": "Gemini 2.5 Pro",
    "claude-3.5-haiku": "Claude 3.5 Haiku",
    "claude-3.5-sonnet": "Claude 3.5 Sonnet",
    "claude-3.7-sonnet": "Claude 3.7 Sonnet",
    "claude-3.7-sonnet-thinking": "Claude 3.7 Sonnet (Thinking)",
    "deepseek-v3": "DeepSeek-V3",
    "deepseek-r1": "DeepSeek-R1",
}

# Order and display names of the code-language column; "average" is an
# extra entry alongside the three concrete languages.
CODE_LANGUAGE_ORDER = [
    "cpp20",
    "python",
    "rust",
    "average",
]
CODE_LANGUAGE_NAME_MAP = {
    "cpp20": "C++20",
    "python": "Python3",
    "rust": "Rust",
    "average": "Average",
}

def _fixed_point(digits):
    """Build a formatter that rounds to *digits* decimals and prints exactly that many."""
    return lambda x: f"{round(x, digits):.{digits}f}"


def _percent_without_sign(x):
    """Format a fraction as a percentage with one decimal, dropping the trailing '%'."""
    return f"{round(x, 3):.1%}"[:-1]


_whole_number = _fixed_point(0)
_three_decimals = _fixed_point(3)

# Per-column string formatters, keyed by column name.
# The name lookups for "model" and "code_language" are resolved at call time.
COLUMN_FORMATTER = {
    "model": lambda x: MODEL_NAME_MAP[x],
    "code_language": lambda x: CODE_LANGUAGE_NAME_MAP[x],
    "average_performance_short": _whole_number,
    "average_performance_long": _whole_number,
    "average_performance": _whole_number,
    "average_performance_rank": _whole_number,
    "average_performance_rank_percentile": _percent_without_sign,
    "rating": _whole_number,
    "rating_rank": _whole_number,
    "rating_rank_percentile": _percent_without_sign,
    "total_cost": _three_decimals,
    "total_response_count": _fixed_point(1),
    "cost_per_response": _three_decimals,
    "brown": _percent_without_sign,
    "green": _percent_without_sign,
    "cyan": _percent_without_sign,
    "blue": _percent_without_sign,
    "yellow": _percent_without_sign,
    "orange": _percent_without_sign,
    "red": _percent_without_sign,
}


# Raw results per experimental setting.
# Each baseline setting is loaded from its own sub-directory of
# BASELINE_RESULTS_DIR; the collector's return value is later iterated with
# `.items()` as experiment name -> list of per-problem results.
results = {
    setting: stats_utils.collect_results_baseline(BASELINE_RESULTS_DIR / setting)
    for setting in ("first_accept", "four_hours")
}
# Each OpenHands directory yields one (model_name, model_results) pair.
results["openhands"] = dict(
    stats_utils.collect_results_openhands(results_dir) for results_dir in OPENHANDS_RESULTS_DIRS
)
# Left empty here; nothing in the visible code populates it.
fishylene_results = {}


# Summary tables, keyed by "<setting>" (all problems) and "<setting>_lite"
# (only the problems listed in stats_utils.PROBLEM_IDS_LITE).
dfs = {}
rating_calculator = RatingCalculator()
ranking_calculator = RankingCalculator()
# Denominator used for the rank percentiles: last entry of the ranking table minus one.
num_actives = ranking_calculator.rating_ranks[-1] - 1
# Sanity check against the ranking data shipped with ale_bench.
# NOTE(review): this presumably counts active participants rather than problems - confirm.
assert num_actives == 2220

# Second argument handed to RatingCalculator.calculate_rating.
# NOTE(review): presumably the contest that closes the rating period - confirm.
_RATING_CONTEST_ID = "ahc046"

# Performance thresholds (strictly greater than) behind the colour columns.
# NOTE(review): the stored values are raw problem counts, whereas
# COLUMN_FORMATTER renders these columns as percentages - confirm that the
# consumer normalises them. Also confirm whether a performance exactly equal
# to a threshold should count (the comparison is strict).
_COLOR_THRESHOLDS = {
    "brown": 400,
    "green": 800,
    "cyan": 1200,
    "blue": 1600,
    "yellow": 2000,
    "orange": 2400,
    "red": 2800,
}

# Summary columns shared by the full and lite tables, in output order.
_SUMMARY_COLUMNS = [
    "experiment_name",
    "model",
    "code_language",
    "average_performance_short",
    "average_performance_long",
    "average_performance",
    "average_performance_rank",
    "average_performance_rank_percentile",
    "rating",
    "rating_rank",
    "rating_rank_percentile",
    "total_cost",
    "total_response_count",
    "cost_per_response",
    *_COLOR_THRESHOLDS,
]


def _new_tally():
    """Return an empty accumulator for one experiment on one problem set."""
    return {"all": {}, "short": {}, "long": {}, "cost": 0, "responses": 0}


def _record(tally, problem_id, performance, cost=0, responses=0):
    """Store one problem's performance in *tally* and add its cost / response count."""
    tally["all"][problem_id] = performance
    duration = "long" if problem_id in stats_utils.LONG_PROBLEM_IDS else "short"
    tally[duration][problem_id] = performance
    tally["cost"] += cost
    tally["responses"] += responses


def _fill_missing(tally, problem_ids, exp_name, label=""):
    """Give every problem in *problem_ids* absent from *tally* its minimum performance.

    Missing problems contribute no cost and no responses. They are visited in
    sorted order so the log output is reproducible between runs.
    """
    for missed_problem_id in sorted(set(problem_ids) - set(tally["all"])):
        print(f"Missing performance{label} for {missed_problem_id} in {exp_name}")
        _record(tally, missed_problem_id, stats_utils.MINIMUM_PERFORMANCES[missed_problem_id])


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _summarize(exp_name, tally):
    """Build one table row (summary columns plus per-problem performances) from *tally*.

    *exp_name* is split at its last underscore into model and code language.
    """
    performances = tally["all"]
    num_problems = len(performances)
    rating = rating_calculator.calculate_rating(performances, _RATING_CONTEST_ID)
    rating_rank = ranking_calculator.calculate_rating_rank(rating)
    avg_perf = sum(performances.values()) / num_problems
    avg_perf_rank = ranking_calculator.calculate_avg_perf_rank(avg_perf)
    model, code_lang = exp_name.rsplit("_", 1)
    row = {
        "experiment_name": exp_name,
        "model": model,
        "code_language": code_lang,
        "average_performance_short": _ratio(sum(tally["short"].values()), len(tally["short"])),
        "average_performance_long": _ratio(sum(tally["long"].values()), len(tally["long"])),
        "average_performance": avg_perf,
        "average_performance_rank": avg_perf_rank,
        "average_performance_rank_percentile": avg_perf_rank / num_actives,
        "rating": rating,
        "rating_rank": rating_rank,
        "rating_rank_percentile": rating_rank / num_actives,
        # Cost and response count are averaged per problem (filled-in problems included).
        "total_cost": tally["cost"] / num_problems,
        "total_response_count": tally["responses"] / num_problems,
        # An experiment without any recorded response yields NaN instead of crashing.
        "cost_per_response": _ratio(tally["cost"], tally["responses"]),
    }
    for color, threshold in _COLOR_THRESHOLDS.items():
        row[color] = sum(p > threshold for p in performances.values())
    row.update(performances)
    return row


def _to_frame(rows, problem_ids):
    """Turn *rows* into a DataFrame with the summary columns followed by *problem_ids*."""
    columns = [pl.col(name) for name in (*_SUMMARY_COLUMNS, *problem_ids)]
    return pl.DataFrame(rows).select(columns).sort("experiment_name")


for setting_name, setting_results in results.items():
    setting_data = []
    setting_lite_data = []
    for exp_name, exp_result in tqdm(setting_results.items(), desc=f"Processing {setting_name} results", total=len(setting_results)):
        tally = _new_tally()
        tally_lite = _new_tally()
        for problem_result in exp_result:
            problem_id = problem_result["problem_id"]
            cost = problem_result["total_cost"]
            responses = problem_result["response_count"]
            _record(tally, problem_id, problem_result["performance"], cost, responses)
            if problem_id in stats_utils.PROBLEM_IDS_LITE:
                # The lite performance is recomputed from the private result
                # rather than copied from the stored "performance" value.
                performance_lite = stats_utils.calculate_lite_performance(
                    problem_id, problem_result["private_result"]
                )
                _record(tally_lite, problem_id, performance_lite, cost, responses)
        _fill_missing(tally, stats_utils.PROBLEM_IDS, exp_name)
        _fill_missing(tally_lite, stats_utils.PROBLEM_IDS_LITE, exp_name, label=" (lite)")
        setting_data.append(_summarize(exp_name, tally))
        setting_lite_data.append(_summarize(exp_name, tally_lite))
    df_setting_results = _to_frame(setting_data, stats_utils.PROBLEM_IDS)
    df_setting_lite_results = _to_frame(setting_lite_data, stats_utils.PROBLEM_IDS_LITE)
    dfs[setting_name] = df_setting_results
    dfs[setting_name + "_lite"] = df_setting_lite_results


# Persist the raw results and the summary tables as a single (results, dfs)
# tuple. The file is opened in a `with` block so the handle is flushed and
# closed deterministically, even if pickling raises.
with (ROOT_DIR / "notebooks" / "preprocessed_results.pkl").open("wb") as output_file:
    pickle.dump((results, dfs), output_file)
