
# NSP-project/datagen/generator.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import json
import random
import pathlib

from .languages import Language, EOS_TOKEN, BOS_TOKEN


@dataclass(frozen=True)
class GenConfig:
    """Immutable settings for one data-generation run.

    The values are validated on construction; an inconsistent combination
    raises ``ValueError`` before any instance is handed out.
    """

    language: Language
    max_context_len: int            # total tokens per example must stay strictly below this
    min_string_len: int             # lower length bound used by the validation below
    max_string_len: int             # upper length bound for a single string
    seed: Optional[int] = None      # optional seed; None is the default

    def __post_init__(self):
        """Reject settings under which not even one string could be emitted.

        Raises:
            ValueError: if ``max_context_len`` is not positive, if the string
                length bounds are negative or inverted, or if BOS + a string
                of ``min_string_len`` tokens + EOS would not stay strictly
                below ``max_context_len``.
        """
        ctx_limit = self.max_context_len
        shortest = self.min_string_len
        longest = self.max_string_len

        if ctx_limit <= 0:
            raise ValueError("max_context_len must be > 0.")

        bounds_ok = 0 <= shortest <= longest
        if not bounds_ok:
            raise ValueError("Require 0 <= min_string_len <= max_string_len.")

        # Smallest possible example: BOS + shortest string + EOS.
        # It has to satisfy  shortest + 2 < ctx_limit.
        smallest_example = shortest + 2
        if smallest_example >= ctx_limit:
            raise ValueError(
                "No room for BOS + one string + EOS under strict '< max_context_len' constraint. "
                "Increase max_context_len or reduce min_string_len."
            )


class DataGenerator:
    """
    Produce examples of the form [BOS] s1 [EOS] s2 [EOS] ... sn [EOS]
    with total length STRICTLY LESS than max_context_len.

    Each si comes from ``cfg.language.sample_positive`` and is requested with
    an upper length bound of min(max_string_len, remaining room).
    ``min_string_len`` is only used to decide whether another string is still
    worth attempting; the sampler call itself receives no lower bound.
    """

    def __init__(self, cfg: GenConfig):
        self.cfg = cfg
        # Dedicated RNG instance seeded from the config (None -> unseeded).
        self.rng = random.Random(cfg.seed)

    @property
    def vocab(self) -> List[str]:
        """Return the alphabet of the language followed by EOS and BOS."""
        # Σ ∪ {EOS, BOS}
        return list(self.cfg.language.sigma) + [EOS_TOKEN, BOS_TOKEN]

    def _sample_one_positive(self, remaining_budget_for_s: int) -> Optional[List[str]]:
        """
        Ask the language for one positive string of at most
        min(max_string_len, remaining_budget_for_s) tokens.

        The result of ``language.sample_positive`` is returned unchanged;
        ``generate_example`` treats ``None`` as "nothing could be sampled".
        """
        max_len = min(self.cfg.max_string_len, remaining_budget_for_s)
        return self.cfg.language.sample_positive(max_len, self.rng)

    def generate_example(self) -> List[str]:
        """
        Returns a single example as a flat list of tokens:
          [BOS] s1 [EOS] s2 [EOS] ... sn [EOS],
        such that total length < max_context_len.

        Raises:
            RuntimeError: if there is no room for even one string of
                ``min_string_len`` tokens (GenConfig validation is expected
                to rule this out).
            AssertionError: if a sampled string is not positive for the
                language or would break the length limit.
        """
        Lmax = self.cfg.max_context_len
        tokens: List[str] = [BOS_TOKEN]
        n_strings = 0

        while True:
            # We need len(tokens) + len(s) + 1 < Lmax  =>  len(s) <= Lmax - len(tokens) - 2
            budget_for_s = Lmax - len(tokens) - 2

            # A string of exactly min_string_len tokens still fits when the
            # budget equals min_string_len, so only a strictly smaller budget
            # means "no room".
            if budget_for_s < self.cfg.min_string_len:
                if n_strings == 0:
                    raise RuntimeError(
                        "Configuration allowed but no single string fit. "
                        "This should not happen given GenConfig validation."
                    )
                break

            s = self._sample_one_positive(remaining_budget_for_s=budget_for_s)
            if s is None:
                # The sampler produced nothing; stop with what we have.
                break
            # Safety check
            if not self.cfg.language.is_positive(s):
                raise AssertionError("Generated string is not positive for the language.")
            if len(s) + 1 + len(tokens) >= Lmax:
                # should not happen due to budget_for_s logic
                raise AssertionError("Generated string exceeded remaining budget.")

            tokens.extend(s)
            tokens.append(EOS_TOKEN)
            n_strings += 1

        # Final guard: strict inequality.  Raised explicitly (not via `assert`)
        # so the check also runs under `python -O`.
        if len(tokens) >= Lmax:
            raise AssertionError("Example must be strictly shorter than max_context_len.")
        return tokens

    @staticmethod
    def _first_string_len(example: List[str]) -> int:
        """Return the number of tokens between the leading BOS and the first EOS."""
        body = example[1:] if example[:1] == [BOS_TOKEN] else example
        try:
            return body.index(EOS_TOKEN)
        except ValueError:
            # No EOS at all (e.g. an example that is only [BOS]).
            return len(body)

    @staticmethod
    def _mean(values: List[int]) -> float:
        """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def generate_many(self, n: int) -> Tuple[List[List[str]], float]:
        """
        Generate ``n`` examples and print two summary statistics.

        Both averages are measured in tokens: the average length of a whole
        example, and the average length of its first string (BOS and EOS
        excluded).

        Returns:
            ``(examples, avg_len)`` where ``avg_len`` is the average length of
            the first string. For ``n == 0`` this is ``([], 0.0)``.
        """
        examples = [self.generate_example() for _ in range(n)]

        avg_seq_len = self._mean([len(ex) for ex in examples])
        print(f"Avg length of full sequence: {avg_seq_len}")

        avg_len = self._mean([self._first_string_len(ex) for ex in examples])
        print(f"Avg length of first string: {avg_len}")

        return examples, avg_len


def write_jsonl(path: str | pathlib.Path, examples: List[List[str]], meta: Optional[Dict] = None) -> None:
    """Write *examples* to *path* as JSON Lines, one record per example.

    Each record has an ``"id"`` (the example's position) and its ``"tokens"``.
    A non-empty *meta* mapping is attached to every record under ``"meta"``.
    Missing parent directories are created and an existing file is overwritten.
    """
    target = pathlib.Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Falsy meta (None or empty) contributes no key at all.
    extra = {"meta": meta} if meta else {}

    with open(target, "w", encoding="utf-8") as sink:
        for index, token_list in enumerate(examples):
            record = {"id": index, "tokens": token_list, **extra}
            line = json.dumps(record, ensure_ascii=False)
            sink.write(line + "\n")


def write_vocab(path: str | pathlib.Path, vocab: List[str]) -> None:
    """Store *vocab* at *path* as an indented JSON object ``{"vocab": [...]}``.

    Missing parent directories are created and an existing file is overwritten.
    Non-ASCII tokens are written verbatim (UTF-8), not escaped.
    """
    target = pathlib.Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {"vocab": vocab}
    with open(target, "w", encoding="utf-8") as sink:
        json.dump(payload, sink, ensure_ascii=False, indent=2)
