import json
from pathlib import Path
from typing import Literal, Union
from dotenv import load_dotenv
from llm_utils.textgen_api.textgen_api import TextGenApi
from natsort import natsorted
import pandas as pd


# Report modes accepted by ``main``.
_REPORT_MODES = ("detailed", "per-run", "grouped", "grouped-v2")

# Price per million tokens for each token column (token sums are divided by 1e6
# before being multiplied by these rates). "input_tokens" is priced at 0 because
# input is billed through its uncached/cached split instead.
# NOTE(review): the same rates are applied regardless of which LLM produced the
# run — confirm they are the intended ones.
_TOKEN_PRICES = {
    "input_tokens": 0,
    "input_tokens_uncached": 0.4,
    "input_tokens_cached": 0.1,
    "output_tokens": 1.6,
    "output_tokens_cached": 0,
}


def _seed_pattern(variant: str, domain: str) -> str:
    """Return the glob pattern, relative to a variant's result directory, matching its run directories."""
    if variant == "sadegh":
        return "*docstrings/seed_*"
    if variant == "lodge":
        return "planning-*/sample-*" if domain.startswith("fb-") else "iclr-w-dk-w-ai-shared/sample-*"
    if variant == "guan":
        return "docstrings/sample-*"
    if variant == "lionel":
        return "no-docstrings"
    if variant == "cluster":
        return "*-based/sample-*"
    raise NotImplementedError(f"Unknown variant: {variant!r}")


def _print_token_cost(eval_file: Path, aux) -> None:
    """Print the token totals (in millions) and the resulting price for one run.

    ``aux`` is passed to ``pd.DataFrame`` and must provide at least the
    ``input_tokens`` and ``input_tokens_cached`` columns. Every column whose
    name contains "tokens" must have an entry in ``_TOKEN_PRICES``.
    """
    usage = pd.DataFrame(aux)
    print(eval_file)
    usage["input_tokens_uncached"] = usage["input_tokens"] - usage["input_tokens_cached"]

    tokens = (usage[[c for c in usage.columns if "tokens" in c]].sum() / 1e6).to_dict()
    print(tokens)
    total = sum(tokens[c] * _TOKEN_PRICES[c] for c in tokens)
    print(f"Price ${total:.2f}")


def _summarise_by_seed_group(df: pd.DataFrame, metrics: list) -> pd.DataFrame:
    """Average ``metrics`` per (domain, llm, variant, seed-group).

    The seed-group is the first path component of the ``seed`` index level.
    The result carries the rounded means, a "# Samples" count column and a
    "metric" column formatted as "<harmonic_mean> / <planning_sr>", so
    ``metrics`` must include both of those columns.
    """
    df = df.reset_index()
    df["seed-group"] = df["seed"].str.split("/").str[0]
    df = df.set_index(["domain", "llm", "variant", "seed-group"])[metrics]

    grouped = df.groupby(level=(0, 1, 2, 3), sort=False)
    summary = grouped.mean().round(2)
    summary["# Samples"] = grouped.size()

    two_decimals = "{:.2f}".format
    summary["metric"] = summary["harmonic_mean"].map(two_decimals) + " / " + summary["planning_sr"].map(two_decimals)
    return summary


def main(mode: Literal["detailed", "per-run", "grouped", "grouped-v2"] = "grouped-v2") -> None:
    """Collect ``eval-random-walks.json`` results and print a summary report.

    Result files are looked up under
    ``<two dirs above this file>/{furniturebench|ipc}/results/<domain>/<model_dir>/<variant dir>/<seed pattern>/``
    for every combination of the domains, LLMs and variants listed below.
    Run directories without an ``eval-random-walks.json`` file are skipped.

    Args:
        mode: Which report to produce.
            "detailed" prints all non-token columns plus token counts in thousands.
            "per-run" prints the main metrics per run and writes them to
            ``evaluation_results.csv`` in the current working directory.
            "grouped" prints sample counts and "harmonic_mean / planning_sr"
            averaged per seed-group.
            "grouped-v2" prints the same, pivoted so that variant/seed-group
            become columns.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
    """
    # Fail fast on a bad mode instead of after scanning every result directory.
    if mode not in _REPORT_MODES:
        raise NotImplementedError(f"Unknown mode: {mode!r}")

    load_dotenv()
    variants = ["guan", "sadegh", "lodge", "cluster"]
    domains = ["logistics", "household", "fb-lamp"]
    llms = ["openrouter:meta-llama/llama-4-scout", "openrouter:meta-llama/llama-4-maverick", "gpt4.1-mini"]

    # Variants whose on-disk directory name differs from the variant name.
    v_mapping = {
        "lodge": "hi-tamp",
        "cluster": "cluster-intersect",
    }

    # The directory name of each LLM does not depend on domain or variant, so
    # resolve it once per LLM rather than inside the innermost loop.
    llm_names = {llm: TextGenApi.default(llm).connections.connections[0].model_dir for llm in llms}

    evals = []
    for domain in domains:
        results_dir = (
            Path(__file__).parent.parent / ("furniturebench" if domain.startswith("fb-") else "ipc") / "results"
        )
        for llm in llms:
            llm_name = llm_names[llm]
            for variant in variants:
                variant_dir = results_dir / domain / llm_name / v_mapping.get(variant, variant)
                for run_dir in natsorted(variant_dir.glob(_seed_pattern(variant, domain))):
                    if not run_dir.is_dir():
                        continue
                    eval_file = run_dir / "eval-random-walks.json"
                    if not eval_file.exists():
                        continue

                    eval_data_json = json.loads(eval_file.read_text(encoding="utf-8"))
                    # "aux" is not part of the metrics table; tolerate files that lack it.
                    aux = eval_data_json.pop("aux", None)
                    if variant == "lodge" and domain == "household" and aux is not None:
                        _print_token_cost(eval_file, aux)

                    eval_data = pd.Series(eval_data_json).fillna(0)
                    evals.append(
                        {
                            "variant": variant,
                            "domain": domain,
                            "llm": llm_name,
                            "seed": run_dir.relative_to(variant_dir).as_posix(),
                            **eval_data,
                        }
                    )

    if not evals:
        # An empty frame has no columns, so set_index below would raise an opaque KeyError.
        print("No evaluation results found.")
        return

    df = pd.DataFrame(evals).set_index(["domain", "llm", "variant", "seed"])

    if mode == "detailed":
        df_performance = df[[c for c in df.columns if "tokens" not in c]]
        # Token counts are reported in thousands, rounded to whole numbers.
        df_tokens = df[[c for c in df.columns if "tokens" in c]].fillna(0)
        df_tokens = (df_tokens / 1000).round(0).astype(int)

        print(df_performance)
        print("\n\n")
        print("Tokens used:")
        df_tokens["total_tokens"] = df_tokens["input_tokens"] + df_tokens["output_tokens"]
        print(df_tokens.map(lambda v: f"{v}k"))
    elif mode == "per-run":
        df_total = df[["planning_sr", "harmonic_mean", "t_to_gen_frac", "gen_to_t_frac"]].round(2)
        print(df_total)
        df_total.to_csv("evaluation_results.csv")
    elif mode == "grouped":
        summary = _summarise_by_seed_group(df, ["planning_sr", "harmonic_mean", "t_to_gen_frac", "gen_to_t_frac"])
        print(summary[["# Samples", "metric"]])
    else:  # "grouped-v2" (mode was validated on entry)
        summary = _summarise_by_seed_group(df, ["planning_sr", "harmonic_mean"])
        # Move variant and seed-group from the row index into the columns.
        summary = summary.unstack(level=(2, 3))
        print(summary["# Samples"])
        print("-" * 100)
        print(summary["metric"])


# Run the report only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
