# datagen/languages.py
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Sequence, Optional
import random
import json
import pathlib
import pdb
from automata.dfa import DFA
from datagen.pfa import PDFA

# Special marker tokens; exposed to callers through Language.eos / Language.bos.
EOS_TOKEN = "[EOS]"
BOS_TOKEN = "[BOS]"

class Language(ABC):
    """
    Abstract interface for a language over a token alphabet.

    Concrete subclasses are expected to provide ``name`` and ``sigma`` (the
    list of admissible tokens) and to implement sampling and membership.
    """

    name: str
    sigma: List[str]

    @property
    def eos(self) -> str:
        """The module-level end-of-sequence marker token."""
        return EOS_TOKEN

    @property
    def bos(self) -> str:
        """The module-level beginning-of-sequence marker token."""
        return BOS_TOKEN

    @abstractmethod
    def sample_positive(self, min_len: int, max_len: int, rng: random.Random) -> List[str]:
        """
        Sample one positive example as a list of tokens.

        The precise handling of ``min_len`` / ``max_len`` is defined by each
        implementation; ``rng`` is the randomness source to draw from.
        """

    @abstractmethod
    def is_positive(self, tokens: Sequence[str]) -> bool:
        """Return whether ``tokens`` is a positive example of this language."""

    def validate_tokens(self, tokens: Sequence[str]) -> None:
        """
        Check that every token belongs to ``self.sigma``.

        Raises:
            ValueError: if at least one token is outside the alphabet; the
                message lists the offending tokens in input order.
        """
        unknown = []
        for tok in tokens:
            if tok not in self.sigma:
                unknown.append(tok)
        if not unknown:
            return
        raise ValueError(f"Tokens not in Σ: {unknown}; Σ={self.sigma}")

def load_dfa_json(path: str | pathlib.Path) -> DFA:
    """
    Build a ``DFA`` from a JSON description stored at ``path``.

    The JSON object must contain the keys ``sigma``, ``start``, ``finals``,
    ``delta`` (a list of rows) and ``dead`` (an integer or ``null``).
    Numeric entries are coerced with ``int`` and all sequences are passed to
    ``DFA`` as tuples.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``)
            or a numeric entry cannot be converted with ``int``.
        KeyError: if one of the required keys is missing.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # the locale default, so non-ASCII alphabet symbols load identically on
    # every platform.
    obj = json.loads(pathlib.Path(path).read_text(encoding="utf-8"))
    return DFA(
        sigma=tuple(obj["sigma"]),
        start=int(obj["start"]),
        finals=tuple(int(x) for x in obj["finals"]),
        delta=tuple(tuple(int(x) for x in row) for row in obj["delta"]),
        dead=(None if obj["dead"] is None else int(obj["dead"])),
    )

@dataclass
class PDFALanguage(Language):
    """
    Language implemented by a DFA loaded from disk and a PDFA sampler
    with uniform admissible emissions and a given termination probability.

    On construction the DFA is read from ``dfa_path`` with ``load_dfa_json``
    and the sampler is built with ``PDFA.from_dfa(dfa, final_hazard=term_prob)``;
    ``sigma`` is set to a list copy of the DFA's alphabet.

    NOTE(review): ``sample_positive`` takes ``(max_len, rng)`` whereas the
    abstract declaration on ``Language`` is ``(min_len, max_len, rng)``. The
    signature is left unchanged here so existing callers keep working.
    """
    name: str
    dfa_path: str
    term_prob: float = 0.2  # termination probability at accepting states
    _dfa: Optional[DFA] = None  # always replaced in __post_init__
    _pdfa: Optional[PDFA] = None  # always replaced in __post_init__
    skip_first: bool = True  # whether to skip the first empty string sample
    oname = None  # un-annotated, so a plain class attribute rather than a dataclass field

    # Number of rejected draws tolerated by sample_positive before giving up.
    # Un-annotated on purpose so it does not become a dataclass field.
    _MAX_TRIES = 200

    def __post_init__(self):
        """Load the DFA from ``dfa_path`` and derive the sampler and alphabet."""
        dfa = load_dfa_json(self.dfa_path)
        # The dataclass is not frozen, so ordinary attribute assignment works.
        self._dfa = dfa
        self._pdfa = PDFA.from_dfa(dfa, final_hazard=self.term_prob)
        self.sigma = list(dfa.sigma)

    def sample_positive(self, max_len: int, rng: random.Random) -> List[str]:
        """
        Draw samples from the PDFA until one has length ``<= max_len``.

        The empty string is allowed and ``min_len`` is ignored by design.
        When ``skip_first`` is true, the first empty sample seen during a call
        is discarded once (and counted as a rejected draw), which slightly
        biases against the empty string.

        Raises:
            RuntimeError: if more than ``_MAX_TRIES`` draws were rejected.
        """
        # ε allowed; ignore min_len by design.
        assert self._pdfa is not None
        tries = 0
        first_zero = True  # slightly bias against empty string if skip_first

        while True:
            s = self._pdfa.sample(rng)
            if first_zero and len(s) == 0:
                first_zero = False
                if self.skip_first:
                    tries += 1
                    continue

            if len(s) <= max_len:
                return s
            tries += 1
            if tries > self._MAX_TRIES:
                # Fail loudly instead of dropping into a debugger, which would
                # hang or crash non-interactive runs.
                raise RuntimeError("Exceeded max retries enforcing max_len; consider increasing term_prob.")

    def is_positive(self, tokens: Sequence[str]) -> bool:
        """
        Return whether the DFA accepts ``tokens``.

        Raises:
            ValueError: if any token is not in ``sigma`` (via ``validate_tokens``).
        """
        self.validate_tokens(tokens)
        assert self._dfa is not None
        return self._dfa.accepts(tokens)




@dataclass
class Dyck22(PDFALanguage):
    """Preset for the language labelled "Dyck-(2,2)", backed by ``dyck22.dfa.json``."""
    name: str = "dyck22"  # same string as this preset's key in LANGUAGES
    oname: str = "Dyck-(2,2)"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/dyck22/dyck22.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.25  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers


@dataclass
class Dyck33(PDFALanguage):
    """Preset for the language labelled "Dyck-(3,3)", backed by ``dyck33.dfa.json``."""
    name: str = "dyck33"  # same string as this preset's key in LANGUAGES
    oname: str = "Dyck-(3,3)"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/dyck33/dyck33.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.6  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers

@dataclass
class Dyck24(PDFALanguage):
    """Preset for the language labelled "Dyck-(2,4)", backed by ``dyck24.dfa.json``."""
    name: str = "dyck24"  # same string as this preset's key in LANGUAGES
    oname: str = "Dyck-(2,4)"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/dyck24/dyck24.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.55  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers


@dataclass
class Dyck43(PDFALanguage):
    """Preset for the language labelled "Dyck-(4,3)", backed by ``dyck43.dfa.json``."""
    name: str = "dyck43"  # same string as this preset's key in LANGUAGES
    oname: str = "Dyck-(4,3)"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/dyck43/dyck43.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.6  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers


@dataclass
class Parity(PDFALanguage):
    """Preset for the language labelled "Parity", backed by ``parity.dfa.json``."""
    name: str = "parity"  # same string as this preset's key in LANGUAGES
    oname: str = "Parity"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/parity/parity.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.2 # calibration notes, apparently term_prob -> mean sample length: 0.1 -> 20 , 0.15 -> 13, 0.2 -> 10, 0.25 -> 8
    min_len: int = 1  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers



@dataclass
class Tomita2(PDFALanguage):
    """Preset for the language labelled "Tomita-2", backed by ``tomita2.dfa.json``."""
    name: str = "tomita2"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-2"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita2/tomita2.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.125  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers


@dataclass
class Tomita3(PDFALanguage):
    """Preset for the language labelled "Tomita-3", backed by ``tomita3.dfa.json``."""
    name: str = "tomita3"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-3"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita3/tomita3.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.09 #  mean_len:  0.09 --> 13
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers
    skip_first: bool = False  # sample_positive will not discard the first empty sample


@dataclass
class Tomita4(PDFALanguage):
    """Preset for the language labelled "Tomita-4", backed by ``tomita4.dfa.json``."""
    name: str = "tomita4"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-4"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita4/tomita4.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.05  # forwarded to PDFA.from_dfa as final_hazard
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers
    skip_first: bool = False  # sample_positive will not discard the first empty sample
    


@dataclass
class Tomita5(PDFALanguage):
    """Preset for the language labelled "Tomita-5", backed by ``tomita5.dfa.json``."""
    name: str = "tomita5"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-5"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita5/tomita5.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.25 #  mean_len: 0.25 --> 15
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers


@dataclass
class Tomita6(PDFALanguage):
    """Preset for the language labelled "Tomita-6", backed by ``tomita6.dfa.json``."""
    name: str = "tomita6"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-6"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita6/tomita6.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.17 #  mean_len:  0.15 --> 17, 0.125 --> 20
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers
    skip_first: bool = False  # sample_positive will not discard the first empty sample

@dataclass
class Tomita7(PDFALanguage):
    """Preset for the language labelled "Tomita-7", backed by ``tomita7.dfa.json``."""
    name: str = "tomita7"  # same string as this preset's key in LANGUAGES
    oname: str = "Tomita-7"  # human-readable label; presumably for display — confirm with callers
    dfa_path: str = "datagen/assets/dfas/tomita7/tomita7.dfa.json"  # relative path, resolved against the working directory
    term_prob: float = 0.05 #  mean_len: 0.05 --> 19
    min_len: int = 0  # NOTE(review): not read by PDFALanguage.sample_positive; presumably consumed by callers
    skip_first: bool = False  # sample_positive will not discard the first empty sample



# Registry of the built-in languages, keyed by each instance's ``name``.
# Every preset is instantiated here, at import time, and instantiation loads
# its DFA JSON from ``dfa_path`` (a path relative to the working directory).
LANGUAGES = {
    language.name: language
    for language in (
        Dyck22(),
        Dyck33(),
        Dyck24(),
        Dyck43(),
        Parity(),
        Tomita2(),
        Tomita3(),
        Tomita4(),
        Tomita5(),
        Tomita6(),
        Tomita7(),
    )
}
