import os
import re
import sys
import json
import tempfile
import torch
import random
import numpy as np
import syncode.infer as infer
import arviz as az
import xarray as xr
from transformers import AutoTokenizer
import tiktoken
from typing import Tuple, Union

import xyz

def set_seed(seed: int = 0):
    """Seed every RNG the experiment touches and make cuDNN deterministic.

    Args:
        seed: Value given to Python's ``random``, NumPy's global RNG and torch
            (CPU plus all CUDA devices).
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # Torch CPU generator and every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

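# Ensure the `commons` package is on sys.path.
# NOTE(review): this block is commented out, yet main_experiment() and
# run_multiple_seeds() still use `models_info` and `build_prompt_generic`;
# re-enable it (or provide those names from another module) before running.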
# parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
# if parent_dir not in sys.path:
#     sys.path.append(parent_dir)
# from commons.model_pymc import models_info, build_prompt_generic

# ── Globals ───────────────────────────────────────────────────────────────────
# Module-level accumulators shared across seeds. main_experiment() writes into
# them and run_multiple_seeds() reads them by bare name, so they must exist at
# import time or both functions raise NameError.
reliability_aggregate = {}   # model_name → llm_model → seed → diagnostics
token_count_aggregate = {}   # model_name → llm_model ('/' → '_') → { seeds: {seed: tokens}, cumulative: {seed: cumulative_tokens} }

# ── Utilities ─────────────────────────────────────────────────────────────────

def convert_np_types(obj):
    """Recursively convert NumPy / xarray values into JSON-serialisable Python types.

    - NumPy integer / floating / bool scalars -> ``int`` / ``float`` / ``bool``
    - ``np.ndarray`` -> nested ``list``
    - ``xr.DataArray`` -> Python scalar if it has exactly one element, else nested list
    - ``dict`` -> new dict with converted values; NumPy scalar keys become Python
      scalars (``json`` rejects e.g. ``np.int64`` keys)
    - ``list`` / ``tuple`` -> ``list`` of converted items

    Anything else is returned unchanged.
    """
    # np.float_ / np.int_ are intentionally not listed: np.float_ no longer
    # exists in NumPy 2.0 (AttributeError), and both are subclasses of
    # np.floating / np.integer anyway.
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): convert_np_types(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_np_types(i) for i in obj]
    # Plain Python values need no conversion; checking them before the xarray
    # test keeps the common case cheap.
    if obj is None or isinstance(obj, (str, int, float)):
        return obj
    if isinstance(obj, xr.DataArray):
        return obj.item() if obj.size == 1 else obj.values.tolist()
    return obj

def save_reliability_aggregate(aggregate_dict, output_path):
    """Write ``aggregate_dict`` to ``output_path`` as indented JSON.

    NumPy / xarray values are converted with ``convert_np_types`` first; the
    conversion happens before the file is opened, so a conversion error leaves
    any existing file untouched.
    """
    payload = convert_np_types(aggregate_dict)
    with open(output_path, mode="w") as out:
        json.dump(payload, out, indent=2)
    print(f"Saved reliability aggregate → {output_path}")

def check_model_reliability(
    idata: az.InferenceData,
    r_hat_threshold: float = 1.05,
    min_ess_bulk: int = 400,
    min_ess_tail: int = 100,
    max_pareto_k: float = 0.7,
    max_prop_k: float = 0.20,
    min_bfmi: float = 0.3,
):
    """Score an ArviZ ``InferenceData`` on a fixed battery of pass/fail checks.

    Each check contributes 1 (pass) or 0 (fail) to ``scores``:

    - ``r_hat``: max R-hat over the ``az.summary`` rows < ``r_hat_threshold``
    - ``ess_bulk``: min bulk ESS >= ``min_ess_bulk``
    - ``ess_tail``: min tail ESS >= ``min_ess_tail`` (fails if the column is absent)
    - ``divergences``: ``sample_stats["diverging"]`` sums to zero
    - ``bfmi``: every value returned by ``az.bfmi`` is > ``min_bfmi``
    - ``loo_success``: ``az.loo`` succeeded
    - ``pareto_k``: fraction of Pareto k values > ``max_pareto_k`` is <= ``max_prop_k``
    - ``elpd_loo_success``: an ELPD estimate is available

    Args:
        idata: Posterior samples to assess.
        r_hat_threshold, min_ess_bulk, min_ess_tail, max_pareto_k, max_prop_k,
        min_bfmi: Thresholds for the checks above.

    Returns:
        ``(reliability_score, diagnostics)``: the number of passed checks, and
        a dict with the per-check scores, human-readable failure ``reasons`` and
        the raw statistics (R-hat, ESS, divergences, BFMI, Pareto-k proportion,
        ELPD and its standard error; ``None`` where unavailable).

    Raises:
        Whatever ``az.summary``, ``az.bfmi`` or the ``sample_stats["diverging"]``
        lookup raise; only the LOO step is guarded.
    """
    scores = {}
    reasons = []

    summary = az.summary(idata)
    max_r_hat = summary["r_hat"].max()
    scores["r_hat"] = int(max_r_hat < r_hat_threshold)
    if scores["r_hat"] == 0:
        reasons.append(f"R‑hat={max_r_hat:.3f}>={r_hat_threshold}")

    min_ess_bulk_val = summary["ess_bulk"].min()
    scores["ess_bulk"] = int(min_ess_bulk_val >= min_ess_bulk)
    if scores["ess_bulk"] == 0:
        reasons.append(f"ESS_bulk={min_ess_bulk_val:.1f}<{min_ess_bulk}")

    if "ess_tail" in summary.columns:
        min_ess_tail_val = summary["ess_tail"].min()
        scores["ess_tail"] = int(min_ess_tail_val >= min_ess_tail)
        if scores["ess_tail"] == 0:
            reasons.append(f"ESS_tail={min_ess_tail_val:.1f}<{min_ess_tail}")
    else:
        min_ess_tail_val = None
        scores["ess_tail"] = 0
        reasons.append("ESS_tail N/A")

    n_div = int(idata.sample_stats["diverging"].sum())
    scores["divergences"] = int(n_div == 0)
    if scores["divergences"] == 0:
        reasons.append(f"{n_div} divergences")

    bfmi_vals = az.bfmi(idata)
    scores["bfmi"] = int((bfmi_vals > min_bfmi).all())
    if scores["bfmi"] == 0:
        reasons.append(f"Low BFMI={bfmi_vals}")

    # PSIS-LOO is expensive: compute it once and reuse the result for both the
    # Pareto-k check and the ELPD estimate.
    prop_high = None
    elpd_loo = None
    loo_se = None
    try:
        loo_res = az.loo(idata, pointwise=True)
    except Exception as e:
        # Both LOO-derived checks fail together; keep one reason per check.
        scores["loo_success"] = 0
        scores["pareto_k"] = 0
        reasons.append(f"LOO error: {e}")
        scores["elpd_loo_success"] = 0
        reasons.append(f"ELPD error: {e}")
    else:
        scores["loo_success"] = 1
        # .size (not len) so multi-dimensional observation arrays count every k.
        high_k = np.asarray(loo_res.pareto_k) > max_pareto_k
        prop_high = float(np.mean(high_k))
        scores["pareto_k"] = int(prop_high <= max_prop_k)
        if scores["pareto_k"] == 0:
            reasons.append(f"{int(high_k.sum())}/{high_k.size} k>{max_pareto_k}")
        elpd_loo = loo_res.elpd_loo
        loo_se = loo_res.se
        scores["elpd_loo_success"] = 1

    total_checks = len(scores)
    reliability_score = sum(scores.values())

    diagnostics = {
        "reliability_score": reliability_score,
        "max_score": total_checks,
        "individual_scores": scores,
        "reasons": reasons,
        "max_r_hat": max_r_hat,
        "min_ess_bulk": min_ess_bulk_val,
        "min_ess_tail": min_ess_tail_val,
        "n_divergent": n_div,
        "bfmi_values": bfmi_vals.tolist() if hasattr(bfmi_vals, "tolist") else bfmi_vals,
        "prop_high_pareto_k": prop_high,
        "elpd_loo": elpd_loo,
        "loo_se": loo_se,
    }
    return reliability_score, diagnostics

def get_new_experiment_folder(base_dir="expts-org", prefix="", modelsize="medium"):
    """Create and return the next numbered experiment folder.

    Scans ``<base_dir>/<modelsize>`` (created if missing) for entries named
    exactly ``<prefix><N>`` with ``N`` an integer, then creates
    ``<prefix><max N + 1>``, or ``<prefix>1`` when none exist.

    ``prefix`` is matched literally (regex-escaped), so e.g. ``"run."`` does not
    also match ``"runX7"``.

    Returns:
        Path of the newly created folder.
    """
    base = os.path.join(base_dir, modelsize)
    os.makedirs(base, exist_ok=True)
    pattern = re.compile(re.escape(prefix) + r"(\d+)")
    nums = [int(m.group(1)) for name in os.listdir(base) if (m := pattern.fullmatch(name))]
    new = max(nums, default=0) + 1
    folder = os.path.join(base, f"{prefix}{new}")
    os.makedirs(folder, exist_ok=True)
    return folder

def store_token_count_aggregate(seed, dataset, llm_model, token_count, agg):
    """Record ``token_count`` for ``seed`` and its running cumulative total.

    ``agg`` is mutated in place into the shape
    ``{dataset: {llm_model: {"seeds": {seed: n}, "cumulative": {seed: total}}}}``.
    The cumulative entry for ``seed`` is ``token_count`` plus the largest
    cumulative value already stored for any smaller seed (0 if none).
    """
    per_llm = agg.setdefault(dataset, {}).setdefault(
        llm_model, {"seeds": {}, "cumulative": {}}
    )
    per_llm["seeds"][seed] = token_count
    running = per_llm["cumulative"]
    earlier_totals = [total for s, total in running.items() if s < seed]
    running[seed] = token_count + max(earlier_totals, default=0)

def save_token_count_aggregate(agg, output_path):
    """Dump the token-count aggregate to ``output_path`` as indented JSON and log it."""
    with open(output_path, mode="w") as out:
        json.dump(agg, out, indent=2)
    message = f"Saved token counts → {output_path}"
    print(message)

def run_pymc_code(full_code: str):
    """Execute a generated program and return the namespace it produced.

    The source is written to ``model.py`` in a throw-away temporary directory
    and compiled against that filename, so tracebacks point at the actual
    generated line instead of ``<string>``.

    WARNING: ``full_code`` is LLM-generated and is run with ``exec`` in this
    process with full privileges; only use this inside a trusted sandbox.

    Args:
        full_code: Complete Python source to execute.

    Returns:
        ``(True, namespace_dict)`` on success, or ``(False, error_message)`` if
        compiling or executing the code raised an ``Exception``.
    """
    with tempfile.TemporaryDirectory() as td:
        path = os.path.join(td, "model.py")
        # Explicit UTF-8: generated code may contain non-ASCII characters that
        # the platform default encoding cannot represent, and this write sits
        # outside the try below.
        with open(path, "w", encoding="utf-8") as f:
            f.write(full_code)
        try:
            code_obj = compile(full_code, path, "exec")
            ns = {}
            exec(code_obj, ns)
            return True, ns
        except Exception as e:
            print("Exec error:", e)
            return False, str(e)

def insert_model_code(
    boilerplate: str,
    raw_snippet: Union[str, list],
    trace_name: str,
    model_name: str = "gpt-3.5-turbo"
) -> Tuple[str, int]:
    """Splice an LLM-generated model snippet into a PyMC boilerplate template.

    Steps:
      1. List snippets are joined with ``""``.
      2. Markdown code-fence markers are removed together with any language tag
         attached to them (e.g. ``python``), so no bare tag line is left behind.
      3. Every non-blank line is left-stripped and re-indented with one tab.
      4. The snippet is truncated after the first line containing ``pm.sample``.
         If that call's parentheses are still open, lines are kept until they
         close, so a multi-line ``pm.sample(...)`` survives intact.
      5. The result is inserted after each ``with pm.Model() as m:`` line of
         ``boilerplate``. All boilerplate lines are left-stripped.
      6. ``summary = az.summary(<trace_name>)`` is appended.

    Args:
        boilerplate: Template program containing ``with pm.Model() as m:``.
        raw_snippet: Generated code, as a string or a list of string chunks.
        trace_name: Variable name used in the appended ``az.summary`` call.
        model_name: Model whose tokenizer counts the snippet's tokens.

    Returns:
        ``(full_code, tok_count)``: the assembled program and the token count of
        the cleaned snippet. The count comes from tiktoken when it knows
        ``model_name``, otherwise from a Hugging Face ``AutoTokenizer``.

    Raises:
        ValueError: if ``boilerplate`` has no ``with pm.Model() as m:`` line.
    """
    if isinstance(raw_snippet, list):
        raw_snippet = "".join(raw_snippet)
    # Remove fence markers plus a glued-on language tag ("```python" -> "").
    snippet = re.sub(r"```[\w+-]*", "", raw_snippet)

    proc_lines = []
    in_sample_call = False
    open_parens = 0
    for line in snippet.splitlines():
        proc_lines.append(("\t" + line.lstrip()) if line.strip() else "")
        if "pm.sample" in line:
            in_sample_call = True
        if in_sample_call:
            # Keep consuming lines only while the pm.sample(...) call is open.
            open_parens += line.count("(") - line.count(")")
            if open_parens <= 0:
                break
    snippet_to_count = "\n".join(proc_lines)

    try:
        enc = tiktoken.encoding_for_model(model_name)
        # disallowed_special=(): count special-token text as ordinary text
        # instead of raising ValueError on it.
        tok_count = len(enc.encode(snippet_to_count, disallowed_special=()))
    except KeyError:
        # tiktoken does not know this model: fall back to its HF tokenizer.
        tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        tok_count = len(tok(snippet_to_count, return_tensors="pt").input_ids[0])

    lines = []
    inserted = False
    for bl in boilerplate.split("\n"):
        lines.append(bl.lstrip())
        if bl.strip().startswith("with pm.Model() as m:"):
            lines.append(snippet_to_count)
            inserted = True
    if not inserted:
        raise ValueError("`with pm.Model() as m:` not found in boilerplate")

    lines.append("\t# Posterior diagnostics")
    lines.append(f"\tsummary = az.summary({trace_name})")
    full_code = "\n".join(lines)
    return full_code, tok_count

def extract_trace_name(snippet):
    """Return the variable name that the snippet assigns ``pm.sample`` to.

    ``snippet`` may be a string or a list of string chunks. Lists are joined
    with ``""``, the same way ``insert_model_code`` joins them. The name is
    therefore searched for in the text that actually gets spliced, and an empty
    list is handled.

    Returns:
        The first ``<name>`` in a ``<name> = pm.sample...`` assignment, or
        ``"trace"`` when none is found.
    """
    code = "".join(snippet) if isinstance(snippet, list) else snippet
    match = re.search(r"(\w+)\s*=\s*pm\.sample", code)
    return match.group(1) if match else "trace"

# ── Main experiment ───────────────────────────────────────────────────────────
# Lark grammar handed to SynCode (``grammar=``); override per call via
# main_experiment(grammar_path=...).
DEFAULT_GRAMMAR_PATH = (
    "/home/madhav/madhav/refinegen/refinegen/itergen/itergen/"
    "syncode/syncode/parsers/grammars/python_grammar.lark"
)


def main_experiment(parent_folder, temperature=None, seed=1, do_sample=True, modelsize="medium",
                    grammar_path=DEFAULT_GRAMMAR_PATH):
    """Run one seed: generate, splice, execute and score one program per (LLM, dataset).

    For every LLM in the bucket selected by ``modelsize`` and every entry of
    ``models_info``, writes ``prompt.txt``, ``snippet.txt``, ``final_code.py`` and
    ``exec_output.txt`` under
    ``<parent_folder>/seed_<seed>/<llm with '/' -> '_'>/<entry name>/``.

    Results are recorded in two module-level aggregates:

    - ``reliability_aggregate``, keyed by the raw LLM id; only for programs whose
      trace variable exists and could be scored.
    - ``token_count_aggregate``, keyed by the LLM id with '/' replaced by '_'.

    Finally writes ``token_count.json`` into ``parent_folder``.

    Args:
        parent_folder: Experiment root; a ``seed_<seed>`` subfolder is created in it.
        temperature: Sampling temperature forwarded to SynCode.
        seed: Passed to ``set_seed`` and used as the aggregate key.
        do_sample: Forwarded to SynCode.
        modelsize: ``"small"``, ``"medium"``, or anything else for the reasoning bucket.
        grammar_path: Grammar file forwarded to SynCode's ``grammar=`` argument.

    Returns:
        ``{"seed": seed, "results": [...]}``, with one record per generated program
        (compiled flag, code, scores, token count, output folder).
    """
    set_seed(seed)
    exp_folder = os.path.join(parent_folder, f"seed_{seed}")
    os.makedirs(exp_folder, exist_ok=True)

    all_results = []

    small_models = ["microsoft/Phi-3.5-mini-instruct", "Qwen/Qwen2.5-Coder-3B"]
    medium_models = [
        "meta-llama/Meta-Llama-3-8B",
        "google/codegemma-7b",
        "Qwen/Qwen2.5-Coder-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    ]
    reasoning_models = small_models

    if modelsize == "small":
        models = small_models
    elif modelsize == "medium":
        models = medium_models
    else:
        models = reasoning_models

    for llm in models:
        llm_folder = os.path.join(exp_folder, llm.replace("/", "_"))
        os.makedirs(llm_folder, exist_ok=True)

        local_llm = infer.Syncode(
            mode="original",
            model=llm,
            do_sample=do_sample,
            temperature=temperature,
            grammar=grammar_path,
            device="cuda",
            parse_output_only=False,
            max_new_tokens=400,
        )

        for entry in models_info:
            mn = entry["name"]
            model_folder = os.path.join(llm_folder, mn)
            os.makedirs(model_folder, exist_ok=True)

            # Explicit UTF-8 on every write: prompts and LLM output may contain
            # characters the platform default encoding cannot represent.
            prompt = build_prompt_generic(entry)
            with open(os.path.join(model_folder, "prompt.txt"), "w", encoding="utf-8") as pf:
                pf.write(prompt)

            snippet = local_llm.infer(prompt)
            with open(os.path.join(model_folder, "snippet.txt"), "w", encoding="utf-8") as sf:
                sf.write(str(snippet))

            trace_name = extract_trace_name(snippet)
            full_code, tok_count = insert_model_code(entry["template_code"], snippet, trace_name, llm)
            with open(os.path.join(model_folder, "final_code.py"), "w", encoding="utf-8") as cf:
                cf.write(full_code)

            compiled, output = run_pymc_code(full_code)
            with open(os.path.join(model_folder, "exec_output.txt"), "w", encoding="utf-8") as ef:
                ef.write(str(output))

            rel_score = None
            elpd_val = None
            diagnostics = None

            if compiled and isinstance(output, dict) and trace_name in output:
                idata = output[trace_name]
                try:
                    rel_score, diagnostics = check_model_reliability(idata)
                except Exception as e:
                    # A program can run yet yield an object the checks cannot
                    # score (e.g. not an InferenceData, or no "diverging" sample
                    # stat). Record the failure instead of aborting every seed.
                    print(f"Reliability check failed for {mn} / {llm} / seed {seed}: {e}")
                else:
                    elpd_val = diagnostics["elpd_loo"] if diagnostics["elpd_loo"] is not None else float("-inf")
                    print(mn, llm, seed)
                    reliability_aggregate.setdefault(mn, {}).setdefault(llm, {})[seed] = {
                        "reliability_score": rel_score,
                        "elpd": elpd_val,
                        "diagnostics": diagnostics
                    }

            store_token_count_aggregate(seed, mn, llm.replace("/", "_"), tok_count, token_count_aggregate)

            all_results.append({
                "compiled": compiled,
                "full_code": full_code,
                "reliability_score": rel_score if rel_score is not None else -1,
                "elpd_loo": elpd_val if elpd_val is not None else float("-inf"),
                "model_name": mn,
                "llm_model": llm,
                "token_count": tok_count,
                "folder": model_folder
            })

        # Free the model before loading the next LLM onto the GPU.
        del local_llm
        torch.cuda.empty_cache()

    # per-seed best is no longer saved here; cross-seed selection happens later
    save_token_count_aggregate(token_count_aggregate, os.path.join(parent_folder, "token_count.json"))

    return {"seed": seed, "results": all_results}

# ── Run multiple seeds & cross-seed best selection ───────────────────────────
def run_multiple_seeds(num_seeds=10, temperature=0.3, modelsize="medium"):
    """Run ``main_experiment`` for seeds 1..num_seeds and write cross-seed summaries.

    A fresh numbered folder is created under ``expts-org/<modelsize>/``. Besides the
    per-seed outputs, it receives:

    - ``aggregated_reliability.json`` / ``aggregated_token_count.json``: raw aggregates
    - ``cross_seed_best_programs.json``: per (dataset, LLM), the seed with the highest
      ``(reliability_score, elpd)``
    - ``leaderboard.json``: every (seed, dataset, LLM) entry, best first
    - ``progress_<dataset>_<llm>.json``: cumulative tokens vs. reliability per seed

    ``main_experiment`` keys ``reliability_aggregate`` by the raw LLM id but
    ``token_count_aggregate`` by the id with '/' replaced by '_'. Token lookups
    below therefore use the sanitised key; with the raw id they would always
    miss, leaving ``cumulative_tokens`` as None and progression files empty.
    """
    small_models = [
        "microsoft/Phi-3.5-mini-instruct",
        "Qwen/Qwen2.5-Coder-3B",
    ]
    medium_models = [
        "meta-llama/Meta-Llama-3-8B",
        "google/codegemma-7b",
        "Qwen/Qwen2.5-Coder-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    ]
    # NOTE: must stay in sync with the bucket selection in main_experiment().
    reasoning_models = small_models

    if modelsize == "small":
        llm_models = small_models
    elif modelsize == "medium":
        llm_models = medium_models
    else:
        llm_models = reasoning_models

    multi_folder = get_new_experiment_folder(
        base_dir="expts-org", prefix="", modelsize=modelsize
    )
    with open(os.path.join(multi_folder, "desc.txt"), "w") as df:
        df.write(f"Seeds={num_seeds}, temp={temperature}, size={modelsize}\n")

    def code_path_for(seed, llm, dataset):
        # Where main_experiment() writes the spliced program for this combination.
        return os.path.join(
            multi_folder,
            f"seed_{seed}",
            llm.replace("/", "_"),
            dataset,
            "final_code.py"
        )

    # run all seeds
    for s in range(1, num_seeds + 1):
        main_experiment(
            parent_folder=multi_folder,
            temperature=temperature,
            seed=s,
            do_sample=True,
            modelsize=modelsize
        )

    # save raw aggregates
    save_reliability_aggregate(
        reliability_aggregate,
        os.path.join(multi_folder, "aggregated_reliability.json")
    )
    save_token_count_aggregate(
        token_count_aggregate,
        os.path.join(multi_folder, "aggregated_token_count.json")
    )

    datasets = [entry["name"] for entry in models_info]

    # cross-seed best per dataset×LLM
    cross_seed_best = []
    for dataset in datasets:
        for llm in llm_models:
            seeds_dict = reliability_aggregate.get(dataset, {}).get(llm, {})
            if seeds_dict:
                best_seed, best_metrics = max(
                    seeds_dict.items(),
                    key=lambda kv: (kv[1]["reliability_score"], kv[1]["elpd"])
                )
                entry = {
                    "dataset":     dataset,
                    "llm_model":   llm,
                    "seed":        best_seed,
                    "reliability": best_metrics["reliability_score"],
                    "elpd_loo":    best_metrics["elpd"],
                    "code_path":   code_path_for(best_seed, llm, dataset)
                }
            else:
                entry = {
                    "dataset":     dataset,
                    "llm_model":   llm,
                    "seed":        None,
                    "reliability": None,
                    "elpd_loo":    None,
                    "code_path":   None
                }
            cross_seed_best.append(entry)

    with open(os.path.join(multi_folder, "cross_seed_best_programs.json"), "w") as f:
        json.dump(cross_seed_best, f, indent=2)
    print(f"Saved {len(cross_seed_best)} cross-seed best entries")

    # full leaderboard of all programs
    leaderboard = []
    for seed in range(1, num_seeds + 1):
        for dataset in datasets:
            for llm in llm_models:
                token_key = llm.replace("/", "_")  # key used by main_experiment()
                metrics = reliability_aggregate.get(dataset, {}).get(llm, {}).get(seed)
                reliability = metrics["reliability_score"] if metrics else None
                elpd = metrics["elpd"] if metrics else None
                cumulative = (
                    token_count_aggregate.get(dataset, {})
                    .get(token_key, {})
                    .get("cumulative", {})
                    .get(seed, None)
                )
                code_path = code_path_for(seed, llm, dataset)
                if not os.path.exists(code_path):
                    code_path = None

                leaderboard.append({
                    "seed":              seed,
                    "dataset":           dataset,
                    "llm_model":         llm,
                    "reliability_score": reliability,
                    "elpd_loo":          elpd,
                    "cumulative_tokens": cumulative,
                    "code_path":         code_path
                })

    leaderboard.sort(
        key=lambda x: (
            x["reliability_score"] if x["reliability_score"] is not None else -1,
            x["elpd_loo"] if x["elpd_loo"] is not None else float("-inf")
        ),
        reverse=True
    )
    with open(os.path.join(multi_folder, "leaderboard.json"), "w") as f:
        json.dump(leaderboard, f, indent=2)
    print(f"Saved full leaderboard ({len(leaderboard)} entries)")

    # progression files per dataset×LLM
    for dataset in datasets:
        for llm in llm_models:
            token_key = llm.replace("/", "_")  # key used by main_experiment()
            cum_map = token_count_aggregate.get(dataset, {}).get(token_key, {}).get("cumulative", {})
            seed_map = reliability_aggregate.get(dataset, {}).get(llm, {})

            progression = [
                {
                    "seed":              seed,
                    "cumulative_tokens": cum_map[seed],
                    "reliability_score": seed_map.get(seed, {}).get("reliability_score")
                }
                for seed in sorted(cum_map)
            ]

            prog_path = os.path.join(
                multi_folder,
                f"progress_{dataset}_{token_key}.json"
            )
            with open(prog_path, "w") as pf:
                json.dump(progression, pf, indent=2)
            print(f"Saved progression for {dataset} × {llm} → {prog_path}")


if __name__ == "__main__":
    # Usage: python <script> [temperature] [size] [num_seeds]
    try:
        temp = float(sys.argv[1]) if len(sys.argv) > 1 else 0.3
        size = sys.argv[2] if len(sys.argv) > 2 else "medium"
        seeds = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    except ValueError as e:
        print(f"Invalid argument: {e}", file=sys.stderr)
        print(
            f"Usage: python {os.path.basename(sys.argv[0])} [temperature] [size] [num_seeds]",
            file=sys.stderr,
        )
        sys.exit(2)

    try:
        run_multiple_seeds(num_seeds=seeds, temperature=temp, modelsize=size)
    except Exception:
        # Only drop into the debugger when someone is at a terminal; batch /
        # cluster runs must fail loudly instead of blocking on pdb's prompt.
        if not sys.stdin.isatty():
            raise
        import pdb
        import traceback
        traceback.print_exc()
        # Post-mortem opens on the frame that raised, not on this handler.
        pdb.post_mortem()
