import os
import re
import torch
import math
import collections
import time
import csv
import hashlib

# cuDNN tuning, applied at import time before any model is loaded:
# benchmark=True lets cuDNN autotune and reuse the fastest kernel for input
# shapes that repeat; enabled=True keeps the cuDNN backend switched on.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from tqdm import tqdm

# -------------------------------
# Config / env (placeholders)
# -------------------------------
# Every setting below is read from the environment. When a variable is unset,
# the value falls back to the placeholder string "to enter values of <NAME>".
# LOG_DIR, CSV_OUT and the TENSORS_DIR_* values are later tested with
# startswith("to enter") and treated as "not configured" when they match;
# HF_TOKEN and MODEL_NAME are handed to from_pretrained without such a check.
HF_TOKEN = os.environ.get("HF_TOKEN", "to enter values of HF_TOKEN")
MODEL_NAME = os.environ.get("MODEL_NAME", "to enter values of MODEL_NAME")

# LoRA tensor directories: 9 in total, (full / top1 / top3) x (mlpatt / att / mlp).
# Each directory is walked recursively for *.pt files by load_lora().
TENSORS_DIR_FULL_MLPATT = os.environ.get("TENSORS_DIR_FULL_MLPATT", "to enter values of TENSORS_DIR_FULL_MLPATT")
TENSORS_DIR_FULL_ATT    = os.environ.get("TENSORS_DIR_FULL_ATT",    "to enter values of TENSORS_DIR_FULL_ATT")
TENSORS_DIR_FULL_MLP    = os.environ.get("TENSORS_DIR_FULL_MLP",    "to enter values of TENSORS_DIR_FULL_MLP")

TENSORS_DIR_TOP1_MLPATT = os.environ.get("TENSORS_DIR_TOP1_MLPATT", "to enter values of TENSORS_DIR_TOP1_MLPATT")
TENSORS_DIR_TOP1_ATT    = os.environ.get("TENSORS_DIR_TOP1_ATT",    "to enter values of TENSORS_DIR_TOP1_ATT")
TENSORS_DIR_TOP1_MLP    = os.environ.get("TENSORS_DIR_TOP1_MLP",    "to enter values of TENSORS_DIR_TOP1_MLP")

TENSORS_DIR_TOP3_MLPATT = os.environ.get("TENSORS_DIR_TOP3_MLPATT", "to enter values of TENSORS_DIR_TOP3_MLPATT")
TENSORS_DIR_TOP3_ATT    = os.environ.get("TENSORS_DIR_TOP3_ATT",    "to enter values of TENSORS_DIR_TOP3_ATT")
TENSORS_DIR_TOP3_MLP    = os.environ.get("TENSORS_DIR_TOP3_MLP",    "to enter values of TENSORS_DIR_TOP3_MLP")

# Output locations: CSV_OUT is the summary file that result rows are appended
# to; LOG_DIR receives the detailed text log and the PNG accuracy table.
CSV_OUT = os.environ.get("CSV_OUT", "to enter values of CSV_OUT")
LOG_DIR = os.environ.get("LOG_DIR", "to enter values of LOG_DIR")

# Hyperparameters placeholders (must be set to integers via env)
def _parse_required_int_env(name):
    v = os.environ.get(name, f"to enter values of {name}")
    if isinstance(v, str) and v.strip().startswith("to enter"):
        raise RuntimeError(f"Environment variable {name} must be set to an integer. Current value is a placeholder: {v!r}")
    try:
        return int(v)
    except Exception:
        raise RuntimeError(f"Environment variable {name} must be an integer. Got: {v!r}")

# Required integer hyperparameters. Each call raises RuntimeError at import
# time if its environment variable is unset or not an integer.
N = _parse_required_int_env("N_SAMPLES")  # number of GSM8K test examples loaded (split "test[:N]")
NUM_K_SAMPLING = _parse_required_int_env("NUM_K_SAMPLING")  # samples per question; > 1 selects self-consistency, otherwise greedy
BATCH_SIZE = _parse_required_int_env("BATCH_SIZE")  # sequences requested per generate() call when sampling
PREVIEW = _parse_required_int_env("PREVIEW")  # per-sample details are printed/logged for indices below this

# Execution settings
# NOTE(review): EXEC_TIMEOUT is not referenced anywhere else in the visible file.
EXEC_TIMEOUT = int(os.environ.get("EXEC_TIMEOUT", "8"))

# Ensure LOG_DIR exists if provided (skipped while it still holds the placeholder)
if LOG_DIR and not LOG_DIR.startswith("to enter"):
    os.makedirs(LOG_DIR, exist_ok=True)

# -------------------------------
# Input cache (prompt hash -> tokenized tensors moved to device)
# -------------------------------
# Module-level cache filled by _get_cached_inputs() and reset by
# _clear_prompt_cache(); it has no size limit of its own.
_prompt_cache = {}  # keys: (prompt_hash, device_str) -> dict of tensors

def _prompt_hash(s):
    return hashlib.md5(s.encode("utf-8")).hexdigest()

def _get_cached_inputs(tok, prompt, device):
    """Return *prompt* tokenized and moved to *device*, reusing a cached copy when present.

    The cache key is (MD5 of the prompt, ``str(device)``). On a miss the prompt
    is tokenized with ``return_tensors="pt"``, every returned tensor is moved to
    *device*, and the resulting dict is stored in ``_prompt_cache`` and returned.
    """
    cache_key = (_prompt_hash(prompt), str(device))
    hit = _prompt_cache.get(cache_key)
    if hit is not None:
        return hit

    def _move(tensor):
        # Preferred path: pin the host tensor first, then copy asynchronously.
        try:
            return tensor.pin_memory().to(device, non_blocking=True)
        except Exception:
            pass
        # Pinning (or the pinned copy) failed: retry unpinned, and as a last
        # resort fall back to an ordinary blocking copy.
        try:
            return tensor.to(device, non_blocking=True)
        except Exception:
            return tensor.to(device)

    encoded = tok(prompt, return_tensors="pt")
    on_device = {field: _move(tensor) for field, tensor in encoded.items()}
    _prompt_cache[cache_key] = on_device
    return on_device

def _clear_prompt_cache():
    global _prompt_cache
    _prompt_cache = {}

# -------------------------------
# Load Pretrained (kept - tokenizer / sanity checks)
# -------------------------------
def load_pretrained(model_name, hf_token):
    """Load the base causal LM and its tokenizer.

    The tokenizer is loaded first, then the model in bfloat16 with
    ``device_map="auto"``. Missing pad/bos tokens fall back to the eos token,
    and the model config's pad/eos/bos ids are aligned with the tokenizer on a
    best-effort basis.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, use_fast=True)
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True,
    )
    base = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(tokenizer, "bos_token", None) is None:
        tokenizer.bos_token = tokenizer.eos_token

    # Best effort: a failure while copying the ids is deliberately ignored.
    try:
        for id_attr in ("pad_token_id", "eos_token_id", "bos_token_id"):
            setattr(base.config, id_attr, getattr(tokenizer, id_attr))
    except Exception:
        pass

    return base.eval(), tokenizer

# -------------------------------
# Load LoRA (minimal prints; same mapping logic)
# -------------------------------
def _read_pt_tensors(tensors_dir):
    """Load every ``*.pt`` file found under *tensors_dir* into ``{file stem: loaded object}``.

    Files that fail to load are reported and skipped rather than silently ignored.
    """
    loaded = {}
    for root, _, files in os.walk(tensors_dir):
        for fname in files:
            if not fname.endswith(".pt"):
                continue
            path = os.path.join(root, fname)
            # Strip only the trailing extension; str.replace would also remove
            # any ".pt" occurring inside the parameter name itself.
            key = fname[:-len(".pt")]
            try:
                # NOTE(security): torch.load unpickles the file; only point
                # tensors_dir at files from a trusted source.
                loaded[key] = torch.load(path, map_location="cpu")
            except Exception as exc:
                print(f"Warning: could not load tensor file {path}: {exc}")
    return loaded


def _map_lora_keys(state_raw, target_keys):
    """Match raw file-stem keys to model state-dict keys.

    For each raw key the candidates ``prefix + key + suffix`` are tried in
    order (prefixes "", "base_model.", "model."; suffixes "", ".weight",
    ".default.weight") and the first one present in *target_keys* wins.

    Returns:
        ``(mapped, unmapped)``: a dict of target key -> tensor, and the list
        of raw keys for which no candidate matched.
    """
    prefixes = ("", "base_model.", "model.")
    suffixes = ("", ".weight", ".default.weight")
    mapped = {}
    unmapped = []
    for raw_key, tensor in state_raw.items():
        match = next(
            (
                prefix + raw_key + suffix
                for prefix in prefixes
                for suffix in suffixes
                if prefix + raw_key + suffix in target_keys
            ),
            None,
        )
        if match is None:
            unmapped.append(raw_key)
        else:
            mapped[match] = tensor
    return mapped, unmapped


def load_lora(base_model, hf_token, tensors_dir, r, alpha):
    """Load the base model, wrap it with a LoRA adapter, and fill the adapter from ``*.pt`` files.

    Args:
        base_model: model name or path passed to ``from_pretrained``.
        hf_token: access token passed to ``from_pretrained``.
        tensors_dir: directory walked recursively for ``*.pt`` files; each
            file stem is treated as a (partial) state-dict key. May be empty.
        r: LoRA rank.
        alpha: LoRA alpha.

    Returns:
        ``(peft_model, tokenizer)`` with the model in eval mode. When nothing
        could be loaded or mapped, the freshly initialised PEFT wrapper is
        returned unchanged.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=hf_token,
        low_cpu_mem_usage=True
    )
    tok = AutoTokenizer.from_pretrained(base_model, token=hf_token, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if getattr(tok, "bos_token", None) is None:
        tok.bos_token = tok.eos_token
    # Best effort: a failure while aligning the config ids is deliberately ignored.
    try:
        model.config.pad_token_id = tok.pad_token_id
        model.config.eos_token_id = tok.eos_token_id
        model.config.bos_token_id = tok.bos_token_id
    except Exception:
        pass

    cfg = LoraConfig(
        r=r, lora_alpha=alpha,
        target_modules=["q_proj","k_proj","v_proj","o_proj","down_proj","up_proj","gate_proj"],
        task_type="CAUSAL_LM"
    )
    peft_model = get_peft_model(model, cfg)

    # Read .pt files if a usable directory was provided.
    dir_exists = bool(tensors_dir) and os.path.exists(tensors_dir)
    if dir_exists:
        state_raw = _read_pt_tensors(tensors_dir)
    else:
        state_raw = {}
        if tensors_dir:
            print(f"Warning: tensors_dir provided but path does not exist: {tensors_dir}")

    mapped, unmapped = _map_lora_keys(state_raw, set(peft_model.state_dict().keys()))
    if unmapped:
        # Surface files that matched nothing; otherwise a naming mismatch
        # would leave the adapter (partly) untrained without any hint.
        print(
            f"Warning: {len(unmapped)} tensor file(s) under {tensors_dir} matched no "
            f"model parameter and were ignored (e.g. {unmapped[0]!r})."
        )

    if mapped:
        peft_model.load_state_dict(mapped, strict=False)
        print(f"Loaded LoRA weights from: {tensors_dir}")
    elif dir_exists:
        print(f"Attempted load from {tensors_dir} (no tensors mapped).")

    peft_model.eval()
    # NOTE(review): this overrides the pad token unconditionally, unlike the
    # "only if None" handling above and in load_pretrained — confirm intent.
    tok.pad_token = tok.eos_token
    return peft_model, tok

# -------------------------------
# GSM8K prefix builder (kept unchanged)
# -------------------------------
# Lazily built few-shot prefix; stays None until _get_prefix() is first called.
_GSM8K_PREFIX = None

def _build_gsm8k_8shot_prefix():
    """Build the 8-shot demonstration prefix from the first eight GSM8K *training* examples.

    The demonstrations must not come from the test split: the evaluation loop
    scores ``test[:N]``, so drawing them from ``test[:8]`` would place the
    first eight evaluated questions, with their answers, inside every prompt.

    Returns:
        The demonstrations formatted as ``"Q: ...\\nA: ..."`` blocks joined by
        blank lines.
    """
    ds_train = load_dataset("openai/gsm8k", "main", split="train[:8]")
    demos = [
        f"Q: {ex['question'].strip()}\nA: {ex['answer'].strip()}"
        for ex in ds_train
    ]
    return "\n\n".join(demos)

def _get_prefix():
    """Return the few-shot prefix, building it on first use and caching it in ``_GSM8K_PREFIX``."""
    global _GSM8K_PREFIX
    if _GSM8K_PREFIX is not None:
        return _GSM8K_PREFIX
    _GSM8K_PREFIX = _build_gsm8k_8shot_prefix()
    return _GSM8K_PREFIX

# -------------------------------
# Generation helpers (use cached inputs + inference_mode)
# -------------------------------
def generate_greedy(model, tok, prompt, max_new_tokens=1024):
    """Generate one continuation of *prompt* with sampling disabled.

    Returns:
        The decoded text after the prompt tokens, with special tokens removed
        and surrounding whitespace stripped.
    """
    model_inputs = _get_cached_inputs(tok, prompt, model.device)
    decode_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        pad_token_id=tok.eos_token_id,
    )
    with torch.inference_mode():
        generated = model.generate(**model_inputs, **decode_kwargs)
    # The returned sequence starts with the prompt tokens; keep only what follows.
    prompt_len = model_inputs["input_ids"].shape[-1]
    continuation_ids = generated[0][prompt_len:]
    return tok.decode(continuation_ids, skip_special_tokens=True).strip()

def generate_samples(model, tok, prompt, num_samples=8, batch_size=2, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Sample *num_samples* continuations of *prompt*, *batch_size* sequences per ``generate`` call.

    Args:
        model: model exposing ``device`` and ``generate``.
        tok: tokenizer used to encode the prompt and decode the outputs.
        prompt: text to continue.
        num_samples: total number of continuations; ``<= 0`` yields ``[]``.
        batch_size: sequences requested per ``generate`` call (>= 1).
        max_new_tokens, temperature, top_p: sampling settings.

    Returns:
        A list with one string per sample: the first non-empty line of the
        decoded continuation (stripped), or "" when the continuation is blank.

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    if num_samples <= 0:
        return []

    inputs = _get_cached_inputs(tok, prompt, model.device)
    in_len = inputs["input_ids"].shape[-1]
    # Loop-invariant: build the sampling configuration once.
    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.1,
    )

    candidates = []
    with torch.inference_mode():
        for produced in range(0, num_samples, batch_size):
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                num_return_sequences=this_bs,
                generation_config=gen_conf,
                pad_token_id=tok.eos_token_id,
            )
            for seq in outs:
                cont = tok.decode(seq[in_len:], skip_special_tokens=True).strip()
                # NOTE(review): only the first non-empty line is kept, so a
                # final "#### <number>" line produced after multi-line
                # reasoning is discarded — confirm this is intended.
                first_line = next((ln.strip() for ln in cont.splitlines() if ln.strip() != ""), "")
                candidates.append(first_line)
    return candidates

def generate_answer_self_consistency(model, tok, question,
                                     num_samples=16,
                                     batch_size=4,
                                     max_new_tokens=1024,
                                     temperature=0.7,
                                     top_p=0.90,
                                     no_repeat_ngram_size=3,
                                     repetition_penalty=1.1):
    """Answer *question* by sampling several candidates and taking a majority vote.

    The prompt is the few-shot prefix followed by the question and a
    step-by-step instruction. A number is extracted from every sampled
    candidate; the most frequent number wins and the first candidate carrying
    it is returned. When no candidate yields a number, the most frequent
    candidate text is returned instead.

    Note: ``no_repeat_ngram_size`` and ``repetition_penalty`` are accepted but
    not forwarded to ``generate_samples``.

    Returns:
        ``(chosen_continuation, candidates, chosen_num)``.
    """
    few_shot = _get_prefix()
    prompt = (
        few_shot
        + f"\n\nQ: {question.strip()}\n"
        "A: Let's reason step by step.\n"
        "At the end, give the final numeric answer on its own line in this exact format:\n"
        "#### <number>\n"
        "Answer:")
    candidates = generate_samples(
        model, tok, prompt,
        num_samples=num_samples,
        batch_size=batch_size,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p
    )
    numbers = [extract_pred_number_from_text(text) for text in candidates]
    votes = collections.Counter(value for value in numbers if value is not None)

    # Numeric vote: return the first candidate whose number is the winner.
    if votes:
        winner = votes.most_common(1)[0][0]
        supporting_text = next(
            (text for text, value in zip(candidates, numbers) if value == winner),
            None,
        )
        return supporting_text, candidates, winner

    # No candidate produced a number: fall back to the most common text.
    if not candidates:
        return "", candidates, None
    top_text = collections.Counter(candidates).most_common(1)[0][0]
    return top_text, candidates, extract_pred_number_from_text(top_text)

# -------------------------------
# Gold/pred extraction (unchanged)
# -------------------------------
def extract_gold(answer_str):
    """Return the reference answer that follows the last ``####`` marker in *answer_str*.

    Thousands separators are removed. The result is an ``int`` when the
    remaining text is all digits (apart from ``-``), otherwise a ``float``.
    Returns ``None`` when no ``#### <number>`` marker is present.
    """
    matches = re.findall(r"####\s*\$?\{?\s*([-+]?\d[\d,]*\.?\d*)\s*\}?\$?", answer_str)
    if not matches:
        return None
    s = matches[-1].replace(",", "")
    try:
        return int(s) if s.replace("-", "").isdigit() else float(s)
    except ValueError:
        # Only a failed numeric conversion is expected here; anything else
        # (e.g. KeyboardInterrupt) must not be swallowed.
        return None

def extract_pred_number_from_text(text):
    """Extract the predicted number from model output *text*.

    The number after the first ``####`` marker is preferred. When there is no
    marker, the last number appearing anywhere in the text is used. Thousands
    separators are removed; the result is an ``int`` when the remaining text
    is all digits (apart from ``-``), otherwise a ``float``.

    Returns:
        The parsed number, or ``None`` for empty input or when no number is found.
    """
    if not text:
        return None

    def _to_number(token):
        # Shared conversion for both patterns; None signals a failed parse.
        cleaned = token.replace(",", "")
        try:
            return int(cleaned) if cleaned.replace("-", "").isdigit() else float(cleaned)
        except ValueError:
            return None

    marked = re.findall(r"####\s*\$?\{?\s*([-+]?\d[\d,]*\.?\d*)\s*\}?\$?", text)
    if marked:
        value = _to_number(marked[0])
        if value is not None:
            return value

    loose = re.findall(r"[\${\s]*([-+]?\d[\d,]*\.?\d*)[\}\$]*", text)
    if loose:
        return _to_number(loose[-1])
    return None

def evaluate_prediction(pred_text, gold):
    """Compare the number extracted from *pred_text* with *gold*.

    Returns:
        ``(is_correct, pred_num)``. ``is_correct`` is False whenever no number
        could be extracted, so an unparseable prediction is never scored as a
        match against a missing (``None``) gold value.
    """
    pred_num = extract_pred_number_from_text(pred_text)
    is_correct = pred_num is not None and pred_num == gold
    return is_correct, pred_num

# -------------------------------
# Main: sequential LoRA evaluation across 10 models
# -------------------------------
def _is_configured(value):
    """Return True when *value* is a real setting, not empty and not a "to enter ..." placeholder."""
    return bool(value) and not str(value).startswith("to enter")


def _ensure_parent_dir(path):
    """Create the directory that will contain *path*, if *path* names one.

    ``os.path.dirname`` returns "" for a bare filename and ``os.makedirs("")``
    raises FileNotFoundError, so the empty case is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _format_preview(model_label, index, question, gold, chosen, candidates, pred_num, is_correct):
    """Build the per-sample preview text; *candidates* is None on the greedy path."""
    lines = [
        "\n" + "=" * 80,
        f"[{model_label}] Index {index} | Question:",
        question,
        f"Ground Truth: {gold}\n",
        f"\n--- Model {model_label} ---",
    ]
    if candidates is not None:
        lines.append("Model output (chosen): " + str(chosen))
        lines.append("Candidates: " + str(candidates))
    else:
        lines.append(chosen)
    lines.append("Interpreted Pred: {} | Correct? {}".format(pred_num, is_correct))
    lines.append("=" * 80)
    return "\n".join(lines)


def _evaluate_model(model_label, model, tok, ds, log_fh):
    """Run one model over every example in *ds* and return the number answered correctly."""
    correct = 0
    with tqdm(total=len(ds), desc=f"{model_label}", unit="it") as pbar:
        for i, ex in enumerate(ds):
            q = ex["question"].strip()
            gold = extract_gold(ex["answer"].strip())

            cand_list = None
            if NUM_K_SAMPLING > 1:
                chosen, cand_list, pred_num = generate_answer_self_consistency(
                    model, tok, q,
                    num_samples=NUM_K_SAMPLING,
                    batch_size=BATCH_SIZE,
                    max_new_tokens=1024,
                    temperature=0.7,
                    top_p=0.90,
                )
            else:
                # NOTE(review): the greedy path sends the bare question, without
                # the few-shot prefix or answer-format instruction used by the
                # self-consistency path — confirm this is intended.
                chosen = generate_greedy(model, tok, q, max_new_tokens=1024)
                _, pred_num = evaluate_prediction(chosen, gold)

            # A missing gold value must never count as a match (None == None).
            is_corr_flag = gold is not None and pred_num == gold
            correct += int(is_corr_flag)

            if i < PREVIEW:
                per_sample_text = _format_preview(
                    model_label, i, q, gold, chosen, cand_list, pred_num, is_corr_flag
                )
                print(per_sample_text)
                if log_fh:
                    # Logging is best effort: report the problem and keep evaluating.
                    try:
                        log_fh.write(per_sample_text + "\n")
                        if (i % 50) == 0:
                            log_fh.flush()
                    except (OSError, ValueError) as exc:
                        print("Could not write to log file:", exc)

            pbar.update(1)
    return correct


def _write_csv_summary(csv_out, dataset_name, total, headers, results_counts):
    """Append one summary row to *csv_out*, writing the header first if the file is new.

    Models that were not evaluated (count is None) get empty cells rather
    than a misleading 0.
    """
    file_exists = os.path.exists(csv_out)
    _ensure_parent_dir(csv_out)
    counts = [results_counts.get(h) for h in headers]
    with open(csv_out, "a", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        if not file_exists:
            header_row = ["dataset", "N"] + [f"{h}_correct" for h in headers] + [f"{h}_acc" for h in headers] + ["timestamp"]
            writer.writerow(header_row)
        row = [dataset_name, total]
        row += ["" if c is None else c for c in counts]
        row += ["" if c is None else f"{(c / total if total > 0 else 0.0):.4f}" for c in counts]
        row.append(time.strftime("%Y-%m-%d %H:%M:%S"))
        writer.writerow(row)


def main():
    """Evaluate the pretrained model and each configured LoRA variant on GSM8K, one after another.

    Prints per-model and aggregated accuracy, appends a CSV summary row when
    CSV_OUT is configured, and saves a PNG accuracy table. Models whose tensor
    directory is missing or still a placeholder are skipped and reported as
    skipped, not as 0 accuracy.
    """
    import gc  # only needed here, to release models between runs

    dataset_name_for_log = "GSM8K"
    log_dir = LOG_DIR if _is_configured(LOG_DIR) else None
    csv_out = CSV_OUT if _is_configured(CSV_OUT) else None

    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    if csv_out:
        _ensure_parent_dir(csv_out)

    log_fh = None
    if log_dir:
        log_file_path = os.path.join(log_dir, f"{dataset_name_for_log}_detailed_log.txt")
        try:
            log_fh = open(log_file_path, "w", encoding="utf-8")
        except OSError as e:
            print("Could not open log file for writing:", e)

    # 10 models ordering
    models_order = [
        "pretrained",
        "full_mlpatt", "full_att", "full_mlp",
        "top1_mlpatt", "top1_att", "top1_mlp",
        "top3_mlpatt", "top3_att", "top3_mlp",
    ]

    tensors_map = {
        "full_mlpatt": TENSORS_DIR_FULL_MLPATT,
        "full_att":    TENSORS_DIR_FULL_ATT,
        "full_mlp":    TENSORS_DIR_FULL_MLP,
        "top1_mlpatt": TENSORS_DIR_TOP1_MLPATT,
        "top1_att":    TENSORS_DIR_TOP1_ATT,
        "top1_mlp":    TENSORS_DIR_TOP1_MLP,
        "top3_mlpatt": TENSORS_DIR_TOP3_MLPATT,
        "top3_att":    TENSORS_DIR_TOP3_ATT,
        "top3_mlp":    TENSORS_DIR_TOP3_MLP,
    }

    # None marks "not evaluated", so skipped models are distinguishable from 0 correct.
    results_counts = {m: None for m in models_order}

    try:
        ds = load_dataset("openai/gsm8k", "main", split=f"test[:{N}]")
        total = len(ds)
        print(f"\nLoaded {total} samples from GSM8K (self-consistency={NUM_K_SAMPLING}, batch_size={BATCH_SIZE})\n")

        # sequential loop
        for model_label in models_order:
            if model_label == "pretrained":
                print("\n--- Evaluating pretrained model ---")
                model, tok = load_pretrained(MODEL_NAME, HF_TOKEN)
            else:
                tensors_dir = tensors_map.get(model_label, "")
                print(f"\n--- Loading LoRA adapter '{model_label}' from: {tensors_dir} ---")
                if not _is_configured(tensors_dir) or not os.path.exists(tensors_dir):
                    print(f"Skipping {model_label}: tensors dir missing or placeholder ({tensors_dir}).")
                    continue
                model, tok = load_lora(MODEL_NAME, HF_TOKEN, tensors_dir, r=16, alpha=32)

            try:
                correct = _evaluate_model(model_label, model, tok, ds, log_fh)
            finally:
                # Drop every reference (including cached device tensors) and
                # collect cycles before asking CUDA to release cached blocks,
                # so the next model has room to load.
                del model, tok
                _clear_prompt_cache()
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            results_counts[model_label] = correct
            print(f"\nModel '{model_label}' done: {correct} / {total} = {correct/total if total>0 else 0.0:.3f}")
    finally:
        # Close the log even when loading or evaluation fails part-way.
        if log_fh:
            log_fh.close()

    # final aggregated print
    print("\n\nFINAL AGGREGATED ACCURACY (ALL MODELS) — Dataset: {}  N = {}\n".format(dataset_name_for_log, total))
    headers = models_order[:]
    vals = []
    for m in headers:
        correct = results_counts.get(m)
        if correct is None:
            vals.append(None)
            print(f"{m}: skipped")
            continue
        acc = correct / total if total > 0 else float('nan')
        vals.append(acc)
        print(f"{m}: {correct} / {total} = {acc:.3f}")

    # append CSV (best effort: a failure is reported, not fatal)
    try:
        if csv_out:
            _write_csv_summary(csv_out, dataset_name_for_log, total, headers, results_counts)
            print(f"Wrote CSV summary to: {CSV_OUT}")
        else:
            print("CSV_OUT placeholder left; skipping CSV write. Set CSV_OUT env var to enable.")
    except Exception as e:
        print("Could not write CSV summary:", e)

    # PNG table (best effort: matplotlib may be missing)
    try:
        import matplotlib.pyplot as plt

        row_acc = ["n/a" if v is None else f"{v:.4f}" for v in vals]
        fig_w = max(8, 0.6 * len(headers))
        fig, ax = plt.subplots(figsize=(fig_w, 1.8))
        ax.axis('off')
        table_data = [headers, row_acc]
        table = ax.table(cellText=table_data, loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.8)
        png_path = os.path.join(log_dir if log_dir else ".", f"{dataset_name_for_log}_ten_models_accuracy_table.png")
        plt.tight_layout()
        plt.savefig(png_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print("Saved accuracy PNG table to:", png_path)
    except Exception as e:
        print("Could not create/save PNG accuracy table:", e)


if __name__ == "__main__":
    main()
