import json
import os
from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    set_seed,
)


@dataclass
class SampleArguments:
    """Command-line options for sampling unconditional text from a causal LM.

    Parsed with ``HfArgumentParser``; each field becomes a ``--<name>`` flag.
    The total number of samples written is
    ``num_train_steps * per_device_train_batch_size``.
    """

    model_name_or_path: str = field(metadata={"help": "Model path or name"})
    output_file: str = field(default="sampled_texts.jsonl", metadata={"help": "Output path for sampled texts"})
    num_train_steps: int = field(
        default=10000,
        metadata={"help": "Number of generation batches; total samples = num_train_steps * per_device_train_batch_size"},
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Number of sequences sampled per generate() call"},
    )
    seq_len: int = field(
        default=32,
        metadata={"help": "Maximum number of tokens to sample per sequence (the seed token is not counted)"},
    )
    temperature: float = field(default=1.0, metadata={"help": "Sampling temperature passed to generate()"})
    seed: int = field(default=42, metadata={"help": "Random seed passed to transformers.set_seed"})
    device: str = field(
        default="cuda",
        metadata={"help": "cuda or cpu; 'cuda' loads fp16 weights with device_map='auto', other values load fp32"},
    )
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Directory to cache the model"})
    # Sampling is always enabled in this script, so top_k always applies.
    top_k: Optional[int] = field(default=100, metadata={"help": "Top-k cutoff passed to generate() (sampling is always on)"})


def _resolve_start_token_id(tokenizer):
    """Return the token id that seeds every sampled sequence (BOS, else EOS).

    Uses explicit ``is None`` checks so that a legitimate token id of 0 is not
    skipped, and fails early if the tokenizer has neither token.
    """
    token_id = tokenizer.bos_token_id
    if token_id is None:
        token_id = tokenizer.eos_token_id
    if token_id is None:
        raise ValueError(
            "Tokenizer defines neither a BOS nor an EOS token; cannot choose a start token."
        )
    return token_id


def _load_tokenizer(args):
    """Load the tokenizer, force left padding, and make sure a pad token exists."""
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    tokenizer.padding_side = "left"
    # init_kwargs is what save_pretrained() serialises, so a saved copy keeps left padding.
    tokenizer.init_kwargs["padding_side"] = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_model(args):
    """Load the causal LM in eval mode on the requested device.

    ``device == "cuda"`` loads fp16 weights dispatched with ``device_map="auto"``;
    any other device loads fp32 weights and moves them there explicitly.
    """
    use_cuda = args.device == "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        cache_dir=args.cache_dir,
    )
    if not use_cuda:
        # Without a device_map the weights load on CPU; honour explicit devices
        # such as "cuda:1" or "mps" instead of leaving the model behind.
        model.to(args.device)
    model.eval()
    return model


def main():
    """Sample unconditional sequences from a causal LM and write them as JSONL.

    Each sequence is seeded with a single start token (BOS, or EOS when the
    tokenizer has no BOS) and extended by up to ``seq_len`` sampled tokens.
    ``num_train_steps * per_device_train_batch_size`` samples are written, one
    JSON object per line with keys ``input_ids`` (generated ids, seed token
    excluded) and ``text`` (their decoding with special tokens skipped).
    """
    parser = HfArgumentParser(SampleArguments)
    args = parser.parse_args_into_dataclasses()[0]

    set_seed(args.seed)

    tokenizer = _load_tokenizer(args)
    model = _load_model(args)

    batch_size = args.per_device_train_batch_size
    total_samples = args.num_train_steps * batch_size
    start_token_id = _resolve_start_token_id(tokenizer)
    # With device_map="auto" the weights may not sit on args.device (e.g. CPU
    # fallback), so inputs go wherever the model actually is.
    input_device = model.device

    print(f"Sampling {total_samples} examples in batches of {batch_size}...")

    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        for batch_start in range(0, total_samples, batch_size):
            actual_batch_size = min(batch_size, total_samples - batch_start)
            input_ids = torch.full(
                (actual_batch_size, 1),
                fill_value=start_token_id,
                dtype=torch.long,
                device=input_device,
            )
            # Every seed token is real. An explicit mask stops generate() from
            # guessing one, which could come out all-zero if the seed token
            # happens to equal pad_token_id.
            attention_mask = torch.ones_like(input_ids)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    # max_new_tokens (not max_length) so a max_new_tokens value in the
                    # model's generation_config cannot override the requested length.
                    max_new_tokens=args.seq_len,
                    do_sample=True,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    # Sequences that finish early are filled with this id; those ids
                    # are kept in "input_ids" but dropped from "text".
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Drop the seed token and copy the whole batch to the host once.
            rows = outputs[:, 1:].tolist()
            texts = tokenizer.batch_decode(rows, skip_special_tokens=True)
            f.writelines(
                json.dumps({"input_ids": ids, "text": text}, ensure_ascii=False) + "\n"
                for ids, text in zip(rows, texts)
            )

            if (batch_start // batch_size) % 10 == 0:
                print(f"Sampled {batch_start + actual_batch_size}/{total_samples}")

    print(f"Sampling complete. Saved to: {args.output_file}")


if __name__ == "__main__":
    main()
