#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import random
from pathlib import Path

import datasets
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from huggingface_hub import hf_hub_download


def _positive_int(text: str) -> int:
    """argparse ``type`` callback that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave as before.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: On invalid arguments (argparse's usual behaviour), e.g. a
            non-positive ``--batch-size`` or a missing
            ``--local-checkpoint-dir`` when generation is requested.
    """
    p = argparse.ArgumentParser(description="Generate AlpacaEval 2.0 outputs from a local (or LoRA) model.")

    # Paths / names
    # Validated up front: a batch size of 0 would otherwise only fail once the
    # generation loop is reached.
    p.add_argument("--batch-size", type=_positive_int, default=1,
                   help="Batch size (try >1 if you have enough GPU RAM).")
    # Not marked required=True because --use-dataset-chosen never touches the
    # checkpoint; the requirement is enforced after parsing instead.
    p.add_argument("--local-checkpoint-dir", type=str, default=None,
                   help="Path to your local checkpoint directory (e.g., test_dpo/checkpoint-5). "
                        "Not required with --use-dataset-chosen.")
    p.add_argument("--base-model", type=str, default=None,
                   help="Base model name/path if --local-checkpoint-dir holds a LoRA/PEFT adapter.")
    p.add_argument("--outputs-json", type=str, default="alpacaeval_outputs.json",
                   help="Where to save generations for AlpacaEval.")
    p.add_argument("--generator-name", type=str, default="my_local_llm",
                   help="String tag stored in each output row under 'generator'.")
    p.add_argument("--max-samples", type=int, default=None,
                   help="Number of samples to keep in the dataset before inference. Default: all.")

    # Inference
    p.add_argument("--max-new-tokens", type=_positive_int, default=512, help="Max new tokens to generate.")
    p.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature.")
    p.add_argument("--top-p", type=float, default=0.9, help="Top-p nucleus sampling.")
    p.add_argument("--use-bf16", action="store_true",
                   help="Prefer bfloat16 on GPU. If not set, uses float16 on GPU. CPU uses float32.")
    p.add_argument("--force-cpu", action="store_true",
                   help="Force CPU inference (overrides GPU).")

    # Reproducibility
    p.add_argument("--seed", type=int, default=0,
                   help="Random seed for reproducibility (affects sampling, dropout).")

    # Dataset
    p.add_argument("--dataset", type=str, default="tatsu-lab/alpaca_eval",
                   help="HuggingFace dataset to load.")
    p.add_argument("--split", type=str, default="eval",
                   help="Which split to use (e.g., eval, test).")
    p.add_argument("--subset", type=str, default=None,
                   help="Optional subset/configuration name for the dataset.")

    # Include extra fields
    p.add_argument(
        "--include-extra-fields",
        action="store_true",
        help="If set, also include dataset-specific fields like 'chosen' and 'rejected'. "
             "Default behavior is AlpacaEval-compatible minimal output."
    )

    # New mode: dataset-chosen outputs only
    p.add_argument(
        "--use-dataset-chosen",
        action="store_true",
        help="If set, skip generation and instead output the dataset's 'chosen' responses."
    )

    args = p.parse_args(argv)
    if not args.use_dataset_chosen and not args.local_checkpoint_dir:
        p.error("--local-checkpoint-dir is required unless --use-dataset-chosen is set")
    return args


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch with the same *seed*.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def load_dataset_examples(name: str, split: str, subset: str | None):
    """Load a Hugging Face dataset and return the requested split.

    Args:
        name: Dataset name or path passed to ``datasets.load_dataset``.
        split: Key of the split to return (e.g. ``"eval"``).
        subset: Optional configuration name. When empty and *name* is
            ``"tatsu-lab/alpaca_eval"``, the ``"alpaca_eval"`` configuration is
            used; an explicitly requested subset is always honoured.

    Returns:
        The ``ds[split]`` entry of the loaded dataset.

    Raises:
        KeyError: If *split* is not present in the loaded dataset.
    """
    # Only supply the default configuration; do not override the caller's choice.
    if not subset and name == "tatsu-lab/alpaca_eval":
        subset = "alpaca_eval"

    # NOTE(security): trust_remote_code=True lets the dataset repository run
    # its own loading script on this machine. Only use datasets you trust.
    if subset:
        ds = datasets.load_dataset(name, subset, trust_remote_code=True)
    else:
        ds = datasets.load_dataset(name, trust_remote_code=True)

    try:
        return ds[split]
    except KeyError:
        raise KeyError(
            f"Split {split!r} not found in dataset {name!r}; available: {list(ds)}"
        ) from None


def detect_checkpoint_kind(ckpt_dir: Path | str):
    if isinstance(ckpt_dir, str):
        # Hub ID: cannot detect local files, so just say "full model"
        return True, False
    # Local path case
    has_adapter = (ckpt_dir / "adapter_model.safetensors").exists()
    has_full_model = (ckpt_dir / "model.safetensors").exists() or (ckpt_dir / "pytorch_model.bin").exists()
    return has_full_model, has_adapter


def pick_tokenizer_path(ckpt_dir: Path | str, base_model: str | None, use_adapter: bool):
    local_tok_files = [
        "tokenizer.json", "tokenizer.model", "vocab.json", "merges.txt", "special_tokens_map.json"
    ]

    # Case 1: local checkpoint folder
    if isinstance(ckpt_dir, Path):
        if any((ckpt_dir / f).exists() for f in local_tok_files):
            return str(ckpt_dir)
        return base_model if use_adapter else str(ckpt_dir)

    # Case 2: HF Hub model id (string)
    if isinstance(ckpt_dir, str):
        # For hub IDs, don't check files; transformers will resolve automatically
        return base_model if use_adapter else ckpt_dir


def build_prompt_fn(tokenizer):
    """Return a function that turns an instruction into a model prompt.

    The returned callable uses ``tokenizer.apply_chat_template`` (single user
    message, ``tokenize=False``, ``add_generation_prompt=True``) when the
    tokenizer has a non-empty ``chat_template``. If there is no template, or
    applying it fails, it falls back to the
    ``"### Instruction: ... ### Response:"`` layout.

    Args:
        tokenizer: Object optionally exposing ``chat_template`` and
            ``apply_chat_template``.

    Returns:
        A ``str -> str`` function building the prompt for one instruction.
    """
    warned = False

    def _prompt(instruction: str) -> str:
        nonlocal warned
        if getattr(tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": instruction}]
            try:
                return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception as exc:
                # Best-effort fallback is kept, but a silently changed prompt
                # format would skew results, so report it (once).
                if not warned:
                    print(f"[WARN] Could not apply the chat template ({exc!r}); "
                          "falling back to the '### Instruction' prompt format.")
                    warned = True
        return f"### Instruction:\n{instruction}\n\n### Response:\n"

    return _prompt


def get_prompts(batch, dataset, build_prompt):
    """Build one model prompt per example in *batch*.

    Args:
        batch: Iterable of example dicts.
        dataset: Dataset name; selects which field holds the user text.
        build_prompt: Callable mapping the raw user text to the final prompt.

    Returns:
        A list of prompts, one per example, in order.

    Raises:
        ValueError: If *dataset* is not one of the known names and an example
            has neither a string ``"instruction"`` nor a string ``"prompt"``.
    """
    def _user_text(ex):
        if dataset == "tatsu-lab/alpaca_eval":
            return ex["instruction"]
        if dataset == "trl-lib/ultrafeedback_binarized":
            # The first message of the 'chosen' conversation is read as the prompt.
            return ex['chosen'][0]['content']
        if dataset == 'trl-lib/tldr-preference':
            return ex['prompt']
        # Unknown dataset: accept the common plain-text columns instead of
        # silently returning None, which only failed later in the tokenizer.
        for key in ("instruction", "prompt"):
            if isinstance(ex.get(key), str):
                return ex[key]
        raise ValueError(
            f"Cannot build a prompt for dataset {dataset!r}: "
            "examples have no string 'instruction' or 'prompt' field."
        )

    return [build_prompt(_user_text(ex)) for ex in batch]


def get_instruction(batch_ex, dataset):
    """Return the user instruction text of one example.

    Lookup order:
      1. ``batch_ex["instruction"]`` when that key is present.
      2. Dataset-specific fields for ``trl-lib/ultrafeedback_binarized``
         (content of the first ``"chosen"`` message) and
         ``trl-lib/tldr-preference`` (``"prompt"``).
      3. A string ``"prompt"`` field for any other dataset.

    Args:
        batch_ex: Example dict.
        dataset: Dataset name the example came from.

    Returns:
        The instruction, or ``None`` when none of the rules above apply.
    """
    if "instruction" in batch_ex:
        return batch_ex["instruction"]
    if dataset == "trl-lib/ultrafeedback_binarized" and "chosen" in batch_ex:
        return batch_ex["chosen"][0]["content"]
    if dataset == "trl-lib/tldr-preference" and "prompt" in batch_ex:
        return batch_ex["prompt"]

    # Generic fallback so datasets with a plain-text 'prompt' column do not
    # end up with an empty instruction in the output rows.
    prompt = batch_ex.get("prompt")
    if isinstance(prompt, str):
        return prompt
    return None


# Fixed seed for sub-sampling so --max-samples picks the same examples
# regardless of --seed.
_DATASET_SEED = 0

# Dataset columns copied into each row when --include-extra-fields is set.
_EXTRA_FIELDS = ("chosen", "rejected", "prompt", "input", "label")


def _resolve_checkpoint(arg: str) -> Path | str:
    """Return a ``Path`` for an existing local path, else a validated Hub id."""
    local = Path(arg)
    if local.exists():
        # Return a Path (not a str) so downstream helpers can tell a local
        # checkpoint from a Hub id by type.
        return local

    from huggingface_hub import model_info

    try:
        model_info(arg)  # raises if the id cannot be resolved on the Hub
    except Exception as exc:
        raise SystemExit(f"[ERROR] '{arg}' is neither a local folder nor a valid HF model id") from exc
    return arg


def _tagged_output_path(outputs_json: str, tag: str) -> Path:
    """Insert *tag* before the final extension of *outputs_json*."""
    path = Path(outputs_json)
    return path.with_name(f"{path.stem}{tag}{path.suffix}")


def _chosen_text(example: dict):
    """Return the reference response of *example* for dataset-chosen mode."""
    value = example.get("chosen", example.get("output", ""))
    # Conversational datasets store 'chosen' as a list of message dicts; the
    # response is the content of the last message, not the whole list.
    if isinstance(value, list) and value and isinstance(value[-1], dict):
        return value[-1].get("content", "")
    return value


def _build_row(example: dict, output, args) -> dict:
    """Assemble one output row (instruction / output / generator [+ extras])."""
    row = {
        "instruction": get_instruction(example, args.dataset),
        "output": output,
        "generator": args.generator_name,
    }
    if args.include_extra_fields:
        row.update({key: example[key] for key in _EXTRA_FIELDS if key in example})
    return row


def _write_rows(rows: list, out_path: Path) -> None:
    """Write *rows* as pretty-printed UTF-8 JSON, creating parent folders."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, ensure_ascii=False, indent=2)


def _select_device_and_dtype(args):
    """Return ``(device_map, torch_dtype)`` for the requested CLI flags."""
    if args.force_cpu or not torch.cuda.is_available():
        return "cpu", torch.float32
    return "auto", (torch.bfloat16 if args.use_bf16 else torch.float16)


def _load_model(ckpt, args, has_full_model: bool, device_map, torch_dtype):
    """Load a full checkpoint, or a base model with the PEFT adapter applied."""
    if has_full_model:
        return AutoModelForCausalLM.from_pretrained(
            str(ckpt),
            torch_dtype=torch_dtype,
            device_map=device_map,
        )

    try:
        from peft import PeftModel
    except ImportError as e:
        raise SystemExit(
            "[ERROR] This checkpoint is a PEFT/LoRA adapter but the 'peft' package is not installed.\n"
            "Install via: pip install peft"
        ) from e

    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    model = PeftModel.from_pretrained(base_model, str(ckpt))
    try:
        model = model.merge_and_unload()
    except Exception as exc:
        # Merging is optional; keep the wrapped model but say so.
        print(f"[WARN] Could not merge adapter weights ({exc!r}); using the unmerged PEFT model.")
    return model


def _needs_special_tokens(tokenizer, prompts) -> bool:
    """Tell whether the tokenizer should still add its special tokens.

    A chat template usually emits the BOS token itself; tokenizing such text
    with the default ``add_special_tokens=True`` would add a second one.
    """
    bos = getattr(tokenizer, "bos_token", None)
    if bos and prompts and all(p.startswith(bos) for p in prompts):
        return False
    return True


def main():
    """Entry point: load the dataset, then either dump its 'chosen' responses
    or generate responses with the model, and write the rows to a JSON file.

    Raises:
        SystemExit: If the checkpoint cannot be resolved or recognised, an
            adapter is given without ``--base-model``, ``peft`` is missing
            for an adapter checkpoint, or no prompts can be built for the
            dataset.
    """
    args = parse_args()
    set_seed(args.seed)
    print(f"[INFO] Global seed set to {args.seed}")

    # Load dataset (optionally a shuffled, fixed-seed sub-sample).
    eval_full = load_dataset_examples(args.dataset, args.split, args.subset)
    if args.max_samples is not None:
        n = max(0, min(args.max_samples, len(eval_full)))
        eval_data = eval_full.shuffle(seed=_DATASET_SEED).select(range(n))
    else:
        eval_data = eval_full
    eval_data = eval_data.to_list()

    print(f"[INFO] Using {len(eval_data)} / {len(eval_full)} examples from {args.dataset}.")

    # --- Early branch: dataset chosen mode (no model needed) ---
    if args.use_dataset_chosen:
        rows = [_build_row(ex, _chosen_text(ex), args) for ex in eval_data]
        out_file = _tagged_output_path(args.outputs_json, "-chosen")
        _write_rows(rows, out_file)
        print(f"[OK] Wrote {len(rows)} dataset 'chosen' outputs to {out_file.resolve()}")
        return

    # --- Normal model-based generation ---
    ckpt = _resolve_checkpoint(args.local_checkpoint_dir)
    has_full_model, has_adapter = detect_checkpoint_kind(ckpt)

    if not (has_full_model or has_adapter):
        raise SystemExit(
            f"[ERROR] '{ckpt}' is not a recognized checkpoint folder.\n"
            "Looked for 'model.safetensors' / 'pytorch_model.bin' (full model) "
            "or 'adapter_model.safetensors' (LoRA/PEFT)."
        )
    if has_adapter and not args.base_model:
        raise SystemExit(
            "[ERROR] Detected a LoRA/PEFT adapter but --base-model was not provided.\n"
            "Please pass the base model name/path used during fine-tuning."
        )

    device_map, torch_dtype = _select_device_and_dtype(args)

    # Tokenizer: left padding keeps generated tokens contiguous at the end of
    # every row in a batch.
    tok_path = pick_tokenizer_path(ckpt, args.base_model, has_adapter)
    tokenizer = AutoTokenizer.from_pretrained(tok_path, use_fast=True)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = _load_model(ckpt, args, has_full_model, device_map, torch_dtype)
    model.eval()
    build_prompt = build_prompt_fn(tokenizer)

    # Only pass sampling parameters when sampling; they are unused for greedy
    # decoding.
    do_sample = args.temperature > 0
    gen_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs.update(temperature=args.temperature, top_p=args.top_p)
    gen_cfg = GenerationConfig(**gen_kwargs)

    # Generation loop
    rows = []
    batch_size = max(1, args.batch_size)  # range() rejects a step of 0
    num_batches = (len(eval_data) + batch_size - 1) // batch_size
    for batch_no, start in enumerate(range(0, len(eval_data), batch_size), start=1):
        print(f"[INFO] Generating batch {batch_no} of {num_batches}...")
        batch = eval_data[start:start + batch_size]
        prompts = get_prompts(batch, args.dataset, build_prompt)
        if prompts is None:
            raise SystemExit(f"[ERROR] Don't know how to build prompts for dataset '{args.dataset}'.")

        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=_needs_special_tokens(tokenizer, prompts),
        ).to(model.device)

        with torch.no_grad():
            out = model.generate(
                **inputs,
                generation_config=gen_cfg,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )

        # All rows are padded to the same length, so the prompt occupies the
        # first in_len positions of every output sequence.
        in_len = inputs["input_ids"].shape[-1]
        for example, sequence in zip(batch, out):
            text = tokenizer.decode(sequence[in_len:], skip_special_tokens=True).strip()
            rows.append(_build_row(example, text, args))

    # Output path
    if args.include_extra_fields:
        out_path = _tagged_output_path(args.outputs_json, "-extended")
    else:
        out_path = Path(args.outputs_json)
    _write_rows(rows, out_path)

    print(f"[OK] Wrote {len(rows)} generations to {out_path.resolve()}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
