# pubmedqa_evaluate_multi_lora.py
import os
import re
import math
import collections
import torch
import time
import csv
import warnings
import gc
import hashlib
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from typing import Tuple, List

# Performance tweaks: make sure cuDNN is in use, and let it benchmark candidate
# kernels and cache the fastest one per input shape.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# -------------------------------
# Config (placeholders in env)
# -------------------------------
# Hub token. Unset or empty -> None, so huggingface_hub can fall back to a cached
# login or anonymous access. A placeholder string would be sent as a real
# credential (and presumably rejected), breaking even public models.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
MODEL_NAME = os.environ.get("MODEL_NAME", "please enter MODEL_NAME")

# The LoRA dir placeholders below all start with "to enter"; the entrypoint uses
# that prefix to detect unset directories and skip those models.

# --- Full case LoRA dirs (three) ---
TENSORS_DIR_FULL_MLPATT = os.environ.get("TENSORS_DIR_FULL_MLPATT", "to enter values of TENSORS_DIR_FULL_MLPATT")
TENSORS_DIR_FULL_ATT = os.environ.get("TENSORS_DIR_FULL_ATT", "to enter values of TENSORS_DIR_FULL_ATT")
TENSORS_DIR_FULL_MLP = os.environ.get("TENSORS_DIR_FULL_MLP", "to enter values of TENSORS_DIR_FULL_MLP")

# --- Top1 case LoRA dirs (three) ---
TENSORS_DIR_TOP1_MLPATT = os.environ.get("TENSORS_DIR_TOP1_MLPATT", "to enter values of TENSORS_DIR_TOP1_MLPATT")
TENSORS_DIR_TOP1_ATT = os.environ.get("TENSORS_DIR_TOP1_ATT", "to enter values of TENSORS_DIR_TOP1_ATT")
TENSORS_DIR_TOP1_MLP = os.environ.get("TENSORS_DIR_TOP1_MLP", "to enter values of TENSORS_DIR_TOP1_MLP")

# --- Top3 case LoRA dirs (three) ---
TENSORS_DIR_TOP3_MLPATT = os.environ.get("TENSORS_DIR_TOP3_MLPATT", "to enter values of TENSORS_DIR_TOP3_MLPATT")
TENSORS_DIR_TOP3_ATT = os.environ.get("TENSORS_DIR_TOP3_ATT", "to enter values of TENSORS_DIR_TOP3_ATT")
TENSORS_DIR_TOP3_MLP = os.environ.get("TENSORS_DIR_TOP3_MLP", "to enter values of TENSORS_DIR_TOP3_MLP")

# CSV / LOG locations placeholders
CSV_OUT = os.environ.get("CSV_OUT", "to enter values of CSV_OUT")
LOG_DIR = os.environ.get("LOG_DIR", "to enter values of LOG_DIR")

# -------------------------------
# Hyperparameters placeholders (must set env vars to integers)
# -------------------------------
def get_required_int_env(name: str) -> int:
    """Return environment variable ``name`` parsed as an ``int``.

    Surrounding whitespace is tolerated (``int()`` strips it).

    Raises:
        RuntimeError: if the variable is unset, or if its value is not a valid
            integer (the underlying ``ValueError`` is chained as ``__cause__``).
    """
    raw = os.environ.get(name)
    if raw is None:
        raise RuntimeError(f"Environment variable {name} is not set. Please set it to an integer.")
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(f"Environment variable {name} must be an integer. Got: {raw!r}") from err

# Required integer hyperparameters; each lookup raises RuntimeError when the
# variable is unset or not an integer.
N = get_required_int_env("N_SAMPLES")
NUM_K_SAMPLING, BATCH_SIZE, MAX_NEW_TOKENS, PREVIEW = (
    get_required_int_env(env_name)
    for env_name in ("NUM_K_SAMPLING", "BATCH_SIZE", "MAX_NEW_TOKENS", "PREVIEW")
)

# multiprocessing workers: PARALLEL_WORKERS env var if set, otherwise all CPUs
# but two (at least 1; 4 CPUs assumed when the count is unknown)
PARALLEL_WORKERS = int(
    os.environ.get("PARALLEL_WORKERS", str(max(1, (os.cpu_count() or 4) - 2)))
)

# ensure log dir and csv
# An empty LOG_DIR/CSV_OUT (variable set to "") would make os.makedirs("") raise
# FileNotFoundError at import time, so fall back to defaults instead.
if not LOG_DIR:
    LOG_DIR = "."
os.makedirs(LOG_DIR, exist_ok=True)
if not CSV_OUT:
    CSV_OUT = os.path.join(LOG_DIR, "PubMedQA_accuracy_summary.csv")
elif os.path.isabs(CSV_OUT):
    # Absolute paths are redirected into LOG_DIR; only the file name is kept.
    CSV_OUT = os.path.join(LOG_DIR, os.path.basename(CSV_OUT))
else:
    CSV_OUT = os.path.join(LOG_DIR, CSV_OUT)
os.makedirs(os.path.dirname(CSV_OUT), exist_ok=True)

# -------------------------------
# Token pin/move helpers (pin CPU tensors then move to device)
# -------------------------------
def _pin_and_move_inputs(inputs_cpu, device):
    inputs_on_device = {}
    for k, v in inputs_cpu.items():
        try:
            v_p = v.pin_memory()
            inputs_on_device[k] = v_p.to(device, non_blocking=True)
        except Exception:
            try:
                inputs_on_device[k] = v.to(device, non_blocking=True)
            except Exception:
                inputs_on_device[k] = v.to(device)
    return inputs_on_device

# -------------------------------
# Load pretrained base (bfloat16 + low_cpu_mem_usage)
# -------------------------------
def load_pretrained(model_name, hf_token):
    """Load the base causal LM and its fast tokenizer for evaluation.

    The model is loaded in bfloat16 with ``device_map="auto"`` and
    ``low_cpu_mem_usage=True``. A missing pad or bos token on the tokenizer is
    filled with the eos token. The model config's pad/eos/bos ids are then
    synced to the tokenizer on a best-effort basis.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token,
    )

    eos = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = eos
    if getattr(tokenizer, 'bos_token', None) is None:
        tokenizer.bos_token = eos

    # Best effort: copy the special-token ids onto the model config, in order.
    try:
        for id_attr in ("pad_token_id", "eos_token_id", "bos_token_id"):
            setattr(base.config, id_attr, getattr(tokenizer, id_attr))
    except Exception:
        pass

    base.eval()
    print("Loaded pretrained base model.")
    return base, tokenizer

# -------------------------------
# Load LoRA (keeps mapping/loading logic but minimal prints)
# -------------------------------
# Modules wrapped by the LoRA adapter unless the caller overrides them:
# attention projections plus MLP projections.
_LORA_TARGET_MODULES = ("q_proj", "k_proj", "v_proj", "o_proj", "down_proj", "up_proj", "gate_proj")


def _read_pt_tensors(tensors_dir):
    """Load every ``*.pt`` file under ``tensors_dir`` (recursively) onto the CPU.

    Each key is the file name with its trailing ``.pt`` removed. Traversal
    order is sorted, so when two files share a name the later one
    deterministically wins (a warning is printed). Files that fail to load are
    reported and skipped.
    """
    state_raw = {}
    for root, dirs, files in os.walk(tensors_dir):
        dirs.sort()  # deterministic traversal order
        for fname in sorted(files):
            if not fname.endswith(".pt"):
                continue
            key = fname[: -len(".pt")]  # strip only the suffix, not inner ".pt" substrings
            path = os.path.join(root, fname)
            try:
                # NOTE(security): torch.load unpickles the file; only point this at trusted dirs.
                tensor = torch.load(path, map_location="cpu")
            except Exception as e:
                print(f"Warning: could not load {path}: {type(e).__name__}: {e}")
                continue
            if key in state_raw:
                print(f"Warning: duplicate tensor name {key!r} at {path}; overriding earlier file.")
            state_raw[key] = tensor
    return state_raw


def _map_to_target_keys(state_raw, target_keys):
    """Match raw tensor names to keys of the PEFT model's state dict.

    For each raw name, candidates are tried in this order: the name itself,
    then the name with a ``base_model.`` prefix, then a ``model.`` prefix. Each
    prefix is combined with the suffixes ``""``, ``".weight"`` and
    ``".default.weight"``. The first candidate present in ``target_keys`` wins.

    Returns:
        ``(mapped, unmatched_raw)``: a dict ``target_key -> tensor`` and the
        list of raw names that matched nothing.
    """
    prefixes = ("", "base_model.", "model.")
    suffixes = ("", ".weight", ".default.weight")
    mapped = {}
    unmatched_raw = []
    for raw_k, tensor in state_raw.items():
        match = next(
            (cand for cand in (p + raw_k + s for p in prefixes for s in suffixes) if cand in target_keys),
            None,
        )
        if match is None:
            unmatched_raw.append(raw_k)
        else:
            mapped[match] = tensor
    return mapped, unmatched_raw


def load_lora(base_model, hf_token, tensors_dir, r, alpha, target_modules=None):
    """Load the base model, wrap it in a LoRA adapter and load saved adapter tensors.

    Args:
        base_model: Hugging Face model id/path of the base model (a name, not a model object).
        hf_token: Hub token passed to ``from_pretrained``.
        tensors_dir: directory searched recursively for ``*.pt`` files. Each
            file name (minus ``.pt``) is matched against the PEFT state-dict
            keys. Each file is assumed to hold a single tensor (TODO confirm).
        r: LoRA rank.
        alpha: ``lora_alpha`` for the adapter config.
        target_modules: module names to wrap. Defaults to all attention and
            MLP projections listed in ``_LORA_TARGET_MODULES``.

    Missing pad/bos tokens are filled with eos; an existing pad token is kept.
    Unmatched tensor files are reported, as is the case where nothing at all
    could be mapped.

    Returns:
        ``(peft_model, tokenizer)`` with the model in eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    tok = AutoTokenizer.from_pretrained(base_model, token=hf_token, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if getattr(tok, 'bos_token', None) is None:
        tok.bos_token = tok.eos_token
    try:
        model.config.pad_token_id = tok.pad_token_id
        model.config.eos_token_id = tok.eos_token_id
        model.config.bos_token_id = tok.bos_token_id
    except Exception:
        pass  # best effort; evaluation passes pad_token_id explicitly to generate()

    cfg = LoraConfig(
        r=r, lora_alpha=alpha,
        target_modules=list(_LORA_TARGET_MODULES if target_modules is None else target_modules),
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, cfg)

    if tensors_dir and os.path.exists(tensors_dir):
        state_raw = _read_pt_tensors(tensors_dir)
    else:
        state_raw = {}
        if tensors_dir:
            print(f"Warning: tensors_dir provided but path does not exist: {tensors_dir}")
        else:
            print("Warning: no tensors_dir provided; returning PEFT wrapper without additional weights.")

    mapped, unmatched_raw = _map_to_target_keys(state_raw, set(peft_model.state_dict()))
    if unmatched_raw:
        shown = ", ".join(unmatched_raw[:5])
        print(f"Warning: {len(unmatched_raw)} tensor file(s) matched no model key (e.g. {shown}).")
    if state_raw and not mapped:
        print("Warning: none of the loaded tensors were applied; the adapter keeps its initial weights.")

    # strict=False: only tensors found on disk are loaded; every other key keeps its current value.
    peft_model.load_state_dict(mapped, strict=False)

    if tensors_dir:
        print(f"Loaded LoRA weights from: {tensors_dir} ({len(mapped)}/{len(state_raw)} tensors mapped)")
    peft_model.eval()
    return peft_model, tok

# -------------------------------
# Prompt builder for PubMedQA
# -------------------------------
def build_pubmedqa_prompt(context, question):
    """Build the zero-shot yes/no prompt for one PubMedQA example.

    Args:
        context: the abstract passages. Accepted forms:
            - a dict with a ``"contexts"`` list (the form ``qiaojin/PubMedQA``
              rows appear to use; verify against the dataset card); other keys
              such as labels or MeSH terms are ignored;
            - a list/tuple of passages;
            - a single string.
            Empty/None passages are dropped. ``None`` means no context.
        question: the question text; ``None`` is treated as empty.

    Returns:
        The prompt string, ending with ``"Answer:\\n"``.
    """
    if isinstance(context, dict):
        # Without this, str(dict) would dump labels/MeSH terms into the prompt.
        context = context.get("contexts", context)
    if context is None:
        context_str = ""
    elif isinstance(context, (list, tuple)):
        context_str = " ".join(str(c).strip() for c in context if c)
    else:
        context_str = str(context).strip()
    question_str = "" if question is None else str(question).strip()
    prompt = (
        f"Context:\n{context_str}\n\n"
        f"Question: {question_str}\n\n"
        "Based on the context above, answer the question with exactly 'yes' or 'no' (lowercase), "
        "and do NOT provide any explanation. Answer:\n"
    )
    return prompt

# -------------------------------
# Generation / extraction helpers
# -------------------------------
# A decoded continuation is cut where the model starts a new "Q..." line
# (e.g. "Q:", "Q1:", "Q 2"). Failing that, it is cut at a line beginning with
# "OR" or "Stop" (case-insensitive).
_NEXT_QUESTION_RE = re.compile(r"\n\s*(Q[:\d ]|Q\d+:|Q:)")
_STOP_LINE_RE = re.compile(r"\n\s*(OR\b|Stop\b)", flags=re.IGNORECASE)


def _truncate_continuation(text):
    """Return ``text`` cut (and stripped) before a follow-up question or an 'OR'/'Stop' line."""
    m = _NEXT_QUESTION_RE.search(text) or _STOP_LINE_RE.search(text)
    return text[:m.start()].strip() if m else text


def _decode_continuation(tok, sequence, input_len):
    """Decode the tokens generated after the prompt, strip them, then truncate."""
    text = tok.decode(sequence[input_len:], skip_special_tokens=True).strip()
    return _truncate_continuation(text)


def generate_answer(model, tok, context, question, max_new_tokens=512):
    """Greedily decode an answer to one PubMedQA question.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tok: the matching tokenizer.
        context, question: forwarded to ``build_pubmedqa_prompt``.
        max_new_tokens: cap on generated tokens.

    Returns:
        The generated continuation (special tokens removed), truncated at a
        follow-up question / stop line.
    """
    prompt = build_pubmedqa_prompt(context, question)
    inputs = _pin_and_move_inputs(tok(prompt, return_tensors="pt"), model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Plain greedy decoding. Sampling knobs (temperature/top_p) are ignored
        # when do_sample=False and only trigger generation-config warnings.
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )

    return _decode_continuation(tok, out[0], input_len)


def generate_answer_self_consistency(model, tok, context, question,
                                     num_samples=16,
                                     batch_size=4,
                                     max_new_tokens=256,
                                     temperature=0.7,
                                     top_p=0.95,
                                     no_repeat_ngram_size=3,
                                     repetition_penalty=1.1):
    """Sample several answers and majority-vote on the parsed yes/no label.

    ``num_samples`` continuations are drawn, at most ``batch_size`` per
    ``generate`` call (values below 1 are treated as 1). Each one is truncated
    like ``generate_answer`` and parsed with ``extract_yes_no_from_text``.

    Voting rules:
        - The most common non-None label wins. Ties go to the label seen first
          (``Counter.most_common`` ordering). The first continuation carrying
          that label is returned.
        - If no sample yields a label, the most common raw continuation is
          returned instead.

    Returns:
        ``(chosen_continuation, candidates, chosen_label)``. With zero samples
        this is ``("", [], None)``.
    """
    prompt = build_pubmedqa_prompt(context, question)
    inputs = _pin_and_move_inputs(tok(prompt, return_tensors="pt"), model.device)
    input_len = inputs["input_ids"].shape[-1]
    batch_size = max(1, int(batch_size))  # guard against 0 / negative batch sizes

    candidates = []
    labels = []
    produced = 0
    with torch.no_grad():
        while produced < num_samples:
            this_bs = min(batch_size, num_samples - produced)
            outs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=this_bs,
                pad_token_id=tok.eos_token_id,
                no_repeat_ngram_size=no_repeat_ngram_size,
                repetition_penalty=repetition_penalty
            )
            for seq in outs:
                cont = _decode_continuation(tok, seq, input_len)
                candidates.append(cont)
                labels.append(extract_yes_no_from_text(cont))
            produced += this_bs

    label_counts = collections.Counter(lab for lab in labels if lab is not None)
    if label_counts:
        chosen_label = label_counts.most_common(1)[0][0]
        chosen_continuation = next(c for c, lab in zip(candidates, labels) if lab == chosen_label)
    elif candidates:
        chosen_continuation = collections.Counter(candidates).most_common(1)[0][0]
        chosen_label = extract_yes_no_from_text(chosen_continuation)
    else:
        chosen_continuation, chosen_label = "", None

    return chosen_continuation, candidates, chosen_label

# Patterns used to read a yes/no decision out of free-form model output.
_YES_NO_RE = re.compile(r"\b(yes|no)\b")
_TRUE_FALSE_RE = re.compile(r"\b(true|false)\b")
_WORD_RE = re.compile(r"[A-Za-z']+")


def extract_yes_no_from_text(text):
    """Map free-form model output to ``'yes'``, ``'no'`` or ``None``.

    Rules, applied in order to the lower-cased text:
    1. the first standalone ``yes``/``no`` word;
    2. the first standalone ``true``/``false`` word (true -> yes, false -> no);
    3. the first word token (letters/apostrophes): starting with ``y`` -> yes,
       starting with ``n`` -> no. This is a loose fallback: ``"nope"`` -> no,
       but also ``"not sure"`` -> no.

    Returns ``None`` for ``None`` input or when no rule matches.
    """
    if text is None:
        return None
    text_l = text.lower()
    m = _YES_NO_RE.search(text_l)
    if m:
        return m.group(1)
    m = _TRUE_FALSE_RE.search(text_l)
    if m:
        return 'yes' if m.group(1) == 'true' else 'no'
    tokens = _WORD_RE.findall(text_l)
    if tokens:
        first = tokens[0]
        if first.startswith('y'):
            return 'yes'
        if first.startswith('n'):
            return 'no'
    return None


def evaluate_prediction_label(pred_text, gold_label):
    """Parse ``pred_text`` into a yes/no label and compare it with ``gold_label``.

    A string gold label is compared after stripping whitespace and
    lower-casing. An unparseable prediction (``None``) is never counted as
    correct, even when the gold label is missing too.

    Returns:
        ``(is_correct, pred_label)``.
    """
    pred_label = extract_yes_no_from_text(pred_text)
    gold = gold_label.strip().lower() if isinstance(gold_label, str) else gold_label
    return (pred_label is not None and pred_label == gold), pred_label

# -------------------------------
# Evaluate loaded model with tqdm progress
# -------------------------------
def evaluate_loaded_model(model, tok, ds, model_tag, log_fh=None, pool=None):
    """Run ``model`` over every example of ``ds`` and count correct yes/no answers.

    Generation mode depends on the module-level ``NUM_K_SAMPLING``:
        - greater than 1: self-consistency voting over ``NUM_K_SAMPLING``
          samples, drawn ``BATCH_SIZE`` at a time;
        - otherwise: greedy decoding.
    Both modes are capped at ``MAX_NEW_TOKENS`` new tokens.

    The first ``PREVIEW`` examples are printed. Every example is appended to
    ``log_fh`` when it is given. If writing to ``log_fh`` fails, a warning is
    issued and logging stops for this model.

    Args:
        model, tok: loaded model and tokenizer.
        ds: sized iterable of examples with ``question``, ``context`` and
            ``final_decision`` fields.
        model_tag: label used for the progress bar and log lines.
        log_fh: optional writable text file for per-sample logs.
        pool: unused; kept for interface compatibility.

    Returns:
        Number of examples whose interpreted prediction equals the gold label
        (case-insensitive).
    """
    correct = 0
    log_enabled = log_fh is not None
    with tqdm(total=len(ds), desc=f"{model_tag}", unit="it") as pbar:
        for i, ex in enumerate(ds):
            question = (ex.get("question") or "").strip()
            context = ex.get("context", "")
            gold = ex.get("final_decision", "")

            if NUM_K_SAMPLING > 1:
                chosen, _candidates, pred_label = generate_answer_self_consistency(
                    model, tok, context, question,
                    num_samples=NUM_K_SAMPLING,
                    batch_size=BATCH_SIZE,
                    max_new_tokens=MAX_NEW_TOKENS,  # honour the configured cap, like greedy mode
                    temperature=0.7,
                    top_p=0.95
                )
            else:
                chosen = generate_answer(model, tok, context, question, max_new_tokens=MAX_NEW_TOKENS)
                _, pred_label = evaluate_prediction_label(chosen, gold)

            is_corr = (
                pred_label is not None
                and gold is not None
                and str(pred_label).lower() == str(gold).strip().lower()
            )
            correct += int(is_corr)

            # Show the passages rather than the raw dict repr when context is a dict.
            if isinstance(context, dict) and "contexts" in context:
                context_preview = " ".join(str(c) for c in context["contexts"] if c)
            else:
                context_preview = str(context)

            per_sample_text = "\n".join([
                "\n" + "=" * 80,
                f"[{model_tag}] Index {i} | Question:",
                question,
                f"Context (first 400 chars): {context_preview[:400]}\n",
                f"Ground Truth: {gold}\n",
                f"Model output: {chosen}",
                f"Interpreted Pred: {pred_label} | Correct? {is_corr}",
            ])

            if i < PREVIEW:
                print(per_sample_text)
            if log_enabled:
                try:
                    log_fh.write(per_sample_text + "\n")
                    log_fh.flush()
                except (OSError, ValueError) as e:  # ValueError: write to a closed file
                    warnings.warn(f"Disabling per-sample log for {model_tag}: {e}")
                    log_enabled = False

            pbar.update(1)

    return correct

# -------------------------------
# Entrypoint: sequentially load each model, evaluate, unload
# -------------------------------
def _display_name(label):
    """Turn a result key such as ``Full_MLPATT`` into the table header ``Full MLP+ATT``."""
    return label.replace("MLPATT", "MLP+ATT").replace("_", " ")


def _release_memory():
    """Run the garbage collector first, then return cached CUDA blocks to the driver.

    The order matters: ``torch.cuda.empty_cache`` can only release blocks whose
    tensors have already been freed. Collecting first covers models kept alive
    only through reference cycles.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _write_csv_summary(csv_path, dataset_name, total, labels, results_correct, results_acc):
    """Append one row (correct counts, 4-decimal accuracies, timestamp) to ``csv_path``.

    Column order: ``dataset``, ``N``, then ``<label>_correct`` for each label,
    then ``acc_<label>`` for each label (labels lower-cased), then
    ``timestamp``. The header is written when the file is missing or empty.
    """
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        if write_header:
            writer.writerow(
                ["dataset", "N"]
                + [f"{lab.lower()}_correct" for lab in labels]
                + [f"acc_{lab.lower()}" for lab in labels]
                + ["timestamp"]
            )
        writer.writerow(
            [dataset_name, total]
            + [results_correct.get(lab, 0) for lab in labels]
            + [f"{results_acc.get(lab, 0.0):.4f}" for lab in labels]
            + [time.strftime("%Y-%m-%d %H:%M:%S")]
        )


def _save_accuracy_table_png(png_path, labels, results_acc):
    """Render a two-row table (model names, accuracies) and save it as a PNG."""
    import matplotlib.pyplot as plt  # optional dependency, only needed for this summary

    headers = [_display_name(lab) for lab in labels]
    vals = [f"{results_acc.get(lab, 0.0):.4f}" for lab in labels]
    fig, ax = plt.subplots(figsize=(12, 2.8))
    try:
        ax.axis('off')
        table = ax.table(cellText=[headers, vals], loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2.0)
        fig.tight_layout()
        fig.savefig(png_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def main():
    """Evaluate the base model and each available LoRA variant, then write summaries.

    Models whose tensors dir is missing or still a placeholder are skipped. A
    model that fails to load or evaluate is reported and the run continues.
    Both skipped and failed models are recorded with 0 correct / 0.0 accuracy
    in the CSV and PNG, and flagged in the printed summary.
    """
    dataset_name_for_log = "PubMedQA"
    os.makedirs(LOG_DIR, exist_ok=True)
    log_file_path = os.path.join(LOG_DIR, f"{dataset_name_for_log}_detailed_log.txt")

    # load dataset once
    ds = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split=f"train[:{N}]")
    total = len(ds)
    print(f"\nLoaded {total} samples from qiaojin/PubMedQA (pqa_artificial) - N = {N}\n")

    # Answer parsing stays single-threaded; PARALLEL_WORKERS is reserved for a future pool.
    pool = None

    # (result label, LoRA tensors dir) in evaluation order; the base model has no dir.
    model_list = [
        ("Pretrained", None),
        ("Full_MLPATT", TENSORS_DIR_FULL_MLPATT),
        ("Full_ATT", TENSORS_DIR_FULL_ATT),
        ("Full_MLP", TENSORS_DIR_FULL_MLP),
        ("Top1_MLPATT", TENSORS_DIR_TOP1_MLPATT),
        ("Top1_ATT", TENSORS_DIR_TOP1_ATT),
        ("Top1_MLP", TENSORS_DIR_TOP1_MLP),
        ("Top3_MLPATT", TENSORS_DIR_TOP3_MLPATT),
        ("Top3_ATT", TENSORS_DIR_TOP3_ATT),
        ("Top3_MLP", TENSORS_DIR_TOP3_MLP),
    ]
    labels = [label for label, _ in model_list]
    results_correct = {label: 0 for label in labels}
    results_acc = {label: 0.0 for label in labels}
    status = {}

    try:
        log_fh = open(log_file_path, "w", encoding="utf-8")
    except OSError as e:
        print("Could not open log file for writing:", e)
        log_fh = None

    try:
        for label, dirpath in model_list:
            # skip if placeholder or missing
            if label != "Pretrained" and (
                not dirpath or dirpath.startswith("to enter") or not os.path.exists(dirpath)
            ):
                print(f"Skipping {label}: tensors dir missing or placeholder ({dirpath}).")
                status[label] = "skipped"
                continue

            model_loaded = tok_loaded = None
            try:
                if label == "Pretrained":
                    print(f"\n--- Loading & evaluating: {label} ---\n")
                    model_loaded, tok_loaded = load_pretrained(MODEL_NAME, HF_TOKEN)
                else:
                    print(f"\n--- Loading & evaluating: {label} from {dirpath} ---\n")
                    model_loaded, tok_loaded = load_lora(MODEL_NAME, HF_TOKEN, dirpath, r=16, alpha=32)
                correct = evaluate_loaded_model(model_loaded, tok_loaded, ds, label, log_fh=log_fh, pool=pool)
                results_correct[label] = correct
                results_acc[label] = correct / total if total > 0 else 0.0
                status[label] = "ok"
            except Exception as e:  # per-model boundary: report and continue with the rest
                print(f"Exception while evaluating {label}: {type(e).__name__}: {e}")
                status[label] = "failed"
            finally:
                was_loaded = model_loaded is not None
                # Drop our references before collecting so the weights become unreachable.
                model_loaded = tok_loaded = None
                _release_memory()
                if was_loaded:
                    print(f"Unloaded {label} and freed GPU memory.\n")
    finally:
        if log_fh is not None:
            log_fh.close()

    print("\n\nFINAL AGGREGATED ACCURACY (ALL MODELS) — Dataset: {}  N = {}\n".format(dataset_name_for_log, total))
    for label in labels:
        state = status.get(label, "not run")
        note = "" if state == "ok" else f"  [{state}]"
        print(f"{label}: {results_correct[label]} / {total} = {results_acc[label]:.3f}{note}")

    # write CSV summary (10 model columns)
    try:
        _write_csv_summary(CSV_OUT, dataset_name_for_log, total, labels, results_correct, results_acc)
        print(f"Wrote CSV summary to: {CSV_OUT}")
    except Exception as e:
        warnings.warn(f"Could not write CSV summary to {CSV_OUT}: {e}")

    # PNG table (10 columns)
    png_path = os.path.join(LOG_DIR, f"{dataset_name_for_log}_accuracy_table.png")
    try:
        _save_accuracy_table_png(png_path, labels, results_acc)
        print("Saved accuracy PNG table to:", png_path)
    except Exception as e:
        warnings.warn(f"Could not create/save PNG accuracy table: {e}")


if __name__ == "__main__":
    main()
