from __future__ import annotations

from typing import Optional, Dict, Any, List

import random
import re
from datasets import load_dataset, Dataset, concatenate_datasets
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

# Task template for the "continue the story" prompt. The formatted string is
# used as the *user* message content of the chat (load_pg19_dataset sends a
# separate system message); `{context}` is replaced with the book excerpt.
QUERY_TEMPLATE = """You are a writing assistant. Continue the following story in a coherent, engaging, and stylistically consistent way.

Story:
{context}

Continue the story:"""

# ---------- sentence boundary (optional) ----------
_SENT_END_RE = re.compile(r"[.!?]\s*[A-Z]")  # simple English heuristic


def _find_last_sentence_end(text: str) -> int:
    matches = list(_SENT_END_RE.finditer(text))
    if not matches:
        return len(text)
    m = matches[-1]
    return m.start() + 1


def _load_pg19_split(split: str | None) -> Dataset:
    """Load one PG-19 split, or test + validation + train concatenated for ``None``/``"all"``."""
    if split is None or split == "all":
        return concatenate_datasets(
            [load_dataset("emozilla/pg19", split=name) for name in ("test", "validation", "train")]
        )
    return load_dataset("emozilla/pg19", split=split)


def _count_tokens(tokenizer: PreTrainedTokenizerBase, text: str) -> int:
    """Count tokens in *text* without adding special tokens or truncating."""
    return len(tokenizer.encode(text, add_special_tokens=False, truncation=False))


def _render_chat(tokenizer: PreTrainedTokenizerBase, system_prompt: str, user_text: str) -> str:
    """Render a system + user chat (with generation prompt) as a prompt string."""
    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )


def _fit_excerpt(
    tokenizer: PreTrainedTokenizerBase,
    system_prompt: str,
    doc_tokens: List[int],
    start: int,
    take: int,
    input_len: int,
    sentence_align: bool,
    max_attempts: int = 8,
) -> Optional[tuple[str, int, int]]:
    """Build a templated excerpt, shrinking it until the chat prompt fits ``input_len``.

    Returns ``(user_text, prompt_len, used_tokens_before_align)`` for the first
    excerpt whose rendered prompt is at most ``input_len`` tokens, or ``None``
    if none fits within ``max_attempts`` tries.
    """
    for _ in range(max_attempts):
        if take <= 0:
            return None
        sliced = doc_tokens[start:start + take]
        context = tokenizer.decode(sliced, skip_special_tokens=True)
        if sentence_align:
            context = context[:_find_last_sentence_end(context)]
        user_text = QUERY_TEMPLATE.format(context=context)
        prompt_len = _count_tokens(tokenizer, _render_chat(tokenizer, system_prompt, user_text))
        overflow = prompt_len - input_len
        if overflow <= 0:
            return user_text, prompt_len, len(sliced)
        # Shrink relative to what was actually kept: sentence alignment may have
        # already dropped a tail, and shrinking only the raw window would then
        # leave the same aligned text (and the same overflow) for many retries.
        kept = len(sliced)
        if sentence_align:
            kept = min(kept, _count_tokens(tokenizer, context))
        take = kept - overflow
    return None


def load_pg19_dataset(
    tokenizer: PreTrainedTokenizerBase,
    input_len: int,
    *,
    split: str | None = "all",
    max_samples: Optional[int] = None,
    seed: int = 42,
    buffer_tokens: int = 200,
    sentence_align: bool = True,
    min_source_tokens: Optional[int] = None,
    start_offset_tokens: int = 0,
    shuffle: bool = False,
    min_fill_ratio: float = 0.97,
) -> List[Dict[str, Any]]:
    """Build "continue the story" prompts of close to ``input_len`` tokens from PG-19.

    For each book, a window of tokens starting at ``start_offset_tokens`` is
    decoded, optionally cut back to the last detected sentence boundary,
    wrapped in ``QUERY_TEMPLATE`` and rendered with the tokenizer's chat
    template (system + user message, generation prompt appended). The window
    is shrunk until the rendered prompt is at most ``input_len`` tokens; books
    whose prompt is then shorter than ``int(min_fill_ratio * input_len)``
    tokens are skipped.

    Args:
        tokenizer: Tokenizer that provides ``apply_chat_template``.
        input_len: Maximum (target) length, in tokens, of the full chat prompt.
        split: PG-19 split to load; ``None`` or ``"all"`` concatenates the
            test, validation and train splits, in that order.
        max_samples: Maximum number of rows to return (default: up to one per book).
        seed: Seed for the book-order shuffle; only used when ``shuffle`` is True.
        buffer_tokens: Extra book tokens taken beyond the estimated budget so
            that sentence alignment still leaves a prompt close to ``input_len``.
        sentence_align: Cut the excerpt at the last detected sentence boundary.
        min_source_tokens: Skip books with fewer tokens than this (default:
            ``start_offset_tokens`` plus the initial window size).
        start_offset_tokens: Number of leading book tokens to skip (>= 0).
        shuffle: Visit books in a seeded random order instead of dataset order.
        min_fill_ratio: Minimum accepted prompt length as a fraction of ``input_len``.

    Returns:
        A list of dicts, one per sample, with keys ``text`` (the templated user
        message, not the full chat prompt), ``target_input_len``,
        ``estimated_prompt_len``, ``doc_id`` (index into the loaded split),
        ``source_token_len``, ``used_tokens_before_align`` and
        ``query_template_tokens``. ``max_samples=0`` returns ``[]``.

    Raises:
        ValueError: If an argument is invalid, or ``input_len`` is too small to
            hold the chat template plus ``QUERY_TEMPLATE``.
        RuntimeError: If samples were requested but no book yields one.
    """
    # ---- Cheap argument validation first, before the expensive dataset download ----
    if input_len <= 0:
        raise ValueError(f"input_len must be positive, got {input_len}")
    if start_offset_tokens < 0:
        raise ValueError(f"start_offset_tokens must be non-negative, got {start_offset_tokens}")
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")
    if not 0.0 <= min_fill_ratio <= 1.0:
        raise ValueError(f"min_fill_ratio must be in [0, 1], got {min_fill_ratio}")
    if not hasattr(tokenizer, "apply_chat_template"):
        raise ValueError("Tokenizer has no apply_chat_template; cannot match your build_chat_pg19 usage.")

    # Content of the chat *system* message; QUERY_TEMPLATE goes into the user message.
    system_prompt = (
        "You are a creative writing assistant. Continue the following story "
        "in a coherent, engaging, and stylistically consistent way."
    )

    # ---- Estimate chat template overhead in tokens ----
    chat_overhead = _count_tokens(tokenizer, _render_chat(tokenizer, system_prompt, ""))
    if input_len <= chat_overhead + 16:
        raise ValueError(
            f"input_len={input_len} is too small relative to chat template overhead={chat_overhead}. "
            f"Use a larger input_len."
        )

    # The user message is QUERY_TEMPLATE.format(context=...), not the raw excerpt,
    # so the template's own tokens must be budgeted too or input_len is exceeded.
    template_tokens = _count_tokens(tokenizer, QUERY_TEMPLATE.format(context=""))

    # Final prompt tokens ~= chat_overhead + template_tokens + len(context tokens).
    text_target_tokens = input_len - chat_overhead - template_tokens
    if text_target_tokens <= 0:
        raise ValueError(
            f"input_len={input_len} too small after accounting for chat_overhead={chat_overhead} "
            f"and query_template_tokens={template_tokens}. Increase input_len."
        )

    # Over-take by buffer_tokens; _fit_excerpt shrinks the window back to fit.
    initial_take = text_target_tokens + max(0, buffer_tokens)
    if min_source_tokens is None:
        min_source_tokens = start_offset_tokens + initial_take
    min_prompt_len = int(min_fill_ratio * input_len)

    if max_samples == 0:
        return []

    ds = _load_pg19_split(split)
    print(f"raw dataset size: {len(ds)} samples")
    doc_indices = list(range(len(ds)))
    if shuffle:
        random.Random(seed).shuffle(doc_indices)

    out_rows: List[Dict[str, Any]] = []
    n_target = max_samples if max_samples is not None else len(ds)
    print(f"Building pg19 dataset with {n_target} samples, each with input_len={input_len} tokens (including chat template).")

    # Skip counters, reported if nothing could be built.
    n_short = n_unfit = n_underfilled = 0
    for doc_id in doc_indices:
        if len(out_rows) >= n_target:
            break
        doc_tokens = tokenizer.encode(ds[doc_id]["text"], add_special_tokens=False, truncation=False)
        source_len = len(doc_tokens)
        if source_len < min_source_tokens:
            n_short += 1
            continue

        fitted = _fit_excerpt(
            tokenizer,
            system_prompt,
            doc_tokens,
            start=min(start_offset_tokens, source_len),
            take=initial_take,
            input_len=input_len,
            sentence_align=sentence_align,
        )
        if fitted is None:
            n_unfit += 1
            continue
        user_text, prompt_len, used_before_align = fitted
        if prompt_len < min_prompt_len:
            n_underfilled += 1
            continue

        out_rows.append(
            {
                "text": user_text,  # NOTE: the templated *user* content, not the full chat prompt
                "target_input_len": input_len,
                "estimated_prompt_len": prompt_len,
                "doc_id": doc_id,
                "source_token_len": source_len,
                "used_tokens_before_align": used_before_align,
                "query_template_tokens": template_tokens,
            }
        )

    if not out_rows:
        raise RuntimeError(
            f"Could not build any samples ({n_short} books shorter than "
            f"min_source_tokens={min_source_tokens}, {n_unfit} could not fit input_len={input_len}, "
            f"{n_underfilled} shorter than {min_prompt_len} tokens). Consider disabling "
            "`sentence_align`, increasing `buffer_tokens`, lowering `min_source_tokens` "
            "or `min_fill_ratio`, or changing `input_len`."
        )

    return out_rows