import numpy as np
import torch
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Any

# ---------- numerics ----------
# Fill value for log-domain tables (forward/backward/Viterbi) before they are overwritten.
LOGTINY = -1e12
# Small positive floor applied to probabilities before taking logs and to denominators.
EPS = 1e-8

def logsumexp(a, axis=None, keepdims=False):
    """Compute ``log(sum(exp(a)))`` along ``axis`` in a numerically stable way.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large magnitudes do not overflow.

    Args:
        a: Array-like of log-domain values.
        axis: Axis (or axes) to reduce over; ``None`` reduces over everything.
        keepdims: If true, the reduced axes are kept with length one.

    Returns:
        The reduced array. A slice made only of ``-inf`` yields ``-inf`` and a
        slice containing ``+inf`` yields ``+inf`` (instead of NaN).
    """
    a = np.asarray(a)
    m = np.max(a, axis=axis, keepdims=True)
    # With an infinite maximum, ``a - m`` would evaluate ``inf - inf`` = NaN,
    # so shift by zero in those slices. For a finite maximum the shifted sum is
    # always >= 1 (the max element contributes exp(0)), so no floor is needed.
    shift = np.where(np.isfinite(m), m, np.zeros_like(m))
    with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended result
        out = shift + np.log(np.sum(np.exp(a - shift), axis=axis, keepdims=True))
    if not keepdims:
        out = np.squeeze(out, axis=axis)
    return out

# ---------- semantic tag space (anchored top) ----------
# Canonical semantic tags. A tag's position in this list is its integer id
# (0 = "final_answer", ..., 4 = "unknown").
CANON_TAGS = [
    "final_answer",
    "setup_and_retrieval",
    "analysis_and_computation",
    "uncertainty_and_verification",
    "unknown",
]
# Reverse lookup: tag string -> integer id.
CANON_TAG2ID = dict(zip(CANON_TAGS, range(len(CANON_TAGS))))

def _unwrap1(x: Any) -> Any:
    if isinstance(x, (tuple, list)) and len(x) > 0:
        return x[0]
    return x

def _label_str_to_canon_id(s: str) -> int:
    """Map a canonical tag string to its integer id.

    Raises:
        KeyError: If ``s`` is not a key of ``CANON_TAG2ID``.
    """
    tag_id = CANON_TAG2ID.get(s)
    if tag_id is None:
        raise KeyError(f"Unknown label string: {s!r}. Valid tags: {list(CANON_TAG2ID.keys())}")
    return tag_id

def coerce_labels_to_ids(labels: List[Any]) -> List[int]:
    """Convert raw labels to canonical integer tag ids.

    Each entry of ``labels`` may be:

    * a tag string listed in ``CANON_TAGS``;
    * an integer-valued number in ``range(len(CANON_TAGS))``;
    * a 2-element tuple/list whose second item is a string or number, which is
      treated as ``(text, label)``;
    * any other non-empty tuple/list, whose first element is used as the label.

    Args:
        labels: Iterable of raw labels in any of the forms above.

    Returns:
        One integer id per entry of ``labels``.

    Raises:
        KeyError: If a string label is not a canonical tag.
        ValueError: If a numeric label is not integer-valued or is out of range.
        TypeError: If a label has an unsupported type.
    """
    scalar_types = (str, int, float, np.integer, np.floating)
    # Derive the valid range from the tag list so the two cannot drift apart.
    n_tags = len(CANON_TAGS)
    valid_ids = "{" + ",".join(str(i) for i in range(n_tags)) + "}"

    out: List[int] = []
    for v in labels:
        # Handle (text, label) pairs; otherwise unwrap single-element containers.
        if isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[1], scalar_types):
            v = v[1]
        else:
            v = _unwrap1(v)

        if isinstance(v, str):
            out.append(_label_str_to_canon_id(v))
        elif isinstance(v, (int, np.integer, float, np.floating)):
            # int() truncates toward zero, so -0.5 or 2.9 would otherwise be
            # silently accepted as valid ids; NaN/inf are rejected here as well.
            if isinstance(v, (float, np.floating)) and not float(v).is_integer():
                raise ValueError(
                    f"Numeric label {v} is not an integer value. Must be in {valid_ids}. "
                    f"Valid labels: {CANON_TAGS}"
                )
            iv = int(v)
            if not 0 <= iv < n_tags:
                raise ValueError(
                    f"Numeric label {v} out of range. Must be in {valid_ids}. "
                    f"Valid labels: {CANON_TAGS}"
                )
            out.append(iv)
        else:
            raise TypeError(f"Unsupported label type: {type(v)} value={v}")
    return out

def extract_labels_from_sentences_with_labels(swls: List[Any]) -> List[int]:
    """Check that ``swls`` is a list and coerce its entries to canonical tag ids.

    Raises:
        TypeError: If ``swls`` is not a list.
    """
    if isinstance(swls, list):
        return coerce_labels_to_ids(swls)
    raise TypeError("sentences_with_labels must be a list.")

# ---------- data adapters ----------
def load_pt_records(pt_path: str):
    """Return the ``"records"`` entry of a ``.pt`` file, with tensors mapped to CPU.

    The file is presumably the output of the DP inference code (not verifiable
    from this module).

    NOTE(review): ``torch.load`` unpickles the file, so only load files from a
    trusted source.
    """
    return torch.load(pt_path, map_location="cpu")["records"]

def build_top_sequences(records) -> List[Dict]:
    """Convert loaded records into per-sequence dicts of NumPy step arrays.

    For each record, every non-``None`` entry of ``"step_hidden_states"`` is
    converted to a float32 NumPy array (presumably ``[num_layers, D]``).
    Records without any usable step are dropped.

    When a record has ``"sentences_with_labels"``, it is copied to the output.
    If hidden-state slots were dropped because they were ``None`` and the
    labels are a list/tuple with one entry per slot, the labels at the dropped
    positions are removed too, so steps and labels stay aligned.

    Args:
        records: Iterable of dict-like records.

    Returns:
        A list of dicts with key ``"steps"`` (list of arrays) and, when
        available, ``"sentences_with_labels"``.
    """
    seqs: List[Dict] = []
    for r in records:
        raw = r.get("step_hidden_states")
        # A missing or explicit None field means "no steps".
        slots = [] if raw is None else list(raw)
        kept = [i for i, H in enumerate(slots) if H is not None]
        if not kept:
            continue

        steps: List[np.ndarray] = []
        for i in kept:
            H = slots[i]
            if isinstance(H, torch.Tensor):
                # detach(): .numpy() raises on tensors that require grad.
                arr = H.detach().to(torch.float32).cpu().numpy()
            else:
                arr = torch.as_tensor(H, dtype=torch.float32).cpu().numpy()
            steps.append(arr)

        seq: Dict[str, Any] = {"steps": steps}

        # Add labels if available (required for anchored training).
        if "sentences_with_labels" in r:
            labels = r["sentences_with_labels"]
            dropped_some = len(kept) != len(slots)
            if dropped_some and isinstance(labels, (list, tuple)) and len(labels) == len(slots):
                labels = [labels[i] for i in kept]
            seq["sentences_with_labels"] = labels

        seqs.append(seq)
    return seqs

# ---------- bottom HMM (per-category) with diagonal Gaussian emissions ----------
@dataclass
class BottomHMMParams:
    """Parameters for one bottom-level HMM (one semantic category).

    K is the number of bottom states and D the emission dimension; emissions
    are Gaussians with diagonal covariance.
    """
    startprob: np.ndarray      # [K] - initial state probabilities
    transmat: np.ndarray       # [K, K] - state transition matrix
    means: np.ndarray          # [K, D] - emission means
    variances: np.ndarray      # [K, D] - diagonal covariances

def init_bottom_params(K: int, D: int, x_samples: np.ndarray, sticky=0.90, rs: Optional[np.random.RandomState]=None) -> BottomHMMParams:
    """
    Initialize bottom HMM parameters from sample data.

    Means are rows drawn from a random pool of at most ``10 * K`` samples, the
    shared variance is the per-dimension pool variance plus 1e-2, the start
    distribution is uniform and the transition matrix is "sticky".

    Args:
        K: Number of regimes (bottom states)
        D: Hidden dimension (not used here; shapes come from ``x_samples``)
        x_samples: Sample data [N, D] for initialization
        sticky: Self-transition probability (default 0.9)
        rs: Random state; a fresh ``RandomState(0)`` is used when omitted

    Returns:
        BottomHMMParams with float64 arrays.

    Raises:
        ValueError: If ``x_samples`` has no rows.
    """
    if rs is None:
        rs = np.random.RandomState(0)
    n_samples = x_samples.shape[0]
    if n_samples == 0:
        raise ValueError("x_samples is empty; cannot initialize emission parameters.")

    pool_n = min(10 * K, n_samples)
    idx = rs.choice(n_samples, size=pool_n, replace=False)
    base = x_samples[idx]

    # Draw distinct rows whenever the pool is large enough: two states that start
    # with identical parameters stay identical under EM and waste a regime.
    mean_idx = rs.choice(pool_n, size=K, replace=pool_n < K)
    means = base[mean_idx].astype(np.float64)

    # The unbiased estimator (ddof=1) is NaN for a single sample; fall back to ddof=0.
    ddof = 1 if pool_n > 1 else 0
    pooled_var = np.var(base, axis=0, ddof=ddof, dtype=np.float64)
    variances = np.tile(pooled_var + 1e-2, (K, 1))

    startprob = np.ones(K, dtype=np.float64) / max(K, 1)
    if K <= 1:
        # A lone state has nowhere else to go; [[sticky]] would not sum to 1.
        transmat = np.ones((K, K), dtype=np.float64)
    else:
        transmat = np.full((K, K), (1.0 - sticky) / (K - 1), dtype=np.float64)
        np.fill_diagonal(transmat, sticky)
    return BottomHMMParams(startprob, transmat, means, variances)

def log_gaussian_diag(x: np.ndarray, means: np.ndarray, variances: np.ndarray) -> np.ndarray:
    """Log-density of every row of ``x`` under every diagonal-covariance Gaussian.

    Args:
        x: [T, D] observations.
        means: [K, D] component means.
        variances: [K, D] per-dimension variances (floored at 1e-6).

    Returns:
        [T, K] array of log-densities.
    """
    _, dim = x.shape
    safe_var = np.maximum(variances, 1e-6)
    precision = 1.0 / safe_var
    half_logdet = 0.5 * np.sum(np.log(safe_var), axis=1)

    delta = x[:, None, :] - means[None, :, :]
    half_mahal = 0.5 * np.sum(np.square(delta) * precision[None, :, :], axis=2)

    norm_const = 0.5 * dim * np.log(2 * np.pi)
    return -(norm_const + half_mahal + half_logdet[None, :])

def fb_bottom(x: np.ndarray, params: BottomHMMParams):
    """Run log-domain forward-backward for one bottom HMM over the rows of ``x``.

    Args:
        x: [T, D] observations, one row per time index.
        params: Parameters of the bottom HMM to evaluate.

    Returns:
        Tuple ``(logZ, gamma, xi)``: the log-likelihood of ``x``, the [T, K]
        state posteriors, and the [K, K] expected transition counts summed
        over time.
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    log_B = log_gaussian_diag(x, params.means, params.variances)
    n_steps, n_states = log_B.shape

    # Forward pass.
    fwd = np.full((n_steps, n_states), LOGTINY)
    fwd[0] = log_pi + log_B[0]
    for t in range(1, n_steps):
        fwd[t] = log_B[t] + logsumexp(fwd[t - 1][:, None] + log_A, axis=0)
    logZ = logsumexp(fwd[-1], axis=0)

    # Backward pass.
    bwd = np.full((n_steps, n_states), LOGTINY)
    bwd[-1] = 0.0
    for t in reversed(range(n_steps - 1)):
        bwd[t] = logsumexp(log_A + (log_B[t + 1] + bwd[t + 1])[None, :], axis=1)

    # State posteriors and expected transition counts.
    gamma = np.exp(fwd + bwd - logZ)
    xi = np.zeros((n_states, n_states), dtype=np.float64)
    for t in range(n_steps - 1):
        xi += np.exp(
            fwd[t][:, None] + log_A + log_B[t + 1][None, :] + bwd[t + 1][None, :] - logZ
        )
    return logZ, gamma, xi

def mstep_bottom(accum, min_var=1e-3) -> BottomHMMParams:
    """Re-estimate bottom HMM parameters from accumulated sufficient statistics.

    Args:
        accum: Dict with keys ``"start"`` [K], ``"trans"`` [K, K], ``"sum_w"`` [K],
            ``"sum_x"`` [K, D] and ``"sum_x2"`` [K, D].
        min_var: Lower bound applied to every variance.

    Returns:
        BottomHMMParams with normalized start/transition probabilities.
    """
    start, trans, sum_w, sum_x, sum_x2 = (
        accum[key] for key in ("start", "trans", "sum_w", "sum_x", "sum_x2")
    )

    # Initial-state distribution (uniform fallback if the total is not positive).
    pi = np.maximum(EPS, start)
    if pi.sum() > 0:
        pi = pi / np.sum(pi)
    else:
        pi = np.ones_like(pi) / len(pi)

    # Row-normalized transition matrix.
    A = np.maximum(EPS, trans)
    A = A / np.maximum(EPS, np.sum(A, axis=1, keepdims=True))

    # Weighted moments -> means and floored variances.
    weight = np.maximum(EPS, sum_w)[:, None]
    means = sum_x / weight
    variances = np.maximum(sum_x2 / weight - means * means, min_var)
    return BottomHMMParams(pi, A, means, variances)

# ---------- top HMM (per-sequence steps) ----------
@dataclass
class TopHMMParams:
    """Parameters for top-level HMM (semantic categories).

    C is the number of top-level categories.
    """
    startprob: np.ndarray   # [C] - initial category probabilities
    transmat: np.ndarray    # [C, C] - category transition matrix

def init_top_params(C: int, sticky=0.80):
    """Initialize top HMM parameters with a uniform start and sticky transitions.

    Args:
        C: Number of top-level categories.
        sticky: Self-transition probability; the remaining mass is split evenly
            over the other ``C - 1`` categories.

    Returns:
        TopHMMParams whose transition rows each sum to 1.
    """
    startprob = np.ones(C, dtype=np.float64) / max(C, 1)
    if C <= 1:
        # With a single category there is no other state to receive the
        # 1 - sticky mass; [[sticky]] would not be a valid stochastic matrix.
        trans = np.ones((C, C), dtype=np.float64)
    else:
        trans = np.full((C, C), (1.0 - sticky) / (C - 1), dtype=np.float64)
        np.fill_diagonal(trans, sticky)
    return TopHMMParams(startprob, trans)

def fb_top(log_emissions: np.ndarray, params: TopHMMParams):
    """Log-domain forward-backward over top-level categories (not used in anchored version).

    Args:
        log_emissions: [T, C] log-likelihood of each step under each category.
        params: Top-level start and transition probabilities.

    Returns:
        Tuple ``(gamma, xi, logZ)``: [T, C] category posteriors, [C, C] expected
        transition counts summed over time, and the sequence log-likelihood.
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    n_steps, n_cats = log_emissions.shape

    # Forward pass.
    fwd = np.full((n_steps, n_cats), LOGTINY)
    fwd[0] = log_pi + log_emissions[0]
    for t in range(1, n_steps):
        fwd[t] = log_emissions[t] + logsumexp(fwd[t - 1][:, None] + log_A, axis=0)
    logZ = logsumexp(fwd[-1], axis=0)

    # Backward pass.
    bwd = np.full((n_steps, n_cats), LOGTINY)
    bwd[-1] = 0.0
    for t in reversed(range(n_steps - 1)):
        bwd[t] = logsumexp(log_A + (log_emissions[t + 1] + bwd[t + 1])[None, :], axis=1)

    gamma = np.exp(fwd + bwd - logZ)
    xi = np.zeros((n_cats, n_cats), dtype=np.float64)
    for t in range(n_steps - 1):
        xi += np.exp(
            fwd[t][:, None] + log_A + (log_emissions[t + 1] + bwd[t + 1])[None, :] - logZ
        )
    return gamma, xi, logZ

# ---------- HHMM container ----------
@dataclass
class HHMM:
    """Container bundling the top-level HMM with one bottom-level HMM per category."""
    C: int                              # Number of top-level categories
    K: int                              # Number of bottom-level regimes per category
    D: int                              # Hidden dimension
    top: TopHMMParams                   # Top-level parameters
    bottom: List[BottomHMMParams]       # Bottom-level parameters (length C)

def init_hhmm(C: int, K: int, D: int, sequences, rs: Optional[np.random.RandomState]=None) -> HHMM:
    """Build an initial HHMM from the pooled step data of all sequences.

    Every step array of every sequence is concatenated along axis 0; each of
    the ``C`` bottom HMMs is then initialized from that pool using the shared
    random state ``rs`` (a fresh ``RandomState(0)`` when omitted).

    Raises:
        ValueError: If no sequence contains any step.
    """
    rs = rs or np.random.RandomState(0)
    pool: List[np.ndarray] = [x for seq in sequences for x in seq["steps"]]
    if not pool:
        raise ValueError("No steps to initialize from.")
    X = np.concatenate(pool, axis=0)

    top = init_top_params(C)
    bottom = []
    for _ in range(C):
        bottom.append(init_bottom_params(K, D, X, sticky=0.90, rs=rs))
    return HHMM(C, K, D, top, bottom)

# ---------- EM training (anchored top with semantic labels) ----------
def _prepare_anchored_sequences(sequences, label_key, C):
    """Validate each sequence once and pair its step arrays with integer category ids."""
    prepared = []
    for seq in sequences:
        steps = seq["steps"]
        labels_raw = seq.get(label_key, None)
        if labels_raw is None:
            raise RuntimeError(
                f"Anchored training requires labels: missing '{label_key}' in a sequence. "
                f"This implementation only supports supervised/anchored HHMM."
            )

        y = coerce_labels_to_ids(labels_raw)

        if len(steps) != len(labels_raw):
            raise ValueError(
                f"Step/label count mismatch: len(steps)={len(steps)} "
                f"but len({label_key})={len(labels_raw)}"
            )

        # Ids index arrays of size C and model.bottom; fail with a clear message
        # instead of an IndexError deep inside the E-step.
        too_big = [c for c in y if c >= C]
        if too_big:
            raise ValueError(
                f"Label id {too_big[0]} in '{label_key}' is not valid for C={C} categories; "
                f"ids must be in 0..{C - 1}."
            )

        if len(y) > 0:
            prepared.append((steps, y))
    return prepared


def fit_hhmm_fixed_top(
    sequences: List[Dict],
    C: int,
    K: int,
    label_key: str = "sentences_with_labels",
    n_iter: int = 10,
    min_var: float = 1e-3,
    seed: int = 0,
    verbose: bool = True,
) -> HHMM:
    """Fit an HHMM whose top-level categories are anchored to given labels.

    The top-level start/transition probabilities are estimated by counting the
    labels. Each bottom HMM is trained with EM on the steps assigned to its
    category, treating every step array as an independent observation sequence.

    Args:
        sequences: Dicts with ``"steps"`` (list of [T, D] arrays) and the labels
            under ``label_key`` (one label per step).
        C: Number of top-level categories.
        K: Number of bottom-level regimes per category.
        label_key: Key under which each sequence stores its labels.
        n_iter: Number of EM iterations.
        min_var: Variance floor used in the bottom M-step.
        seed: Seed for the initialization random state.
        verbose: Print the average per-step log-likelihood every iteration.

    Returns:
        The fitted HHMM.

    Raises:
        AssertionError: If no sequence contains any step.
        RuntimeError: If a sequence has no labels under ``label_key``.
        ValueError: If step and label counts differ, or a label id is >= ``C``.
    """
    # Infer D from the first available step.
    D = next((int(x.shape[1]) for seq in sequences for x in seq["steps"]), None)
    assert D is not None, "Empty sequences."

    rs = np.random.RandomState(seed)
    model = init_hhmm(C, K, D, sequences, rs)

    # Labels are fixed, so validation and the top-level counts do not depend on
    # the EM iteration: compute them once.
    prepared = _prepare_anchored_sequences(sequences, label_key, C)
    n_steps_total = sum(len(y) for _, y in prepared)

    top_start_counts = np.zeros(C, dtype=np.float64)
    top_trans_counts = np.zeros((C, C), dtype=np.float64)
    for _, y in prepared:
        top_start_counts[y[0]] += 1.0
        for c_prev, c_next in zip(y[:-1], y[1:]):
            top_trans_counts[c_prev, c_next] += 1.0

    # EPS flooring keeps unseen starts/transitions at a tiny non-zero probability
    # and makes rows without any count uniform.
    startprob = np.maximum(EPS, top_start_counts)
    startprob = startprob / np.sum(startprob)
    transmat = np.maximum(EPS, top_trans_counts)
    transmat = transmat / np.sum(transmat, axis=1, keepdims=True)

    for it in range(1, n_iter + 1):
        bottom_acc = [{
            "start": np.zeros(K),
            "trans": np.zeros((K, K)),
            "sum_w": np.zeros(K),
            "sum_x": np.zeros((K, D)),
            "sum_x2": np.zeros((K, D)),
        } for _ in range(C)]
        total_logZ = 0.0

        # E-step for bottom HMMs (conditioned on the fixed category of each step).
        for steps, y in prepared:
            for c, x in zip(y, steps):
                # Accumulate in float64: squaring float32 inputs loses precision in
                # the E[x^2] - E[x]^2 variance estimate.
                x = np.asarray(x, dtype=np.float64)
                logZ, gamma_r, xi_r = fb_bottom(x, model.bottom[c])
                total_logZ += logZ

                acc = bottom_acc[c]
                acc["start"] += gamma_r[0]
                acc["trans"] += xi_r
                acc["sum_w"] += np.sum(gamma_r, axis=0)
                acc["sum_x"] += gamma_r.T @ x
                acc["sum_x2"] += gamma_r.T @ (x * x)

        # M-step for top (from label counts).
        model.top = TopHMMParams(startprob.copy(), transmat.copy())

        # M-step for bottoms. A category without any observation keeps its current
        # parameters; re-estimating from empty statistics would collapse it to
        # zero means and minimum variances.
        for c in range(C):
            if bottom_acc[c]["sum_w"].sum() > 0.0:
                model.bottom[c] = mstep_bottom(bottom_acc[c], min_var=min_var)

        if verbose:
            avg_ll = total_logZ / max(1, n_steps_total)
            print(f"[HHMM-anchored] iter {it:02d}  avg step loglik = {avg_ll:.4f}", flush=True)

    return model

# ---------- decoding (anchored only) ----------
def viterbi_bottom(x: np.ndarray, params: BottomHMMParams) -> Tuple[float, np.ndarray]:
    """Find the most likely bottom-state path for the rows of ``x``.

    Args:
        x: [T, D] observations.
        params: Parameters of the bottom HMM.

    Returns:
        Tuple of the best path's log-score and the length-T integer array of
        state indices along that path.
    """
    log_pi = np.log(np.maximum(EPS, params.startprob))
    log_A = np.log(np.maximum(EPS, params.transmat))
    log_B = log_gaussian_diag(x, params.means, params.variances)
    n_steps, n_states = log_B.shape

    score = np.full((n_steps, n_states), LOGTINY)
    backptr = np.full((n_steps, n_states), -1, dtype=int)
    score[0] = log_pi + log_B[0]
    for t in range(1, n_steps):
        candidates = score[t - 1][:, None] + log_A
        backptr[t] = np.argmax(candidates, axis=0)
        score[t] = log_B[t] + np.max(candidates, axis=0)

    best_score = np.max(score[-1])

    # Backtrack from the best final state.
    path = np.zeros(n_steps, dtype=int)
    path[-1] = int(np.argmax(score[-1]))
    for t in reversed(range(n_steps - 1)):
        path[t] = backptr[t + 1, path[t + 1]]
    return best_score, path

def decode_hhmm_anchored(sequences: List[Dict], model: HHMM, 
                         label_key: str = "sentences_with_labels"):
    """Decode bottom-level regimes for every step, with categories fixed by the labels.

    Args:
        sequences: Dicts with ``"steps"`` (list of [T, D] arrays) and the labels
            under ``label_key`` (one label per step).
        model: Fitted HHMM; ``model.bottom[c]`` is used for steps labelled ``c``.
        label_key: Key under which each sequence stores its labels.

    Returns:
        One dict per sequence with ``"best_categories"`` (the label ids) and
        ``"best_regimes_per_step"`` (a list of Viterbi state paths, one per step).

    Raises:
        RuntimeError: If a sequence has no labels under ``label_key``.
        ValueError: If step and label counts differ, or a label id has no
            corresponding bottom HMM in ``model``.
    """
    n_categories = len(model.bottom)
    out = []
    for seq in sequences:
        steps = seq["steps"]
        labels_raw = seq.get(label_key, None)
        if labels_raw is None:
            raise RuntimeError(
                f"Anchored decode requires labels: missing '{label_key}'. "
                f"This implementation only supports supervised/anchored HHMM."
            )

        y = coerce_labels_to_ids(labels_raw)

        if len(steps) != len(labels_raw):
            raise ValueError(
                f"Length mismatch: len(steps)={len(steps)} "
                f"!= len({label_key})={len(labels_raw)}."
            )

        # Fail with a clear message instead of an IndexError on model.bottom[c].
        for c in y:
            if c >= n_categories:
                raise ValueError(
                    f"Label id {c} in '{label_key}' has no bottom HMM: "
                    f"the model has {n_categories} categories."
                )

        # Lengths match (checked above), so labels and steps pair up one-to-one.
        regimes = [
            viterbi_bottom(x, model.bottom[c])[1].tolist()
            for c, x in zip(y, steps)
        ]

        out.append({
            "best_categories": list(y),
            "best_regimes_per_step": regimes
        })
    return out