#!/usr/bin/env python3
import argparse, csv
from tqdm import tqdm
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from src.utils.eval_datasets import load_truthfulqa_csv


def load_model(model_name, device_map="auto", dtype="bfloat16", load_in_4bit=False, load_in_8bit=False):
    """Load a causal language model together with its tokenizer.

    Args:
        model_name: Hugging Face model name or local path.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        dtype: One of ``"bfloat16"``, ``"float16"`` or ``"float32"``. Any
            other value leaves ``torch_dtype`` unset.
        load_in_4bit: Request a 4-bit ``BitsAndBytesConfig``.
        load_in_8bit: Request an 8-bit ``BitsAndBytesConfig``.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer pads on the left and
        uses its EOS token as the pad token when it has none of its own.
    """
    dtype_by_name = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    model_kwargs = {}
    if dtype in dtype_by_name:
        model_kwargs["torch_dtype"] = dtype_by_name[dtype]

    # A quantization config is only attached when a quantized load was requested.
    wants_quantization = load_in_4bit or load_in_8bit
    if wants_quantization:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        trust_remote_code=True,
        **model_kwargs,
    )

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model


def build_prompt(tokenizer, question: str) -> str:
    """Render *question* as a single-user-turn chat prompt.

    The question is stripped of surrounding whitespace, wrapped in one
    ``user`` message and passed to the tokenizer's chat template. The
    template is asked for text (``tokenize=False``) with the generation
    prompt appended.
    """
    user_turn = {"role": "user", "content": question.strip()}
    conversation = [user_turn]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered


def generate_batch(model, tokenizer, prompts, max_new_tokens, do_sample, temperature, top_p, top_k):
    """Generate one completion per prompt and return only the new text.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs, ...)``.
        tokenizer: Tokenizer used both to encode ``prompts`` (with padding)
            and to decode the generated ids.
        prompts: List of prompt strings; an empty list yields ``[]``.
        max_new_tokens: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.

    Returns:
        A list of strings, one per prompt, each stripped of surrounding
        whitespace and containing only the text decoded from the tokens
        produced after the prompt.
    """
    if not prompts:
        return []

    # NOTE(review): if the prompts come from a chat template they may already
    # contain BOS/special tokens, and the tokenizer's default
    # add_special_tokens behaviour could duplicate them -- confirm per model.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Cut the prompt off at the token level instead of matching decoded text:
    # decoding is not guaranteed to be prefix-preserving, and a failed string
    # match would silently return prompt + answer. This assumes `generate`
    # returns each row as the (padded) input ids followed by the new tokens,
    # as decoder-only models do.
    prompt_width = inputs["input_ids"].shape[1]
    new_token_ids = outputs[:, prompt_width:]

    completions = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
    return [text.strip() for text in completions]


def _write_batch(writer, batch, model, tokenizer, args, do_sample):
    """Generate answers for the queued examples and write one CSV row each.

    ``batch`` is a list of ``(fallback_id, example, question, prompt)``
    tuples. The row's question id is the example's ``"id"`` entry when
    present, otherwise ``fallback_id``.
    """
    answers = generate_batch(
        model,
        tokenizer,
        [item[3] for item in batch],
        max_new_tokens=args.max_new_tokens,
        do_sample=do_sample,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
    )
    for (fallback_id, ex, question, _), ans in zip(batch, answers):
        writer.writerow(["truthfulqa", args.model, ex.get("id", fallback_id), question, ans,
                         args.temperature, args.top_p, args.top_k, args.max_new_tokens])


def main():
    """Command-line entry point: generate answers for a dataset and save them as CSV.

    Parses the CLI arguments, loads the model and tokenizer, reads the
    questions via ``load_truthfulqa_csv``, generates answers in batches of
    ``--batch_size`` and writes one row per question to ``--output_csv``
    (parent directories are created as needed). Sampling is enabled only
    when ``--temperature`` is greater than zero.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="HF model name or path")
    ap.add_argument("--input_path", required=True, help="Dataset path")
    ap.add_argument("--output_csv", required=True, help="Where to save generations")
    ap.add_argument("--max_new_tokens", type=int, default=256)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--top_p", type=float, default=1.0)
    ap.add_argument("--top_k", type=int, default=50)
    ap.add_argument("--dtype", choices=["bfloat16", "float16", "float32"], default="bfloat16")
    ap.add_argument("--device_map", default="auto")
    ap.add_argument("--load_in_4bit", action="store_true")
    ap.add_argument("--load_in_8bit", action="store_true")
    ap.add_argument("--limit", type=int, default=0, help="Debug: only first N questions")
    ap.add_argument("--batch_size", type=int, default=1, help="Number of questions to generate per batch")
    args = ap.parse_args()

    # A non-positive batch size would never trigger a flush inside the loop,
    # so every question would pile up into one huge final batch. Reject it
    # before paying for the model load.
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")

    tokenizer, model = load_model(
        args.model,
        device_map=args.device_map,
        dtype=args.dtype,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
    )

    rows = load_truthfulqa_csv(args.input_path, limit=args.limit)

    out_path = Path(args.output_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    do_sample = args.temperature > 0.0

    with open(out_path, "w", newline="", encoding="utf-8") as f_out:
        writer = csv.writer(f_out)
        writer.writerow(["dataset", "model", "question_id", "question", "answer",
                         "temperature", "top_p", "top_k", "max_new_tokens"])

        batch = []
        for qid, ex in tqdm(enumerate(rows), total=len(rows), desc="Generating answers"):
            question = ex["question"]
            batch.append((qid, ex, question, build_prompt(tokenizer, question)))

            if len(batch) >= args.batch_size:
                _write_batch(writer, batch, model, tokenizer, args, do_sample)
                # Flush per batch so an interrupted run keeps finished rows.
                f_out.flush()
                batch = []

        # Leftover examples when the dataset size is not a multiple of batch_size.
        if batch:
            _write_batch(writer, batch, model, tokenizer, args, do_sample)

    print(f"Saved generations to {out_path}")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
