import argparse
import json
import os
import re
import sys
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


# -------------------------
# JSONL I/O
# -------------------------

def read_jsonl(path: str) -> Iterable[Dict[str, Any]]:
    """
    Lazily yield one parsed JSON value per non-blank line of a UTF-8 file.

    Blank (whitespace-only) lines are ignored. A line that fails to parse is
    reported on stderr with its 1-based line number and then skipped, so a
    single corrupt line does not abort the whole stream.
    """
    with open(path, "r", encoding="utf-8") as handle:
        lineno = 0
        for raw in handle:
            lineno += 1
            text = raw.strip()
            if text:
                try:
                    parsed = json.loads(text)
                except json.JSONDecodeError:
                    sys.stderr.write(f"[warn] Invalid JSON at line {lineno}\n")
                else:
                    yield parsed


def append_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """
    Serialize each record as one JSON line and append it to ``path``.

    The parent directory is created when missing (the current directory is
    used when ``path`` has no directory part). Non-ASCII characters are
    written verbatim as UTF-8 rather than as ``\\u`` escapes.
    """
    parent = os.path.dirname(path)
    if not parent:
        parent = "."
    os.makedirs(parent, exist_ok=True)

    with open(path, "a", encoding="utf-8") as sink:
        # Lazy generator: records are serialized and written in order.
        sink.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in records
        )


# -------------------------
# Prompt builder
# -------------------------

# Fixed prompt pieces; build_prompt concatenates them around the question.
PREFIX_PROMPT = "[Round 0] USER:\n"
ASSISTANT_PROMPT = "\nPlease reason step by step, and put your final answer within \\boxed{}. ASSISTANT:\n"
THINK_TAG = "<think>\n"


def build_prompt(question: str) -> str:
    """
    Wrap ``question`` in the fixed prompt template.

    The result is PREFIX_PROMPT, the question text, ASSISTANT_PROMPT and
    THINK_TAG joined in that order. A non-string question is converted with
    ``str()`` first.
    """
    text = question
    if not isinstance(text, str):
        text = str(text)
    return "".join((PREFIX_PROMPT, text, ASSISTANT_PROMPT, THINK_TAG))


# -------------------------
# Tokenization helpers
# -------------------------

def ensure_pad_token(tokenizer: AutoTokenizer, model: AutoModelForCausalLM) -> None:
    """
    Give the tokenizer a pad token if it has none and force left padding.

    When ``pad_token_id`` is unset, the EOS token is reused as the pad token
    if one exists; otherwise a new ``<|pad|>`` special token is registered and
    the model's embedding table is resized to the new tokenizer length.
    ``padding_side`` is set to "left" in every case.
    """
    pad_missing = tokenizer.pad_token_id is None
    if pad_missing and tokenizer.eos_token_id is not None:
        # Reuse EOS so no new embedding row is needed.
        tokenizer.pad_token = tokenizer.eos_token
    elif pad_missing:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        model.resize_token_embeddings(len(tokenizer))

    tokenizer.padding_side = "left"


# -------------------------
# Auto-truncate forward to avoid OOM
# -------------------------

def _truncate_keep_target(inp_ids: List[int],
                          tgt_ids: List[int],
                          cap: int,
                          min_tgt_keep: int = 1,
                          min_inp_keep: int = 0) -> Tuple[List[int], List[int]]:
    """
    Truncate inp_ids + tgt_ids to length <= cap, preferring to keep target tokens.
    Returns trimmed (inp_cut, tgt_cut).
    """
    total = len(inp_ids) + len(tgt_ids)
    if cap >= total:
        return inp_ids, tgt_ids

    keep_tgt = max(min_tgt_keep, min(len(tgt_ids), cap))
    keep_inp = max(min_inp_keep, cap - keep_tgt)
    keep_inp = min(keep_inp, len(inp_ids))
    keep_tgt = min(keep_tgt, cap - keep_inp)
    return inp_ids[:keep_inp], tgt_ids[:keep_tgt]


def auto_truncate_forward(
    model: AutoModelForCausalLM,
    device: torch.device,
    inp_ids: List[int],
    tgt_ids: List[int],
    ctx_len: Optional[int] = None,
    initial_cap: Optional[int] = None,
    shrink_factor: float = 0.8,
    min_tgt_keep: int = 1,
    min_inp_keep: int = 0,
    min_total_keep: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, int]]:
    """
    Attempt a forward pass with increasing truncation until it fits GPU memory.

    The sequence ``inp_ids + tgt_ids`` is first capped at the smallest of its
    own length, ``ctx_len`` and ``initial_cap`` (but never below
    ``min_total_keep``). Each time the forward pass raises an out-of-memory
    error, the cap is multiplied by ``shrink_factor`` and the pass is retried.

    Args:
        model: Causal LM called with ``input_ids``, ``attention_mask`` and
            ``use_cache=False``.
        device: Device on which the input tensors are created.
        inp_ids: Prompt token ids.
        tgt_ids: Target token ids.
        ctx_len: Optional upper bound on the sequence length.
        initial_cap: Optional extra upper bound; a falsy value (None or 0)
            disables it.
        shrink_factor: Multiplier applied to the cap after each OOM.
        min_tgt_keep: Forwarded to ``_truncate_keep_target``.
        min_inp_keep: Forwarded to ``_truncate_keep_target``.
        min_total_keep: Smallest total length that is still attempted.

    Returns:
        logits_out, input_ids_tensor, attention_mask_tensor, meta_dict.
        ``meta_dict`` holds ``orig_total_len``, ``cap_used``, ``inp_len_used``,
        ``tgt_len_used`` and ``attn_len``.

    Raises:
        RuntimeError: if the truncated sequence is shorter than
            ``min_total_keep``, if OOM persists once the cap can no longer
            shrink, or for any non-OOM RuntimeError from the model
            (re-raised unchanged).
    """
    total_len = len(inp_ids) + len(tgt_ids)
    cap = total_len
    if ctx_len is not None:
        cap = min(cap, int(ctx_len))
    if initial_cap:
        cap = min(cap, int(initial_cap))
    cap = max(min_total_keep, cap)

    while True:
        inp_cut, tgt_cut = _truncate_keep_target(inp_ids, tgt_ids, cap,
                                                 min_tgt_keep=min_tgt_keep,
                                                 min_inp_keep=min_inp_keep)
        full_ids = inp_cut + tgt_cut
        if len(full_ids) < min_total_keep:
            raise RuntimeError(f"Auto-truncate failed: cap={cap}, len(full)={len(full_ids)}")

        input_ids_tensor = torch.tensor([full_ids], dtype=torch.long, device=device)
        attn_mask_tensor = torch.ones_like(input_ids_tensor, dtype=torch.long, device=device)

        try:
            with torch.inference_mode():
                outputs = model(
                    input_ids=input_ids_tensor,
                    attention_mask=attn_mask_tensor,
                    use_cache=False,
                )
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            cap_next = int(max(min_total_keep, cap * shrink_factor))
            if cap_next >= cap:
                raise RuntimeError(
                    f"OOM even at cap={cap}. orig_total_len={total_len}, ctx_len={ctx_len}"
                ) from e
        else:
            meta = {
                "orig_total_len": int(total_len),
                "cap_used": int(cap),
                "inp_len_used": int(len(inp_cut)),
                "tgt_len_used": int(len(tgt_cut)),
                "attn_len": int(attn_mask_tensor.sum().item()),
            }
            return outputs.logits, input_ids_tensor, attn_mask_tensor, meta

        # Recoverable OOM. The cleanup is done here, outside the `except`
        # block, on purpose: while the handler is active the exception's
        # traceback keeps the failed forward's frames (and their tensors)
        # alive, so flushing the cache there cannot release that memory.
        del input_ids_tensor, attn_mask_tensor
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except RuntimeError:
                # Best effort only: a failed flush must not abort the retry.
                pass
        cap = cap_next


# -------------------------
# Entropy computation
# -------------------------

def compute_next_token_entropy_nats(logits: torch.Tensor,
                                    attn_mask: torch.Tensor,
                                    chunk_size: int = 1024) -> torch.Tensor:
    """
    Compute per-position next-token entropy in nats.

    The softmax and the sum over the vocabulary are evaluated in float32 so
    that half-precision logits do not degrade the result. Work is done in
    chunks along the time axis so the float32 upcast never materialises
    full-size [B, T, V] temporaries.

    Args:
        logits: Tensor of shape [B, T, V].
        attn_mask: Tensor of shape [B, T] with 1 for valid tokens.
        chunk_size: Number of time steps processed at once (values below 1
            are treated as 1). Affects peak memory only, not the result.

    Returns:
        entropies: float32 Tensor of shape [B, T-1]. Entry ``s`` is the
        entropy of the distribution produced at position ``s`` (the one that
        predicts token ``s + 1``); it is 0 where ``attn_mask[:, s + 1]`` is 0.
    """
    # Align for next-token predictions
    shift_logits = logits[:, :-1, :]         # [B, T-1, V]
    shift_mask = attn_mask[:, 1:]            # [B, T-1]

    n_steps = shift_logits.shape[1]
    step = max(1, int(chunk_size))

    pieces = []
    for start in range(0, n_steps, step):
        chunk = shift_logits[:, start:start + step, :].float()
        log_probs = F.log_softmax(chunk, dim=-1)
        probs = log_probs.exp()
        # A logit of -inf gives p = 0 and log p = -inf, and 0 * -inf is NaN.
        # Clamping log p to the lowest finite value makes that term 0.
        log_probs = log_probs.clamp_min(torch.finfo(log_probs.dtype).min)
        pieces.append(-(probs * log_probs).sum(dim=-1))

    if pieces:
        entropy = torch.cat(pieces, dim=1)   # [B, T-1]
    else:
        # T <= 1: there is no next-token position at all.
        entropy = shift_logits.new_zeros(shift_logits.shape[:2], dtype=torch.float32)

    # Select instead of multiply so a NaN at a masked position cannot leak
    # into the output (NaN * 0 is still NaN).
    return torch.where(shift_mask.to(torch.bool), entropy, torch.zeros_like(entropy))


# -------------------------
# Core processing for one record
# -------------------------

def process_record(
    rec: Dict[str, Any],
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    device: torch.device,
    ctx_len: Optional[int],
    initial_cap: Optional[int],
    shrink_factor: float
) -> Optional[Dict[str, Any]]:
    """
    Compute output token entropies for one JSON record.

    The prompt is built from ``rec["question"]`` and the target is
    ``rec["cot"]`` with a leading THINK_TAG removed (the prompt already ends
    with it). Both are tokenized separately without special tokens,
    concatenated, and passed through the model with OOM-driven truncation.

    Args:
        rec: Input record; must contain "question" and "cot".
        tokenizer: Tokenizer used for encoding and for token strings.
        model: Causal LM used for the forward pass.
        device: Device for the forward pass.
        ctx_len: Optional maximum sequence length.
        initial_cap: Optional initial cap on the sequence length.
        shrink_factor: Cap multiplier applied after each OOM.

    Returns:
        None when "question" or "cot" is missing. Otherwise a dict with the
        record's identity/question/question_type/answer/cot/pattern fields
        and "output_token_entropies": one entry per retained target token
        (token_id, token_string, entropy_nats), where entropy_nats is the
        entropy of the distribution that predicted that token. A target
        token at sequence position 0 has no predicting distribution and is
        omitted.
    """
    if "question" not in rec or "cot" not in rec:
        return None

    prompt = build_prompt(rec["question"])
    target = rec["cot"] if isinstance(rec["cot"], str) else str(rec["cot"])
    if target.startswith(THINK_TAG):
        target = target[len(THINK_TAG):]

    # Encode without adding special tokens
    inp_ids = tokenizer.encode(prompt, add_special_tokens=False)
    tgt_ids = tokenizer.encode(target, add_special_tokens=False)

    # Forward with auto-truncation
    logits, input_ids_tensor, attn_mask_tensor, meta = auto_truncate_forward(
        model=model,
        device=device,
        inp_ids=inp_ids,
        tgt_ids=tgt_ids,
        ctx_len=ctx_len,
        initial_cap=initial_cap,
        shrink_factor=shrink_factor,
        min_tgt_keep=1,
        min_inp_keep=0,
        min_total_keep=2,
    )

    # Compute next-token entropy
    entropy = compute_next_token_entropy_nats(logits, attn_mask_tensor)  # [1, T-1]

    # Move everything the loop needs to the host in one transfer each.
    # Calling .item() per token would force a device sync on every iteration.
    tok_ids = input_ids_tensor[0].tolist()
    mask_row = attn_mask_tensor[0].tolist()
    entropy_row = entropy[0].float().tolist()
    tok_pieces = tokenizer.convert_ids_to_tokens(tok_ids)

    inp_len_used = meta["inp_len_used"]

    # Token index t is predicted by the distribution at index t - 1, so
    # t = 0 (possible when the whole prompt was truncated away) is skipped.
    out_entries: List[Dict[str, Any]] = []
    for t in range(max(inp_len_used, 1), len(tok_ids)):
        if mask_row[t] == 0:
            break
        out_entries.append({
            "token_id": int(tok_ids[t]),
            # Readable string for this single token piece
            "token_string": tokenizer.convert_tokens_to_string([tok_pieces[t]]),
            "entropy_nats": float(entropy_row[t - 1]),
        })

    # Build output record
    return {
        "identity": rec.get("identity"),
        "question": rec.get("question"),
        "question_type": rec.get("question_type"),
        "answer": rec.get("answer"),
        "cot": rec.get("cot"),
        "pattern": rec.get("pattern", None),
        "output_token_entropies": out_entries,
    }


# -------------------------
# Main
# -------------------------

def parse_args() -> argparse.Namespace:
    """
    Parse the command line (``sys.argv``) for the entropy script.

    Three string options are mandatory: --input_path, --output_dir and
    --model_path. The remaining options have defaults: --gpu_id (0),
    --dtype ("bfloat16"), --initial_cap (None), --shrink_factor (0.8) and
    the --progress flag (off).
    """
    parser = argparse.ArgumentParser(
        description="Compute token-level entropies for CoT outputs using a HF causal LLM."
    )

    # Mandatory string options, registered in display order.
    mandatory = (
        ("--input_path", "Input JSONL file."),
        ("--output_dir", "Directory to write JSONL output."),
        ("--model_path", "HF repo id or local model path."),
    )
    for flag, description in mandatory:
        parser.add_argument(flag, type=str, required=True, help=description)

    parser.add_argument("--gpu_id", default=0, type=int, help="CUDA device index.")
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        type=str,
        choices=["bfloat16", "float16", "float32"],
        help="Computation dtype for model weights.",
    )
    parser.add_argument(
        "--initial_cap",
        default=None,
        type=int,
        help="Optional initial cap on effective sequence length (tokens).",
    )
    parser.add_argument(
        "--shrink_factor",
        default=0.8,
        type=float,
        help="OOM backoff factor for auto truncation.",
    )
    parser.add_argument(
        "--progress",
        action="store_true",
        help="Show per-record progress bar.",
    )

    namespace = parser.parse_args()
    return namespace


def main() -> None:
    """
    CLI entry point: compute per-token entropies for every record of a JSONL file.

    Steps: parse arguments, resolve the output path, pick the device, load
    the model and tokenizer, then stream the input records through
    ``process_record``. Results are appended in batches to
    ``<output_dir>/<basename of input_path>``; an existing output file is
    extended, not overwritten.

    A record that fails to process is reported on stderr and skipped.
    Failures while writing output are not swallowed and abort the run.
    """
    args = parse_args()

    # Prepare output path
    out_fn = os.path.basename(args.input_path)
    out_path = os.path.join(args.output_dir, out_fn)

    # The output reuses the input's file name. If --output_dir is the input's
    # own directory, results would be appended to the file being streamed,
    # corrupting it. Refuse before loading the model.
    same_file = (
        os.path.normcase(os.path.realpath(out_path))
        == os.path.normcase(os.path.realpath(args.input_path))
    )
    if same_file:
        sys.stderr.write(
            f"[error] Output path {out_path} is the input file; "
            "choose a different --output_dir.\n"
        )
        sys.exit(1)
    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve device
    if not torch.cuda.is_available():
        sys.stderr.write("[warn] CUDA not available; running on CPU may be very slow.\n")
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)

    # Load model and tokenizer
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint; only point --model_path at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        padding_side="left"
    )
    ensure_pad_token(tokenizer, model)

    # Determine context length softly: prefer the model config; fall back to
    # the tokenizer's limit unless it is a huge "unbounded" placeholder.
    ctx_len = getattr(model.config, "max_position_embeddings", None)
    tok_ctx = getattr(tokenizer, "model_max_length", None)
    if ctx_len is None and isinstance(tok_ctx, int) and tok_ctx < 10**7:
        ctx_len = int(tok_ctx)

    # Process records
    it = read_jsonl(args.input_path)
    it = tqdm(it, desc="Computing CoT entropies") if args.progress else it

    batch: List[Dict[str, Any]] = []
    batch_cap = 64  # write in chunks to reduce I/O overhead
    written = 0

    for rec in it:
        # Only the per-record computation is best-effort. Writing stays
        # outside the try so an I/O failure is not mistaken for a bad record
        # and the same batch is not appended twice on a later flush.
        try:
            out_rec = process_record(
                rec=rec,
                tokenizer=tokenizer,
                model=model,
                device=device,
                ctx_len=ctx_len,
                initial_cap=args.initial_cap,
                shrink_factor=args.shrink_factor
            )
        except Exception as e:
            sys.stderr.write(f"[error] Failed on a record: {e}\n")
            continue

        if out_rec is None:
            continue
        batch.append(out_rec)
        if len(batch) >= batch_cap:
            append_jsonl(out_path, batch)
            written += len(batch)
            batch.clear()

    if batch:
        append_jsonl(out_path, batch)
        written += len(batch)

    sys.stderr.write(f"[info] Wrote {written} records -> {out_path}\n")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
