from tokenizer_conversion.benchmarking.utils.scoring_utils import load_data, compute_jds, create_latex_table
from collections import defaultdict
import os
import numpy as np
import pandas as pd
import pickle

if __name__ == "__main__":
    # --- Run configuration ---------------------------------------------
    # Directory scanned for ``.pkl`` result files; the per-model CSVs and
    # the combined CSV are written back into the same directory.
    data_dir = ""

    # When True, a genlm baseline is loaded for each model and scored
    # alongside the other settings under the extra key K = 0.
    compare_genlm = False
    # Directory that holds the pickled genlm distributions.
    genlm_dir = ""

    # Collects one result frame per processed pickle.
    all_dfs = []

    def load_genlm(model_name):
        """Load the pickled genlm distributions belonging to ``model_name``.

        Args:
            model_name: Model identifier; must be one of the keys of the
                lookup table below.

        Returns:
            The unpickled object, returned as-is. The caller indexes it by
            paragraph number and calls ``.values()`` on each entry.

        Raises:
            KeyError: If ``model_name`` has no entry in the lookup table.
            OSError: If the resolved path cannot be opened.
        """
        # File name (relative to ``genlm_dir``) of each model's pickle.
        pickle_by_model = {
            "meta-llama/Llama-3.1-8B": "",
            "meta-llama/Llama-3.2-1B": "",
            "gpt2-large": "",
        }
        genlm_path = os.path.join(genlm_dir, pickle_by_model[model_name])
        print(f"Loading genlm data from: {genlm_path}")
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at files produced by a trusted run.
        with open(genlm_path, "rb") as handle:
            return pickle.load(handle)
        

    # Process every ``.pkl`` in ``data_dir``: load it, turn the stored
    # values into per-position {index: value} dicts, score them with
    # ``compute_jds``, write a per-file CSV and keep the frame in ``all_dfs``.
    #
    # Runs read from "results_dna/" are special-cased: their cap is parsed
    # out of the file name and appended to the model label.
    is_dna_run = data_dir == "results_dna/"

    for filename in os.listdir(data_dir):
        if not filename.endswith(".pkl"):
            continue

        data = os.path.join(data_dir, filename)
        print(f"Loading data from: {data}")
        if is_dna_run:
            print(data)
            # Take what follows the last "dna_max" and drop the extension.
            # str.strip(".pkl") is NOT a suffix removal: it strips any of the
            # characters '.', 'p', 'k', 'l' from both ends, so a cap such as
            # "100k" would be truncated to "100". The endswith() guard above
            # guarantees the extension is present, so slicing is safe.
            max_cap = data.split("dna_max")[-1][: -len(".pkl")]
        else:
            max_cap = ""
        logp_nexts, Ks, metadata, stats = load_data(data)

        model_name = metadata["model_name"]

        print(model_name)
        paragraphs = metadata["paragraphs"]
        # NOTE: the value from the last processed file is read again after
        # the loop to name the combined CSV.
        transducer_name = metadata["transducer_name"]

        print(len(logp_nexts))
        print(logp_nexts.keys())

        # p_nexts[K] is a flat list, over all paragraphs, of {index: value}
        # dicts, one per row of the exponentiated array.
        # Presumably the stored values are log-probabilities - TODO confirm.
        p_nexts = defaultdict(list)
        for K in Ks:
            for i in range(paragraphs):
                # np.asarray may return the input itself when it is already a
                # float32 array, in which case the in-place exp below also
                # overwrites the loaded data. It is not reused afterwards.
                arr = np.asarray(logp_nexts[K][i], dtype=np.float32)
                np.exp(arr, out=arr)
                p_nexts[K].extend(dict(enumerate(row)) for row in arr)

        if compare_genlm:
            # The genlm baseline is stored under the extra key K = 0.
            genlm_dists = load_genlm(model_name)
            p_nexts[0] = []
            for i in range(paragraphs):
                val = [list(p.values()) for p in genlm_dists[i]]
                arr = np.asarray(val, dtype=np.float32)
                np.exp(arr, out=arr)
                p_nexts[0].extend(dict(enumerate(row)) for row in arr)
            Ks.append(0)
        df_js = compute_jds(p_nexts, Ks, metadata, stats, lower=True).copy()

        # Rename the JSD-specific columns to metric-agnostic names; only
        # columns that are actually present are renamed.
        rename_cols = {
            "mean_jsd": "mean_metric",
            "jsd_ci_lower": "metric_ci_lower",
            "jsd_ci_upper": "metric_ci_upper",
            "mean_chars_per_sec": "chars_per_sec",
            "chars_per_sec_ci_lower": "speed_ci_lower",
            "chars_per_sec_ci_upper": "speed_ci_upper",
        }
        df_js.rename(
            columns={k: v for k, v in rename_cols.items() if k in df_js.columns},
            inplace=True,
        )

        # DNA runs share a model name, so the cap is appended to keep the
        # rows apart when they are later de-duplicated on ("model", "K").
        df_js["model"] = f"{model_name}-{max_cap}" if is_dna_run else model_name
        df_js["max_cap"] = max_cap

        print_model_name = model_name.replace("/", "_").replace("-", "_")
        # os.path.join (as used for the input path) avoids the doubled
        # separator an f-string produces when data_dir ends with "/".
        out_csv = os.path.join(
            data_dir, f"{print_model_name}_{transducer_name}_{max_cap}.csv"
        )
        df_js.to_csv(out_csv, index=False)

        all_dfs.append(df_js)

    # Nothing was collected by the loop: no usable pickle was processed.
    if not all_dfs:
        raise SystemExit("No .pkl files found in results")

    # Stack the per-file frames into one table with a fresh 0..n-1 index.
    combined_df = pd.concat(all_dfs, ignore_index=True)

    # When both key columns exist, order by them and keep only the last row
    # of every (model, K) pair.
    dedupe_keys = ["model", "K"]
    if set(dedupe_keys) <= set(combined_df.columns):
        ordered_df = combined_df.sort_values(dedupe_keys)
        combined_df = ordered_df.drop_duplicates(subset=dedupe_keys, keep="last")

    # ``transducer_name`` still holds the value from the last processed file.
    combined_name = f"ALL_MODELS_{transducer_name}_jsd_results.csv"
    combined_out = os.path.join(data_dir, combined_name)
    combined_df.to_csv(combined_out, index=False)
    print(f"Wrote combined CSV to: {combined_out}")

    latex_table = create_latex_table(combined_df, "JSD")
    print(latex_table)