"""
Factorized Scheduling Model (FSP)

A lightweight numpy implementation of a factorized scheduling principle function:
  s(x) = b + sum_k phi_k(x_k) + sum_(i,j) psi_{ij}(x_i, x_j)

- 1D and 2D components are parameterized by piecewise-linear bases.
- Identifiability is enforced by centering components under a reference distribution.
- Optional smoothness penalties encourage stable learned shapes.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np


# Annotation shorthand for NumPy arrays.
Array = np.ndarray
# An interaction between two input dimensions, written as a pair of column indices.
Pair = Tuple[int, int]


def _make_hat_basis(m: int) -> Tuple[Array, callable]:
    """Build a piecewise-linear ("hat") basis with ``m`` equally spaced knots on [0, 1].

    Args:
        m: Number of knots / basis functions; must be positive.

    Returns:
        ``(knots, Phi)``. ``knots`` has shape ``(m,)``. ``Phi(x)`` flattens its
        argument to length ``N`` and returns an ``(N, m)`` matrix whose rows sum
        to one (up to the ``1e-12`` stabilizer). Inputs outside [0, 1] are
        clamped to the nearest boundary, so they evaluate like the end knots
        instead of producing an all-zero row. NaN inputs yield NaN rows.

    Raises:
        ValueError: If ``m`` is not positive.
    """
    if m <= 0:
        raise ValueError("m must be positive.")
    knots = np.linspace(0.0, 1.0, m)
    h = knots[1] - knots[0] if m > 1 else 1.0

    def Phi(x: Array) -> Array:
        x = np.asarray(x, dtype=float).reshape(-1)
        if m == 1:
            # A single knot can only represent a constant; a literal hat centred
            # at 0 would vanish at x = 1 and leave that row empty.
            return np.where(np.isnan(x), np.nan, 1.0)[:, None]
        # Clamp so out-of-range points fall on a boundary knot rather than
        # outside the support of every hat function.
        x = np.clip(x, 0.0, 1.0)
        # One broadcasted evaluation of all hats: (N, 1) against (1, m).
        B = np.maximum(1.0 - np.abs(x[:, None] - knots[None, :]) / h, 0.0)
        # normalize to sum to 1 for numerical stability in edge cases
        B /= (B.sum(axis=1, keepdims=True) + 1e-12)
        return B

    return knots, Phi


def _rowwise_kron(A: Array, B: Array) -> Array:
    """Pair each row of ``A`` with the matching row of ``B`` via a Kronecker product.

    Row ``n`` of the result equals ``kron(A[n], B[n])``, so the product of
    ``A[n, a]`` and ``B[n, b]`` lands in column ``a * B.shape[1] + b``.

    Raises:
        ValueError: If the inputs differ in their number of rows.
    """
    n_rows = A.shape[0]
    if n_rows != B.shape[0]:
        raise ValueError("A and B must have the same number of rows.")
    # Broadcast (N, Ma, 1) against (N, 1, Mb), then flatten the two trailing axes.
    outer = np.multiply(A[:, :, np.newaxis], B[:, np.newaxis, :])
    return outer.reshape(n_rows, -1)


def _normalize_pairs(pairs: Iterable[Pair]) -> List[Pair]:
    """Canonicalize interaction pairs.

    Both members of every pair are cast to ``int`` and ordered as
    ``(smaller, larger)``; self-pairs are dropped, duplicates collapse, and
    the result is returned sorted in ascending order.
    """
    as_ints = ((int(first), int(second)) for first, second in pairs)
    canonical = {(min(u, v), max(u, v)) for u, v in as_ints if u != v}
    return sorted(canonical)


@dataclass(frozen=True)
class FSPConfig:
    """Immutable bundle of constructor settings for ``FactorizedSchedulingModel``."""

    # Number of input dimensions; the model requires inputs of shape (N, K).
    K: int
    # Basis size for the 1D components: one int shared by all dimensions, or a
    # sequence with exactly K entries.
    M_per_dim: Union[int, Sequence[int]] = 30
    # Pairs of dimension indices that get a 2D interaction component; the model
    # reorders, de-duplicates and drops self-pairs before use.
    interactions: Sequence[Pair] = ()
    # Basis size per axis for the 2D components: one int shared by all pairs, or
    # a mapping from pair to size that must cover every listed interaction.
    M_pair: Union[int, Dict[Pair, int]] = 15
    # Step size applied to the gradients in train_step.
    lr: float = 0.03
    # Coefficient of the squared-weight penalty in the training loss.
    l2: float = 1e-4
    # Seed for the NumPy random generator that draws the initial weights.
    seed: int = 0


class FactorizedSchedulingModel:
    """
    Numpy reference implementation of a factorized scoring model.
    """

    def __init__(self, **kwargs):
        cfg = FSPConfig(**kwargs)
        self.K = int(cfg.K)
        self.lr = float(cfg.lr)
        self.l2 = float(cfg.l2)
        self.rng = np.random.default_rng(int(cfg.seed))

        # Per-dimension basis sizes
        if isinstance(cfg.M_per_dim, int):
            self.M_phi: List[int] = [int(cfg.M_per_dim)] * self.K
        else:
            if len(cfg.M_per_dim) != self.K:
                raise ValueError("M_per_dim must have length K.")
            self.M_phi = [int(m) for m in cfg.M_per_dim]

        # Pair interactions and pair basis sizes
        self.P: List[Pair] = _normalize_pairs(cfg.interactions)

        if isinstance(cfg.M_pair, int):
            self.M_pairs: Dict[Pair, int] = {p: int(cfg.M_pair) for p in self.P}
        else:
            self.M_pairs = {}
            for (i, j), m in cfg.M_pair.items():
                key = (min(int(i), int(j)), max(int(i), int(j)))
                self.M_pairs[key] = int(m)
            # Ensure all listed interactions have sizes
            for p in self.P:
                if p not in self.M_pairs:
                    raise KeyError(f"Missing M_pair entry for interaction {p}.")

        # Build 1D bases
        self.knots: List[Array] = []
        self.Phi: List[callable] = []
        self.meanB: List[Array] = []
        self.w_phi: List[Array] = []

        grid = np.linspace(0.0, 1.0, 1001)
        for k in range(self.K):
            kn, ph = _make_hat_basis(self.M_phi[k])
            self.knots.append(kn)
            self.Phi.append(ph)
            self.meanB.append(ph(grid).mean(axis=0))
            self.w_phi.append(self.rng.normal(0.0, 0.05, size=self.M_phi[k]))

        # Build 2D pair bases (square MxM for simplicity)
        self.Phi_pair: Dict[Pair, Tuple[callable, callable]] = {}
        self.mean_pair: Dict[Pair, Tuple[Array, Array]] = {}
        self.w_psi: Dict[Pair, Array] = {}
        self.proj_pair: Dict[Pair, Array] = {}

        for (i, j) in self.P:
            M = self.M_pairs[(i, j)]
            _, Phi_i = _make_hat_basis(M)
            _, Phi_j = _make_hat_basis(M)
            self.Phi_pair[(i, j)] = (Phi_i, Phi_j)

            G = np.linspace(0.0, 1.0, 1001)
            meanBi = Phi_i(G).mean(axis=0)
            meanBj = Phi_j(G).mean(axis=0)
            self.mean_pair[(i, j)] = (meanBi, meanBj)

            self.w_psi[(i, j)] = self.rng.normal(0.0, 0.02, size=M * M)

            # Null-space projection to enforce identifiability constraints:
            # row/col marginals integrate to zero under the reference distribution.
            A_row = np.zeros((M, M * M))
            for col in range(M):
                e = np.zeros(M)
                e[col] = 1.0
                A_row[col, :] = np.kron(meanBi, e)

            A_col = np.zeros((M, M * M))
            for row in range(M):
                e = np.zeros(M)
                e[row] = 1.0
                A_col[row, :] = np.kron(e, meanBj)

            A = np.vstack([A_row, A_col])
            At = A.T
            P_null = np.eye(At.shape[0]) - At @ np.linalg.pinv(A @ At) @ A
            self.proj_pair[(i, j)] = P_null

        self.b = 0.0

    # -------------------- Core scoring --------------------

    def forward(self, X: Array) -> Tuple[Array, dict]:
        """Return scores s(X) and a cache for gradient computation."""
        X = np.asarray(X, dtype=float)
        if X.ndim != 2 or X.shape[1] != self.K:
            raise ValueError(f"X must have shape (N, {self.K}).")

        N = X.shape[0]
        s = np.full(N, self.b, dtype=float)
        cache = {"B_phi": [], "B_psi": {}, "pairs": self.P}

        for k in range(self.K):
            Bk = self.Phi[k](X[:, k])
            s += Bk @ self.w_phi[k]
            cache["B_phi"].append(Bk)

        for (i, j) in self.P:
            Phi_i, Phi_j = self.Phi_pair[(i, j)]
            Bi = Phi_i(X[:, i])
            Bj = Phi_j(X[:, j])
            Bij = _rowwise_kron(Bi, Bj)
            s += Bij @ self.w_psi[(i, j)]
            cache["B_psi"][(i, j)] = Bij

        return s, cache

    def project_identifiability(self) -> None:
        """Enforce centering constraints and absorb offsets into the bias term."""
        for k in range(self.K):
            mk = float(self.meanB[k] @ self.w_phi[k])
            denom = float(np.dot(self.meanB[k], self.meanB[k])) + 1e-12
            self.w_phi[k] -= (mk / denom) * self.meanB[k]
            self.b += mk

        for p in self.P:
            self.w_psi[p] = self.proj_pair[p] @ self.w_psi[p]

    # -------------------- Smoothness penalties --------------------

    @staticmethod
    def _smooth_1d_loss_grad(w: Array, alpha: float, beta: float) -> Tuple[float, Array]:
        """1D smoothness: first and second difference penalties."""
        w = np.asarray(w, dtype=float)
        M = w.size
        g = np.zeros_like(w)
        L = 0.0

        if M >= 2 and alpha > 0.0:
            d1 = w[1:] - w[:-1]
            L += alpha * np.mean(d1**2)
            c1 = 2.0 * alpha / (M - 1)
            g[:-1] += c1 * (-d1)
            g[1:] += c1 * (d1)

        if M >= 3 and beta > 0.0:
            d2 = w[2:] - 2.0 * w[1:-1] + w[:-2]
            L += beta * np.mean(d2**2)
            c2 = 2.0 * beta / (M - 2)
            g[:-2] += c2 * (d2)
            g[1:-1] += c2 * (-2.0 * d2)
            g[2:] += c2 * (d2)

        return float(L), g

    @staticmethod
    def _smooth_2d_loss_grad(w: Array, M: int, alpha: float, beta: float) -> Tuple[float, Array]:
        """2D smoothness on an MxM grid with first/second differences along both axes."""
        w = np.asarray(w, dtype=float)
        if w.size != M * M:
            raise ValueError("w must have length M*M.")
        W = w.reshape(M, M)
        gW = np.zeros_like(W)
        L = 0.0

        if M >= 2 and alpha > 0.0:
            dx = W[1:, :] - W[:-1, :]
            dy = W[:, 1:] - W[:, :-1]
            L += alpha * (np.mean(dx**2) + np.mean(dy**2))
            c1x = 2.0 * alpha / ((M - 1) * M)
            c1y = 2.0 * alpha / (M * (M - 1))
            gW[:-1, :] += c1x * (-dx)
            gW[1:, :] += c1x * (dx)
            gW[:, :-1] += c1y * (-dy)
            gW[:, 1:] += c1y * (dy)

        if M >= 3 and beta > 0.0:
            d2x = W[2:, :] - 2.0 * W[1:-1, :] + W[:-2, :]
            d2y = W[:, 2:] - 2.0 * W[:, 1:-1] + W[:, :-2]
            L += beta * (np.mean(d2x**2) + np.mean(d2y**2))
            c2x = 2.0 * beta / ((M - 2) * M)
            c2y = 2.0 * beta / (M * (M - 2))
            gW[:-2, :] += c2x * (d2x)
            gW[1:-1, :] += c2x * (-2.0 * d2x)
            gW[2:, :] += c2x * (d2x)
            gW[:, :-2] += c2y * (d2y)
            gW[:, 1:-1] += c2y * (-2.0 * d2y)
            gW[:, 2:] += c2y * (d2y)

        return float(L), gW.reshape(-1)

    # -------------------- Training --------------------

    def _loss_and_grads(
        self,
        X: Array,
        a: int,
        r: float,
        X_next: Array,
        gamma: float = 0.0,
        lam_val: float = 0.1,
        alpha_phi: float = 1e-3,
        beta_phi: float = 1e-4,
        alpha_psi: float = 1e-3,
        beta_psi: float = 1e-4,
    ) -> Tuple[dict, dict]:
        """Compute loss components and gradients for a single transition."""
        s, cache = self.forward(X)
        V = float(s.mean())

        s_next, _ = self.forward(X_next)
        Vn = float(s_next.mean())
        y = float(r + gamma * Vn)

        a = int(a)
        if not (0 <= a < s.shape[0]):
            raise IndexError("Action index a is out of bounds.")

        err = float(r - s[a])
        L_band = err**2
        L_val = (y - V) ** 2

        reg_phi = sum(float(np.sum(w**2)) for w in self.w_phi)
        reg_psi = sum(float(np.sum(w**2)) for w in self.w_psi.values())
        L_reg = self.l2 * (reg_phi + reg_psi)

        # Smoothness
        L_smooth_phi = 0.0
        g_smooth_phi: List[Array] = []
        for w in self.w_phi:
            Lk, gk = self._smooth_1d_loss_grad(w, alpha_phi, beta_phi)
            L_smooth_phi += Lk
            g_smooth_phi.append(gk)

        L_smooth_psi = 0.0
        g_smooth_psi: Dict[Pair, Array] = {}
        for p in self.P:
            M = self.M_pairs[p]
            Lp, gp = self._smooth_2d_loss_grad(self.w_psi[p], M, alpha_psi, beta_psi)
            L_smooth_psi += Lp
            g_smooth_psi[p] = gp

        L_total = L_band + lam_val * L_val + L_reg + L_smooth_phi + L_smooth_psi

        # Gradients
        gb = 2.0 * (s[a] - r) + lam_val * (-2.0 * (y - V)) * 1.0

        gw_phi: List[Array] = []
        for k in range(self.K):
            Bk = cache["B_phi"][k]
            g_band = 2.0 * (s[a] - r) * Bk[a]
            g_val = lam_val * (-2.0 * (y - V)) * Bk.mean(axis=0)
            g_reg = 2.0 * self.l2 * self.w_phi[k]
            gw_phi.append(g_band + g_val + g_reg + g_smooth_phi[k])

        gw_psi: Dict[Pair, Array] = {}
        for p in self.P:
            Bij = cache["B_psi"][p]
            g_band = 2.0 * (s[a] - r) * Bij[a]
            g_val = lam_val * (-2.0 * (y - V)) * Bij.mean(axis=0)
            g_reg = 2.0 * self.l2 * self.w_psi[p]
            gw_psi[p] = g_band + g_val + g_reg + g_smooth_psi[p]

        loss = {
            "bandit": float(L_band),
            "value": float(L_val),
            "reg": float(L_reg),
            "smooth_phi": float(L_smooth_phi),
            "smooth_psi": float(L_smooth_psi),
            "total": float(L_total),
        }
        grads = {"b": float(gb), "phi": gw_phi, "psi": gw_psi}
        return loss, grads

    def train_step(
        self,
        X: Array,
        a: int,
        r: float,
        X_next: Array,
        gamma: float = 0.0,
        lam_val: float = 0.1,
        alpha_phi: float = 1e-3,
        beta_phi: float = 1e-4,
        alpha_psi: float = 1e-3,
        beta_psi: float = 1e-4,
    ) -> dict:
        """One SGD step on a single transition."""
        loss, grads = self._loss_and_grads(
            X, a, r, X_next,
            gamma=gamma, lam_val=lam_val,
            alpha_phi=alpha_phi, beta_phi=beta_phi,
            alpha_psi=alpha_psi, beta_psi=beta_psi,
        )

        self.b -= self.lr * grads["b"]
        for k in range(self.K):
            self.w_phi[k] -= self.lr * grads["phi"][k]
        for p in self.P:
            self.w_psi[p] -= self.lr * grads["psi"][p]

        self.project_identifiability()
        return loss

    # -------------------- Inspection helpers --------------------

    def components_1d(self, num: int = 121) -> Tuple[List[Array], List[Array]]:
        """Return per-dimension grids and centered phi_k values for plotting."""
        grids = [np.linspace(0.0, 1.0, num) for _ in range(self.K)]
        phis: List[Array] = []
        for k in range(self.K):
            Bk = self.Phi[k](grids[k])
            phi = Bk @ self.w_phi[k]
            phis.append(phi - phi.mean())
        return grids, phis

    def component_2d_slice(self, i: int, j: int, xi_grid: Array, xj_fixed: float) -> Array:
        """
        Return a centered slice psi_{ij}(xi, xj_fixed) along xi_grid.
        """
        key = (min(int(i), int(j)), max(int(i), int(j)))
        if key not in self.w_psi:
            raise KeyError(f"Pair {key} is not in interactions.")
        Phi_i, Phi_j = self.Phi_pair[key]
        Bi = Phi_i(np.asarray(xi_grid, dtype=float).reshape(-1))
        Bj = Phi_j(np.full(Bi.shape[0], float(xj_fixed)))
        Bij = _rowwise_kron(Bi, Bj)
        psi = Bij @ self.w_psi[key]
        return psi - psi.mean()