# datagen/pfa.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
import json
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt  # used only in plot_length_histogram

from automata.dfa import DFA

# Type alias for one alphabet symbol: PDFA.sigma is a tuple of these and PDFA.sample returns a list of them.
Symbol = str


@dataclass
class PDFA:
    """
    Probabilistic DFA (PDFA) over a canonical DFA.

    Semantics per step at state q:
      - If q is accepting: with probability hazard[q], STOP (EOS).
        Otherwise (prob 1 - hazard[q]) emit ONE symbol σ with probability symbol_probs[q,σ]
        and move via δ(q,σ).
      - If q is not accepting: emit ONE symbol according to symbol_probs[q,σ] and move.
        hazard[q] is ignored at non-accepting states (both by `sample` and by
        `expected_length_theory`).

    Invariants (established by `from_dfa`; not re-checked when the dataclass is built directly):
      - symbol_probs[q,σ] == 0 whenever δ(q,σ) is the dead sink.
      - For every non-dead state q, sum_σ symbol_probs[q,σ] == 1. (Dead row is all zeros.)
      - hazard[q] ∈ [0,1) and hazard[q] > 0 only for accepting states.
    """

    dfa: DFA
    symbol_probs: np.ndarray    # shape (n_states, |Σ|)
    hazard: np.ndarray          # shape (n_states,)

    # ---------- construction ----------
    @classmethod
    def from_dfa(
        cls,
        dfa: DFA,
        *,
        final_hazard: float = 0.1,
        per_final_hazard: Optional[Dict[int, float]] = None,
        symbol_probs: Optional[np.ndarray] = None,
    ) -> "PDFA":
        """
        Build a PDFA from a DFA.

        Args:
            dfa: The underlying automaton; `dfa.dead` (may be None) marks the dead sink.
            final_hazard: Stop probability for every accepting state, used when
                `per_final_hazard` is None. Must lie in [0, 1).
            per_final_hazard: Optional {final_state: hazard} mapping; takes precedence over
                `final_hazard`. Accepting states not listed get hazard 0.
            symbol_probs: Optional (n_states, |Σ|) array-like of non-negative, finite weights.
                If None, each non-dead state gets UNIFORM mass over its admissible symbols
                (δ(q,σ) != dead).

        Whatever the source of the weights, mass on dead-leading symbols is removed, every
        non-dead row is renormalized to sum to 1, and the dead row is zeroed.

        Raises:
            ValueError: on a shape mismatch, negative or non-finite weights, a non-dead row with
                no remaining mass, a hazard outside [0, 1), or a hazard given for a non-final state.
        """
        n, m = dfa.n_states, dfa.alphabet_size
        dead = dfa.dead

        # Symbol probabilities (default: uniform over admissible)
        if symbol_probs is None:
            probs = np.zeros((n, m), dtype=float)
            for q in range(n):
                if dead is not None and q == dead:
                    continue
                admiss = [a for a in range(m) if (dead is None or dfa.delta[q][a] != dead)]
                if not admiss:
                    # Only the dead state should have zero admissible symbols in a trimmed/minimal DFA
                    raise ValueError(f"State {q} has no admissible outgoing symbols.")
                probs[q, admiss] = 1.0 / len(admiss)
        else:
            # np.array copies, so the caller's array is never mutated below.
            probs = np.array(symbol_probs, dtype=float)
            if probs.shape != (n, m):
                raise ValueError(f"symbol_probs must have shape {(n, m)}, got {probs.shape}")
            # NaN would slip past the `s <= 0.0` check below and poison the row.
            if not np.all(np.isfinite(probs)) or np.any(probs < 0.0):
                raise ValueError("symbol_probs must contain only finite, non-negative values.")

        # Zero out dead-leading symbols and renormalize rows. This runs even when the DFA has
        # no dead state, so user-supplied rows always satisfy the sum-to-one invariant.
        for q in range(n):
            if dead is not None and q == dead:
                probs[q, :] = 0.0
                continue
            if dead is not None:
                for a in range(m):
                    if dfa.delta[q][a] == dead:
                        probs[q, a] = 0.0
            s = probs[q, :].sum()
            if s <= 0.0:
                raise ValueError(f"Row {q} has zero probability after removing dead transitions.")
            probs[q, :] /= s

        # Hazards (termination at accepting states)
        hz = np.zeros(n, dtype=float)
        finals = set(dfa.finals)
        if per_final_hazard is not None:
            for q, h in per_final_hazard.items():
                if q not in finals:
                    raise ValueError(f"hazard specified for non-final state {q}")
                if not (0.0 <= h < 1.0):
                    raise ValueError(f"hazard[{q}]={h} not in [0,1).")
                hz[q] = float(h)
        else:
            if not (0.0 <= final_hazard < 1.0):
                raise ValueError("final_hazard must be in [0,1).")
            for q in finals:
                hz[q] = float(final_hazard)

        return cls(dfa=dfa, symbol_probs=probs, hazard=hz)

    @property
    def sigma(self) -> Tuple[Symbol, ...]:
        """The alphabet of the underlying DFA (`dfa.sigma`)."""
        return self.dfa.sigma

    def set_uniform_final_hazard(self, h: float) -> None:
        """Set hazard h on every accepting state and 0 elsewhere (in place). h must be in [0,1)."""
        if not (0.0 <= h < 1.0):
            raise ValueError("h must be in [0,1).")
        self.hazard[:] = 0.0
        for q in self.dfa.finals:
            self.hazard[q] = float(h)

    # ---------- sampling ----------
    def _admissible_columns(self) -> List[List[int]]:
        """Per state, the symbol indices with positive probability that do not lead to dead."""
        dfa = self.dfa
        dead = dfa.dead
        cols: List[List[int]] = []
        for q in range(dfa.n_states):
            if dead is not None and q == dead:
                cols.append([])
                continue
            cols.append([a for a in range(dfa.alphabet_size)
                         if (dead is None or dfa.delta[q][a] != dead) and self.symbol_probs[q, a] > 0.0])
        return cols

    def sample(self, rng: random.Random, *, max_steps: int = 1_000_000) -> List[Symbol]:
        """
        Sample a positive string (no EOS token included). ε is allowed.

        Args:
            rng: Random source. One `rng.random()` draw is made per hazard check at an accepting
                state with hazard > 0, and one per emitted symbol.
            max_steps: Maximum number of loop iterations; each iteration either stops or emits
                one symbol.

        Raises:
            RuntimeError: if a state with no admissible emission is reached, or if `max_steps`
                iterations pass without stopping.
        """
        dfa = self.dfa
        dead = dfa.dead
        finals = set(dfa.finals)
        tokens: List[Symbol] = []
        q = dfa.start

        # Precompute admissible columns once per call.
        adm_cols = self._admissible_columns()

        for _ in range(max_steps):
            # Option to stop at accepting states
            if q in finals:
                h = float(self.hazard[q])
                if h > 0.0 and rng.random() < h:
                    return tokens

            # Otherwise emit one symbol
            cols = adm_cols[q]
            if not cols:
                if dead is not None and q == dead:
                    raise RuntimeError("Reached dead during PDFA sampling (should be impossible).")
                raise RuntimeError(f"No admissible emissions from state {q}.")

            # Inverse-CDF draw. Default to the last column in case float rounding leaves the
            # cumulative sum just below u.
            u = rng.random()
            c = 0.0
            pick = cols[-1]
            for a in cols:
                c += float(self.symbol_probs[q, a])
                if u <= c:
                    pick = a
                    break

            tokens.append(dfa.sigma[pick])
            q = dfa.delta[q][pick]

        raise RuntimeError("PDFA.sample exceeded max_steps; check hazards and DFA structure.")

    def _sample_lengths(self, N: int, skip_first: bool, rng: random.Random) -> List[int]:
        """Draw N string lengths. With skip_first, an empty draw is retried ONCE (the retry may be empty too)."""
        lengths: List[int] = []
        for _ in range(N):
            s = self.sample(rng)
            if not s and skip_first:
                s = self.sample(rng)
            lengths.append(len(s))
        return lengths

    def estimate_mean(self, N: int, skip_first: bool = False, rng: Optional[random.Random] = None) -> float:
        """
        Monte-Carlo estimate of E[length] by sampling N strings.

        Args:
            N: Number of strings to sample; must be positive.
            skip_first: If True, each draw that comes back empty is replaced by ONE fresh draw
                (which may itself be empty). This is a single retry, not rejection of ε.
            rng: Random source; defaults to `random.Random(0)` for reproducibility.

        Raises:
            ValueError: if N <= 0.
        """
        if N <= 0:
            raise ValueError(f"N must be positive, got {N}.")
        if rng is None:
            rng = random.Random(0)
        return sum(self._sample_lengths(N, skip_first, rng)) / float(N)

    def _transition_matrix(self) -> np.ndarray:
        """T[q, r] = P(move to r | we emit one symbol at q), built from symbol_probs and δ."""
        n, m = self.dfa.n_states, self.dfa.alphabet_size
        T = np.zeros((n, n), dtype=float)
        for q in range(n):
            for a in range(m):
                T[q, self.dfa.delta[q][a]] += float(self.symbol_probs[q, a])
        return T

    def _effective_hazard(self) -> np.ndarray:
        """Hazard as `sample` uses it: hazard[q] at accepting states, 0 everywhere else."""
        h = np.zeros(self.dfa.n_states, dtype=float)
        for q in self.dfa.finals:
            h[q] = float(self.hazard[q])
        return h

    @staticmethod
    def _closure(adj: np.ndarray, seeds: List[int]) -> set:
        """Indices reachable from `seeds` (inclusive) along edges i -> j where adj[i, j] is True."""
        seen = {int(s) for s in seeds}
        stack = list(seen)
        while stack:
            i = stack.pop()
            for j in np.flatnonzero(adj[i]):
                j = int(j)
                if j not in seen:
                    seen.add(j)
                    stack.append(j)
        return seen

    def expected_length_theory(self) -> float:
        """
        Theoretical E[length] from the start state via t = b + M t, with
          M = diag(1 - h) @ T_symbol,
          b[q] = 1 - h[q],
        where h is the hazard restricted to accepting states (0 elsewhere, matching `sample`).
        So b[q] = 1 for non-accepting q.

        Only states reachable from the start state enter the linear system, so unreachable parts
        of the automaton cannot make it singular. Returns float('inf') if some reachable state
        can never lead to termination (e.g. a zero-hazard cycle). In that case the process fails
        to stop with positive probability, so the expectation is unbounded.
        """
        n = self.dfa.n_states
        start = int(self.dfa.start)
        h = self._effective_hazard()
        M = (1.0 - h)[:, None] * self._transition_matrix()
        b = 1.0 - h

        edges = M > 0.0
        reachable = self._closure(edges, [start])

        # A state "exits" if it can stop directly: positive hazard, or a row whose total mass is
        # below 1 (the linear system treats the missing mass as termination). The tolerance
        # absorbs float rounding in rows that are meant to sum to 1.
        exit_prob = 1.0 - M.sum(axis=1)
        stoppers = [q for q in range(n) if h[q] > 0.0 or exit_prob[q] > 1e-12]
        can_stop = self._closure(edges.T, stoppers)  # backwards: states that can reach an exit
        if not reachable <= can_stop:
            return float("inf")

        # Reachable states only transition to reachable states, so the restricted system is exact;
        # since every one of them can reach an exit, I - M on this subset is nonsingular.
        idx = sorted(reachable)
        A = np.eye(len(idx)) - M[np.ix_(idx, idx)]
        t = np.linalg.solve(A, b[idx])
        return float(t[idx.index(start)])

    def plot_length_histogram(
        self,
        N: int,
        skip_first: bool = False,
        *,
        rng: Optional[random.Random] = None,
        save_path: Optional[str | pathlib.Path] = None,
        show: bool = True,
        title: Optional[str] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample N strings, build a DISCRETE histogram (PMF) of lengths, and either save or display it.
        Returns (length_values, probabilities) as computed by np.unique over the sampled lengths.

        - For discrete lengths, we prefer a bar plot over plt.hist binning.
        - `skip_first` has the same single-retry meaning as in `estimate_mean`.
        - If save_path is provided, the figure is written there (parent dirs are created) and
          plt.show() is NOT called. Otherwise plt.show() is called when `show` is True.
        - The figure is always closed before returning.
        """
        if rng is None:
            rng = random.Random(0)

        lengths = np.asarray(self._sample_lengths(N, skip_first, rng), dtype=int)
        values, counts = np.unique(lengths, return_counts=True)
        probs = counts.astype(float) / float(N)

        fig, ax = plt.subplots()
        ax.bar(values, probs, width=0.9, align="center", edgecolor="black")
        ax.set_xlabel("String length")
        ax.set_ylabel("Probability")
        ttl = title or f"Length distribution (N={N})"
        ax.set_title(ttl)
        ax.grid(True, axis="y", linestyle="--", alpha=0.3)

        if save_path is not None:
            save_path = pathlib.Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight")
        elif show:
            plt.show()
        plt.close(fig)
        return values, probs

    # ---------- optional persistence (handy for caching) ----------
    def to_json(self) -> Dict[str, Any]:
        """Serialize the DFA structure, symbol probabilities and hazards to plain JSON types."""
        return {
            "dfa": {
                "sigma": list(self.dfa.sigma),
                "start": int(self.dfa.start),
                "finals": list(map(int, self.dfa.finals)),
                "delta": [list(map(int, row)) for row in self.dfa.delta],
                "dead": (None if self.dfa.dead is None else int(self.dfa.dead)),
            },
            "symbol_probs": self.symbol_probs.tolist(),
            "hazard": self.hazard.tolist(),
        }

    @classmethod
    def from_json(cls, obj: Dict[str, Any]) -> "PDFA":
        """Inverse of `to_json`. Values are loaded as-is (no renormalization)."""
        d = obj["dfa"]
        dfa = DFA(
            sigma=tuple(d["sigma"]),
            start=int(d["start"]),
            finals=tuple(int(x) for x in d["finals"]),
            delta=tuple(tuple(int(x) for x in row) for row in d["delta"]),
            dead=(None if d["dead"] is None else int(d["dead"])),
        )
        probs = np.array(obj["symbol_probs"], dtype=float)
        haz = np.array(obj["hazard"], dtype=float)
        return cls(dfa=dfa, symbol_probs=probs, hazard=haz)

    def save(self, path: str | pathlib.Path) -> None:
        """Write `to_json()` to `path` as UTF-8 JSON, creating parent directories."""
        p = pathlib.Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False may emit non-ASCII symbols, so pin the encoding explicitly.
        p.write_text(json.dumps(self.to_json(), ensure_ascii=False, indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: str | pathlib.Path) -> "PDFA":
        """Read a file written by `save` (UTF-8 JSON)."""
        p = pathlib.Path(path)
        return cls.from_json(json.loads(p.read_text(encoding="utf-8")))
