import os
import re
import math
import csv
import tempfile
import shutil
import subprocess
import time
from math import comb
import collections
import multiprocessing

import torch
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from tqdm import tqdm

# Configuration, read from environment variables at import time.
# Every default of the form "to enter values of <NAME>" is a placeholder; the
# code in this file treats any value starting with "to enter" as "not set".
HF_TOKEN = os.environ.get("HF_TOKEN", "to enter values of HF_TOKEN")  # passed as `token=` to from_pretrained
MODEL_NAME = os.environ.get("MODEL_NAME", "to enter values of MODEL_NAME")  # base model id/path for every variant

# --- 9 LoRA paths for Python (placeholders) ---
# Each directory is walked recursively by load_lora for "*.pt" tensor files.
TENSORS_DIR_PRETRAINED_PY = ""  # pretrained uses load_pretrained; left empty placeholder
TENSORS_DIR_FULL_MLPATT_PY = os.environ.get("TENSORS_DIR_FULL_MLPATT_PY", "to enter values of TENSORS_DIR_FULL_MLPATT_PY")
TENSORS_DIR_FULL_ATT_PY    = os.environ.get("TENSORS_DIR_FULL_ATT_PY",    "to enter values of TENSORS_DIR_FULL_ATT_PY")
TENSORS_DIR_FULL_MLP_PY    = os.environ.get("TENSORS_DIR_FULL_MLP_PY",    "to enter values of TENSORS_DIR_FULL_MLP_PY")

TENSORS_DIR_TOP1_MLPATT_PY = os.environ.get("TENSORS_DIR_TOP1_MLPATT_PY", "to enter values of TENSORS_DIR_TOP1_MLPATT_PY")
TENSORS_DIR_TOP1_ATT_PY    = os.environ.get("TENSORS_DIR_TOP1_ATT_PY",    "to enter values of TENSORS_DIR_TOP1_ATT_PY")
TENSORS_DIR_TOP1_MLP_PY    = os.environ.get("TENSORS_DIR_TOP1_MLP_PY",    "to enter values of TENSORS_DIR_TOP1_MLP_PY")

TENSORS_DIR_TOP3_MLPATT_PY = os.environ.get("TENSORS_DIR_TOP3_MLPATT_PY", "to enter values of TENSORS_DIR_TOP3_MLPATT_PY")
TENSORS_DIR_TOP3_ATT_PY    = os.environ.get("TENSORS_DIR_TOP3_ATT_PY",    "to enter values of TENSORS_DIR_TOP3_ATT_PY")
TENSORS_DIR_TOP3_MLP_PY    = os.environ.get("TENSORS_DIR_TOP3_MLP_PY",    "to enter values of TENSORS_DIR_TOP3_MLP_PY")

# --- 9 LoRA paths for C++ (placeholders) ---
# Same layout as the Python block above, used for the "cpp" dataset split.
TENSORS_DIR_PRETRAINED_CPP = ""  # pretrained uses load_pretrained; left empty placeholder
TENSORS_DIR_FULL_MLPATT_CPP = os.environ.get("TENSORS_DIR_FULL_MLPATT_CPP", "to enter values of TENSORS_DIR_FULL_MLPATT_CPP")
TENSORS_DIR_FULL_ATT_CPP    = os.environ.get("TENSORS_DIR_FULL_ATT_CPP",    "to enter values of TENSORS_DIR_FULL_ATT_CPP")
TENSORS_DIR_FULL_MLP_CPP    = os.environ.get("TENSORS_DIR_FULL_MLP_CPP",    "to enter values of TENSORS_DIR_FULL_MLP_CPP")

TENSORS_DIR_TOP1_MLPATT_CPP = os.environ.get("TENSORS_DIR_TOP1_MLPATT_CPP", "to enter values of TENSORS_DIR_TOP1_MLPATT_CPP")
TENSORS_DIR_TOP1_ATT_CPP    = os.environ.get("TENSORS_DIR_TOP1_ATT_CPP",    "to enter values of TENSORS_DIR_TOP1_ATT_CPP")
TENSORS_DIR_TOP1_MLP_CPP    = os.environ.get("TENSORS_DIR_TOP1_MLP_CPP",    "to enter values of TENSORS_DIR_TOP1_MLP_CPP")

TENSORS_DIR_TOP3_MLPATT_CPP = os.environ.get("TENSORS_DIR_TOP3_MLPATT_CPP", "to enter values of TENSORS_DIR_TOP3_MLPATT_CPP")
TENSORS_DIR_TOP3_ATT_CPP    = os.environ.get("TENSORS_DIR_TOP3_ATT_CPP",    "to enter values of TENSORS_DIR_TOP3_ATT_CPP")
TENSORS_DIR_TOP3_MLP_CPP    = os.environ.get("TENSORS_DIR_TOP3_MLP_CPP",    "to enter values of TENSORS_DIR_TOP3_MLP_CPP")

# Hyperparameters (placeholders). Set via env to integer values prior to running.
# These are the raw strings; they are converted to int further down the file.
N_env = os.environ.get("N_SAMPLES_PER_LANG", "to enter values of N_SAMPLES_PER_LANG")
NUM_K_SAMPLING_env = os.environ.get("NUM_K_SAMPLING", "to enter values of NUM_K_SAMPLING")
BATCH_SIZE_env = os.environ.get("BATCH_SIZE", "to enter values of BATCH_SIZE")
MAX_NEW_TOKENS_env = os.environ.get("MAX_NEW_TOKENS", "to enter values of MAX_NEW_TOKENS")
PREVIEW_env = os.environ.get("PREVIEW", "to enter values of PREVIEW")

# Other settings
EXEC_TIMEOUT = int(os.environ.get("EXEC_TIMEOUT", "8"))  # timeout handed to subprocess.run for compile/run steps
PASSAT_K = [1, 5, 10]  # k values for which pass@k is computed
PYTHON_EXEC = os.environ.get("PYTHON_EXEC", "python3")  # interpreter used to run Python candidates
CPP_COMPILER = os.environ.get("CPP_COMPILER", "g++")  # compiler used to build C++ candidates

# Logging / CSV placeholders
LOG_DIR = os.environ.get("LOG_DIR", "to enter values of LOG_DIR")
CSV_OUT = os.environ.get("CSV_OUT", "to enter values of CSV_OUT")
CSV_OUT_PY = os.environ.get("CSV_OUT_PY", "to enter values of CSV_OUT_PY")
CSV_OUT_CPP = os.environ.get("CSV_OUT_CPP", "to enter values of CSV_OUT_CPP")

# Create the log dir only when a real (non-placeholder) value was provided.
if LOG_DIR and not LOG_DIR.startswith("to enter"):
    os.makedirs(LOG_DIR, exist_ok=True)

# Utility: parse placeholder envs to ints at runtime (clear error if not set properly)
def _parse_required_int(name, val):
    try:
        return int(val)
    except Exception:
        raise RuntimeError(f"Environment variable {name} must be set to an integer (currently: {val!r}). Please set it before running.")

# Required integer settings. Each call raises RuntimeError at import time when
# the corresponding environment variable is missing or not an integer.
N = _parse_required_int("N_SAMPLES_PER_LANG", N_env)  # number of dataset items taken from the test split
NUM_K_SAMPLING = _parse_required_int("NUM_K_SAMPLING", NUM_K_SAMPLING_env)  # candidates per problem (>1 => sampling, else greedy)
BATCH_SIZE = _parse_required_int("BATCH_SIZE", BATCH_SIZE_env)  # sequences requested per generate() call when sampling
MAX_NEW_TOKENS = _parse_required_int("MAX_NEW_TOKENS", MAX_NEW_TOKENS_env)  # generation length cap
PREVIEW = _parse_required_int("PREVIEW", PREVIEW_env)  # number of leading problems printed/logged in detail

# Ensure CSV paths live inside LOG_DIR: only the basename of each configured path
# is kept and re-rooted under LOG_DIR (or "." when LOG_DIR is unset/placeholder).
# Unset or placeholder CSV paths become ''.
_final_csv_map = {}
for key, raw_path in [('all', CSV_OUT), ('py', CSV_OUT_PY), ('cpp', CSV_OUT_CPP)]:
    if raw_path and (not raw_path.startswith("to enter")):
        final_path = os.path.join(LOG_DIR if LOG_DIR and not LOG_DIR.startswith("to enter") else ".", os.path.basename(raw_path))
        _final_csv_map[key] = final_path
    else:
        _final_csv_map[key] = ''
# Per-language paths fall back to the shared 'all' path when not configured.
if not _final_csv_map.get('py'):
    _final_csv_map['py'] = _final_csv_map.get('all', '')
if not _final_csv_map.get('cpp'):
    _final_csv_map['cpp'] = _final_csv_map.get('all', '')
# Create the parent directory of every distinct output path that has one.
for p in set([_final_csv_map.get('py'), _final_csv_map.get('cpp')]):
    if p and os.path.dirname(p):
        os.makedirs(os.path.dirname(p), exist_ok=True)
# Rebind the public names to the normalised paths ('' when nothing was configured).
CSV_OUT = _final_csv_map.get('all') or _final_csv_map.get('py') or _final_csv_map.get('cpp')
CSV_OUT_PY = _final_csv_map.get('py')
CSV_OUT_CPP = _final_csv_map.get('cpp')

# Utilities: load pretrained and load_lora (minimal prints)
def load_pretrained(model_name, hf_token):
    """Load the unmodified base model and its tokenizer.

    The tokenizer's missing pad/bos tokens fall back to its EOS token, and the
    resulting special-token ids are copied onto ``model.config`` on a
    best-effort basis.

    Args:
        model_name: Model identifier/path given to ``from_pretrained``.
        hf_token: Access token forwarded as ``token=``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=hf_token)
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True,
    )
    base = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    # Tokenizers without dedicated pad/bos tokens reuse EOS for both.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "bos_token", None) is None:
        tokenizer.bos_token = tokenizer.eos_token

    # Best effort: keep the model config in sync with the tokenizer's ids.
    try:
        for id_field in ("pad_token_id", "eos_token_id", "bos_token_id"):
            setattr(base.config, id_field, getattr(tokenizer, id_field))
    except Exception:
        pass

    return base.eval(), tokenizer

def load_lora(base_model, hf_token, tensors_dir, r, alpha):
    """
    Load the base model, wrap it with LoRA adapters and fill them from ``.pt`` files.

    Every ``*.pt`` file found (recursively) under ``tensors_dir`` is loaded and
    its file name, minus the ``.pt`` suffix, is used as a state-dict key. A key
    is matched against the PEFT model's state dict by trying, in order, the
    prefixes ``""``, ``"base_model."``, ``"model."`` combined with the suffixes
    ``""``, ``".weight"``, ``".default.weight"``; the first hit wins. Files whose
    name matches nothing are ignored.

    Stdout is kept minimal: 'Loaded LoRA weights from: <tensors_dir>' when at
    least one tensor was mapped, plus warnings for a missing directory, files
    that failed to load, or a directory from which nothing could be mapped.

    Args:
        base_model: Model identifier/path given to ``from_pretrained``.
        hf_token: Access token forwarded as ``token=``.
        tensors_dir: Directory containing the ``.pt`` tensors (may be empty/None).
        r: LoRA rank.
        alpha: LoRA alpha.

    Returns:
        ``(peft_model, tokenizer)`` with the model in eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True
    )
    tok = AutoTokenizer.from_pretrained(base_model, token=hf_token, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if getattr(tok, "bos_token", None) is None:
        tok.bos_token = tok.eos_token
    # Best effort: keep the model config in sync with the tokenizer's ids.
    try:
        model.config.pad_token_id = tok.pad_token_id
        model.config.eos_token_id = tok.eos_token_id
        model.config.bos_token_id = tok.bos_token_id
    except Exception:
        pass

    cfg = LoraConfig(
        r=r, lora_alpha=alpha,
        target_modules=["q_proj","k_proj","v_proj","o_proj","down_proj","up_proj","gate_proj"],
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, cfg)

    # read .pt files if provided
    ext = ".pt"
    state_raw = {}
    if tensors_dir and os.path.exists(tensors_dir):
        for root, _, files in os.walk(tensors_dir):
            for f in files:
                if not f.endswith(ext):
                    continue
                path = os.path.join(root, f)
                try:
                    # NOTE(security): torch.load unpickles the file; only point
                    # tensors_dir at trusted checkpoints.
                    tensor = torch.load(path, map_location="cpu")
                except Exception as e:
                    # A file that cannot be read must not pass silently: the
                    # variant would otherwise be evaluated with missing weights.
                    print(f"Warning: could not load tensor file {path}: {e}")
                    continue
                # Strip only the trailing extension; str.replace would also
                # remove ".pt" occurring inside the key itself.
                state_raw[f[:-len(ext)]] = tensor
    elif tensors_dir:
        print(f"Warning: tensors_dir provided but path does not exist: {tensors_dir}")

    # mapping heuristic: first candidate name present in the target state dict wins
    target_keys = set(peft_model.state_dict().keys())
    prefixes = ["", "base_model.", "model."]
    suffixes = ["", ".weight", ".default.weight"]
    mapped = {}
    for raw_k, tensor in state_raw.items():
        match = next(
            (
                prefix + raw_k + suf
                for prefix in prefixes
                for suf in suffixes
                if prefix + raw_k + suf in target_keys
            ),
            None,
        )
        if match is not None:
            mapped[match] = tensor

    if mapped:
        # strict=False: only the adapter tensors found on disk are overwritten.
        peft_model.load_state_dict(mapped, strict=False)
        print(f"Loaded LoRA weights from: {tensors_dir}")
    elif tensors_dir and os.path.exists(tensors_dir):
        print(f"Attempted load from {tensors_dir} (no tensors mapped).")

    peft_model.eval()
    # Kept from the original flow: padding always uses EOS for LoRA variants,
    # even when the tokenizer ships its own pad token.
    tok.pad_token = tok.eos_token
    return peft_model, tok

# Prompt / cleaning / generation helpers (unchanged logic)
def build_prompt_for_lang(item, language="python"):
    """Build the generation prompt for one dataset item.

    Args:
        item: Mapping with optional keys ``instruction``/``prompt``,
            ``signature``, ``declaration``, ``docstring`` and ``example_test``.
            ``instruction`` wins over ``prompt``; missing keys render as "".
        language: ``"python"`` selects the Python template (``#`` headers and
            the signature); any other value selects the C++ template (``//``
            headers and the declaration).

    Returns:
        The prompt text, ending with ``"Implementation:\\n"``.
    """
    task_text = item.get("instruction") or item.get("prompt") or ""
    doc_text = item.get("docstring", "")
    examples = item.get("example_test", "")

    if language == "python":
        sections = [
            ("# Problem:", task_text),
            ("# Signature:", item.get("signature", "")),
            ("# Docstring:", doc_text),
            ("# Examples:", examples),
        ]
        closing = (
            "Write the complete Python function implementation only.\n"
            "Output only valid Python code for the function (no explanation, no tests, no surrounding markdown).\n"
            "Make sure the function name and signature match the signature above.\n\n"
            "Implementation:\n"
        )
    else:
        sections = [
            ("// Problem:", task_text),
            ("// Declaration:", item.get("declaration", "")),
            ("// Docstring / Notes:", doc_text),
            ("// Examples:", examples),
        ]
        closing = (
            "Write the C++ implementation only (no explanation, no tests, no surrounding markdown).\n"
            "Include necessary #include lines if needed. Ensure function name and signature match the declaration above.\n\n"
            "Implementation:\n"
        )

    # Each section is "<header>\n<content>\n\n", followed by the fixed instructions.
    header = "".join(f"{title}\n{content}\n\n" for title, content in sections)
    return header + closing

def clean_code_output(out):
    """Strip Markdown code fences from a model completion.

    If the text contains a complete fenced block (an opening fence line, then
    a closing fence on its own line), the content of the first such block is
    returned. Otherwise a leading and/or trailing fence marker is removed.

    The language tag after the opening fence may contain any non-whitespace
    characters (e.g. ``c++``, ``c#``), not only word characters; with a
    ``\\w+``-only tag a ``c++`` fence left a stray ``++`` at the top of the code.

    Args:
        out: Raw completion; ``None`` yields ``""``, other values are ``str()``-ed.

    Returns:
        The cleaned code with surrounding whitespace removed.
    """
    if out is None:
        return ""
    s = str(out).strip()
    # Opening fence + optional tag + optional trailing blanks, then the body up
    # to the first closing fence that starts a line.
    fence_match = re.search(r"```[^\s`]*[ \t]*\r?\n(.*?)\n```", s, flags=re.DOTALL)
    if fence_match:
        s = fence_match.group(1).strip()
    else:
        # Unterminated or single-line fences: drop whatever markers are present.
        s = re.sub(r"^```[^\s`]*\s*", "", s)
        s = re.sub(r"\s*```$", "", s)
    return s.strip()

def extract_def_block_python(code, signature):
    """Trim a completion so it starts at the target function definition.

    Text before ``def <name>(`` (typically chatter) is dropped, except for
    simple single-line top-level ``import``/``from ... import`` statements,
    which are kept so the function can still resolve names such as ``List``
    or ``math``.

    Args:
        code: Cleaned completion text.
        signature: Function signature such as ``"foo(x: int) -> int"``; a
            leading ``def`` (``"def foo(x)"``) is tolerated. The function name
            is the last word before the first ``(``.

    Returns:
        The trimmed code, or ``code`` unchanged when ``signature`` is empty, no
        name can be derived from it, or the definition is not found.
    """
    if not signature:
        return code
    head_words = signature.split("(")[0].split()
    if not head_words:
        return code
    sig_name = head_words[-1]

    # \b keeps identifiers that merely end in "def" from matching.
    m = re.search(rf"\bdef\s+{re.escape(sig_name)}\s*\(", code)
    if not m:
        return code

    # Deliberately strict so that prose like "import the module" is not kept;
    # multi-line (parenthesised / backslash) imports are not recognised.
    import_line = re.compile(
        r"(?:import\s+[\w.]+(?:\s+as\s+\w+)?(?:\s*,\s*[\w.]+(?:\s+as\s+\w+)?)*"
        r"|from\s+[\w.]+\s+import\s+[\w\s,*]+?)"
        r"\s*(?:#.*)?$"
    )
    kept_imports = [ln for ln in code[:m.start()].splitlines() if import_line.match(ln)]

    body = code[m.start():].strip()
    if kept_imports:
        return "\n".join(kept_imports) + "\n\n" + body
    return body

def _pin_and_move_inputs(inputs, device):
    inputs_pinned = {}
    for k, v in inputs.items():
        try:
            v_p = v.pin_memory()
            inputs_pinned[k] = v_p.to(device, non_blocking=True)
        except Exception:
            inputs_pinned[k] = v.to(device)
    return inputs_pinned

def generate_greedy(model, tok, prompt, max_new_tokens=MAX_NEW_TOKENS):
    """Generate one deterministic (greedy) completion for ``prompt``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tok: Tokenizer matching ``model``.
        prompt: Prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), stripped.
    """
    inputs_cpu = tok(prompt, return_tensors="pt")
    inputs = _pin_and_move_inputs(inputs_cpu, model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Sampling knobs are unused in greedy mode. They are pinned to the
            # neutral 1.0 instead of temperature=0.0, which is outside the
            # documented "strictly positive" domain and trips generation-config
            # validation.
            temperature=1.0,
            top_p=1.0,
            pad_token_id=tok.eos_token_id
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    in_len = inputs["input_ids"].shape[-1]
    cont_ids = out[0][in_len:]
    return tok.decode(cont_ids, skip_special_tokens=True).strip()

def generate_samples(model, tok, prompt, num_samples=8, batch_size=2, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Sample ``num_samples`` completions for ``prompt`` in batches.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tok: Tokenizer matching ``model``.
        prompt: Prompt text.
        num_samples: Total number of completions to produce.
        batch_size: Number of sequences requested per ``generate`` call (>= 1).
        max_new_tokens: Upper bound on generated tokens per completion.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.

    Returns:
        List of decoded continuations (prompt removed, stripped).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (previously a
            ZeroDivisionError for 0 and a silent empty result for negatives).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 (got {batch_size!r})")

    inputs_cpu = tok(prompt, return_tensors="pt")
    inputs = _pin_and_move_inputs(inputs_cpu, model.device)
    in_len = inputs["input_ids"].shape[-1]  # prompt length, constant across batches

    # A freshly built GenerationConfig carries no EOS id. Some transformers
    # releases do not back-fill it from the model (presumably older ones; verify
    # against the installed version), in which case sampling would always run to
    # max_new_tokens. Set it explicitly, preferring the model's own value.
    eos_ids = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if eos_ids is None:
        eos_ids = tok.eos_token_id

    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=eos_ids,
    )

    candidates = []
    produced = 0
    with torch.no_grad():
        for _ in range(math.ceil(num_samples / batch_size)):
            # The last batch may be smaller than batch_size.
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                num_return_sequences=this_bs,
                generation_config=gen_conf,
                pad_token_id=tok.eos_token_id,
            )
            for seq in outs:
                candidates.append(tok.decode(seq[in_len:], skip_special_tokens=True).strip())
            produced += this_bs
    return candidates

def write_and_run_python_solution(solution_code, test_code, workdir, timeout=EXEC_TIMEOUT):
    """Write candidate + tests to ``solution.py`` in ``workdir`` and execute it.

    Args:
        solution_code: Candidate implementation.
        test_code: Test code appended after the candidate.
        workdir: Existing directory in which ``solution.py`` is created.
        timeout: Timeout passed to ``subprocess.run``.

    Returns:
        ``(passed, stdout, stderr, exit_code)``. ``passed`` is True only for
        exit code 0. On timeout the exit code is -1, on any other failure to
        run the process it is -2; in both cases ``stderr`` holds the reason.
    """
    file_path = os.path.join(workdir, "solution.py")
    combined = solution_code.strip() + "\n\n" + test_code.strip() + "\n"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(combined)
    try:
        proc = subprocess.run(
            [PYTHON_EXEC, file_path],
            capture_output=True,
            text=True,
            # Generated code may print arbitrary bytes; undecodable output must
            # not be reported as a harness failure.
            errors="replace",
            # Never let a candidate read from the harness' stdin.
            stdin=subprocess.DEVNULL,
            timeout=timeout,
        )
        passed = proc.returncode == 0
        return passed, proc.stdout, proc.stderr, proc.returncode
    except subprocess.TimeoutExpired:
        return False, "", f"TimeoutExpired after {timeout}s", -1
    except Exception as e:
        # Deliberately broad: one broken candidate must not abort the run.
        return False, "", f"Exception while running: {e}", -2

def write_compile_and_run_cpp(solution_code, test_code, workdir, timeout=EXEC_TIMEOUT):
    """Write candidate + tests to ``solution.cpp``, compile it and run the binary.

    Compilation and execution each get their own ``timeout``. Failures are
    attributed to the phase in which they happened, so a run-time timeout no
    longer shows up as a compile error or discards the real compiler output.

    Args:
        solution_code: Candidate implementation.
        test_code: Test code appended after the candidate.
        workdir: Existing directory for the source file and the executable.
        timeout: Timeout passed to each ``subprocess.run`` call.

    Returns:
        ``(passed, compile_stdout, compile_stderr, run_stdout, run_stderr,
        exit_code)``. ``passed`` is True only when compilation succeeds and the
        binary exits with 0. ``exit_code`` is the compiler's code on a compile
        error, -1 on a timeout and -2 on any other failure to start a process.
    """
    src_path = os.path.join(workdir, "solution.cpp")
    exe_path = os.path.join(workdir, "solution_exec")
    combined = solution_code.strip() + "\n\n" + test_code.strip() + "\n"
    with open(src_path, "w", encoding="utf-8") as f:
        f.write(combined)

    timeout_msg = f"TimeoutExpired after {timeout}s"
    run_kwargs = dict(
        capture_output=True,
        text=True,
        errors="replace",          # tolerate non-UTF-8 output from generated code
        stdin=subprocess.DEVNULL,  # never read from the harness' stdin
        timeout=timeout,
    )

    # --- compile phase ---
    try:
        comp = subprocess.run(
            [CPP_COMPILER, src_path, "-std=c++17", "-O2", "-o", exe_path], **run_kwargs
        )
    except subprocess.TimeoutExpired:
        return False, "", timeout_msg, "", "", -1
    except Exception as e:
        # Deliberately broad: one broken candidate must not abort the run.
        return False, "", f"Exception while compiling/running: {e}", "", str(e), -2
    if comp.returncode != 0:
        return False, comp.stdout, comp.stderr, "", "", comp.returncode

    # --- run phase ---
    try:
        proc = subprocess.run([exe_path], **run_kwargs)
    except subprocess.TimeoutExpired:
        return False, comp.stdout, comp.stderr, "", timeout_msg, -1
    except Exception as e:
        return False, comp.stdout, comp.stderr, "", f"Exception while compiling/running: {e}", -2
    passed = proc.returncode == 0
    return passed, comp.stdout, comp.stderr, proc.stdout, proc.stderr, proc.returncode

def _worker_run_candidate(args):
    lang, cand_code, test_code, timeout = args
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            if lang == "python":
                passed, sout, serr, exit_code = write_and_run_python_solution(cand_code, test_code, tmpdir, timeout=timeout)
                logs = {"stdout": sout, "stderr": serr, "exit_code": exit_code}
                return passed, logs
            else:
                passed, comp_out, comp_err, run_out, run_err, exit_code = write_compile_and_run_cpp(cand_code, test_code, tmpdir, timeout=timeout)
                logs = {"compile_stdout": comp_out, "compile_stderr": comp_err, "run_stdout": run_out, "run_stderr": run_err, "exit_code": exit_code}
                return passed, logs
    except Exception as e:
        return False, {"exception": str(e)}

def pass_at_k(n, c, k):
    """Estimate pass@k from ``n`` samples of which ``c`` passed.

    Computes ``1 - C(n - c, k) / C(n, k)``, i.e. one minus the fraction of
    size-``k`` subsets of the samples that contain no passing sample.

    Args:
        n: Number of generated samples.
        c: Number of samples that passed.
        k: Subset size.

    Returns:
        NaN when fewer than ``k`` samples exist, 0.0 when nothing passed,
        otherwise the estimate as a float.
    """
    if k > n:
        return float('nan')
    if c == 0:
        return 0.0
    failing_subsets = comb(n - c, k)
    all_subsets = comb(n, k)
    return 1.0 - failing_subsets / all_subsets

# Evaluate one loaded model over dataset ds (returns aggregated pass@k and accuracy counts)
def evaluate_loaded_model(model, tok, ds, lang, model_label, pool, log_fh=None):
    """Evaluate one already-loaded model on every problem of ``ds``.

    For each problem candidates are generated (sampling when NUM_K_SAMPLING > 1,
    otherwise a single greedy completion), cleaned, executed through ``pool``
    (falling back to in-process execution if ``pool.map`` raises) and scored
    with pass@k for every k in PASSAT_K. The first PREVIEW problems are printed
    and, when ``log_fh`` is given, also written to it.

    Args:
        model: Loaded model.
        tok: Matching tokenizer.
        ds: Sized iterable of problem dicts.
        lang: ``"python"`` or another value for the C++ path.
        model_label: Label used in progress output and result rows.
        pool: Object exposing ``map`` (a multiprocessing pool).
        log_fh: Optional writable text handle for the preview log.

    Returns:
        Dict with ``problems``, ``passed_problems``, ``accuracy`` (share of
        problems with at least one passing candidate), ``avg_passatk`` (mean
        pass@k over problems where it is defined) and ``rows`` (one dict per
        problem).
    """
    total = len(ds)
    solved_count = 0
    pk_totals = dict.fromkeys(PASSAT_K, 0.0)
    pk_counts = dict.fromkeys(PASSAT_K, 0)
    rows_for_model = []
    is_python = lang == "python"
    separator = "=" * 60

    progress = tqdm(total=total, desc=f"{model_label}", unit="it")
    try:
        for i, ex in enumerate(ds):
            signature = ex.get("signature", "")
            declaration = ex.get("declaration", "")
            entry_point = ex.get("entry_point", "")
            test_code = ex.get("test", "")
            prompt_text = build_prompt_for_lang(ex, language=lang)

            # --- generation ---
            if NUM_K_SAMPLING <= 1:
                raw_cands = [generate_greedy(model, tok, prompt_text, max_new_tokens=MAX_NEW_TOKENS)]
            else:
                raw_cands = generate_samples(model, tok, prompt_text,
                                             num_samples=NUM_K_SAMPLING,
                                             batch_size=BATCH_SIZE,
                                             max_new_tokens=MAX_NEW_TOKENS,
                                             temperature=0.7, top_p=0.95)

            # --- cleaning ---
            cleaned_cands = [clean_code_output(text) for text in raw_cands]
            if is_python:
                cleaned_cands = [extract_def_block_python(text, signature) for text in cleaned_cands]

            # --- execution (pool first, sequential fallback) ---
            job_args = [(lang, cand_code, test_code, EXEC_TIMEOUT) for cand_code in cleaned_cands]
            try:
                outcomes = pool.map(_worker_run_candidate, job_args)
            except Exception:
                outcomes = [_worker_run_candidate(job) for job in job_args]

            passes_per_candidate = [bool(ok) for ok, _ in outcomes]
            logs_per_candidate = [run_logs for _, run_logs in outcomes]

            # --- scoring ---
            n = len(passes_per_candidate)
            c = sum(passes_per_candidate)
            if c > 0:
                solved_count += 1

            passatk_per_k = {}
            for k in PASSAT_K:
                score = pass_at_k(n, c, k) if n > 0 else float('nan')
                passatk_per_k[k] = score
                if math.isnan(score):
                    continue  # undefined for this problem: excluded from the mean
                pk_totals[k] += score
                pk_counts[k] += 1

            rows_for_model.append({
                "lang": lang,
                "index": i,
                "model": model_label,
                "entry_point": entry_point,
                "signature": signature,
                "declaration": declaration,
                "n": n,
                "c": c,
                "pass@1": passatk_per_k.get(1),
                "pass@5": passatk_per_k.get(5),
                "pass@10": passatk_per_k.get(10),
                "candidates": "\n===CAND===\n".join(cleaned_cands),
                "logs_sample": str(logs_per_candidate[:3])
            })

            # --- optional preview of the first PREVIEW problems ---
            if i < PREVIEW:
                prompt_src = (ex.get("instruction") or ex.get("prompt") or "")
                first_line = next((ln for ln in prompt_src.splitlines() if ln.strip()), None)
                prompt_first = first_line[:200] if first_line is not None else ""
                top_candidate = "\n".join(cleaned_cands[0].splitlines()[:12]) if cleaned_cands else ""
                per_sample_text = "\n".join([
                    "\n" + separator,
                    f"[{lang}] Problem {i} | model: {model_label} | n={n} c={c}",
                    "Prompt (first line): " + prompt_first,
                    "pass@1/5/10: {} {} {}".format(passatk_per_k.get(1), passatk_per_k.get(5), passatk_per_k.get(10)),
                    "Top candidate (first 12 lines):\n" + top_candidate,
                    separator,
                ])
                print(per_sample_text)

                if log_fh:
                    # Logging is best effort; a failing log write never stops the run.
                    try:
                        log_fh.write(per_sample_text + "\n")
                    except Exception:
                        pass

            progress.update(1)
    finally:
        progress.close()

    avg_passatk = {}
    for k in PASSAT_K:
        avg_passatk[k] = (pk_totals[k] / pk_counts[k]) if pk_counts[k] > 0 else float('nan')

    return {
        "problems": total,
        "passed_problems": solved_count,
        "accuracy": solved_count / total if total > 0 else float('nan'),
        "avg_passatk": avg_passatk,
        "rows": rows_for_model
    }

# Orchestration: evaluate ten models sequentially for a language
def evaluate_ten_models_for_lang(lang, model_dir_map, log_fh=None, csv_out_path=None):
    """
    Evaluate the pretrained baseline and nine LoRA variants on one language.

    model_dir_map: dict mapping the 9 LoRA labels to dirs; we also evaluate 'pretrained' baseline.
    Expected labels for LoRA keys:
      'full_mlpatt','full_att','full_mlp',
      'top1_mlpatt','top1_att','top1_mlp',
      'top3_mlpatt','top3_att','top3_mlp'
    We will evaluate in this order:
      ['pretrained', 'full_mlpatt','full_att','full_mlp','top1_mlpatt','top1_att','top1_mlp','top3_mlpatt','top3_att','top3_mlp']

    A LoRA variant whose directory is empty, a placeholder or missing is skipped
    and recorded with NaN metrics. Per-problem rows of all evaluated variants
    are appended to ``csv_out_path`` (or CSV_OUT), and a PNG summary table is
    written to LOG_DIR (or the current directory).

    Args:
        lang: Dataset configuration name, e.g. "python" or "cpp".
        model_dir_map: See above.
        log_fh: Optional writable text handle for the detailed preview log.
        csv_out_path: Optional CSV path overriding CSV_OUT.

    Returns:
        Dict mapping each model label to its aggregated result dict.
    """
    import gc  # local import: only needed for releasing models between variants

    print(f"\n=== Running HumanEvalpack ({lang}) N={N}, samples per problem={NUM_K_SAMPLING} ===\n")
    ds = load_dataset("bigcode/humanevalpack", lang, split=f"test[:{N}]")
    print(f"Loaded {len(ds)} items for {lang}")

    # models order
    models_order = [
        "pretrained",
        "full_mlpatt","full_att","full_mlp",
        "top1_mlpatt","top1_att","top1_mlp",
        "top3_mlpatt","top3_att","top3_mlp"
    ]

    # prepare pool for candidate execution: half the cores, between 1 and 8 workers
    cpu_count = multiprocessing.cpu_count() or 1
    pool_size = max(1, min(8, cpu_count // 2))

    results_per_model = {}
    all_rows = []

    # The context manager terminates the workers if anything below raises, so
    # no worker processes are leaked; on success the pool is closed and joined.
    with multiprocessing.Pool(processes=pool_size) as pool:
        for model_label in models_order:
            if model_label == "pretrained":
                print(f"\n--- Evaluating model variant: {model_label} (pretrained base) ---\n")
                model_obj, tok_obj = load_pretrained(MODEL_NAME, HF_TOKEN)
            else:
                tensors_dir = model_dir_map.get(model_label, "")
                print(f"\n--- Evaluating model variant: {model_label} (tensors_dir={tensors_dir}) ---\n")
                if not tensors_dir or str(tensors_dir).startswith("to enter") or (not os.path.exists(tensors_dir)):
                    print(f"Skipping {model_label}: tensors dir missing or placeholder ({tensors_dir}). Recording nan metrics.")
                    results_per_model[model_label] = {
                        "problems": len(ds),
                        "passed_problems": 0,
                        "accuracy": float('nan'),
                        "avg_passatk": {k: float('nan') for k in PASSAT_K},
                        "rows": []
                    }
                    continue
                model_obj, tok_obj = load_lora(MODEL_NAME, HF_TOKEN, tensors_dir, r=16, alpha=32)

            try:
                agg = evaluate_loaded_model(model_obj, tok_obj, ds, lang, model_label, pool, log_fh=log_fh)
            finally:
                # unload (also on failure) so the next variant starts with free memory;
                # gc.collect() breaks reference cycles before the CUDA cache is emptied
                del model_obj, tok_obj
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            results_per_model[model_label] = agg
            all_rows.extend(agg["rows"])

        # close pool
        pool.close()
        pool.join()

    # write CSV per-problem rows (append)
    out_path = csv_out_path or CSV_OUT
    if out_path and all_rows:
        try:
            # A bare file name has no directory part; os.makedirs("") would raise.
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # Write the header for a new file and for an existing empty one.
            needs_header = (not os.path.exists(out_path)) or os.path.getsize(out_path) == 0
            with open(out_path, "a", newline="", encoding="utf-8") as fh:
                writer = csv.DictWriter(fh, fieldnames=list(all_rows[0].keys()))
                if needs_header:
                    writer.writeheader()
                writer.writerows(all_rows)
            print(f"Appended {len(all_rows)} rows to CSV: {out_path}")
        except Exception as e:
            # Broad on purpose: a CSV problem must not discard the summary below.
            print(f"Could not write CSV summary to {out_path}: {e}")

    # aggregated table for printing + PNG
    print(f"\n=== Final aggregated results for {lang} ===\n")
    nan = float('nan')
    headers = []
    vals_acc = []
    vals_pass1 = []
    vals_pass5 = []
    vals_pass10 = []
    for m in models_order:
        headers.append(m)
        agg = results_per_model.get(m, {})
        acc = agg.get("accuracy", nan)
        vals_acc.append(acc if acc is not None else nan)
        avg_pk = agg.get("avg_passatk", {1: nan, 5: nan, 10: nan})
        vals_pass1.append(avg_pk.get(1, nan))
        vals_pass5.append(avg_pk.get(5, nan))
        vals_pass10.append(avg_pk.get(10, nan))

        print(f"{m}: passed_problems = {agg.get('passed_problems', 'n/a')} / {agg.get('problems', 'n/a')}  acc = {acc}")

    # PNG table (best effort: matplotlib may be missing)
    try:
        import matplotlib.pyplot as plt

        def safe_fmt(x):
            try:
                return f"{x:.4f}"
            except Exception:
                return "nan"

        table_data = [
            ["pass@1"] + [safe_fmt(x) for x in vals_pass1],
            ["pass@5"] + [safe_fmt(x) for x in vals_pass5],
            ["pass@10"] + [safe_fmt(x) for x in vals_pass10],
            ["accuracy"] + [safe_fmt(x) for x in vals_acc],
        ]
        col_labels = ["metric"] + headers

        fig_w = max(8, 0.6 * len(headers))
        fig, ax = plt.subplots(figsize=(fig_w, 2.8))
        ax.axis('off')
        table = ax.table(cellText=table_data, colLabels=col_labels, loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        table.scale(1, 1.8)
        png_path = os.path.join(LOG_DIR if LOG_DIR and not LOG_DIR.startswith("to enter") else ".", f"{lang}_ten_models_accuracy_table.png")
        plt.tight_layout()
        plt.savefig(png_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print("Saved accuracy PNG table to:", png_path)
    except Exception as e:
        print("Could not create/save PNG accuracy table:", e)

    return results_per_model

# Entrypoint
if __name__ == "__main__":
    # Script entry point: evaluate all variants on Python, then (when the
    # configured compiler is on PATH) on C++. Detailed logs go to LOG_DIR when
    # it is set to a real value.
    has_gpp = shutil.which(CPP_COMPILER) is not None
    print("g++ available:", has_gpp, " (", CPP_COMPILER, ")")
    print("LOG_DIR:", LOG_DIR)
    print("CSV_OUT (all):", CSV_OUT)
    print("CSV_OUT_PY:", CSV_OUT_PY)
    print("CSV_OUT_CPP:", CSV_OUT_CPP)

    # The pretrained baseline is loaded inside the per-language flow.
    # Prepare python model_dir_map
    py_model_dirs = {
        "full_mlpatt": TENSORS_DIR_FULL_MLPATT_PY,
        "full_att":    TENSORS_DIR_FULL_ATT_PY,
        "full_mlp":    TENSORS_DIR_FULL_MLP_PY,
        "top1_mlpatt": TENSORS_DIR_TOP1_MLPATT_PY,
        "top1_att":    TENSORS_DIR_TOP1_ATT_PY,
        "top1_mlp":    TENSORS_DIR_TOP1_MLP_PY,
        "top3_mlpatt": TENSORS_DIR_TOP3_MLPATT_PY,
        "top3_att":    TENSORS_DIR_TOP3_ATT_PY,
        "top3_mlp":    TENSORS_DIR_TOP3_MLP_PY
    }

    cpp_model_dirs = {
        "full_mlpatt": TENSORS_DIR_FULL_MLPATT_CPP,
        "full_att":    TENSORS_DIR_FULL_ATT_CPP,
        "full_mlp":    TENSORS_DIR_FULL_MLP_CPP,
        "top1_mlpatt": TENSORS_DIR_TOP1_MLPATT_CPP,
        "top1_att":    TENSORS_DIR_TOP1_ATT_CPP,
        "top1_mlp":    TENSORS_DIR_TOP1_MLP_CPP,
        "top3_mlpatt": TENSORS_DIR_TOP3_MLPATT_CPP,
        "top3_att":    TENSORS_DIR_TOP3_ATT_CPP,
        "top3_mlp":    TENSORS_DIR_TOP3_MLP_CPP
    }

    # Python evaluation
    print("\n=== Python evaluation (10 models sequentially) ===")
    dataset_log_py = os.path.join(LOG_DIR, "python_detailed_log.txt") if LOG_DIR and not LOG_DIR.startswith("to enter") else None
    try:
        log_fh_py = open(dataset_log_py, "w", encoding="utf-8") if dataset_log_py else None
    except (OSError, ValueError) as e:
        print("Could not open python log file for writing:", e)
        log_fh_py = None
    try:
        py_results = evaluate_ten_models_for_lang("python", py_model_dirs, log_fh=log_fh_py, csv_out_path=CSV_OUT_PY or CSV_OUT)
    finally:
        # Close (and flush) the log even when the evaluation raises.
        if log_fh_py:
            log_fh_py.close()

    # C++ evaluation
    if has_gpp:
        print("\n=== C++ evaluation (10 models sequentially) ===")
        dataset_log_cpp = os.path.join(LOG_DIR, "cpp_detailed_log.txt") if LOG_DIR and not LOG_DIR.startswith("to enter") else None
        try:
            log_fh_cpp = open(dataset_log_cpp, "w", encoding="utf-8") if dataset_log_cpp else None
        except (OSError, ValueError) as e:
            print("Could not open cpp log file for writing:", e)
            log_fh_cpp = None
        try:
            cpp_results = evaluate_ten_models_for_lang("cpp", cpp_model_dirs, log_fh=log_fh_cpp, csv_out_path=CSV_OUT_CPP or CSV_OUT)
        finally:
            # Close (and flush) the log even when the evaluation raises.
            if log_fh_cpp:
                log_fh_cpp.close()
    else:
        print("g++ not available; skipping C++ evaluation.")

    print("\nAll done. CSVs (if any) appended to:")
    for p in set(filter(None, [CSV_OUT, CSV_OUT_PY, CSV_OUT_CPP])):
        print(" -", p)
