import os
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from param_counter import count_params, count_params_with_rec, warmup_hist
from transformers import AutoModelForCausalLM
import copy
from matplotlib import font_manager
from natsort import natsorted

def import_times_new_roman(this_font_manager=font_manager, this_plt=plt, font_size=16):
    """Register the bundled Times New Roman font and make it the default plot font.

    Args:
        this_font_manager: object exposing ``fontManager.addfont`` (defaults to
            ``matplotlib.font_manager``).
        this_plt: object exposing ``rcParams`` (defaults to ``matplotlib.pyplot``).
        font_size: value written to ``rcParams["font.size"]``.
    """
    # Relative path: resolved against the current working directory.
    ttf_path = "paper_plots/Times New Roman.ttf"
    this_font_manager.fontManager.addfont(ttf_path)

    rc_params = this_plt.rcParams
    rc_params["font.family"] = "Times New Roman"
    rc_params["font.size"] = font_size

# Maps a task name (a key of the "results" section of an evaluation JSON file) to
# the metric key that is read for that task. The matching standard-error key is
# derived from it by replacing "," with "_stderr,".
task_to_key = dict(
    arc_challenge="acc_norm,none",
    arc_easy="acc_norm,none",
    hellaswag="acc_norm,none",
    lambada_openai="acc,none",
    mmlu="acc,none",
    openbookqa="acc_norm,none",
    piqa="acc_norm,none",
    social_iqa="acc,none",
    winogrande="acc,none",
    #
    cola="mcc,none",
    commonsense_qa="acc,none",
    glue="acc,none",
    logiqa="acc_norm,none",
    mathqa="acc_norm,none",
    mc_taco="f1,none",
    webqs="exact_match,none",
    inverse_scaling_mc="acc_norm,none",
    #
    gsm8k="exact_match,flexible-extract",
    open_openbookqa="acc_norm,none",
)

# Chance-level score per task; drawn as the "Random Baseline" line in the plots.
# Multiple-choice accuracy tasks use 1 / (number of answer options). Free-form
# generation metrics and correlation metrics have a chance level of 0, as was
# already the case for gsm8k.
# NOTE(review): "glue", "mc_taco", "inverse_scaling_mc" and "open_openbookqa" are
# aggregate or custom tasks and are left at their previous 0.25 - confirm.
task_to_random = {
    "arc_challenge": 0.25,
    "arc_easy": 0.25,
    "hellaswag": 0.25,
    "lambada_openai": 0.0,  # open-vocabulary last-word prediction
    "mmlu": 0.25,
    "openbookqa": 0.25,
    "open_openbookqa": 0.25,
    "piqa": 0.5,
    "social_iqa": 1 / 3,  # three answer options
    "winogrande": 0.5,

    "cola": 0.0,  # scored with MCC, which is 0 at chance
    "commonsense_qa": 0.2,  # five answer options
    "glue": 0.25,
    "logiqa": 0.25,
    "mathqa": 0.2,  # five answer options
    "mc_taco": 0.25,
    "webqs": 0.0,  # exact match on free-form answers
    "inverse_scaling_mc": 0.25,
    "gsm8k": 0.0,
}

# On-disk location of the per-model parameter-count cache (relative to the cwd).
CACHE_FILE = "model_size_cache.json"

def _load_model_size_cache(path: str = CACHE_FILE) -> dict:
    """Read the model-size cache from ``path``.

    Returns:
        ``{model_dir: {num_rec (int): counts}}``, or an empty dict when ``path``
        does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but does not contain valid JSON.
    """
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        raw = json.load(f)
    # JSON object keys are always strings; restore the integer recurrence keys.
    return {m: {int(k): v for k, v in sizes.items()} for m, sizes in raw.items()}

def _save_model_size_cache(cache: dict, path: str = CACHE_FILE):
    """Write ``cache`` to ``path`` as JSON without ever leaving a broken file behind.

    The data is first written to ``{path}.tmp`` and only moved over ``path`` once
    serialization has succeeded. Opening ``path`` directly for writing would
    truncate the previous cache immediately, so a failing or interrupted dump would
    leave invalid JSON there and make the next load (done at import time) crash.

    Raises:
        TypeError / ValueError: propagated from ``json.dump`` when ``cache`` is not
            serializable; the previous file content is kept in that case.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(cache, f)
    except BaseException:
        # Drop the partial temp file, keep the previous cache, and re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, path)

# Loaded once at import time; make_df() adds entries and writes the cache back.
model_size_cache = _load_model_size_cache()

def get_flops(all_counts, train_rec, this_chkpt, max_steps, batch_size=2**20, warmup_duration=0.25):
    """Estimate the training FLOPs spent up to checkpoint ``this_chkpt``.

    Args:
        all_counts: mapping ``recurrence -> {"flops_times_by_6d": ...}``.
        train_rec: recurrence the run was trained with. ``1`` means a fixed
            per-step cost; any other value uses the schedule from ``warmup_hist``.
        this_chkpt: checkpoint step.
        max_steps: forwarded to ``warmup_hist`` (unused when ``train_rec == 1``).
        batch_size: multiplier per step; defaults to 2**20 (about 1M) -
            presumably tokens per step, confirm against the training config.
        warmup_duration: forwarded to ``warmup_hist`` (unused when ``train_rec == 1``).

    Returns:
        ``6 * batch_size`` times the per-step ``flops_times_by_6d`` accumulated
        over the steps taken.
    """
    if train_rec != 1:
        # warmup_hist yields {recurrence: weight}; each weight scales the
        # per-step cost of that recurrence.
        schedule = warmup_hist(this_chkpt, warmup_duration, max_steps=max_steps, max_rec=train_rec)
        weighted_cost = sum(
            all_counts[rec]["flops_times_by_6d"] * weight for rec, weight in schedule.items()
        )
        return weighted_cost * 6 * batch_size

    per_step_cost = all_counts[train_rec]["flops_times_by_6d"]
    return per_step_cost * 6 * this_chkpt * batch_size
    
def _ensure_model_size_cached(model_path, dict_name, is_baseline, is_non_recur):
    """Load the model once and store its parameter counts under ``model_size_cache[dict_name]``."""
    if dict_name in model_size_cache:
        return
    print(f"Adding {dict_name} to model size cache")

    # For baseline runs a leading "models/" is dropped from the recorded model
    # name before loading. Only the leading prefix is removed; str.replace would
    # also delete any later "models/" inside the path.
    if is_baseline and model_path.startswith("models/"):
        model_path = model_path[len("models/"):]

    # NOTE: trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point this at trusted model paths.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    if is_baseline or is_non_recur:
        # Non-recurrent models: count every parameter outside the token
        # embedding and the LM head; only recurrence 1 exists.
        total_params = sum(
            p.numel() for name, p in model.named_parameters()
            if not (("embed_tokens" in name) or ("lm_head" in name))
        )
        model_size_cache[dict_name] = {
            1: {"total_not_emb_or_lm_head": total_params, "flops_times_by_6d": total_params}
        }
    else:
        counts = count_params(model)
        # deepcopy: count_params_with_rec gets its own copy for every recurrence.
        model_size_cache[dict_name] = {
            k: count_params_with_rec(copy.deepcopy(counts), num_rec=k) for k in range(1, 33)
        }


def make_df(runs):
    """Collect evaluation result JSON files of several runs into one long-format DataFrame.

    Args:
        runs: iterable of ``(run_dir, short_name, extra_dict)``. Every entry of
            ``run_dir`` is treated as a checkpoint directory and searched
            recursively for ``*.json`` files; each file must contain
            ``"model_name"`` and ``"results"`` (and ``"configs"`` for recurrent
            runs). ``short_name`` becomes the ``model`` column. ``extra_dict``
            must contain ``"train_rec"`` and may contain ``"max_steps"``
            (default 25_000) and ``"warmup_duration"`` (default 0.25).

    Returns:
        ``pandas.DataFrame`` with one row per (result file, task listed in
        ``task_to_key``) and the columns ``model``, ``chkpt``, ``task``, ``acc``,
        ``stderr``, ``train_rec``, ``num_rec``, ``cooldown``,
        ``effective_params`` and ``FLOPs``.

    Side effects:
        Adds missing models to the module-level ``model_size_cache`` (which
        requires loading the model) and writes the cache back to disk, also
        when an error interrupts the scan, so finished model loads are not lost.
    """
    baseline_markers = ("TinyLlama-1.1B-intermediate-step-1431k-3T", "Llama-3.2-1B-untied")

    all_data = []
    try:
        for run_dir, short_name, extra_dict in runs:
            is_baseline = any(marker in run_dir for marker in baseline_markers)
            is_non_recur = "non_recur" in run_dir

            for chkpt in os.listdir(run_dir):
                for json_file in glob.glob(f"{run_dir}/{chkpt}/**/*.json", recursive=True):
                    with open(json_file, "r") as f:
                        data = json.load(f)

                    model_path = data["model_name"]
                    # Cache key: the model path without its last component.
                    dict_name = "/".join(model_path.split("/")[:-1])
                    _ensure_model_size_cached(model_path, dict_name, is_baseline, is_non_recur)

                    if is_baseline:
                        num_rec = 1
                        this_chkpt = 0
                    elif is_non_recur:
                        num_rec = 1
                        this_chkpt = int(chkpt.replace("model_only_chkpt_",""))
                    else:
                        print(json_file)
                        # The recurrence used at evaluation time is read from the task
                        # metadata; fall back to gsm8k when hellaswag was not evaluated.
                        try:
                            num_rec = data["configs"]["hellaswag"]["metadata"]["mean_recurrence"]
                        except (KeyError, TypeError):
                            num_rec = data["configs"]["gsm8k"]["metadata"]["mean_recurrence"]
                        this_chkpt = int(chkpt.replace("model_only_chkpt_",""))

                    for k, v in data["results"].items():
                        if k not in task_to_key:
                            continue
                        task_key = task_to_key[k]
                        all_data.append(
                            {
                                "model": short_name,
                                "chkpt": this_chkpt,
                                "task": k,
                                "acc": v[task_key],
                                "stderr": v[task_key.replace(",","_stderr,")],
                                "train_rec": extra_dict["train_rec"],
                                "num_rec": num_rec,
                                "cooldown": "cooldown" in short_name,
                                "effective_params": model_size_cache[dict_name][num_rec]["total_not_emb_or_lm_head"],
                                "FLOPs": get_flops(
                                    model_size_cache[dict_name],
                                    extra_dict["train_rec"],
                                    this_chkpt,
                                    max_steps=extra_dict.get("max_steps", 25_000),
                                    warmup_duration=extra_dict.get("warmup_duration", 0.25),
                                ),
                            }
                        )
    finally:
        # Persist whatever was computed, even if a later file failed to parse.
        _save_model_size_cache(model_size_cache)
    return pd.DataFrame(all_data)

def add_table_section(
        df,
        sec_name,
        task_order=('arc_easy', 'arc_challenge', 'hellaswag', 'winogrande', 'mmlu', 'piqa', 'openbookqa'),
        rec_order=(1, 2, 4, 8, 16, 32),
    ):
    """Render one model's results as the lines of a LaTeX table section.

    Args:
        df: DataFrame with the columns ``model``, ``num_rec``, ``task`` and ``acc``.
        sec_name: value of the ``model`` column to select; also used as the
            section heading.
        task_order: tasks to emit, one table column each, in this order.
        rec_order: ``num_rec`` values to emit, one table row each, in this order.

    Returns:
        List of strings: ``\\midrule``, a ``\\multicolumn`` heading, ``\\midrule``,
        then one ``"<rec> & <cell> & ... \\\\"`` row per entry of ``rec_order``.
        Cells hold ``acc * 100`` with one decimal place; a cell is left empty
        when there is no value for that (recurrence, task) pair.
    """
    body_lines = [
        r'\midrule',
        rf'\multicolumn{{{len(task_order)+1}}}{{c}}{{{sec_name}}} \\',
        r'\midrule',
    ]
    piv = (df[df['model'] == sec_name]
           .pivot_table(index='num_rec', columns='task', values='acc', aggfunc='first'))
    for rec in rec_order:
        row = piv.loc[rec] if rec in piv.index else None
        cells = []
        for task in task_order:
            value = row.get(task) if row is not None else None
            # The pivot holds NaN where a task exists for another recurrence but
            # not this one; treat that like a missing column, not as "nan".
            if value is None or pd.isna(value):
                cells.append("")
            else:
                cells.append(f"{value*100:.1f}")
        joined = ' & '.join(cells)
        body_lines.append(f'{rec} & {joined} \\\\')
    return body_lines

def plot(runs, save_prefix, save_df=False):
    """Build the results DataFrame for ``runs`` and write one PDF per task and x-axis.

    Args:
        runs: run descriptions, passed unchanged to ``make_df``.
        save_prefix: prefix of every output file name. The prefixes listed in
            ``table_last_checkpoint`` below switch to table mode: LaTeX table
            rows for that checkpoint are printed and the process then exits
            (``SystemExit``) before any figure is drawn.
        save_df: when true, also dump the DataFrame to
            ``paper_plots/{save_prefix}.jsonl``.

    Side effects:
        Prints the DataFrame and progress information, runs ``check_df``,
        creates ``eval_plots/`` if it does not exist and writes
        ``eval_plots/{save_prefix}_{task}[_{x_axis}].pdf``.
    """
    df = make_df(runs)
    if save_df:
        df.to_json(f"paper_plots/{save_prefix}.jsonl", orient="records", lines=True)
    print(df)

    # Table mode: save_prefix -> checkpoint whose results are tabulated.
    table_last_checkpoint = {
        "tinyllama_long_run_table": 28_000,
        "llama_long": 48_000,
        "tinyllama_ablate_rec_long": 50_000,
        "tinyllama_nemotron_fineweb_2": 25_000,
    }
    if save_prefix in table_last_checkpoint:
        last_checkpoint = table_last_checkpoint[save_prefix]
        task_order = ['arc_easy', 'arc_challenge', 'hellaswag', 'winogrande', 'mmlu', 'piqa', 'openbookqa']
        if save_prefix == "tinyllama_nemotron_fineweb_2":
            task_order.append('gsm8k')

        this_df = df[df["chkpt"] == last_checkpoint][["model", "task", "acc", "num_rec"]]
        print(this_df)
        for model in natsorted(list(this_df["model"].unique())):
            for row in add_table_section(this_df, model, task_order=task_order):
                print(row)
        # Same effect as the interactive-only exit() helper, without relying on `site`.
        raise SystemExit

    check_df(df)

    num_rec_styles = {
        1: '-',
        2: '--',
        4: '-.',
        8: ':',
        16: (0, (3, 1, 1, 1)),  # densely dash-dotted
        32: (0, (1, 1)),        # densely dotted
    }

    os.makedirs("eval_plots", exist_ok=True)

    for task_name in df["task"].unique():
        x_axes = ["chkpt"]
        if task_name == "gsm8k":
            if "size_chkpt" not in df:
                # Parameters x steps; the 6 * batch-size factor of the 6ND estimate is left out.
                df["size_chkpt"] = df["effective_params"] * df["chkpt"]
            x_axes = ["chkpt", "size_chkpt", "FLOPs", "effective_params"]

        for x_axis in x_axes:
            subset = df[df['task'] == task_name]
            print(f"{task_name}: {subset['model'].unique()}")

            fig = plt.figure(figsize=(8, 4))
            plt.grid()

            # One colour per model; tab20 has 20 colours and is cycled beyond that.
            models = sorted(subset["model"].unique())
            cmap = plt.get_cmap('tab20')
            model_colors = {m: cmap(i % cmap.N) for i, m in enumerate(models)}

            for model_name in models:
                model_subset = subset[subset["model"] == model_name]
                for num_rec in sorted(model_subset["num_rec"].unique()):
                    style = num_rec_styles.get(num_rec, '-')  # unknown recurrence: solid line
                    data = model_subset[model_subset["num_rec"] == num_rec].sort_values(x_axis)

                    # yerr is not passed, so no error bars are drawn
                    # (add yerr=data["stderr"] to show them).
                    plt.errorbar(
                        data[x_axis],
                        data["acc"],
                        label=f"{model_name} (num_rec={num_rec})",
                        color=model_colors[model_name],
                        marker='o',
                        capsize=4,
                        linestyle=style
                    )

            plt.axhline(y=task_to_random[task_name], color='black', linestyle='--', label='Random Baseline')

            plt.xlabel("Step" if x_axis == "chkpt" else x_axis)
            plt.ylabel("Accuracy")

            # Two separate legends below the axes: colour = model, line style = num_rec.
            model_legend_handles = [
                matplotlib.lines.Line2D([0], [0], color=model_colors[model], lw=2, label=model)
                for model in models
            ]
            num_rec_legend_handles = [
                matplotlib.lines.Line2D([0], [0], color='black', lw=2, linestyle=style, label=f"num_rec={num_rec}")
                for num_rec, style in num_rec_styles.items()
            ]

            second_legend = plt.legend(
                handles=num_rec_legend_handles,
                title="Num Rec (Line Style)",
                loc='upper center',
                bbox_to_anchor=(0.75, -0.1),
                ncol=2,
                frameon=False
            )
            # A second plt.legend() call replaces the first legend, so keep the
            # first one alive by re-adding it as a plain artist.
            plt.gca().add_artist(second_legend)

            plt.legend(
                handles=model_legend_handles,
                title="Model (Color)",
                loc='upper center',
                bbox_to_anchor=(0.2, -0.1),
                ncol=2,
                frameon=False
            )

            axis_name = f"_{x_axis}" if x_axis != "chkpt" else ""
            out_path = f"eval_plots/{save_prefix}_{task_name}{axis_name}.pdf"
            plt.savefig(out_path, bbox_inches="tight")
            # Close (not just clear) the figure so figures do not pile up across iterations.
            plt.close(fig)
            print(out_path)

def check_df(df, expected_chkpts=None, expected_recurs=None):
    """Report which (model, checkpoint, recurrence) combinations lack results.

    For every model whose name does not contain ``"cooldown"`` and every task in
    ``df``, each expected (checkpoint, recurrence) pair must have at least one
    row. A combination is reported once even if several tasks are missing.

    Args:
        df: DataFrame with the columns ``model``, ``task``, ``chkpt`` and ``num_rec``.
        expected_chkpts: checkpoint steps to require; defaults to
            1000, 2000, ..., 28000.
        expected_recurs: ``num_rec`` values to require; defaults to
            ``[1, 2, 4, 8, 16, 32]``.

    Returns:
        The set of missing ``(model, chkpt, rec)`` tuples, which is also printed
        after a ``==== MISSING ====`` header.
    """
    if expected_chkpts is None:
        expected_chkpts = list(range(1000, 28001, 1000))
    if expected_recurs is None:
        expected_recurs = [1, 2, 4, 8, 16, 32]

    # Build the set of present combinations once instead of running a
    # boolean-mask scan over the whole frame for every expected combination.
    present = set(zip(df["model"], df["task"], df["chkpt"], df["num_rec"]))
    tasks = df["task"].unique()

    missing_entries = set()
    for model in df["model"].unique():
        if "cooldown" in model:
            continue
        for task in tasks:
            for chkpt in expected_chkpts:
                for rec in expected_recurs:
                    if (model, task, chkpt, rec) not in present:
                        missing_entries.add((model, chkpt, rec))
    print("==== MISSING ====")
    print(missing_entries)
    return missing_entries

def main():
    """Script entry point: plot the runs listed below and save the DataFrame.

    Each run is ``(results directory, legend label, extra settings)``; the
    values here are placeholders to be replaced with real runs.
    """
    runs = [
        (
            "lm_evals/huginn_llama/my_model_name",
            "My Legend Name",
            {"train_rec": 32, "max_steps": 120_000, "warmup_duration": 0.25},
        ),
    ]
    save_prefix = "save_name"
    plot(runs, save_prefix, save_df=True)


if __name__ == "__main__":
    main()