# dgp_torch_optimized.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Literal

import torch
import numpy as np
from itertools import product


EmissionType = Literal["linear", "nonlinear"]


# ============================================================
# Config / Params
# ============================================================
@dataclass
class DGPConfig:
    """
    Configuration of the data generating process (DGP).

    taus: number of time slices within a day (T)
    p: state dimension
    seed: seed handed to the torch RNG before parameters are drawn
    dtype: floating-point dtype of the generated parameter tensors
    device: device for the generated tensors; when None, generate_params
        falls back to CUDA if it is available and to CPU otherwise
    emission: "linear" or "nonlinear"
    phi1_range: scale for Phi_t entries ~ Unif(-phi1_range, phi1_range)
    Xi1_range: scale for Xi_t entries ~ Unif(-Xi1_range, Xi1_range)
    alpha_mean, alpha_sd: carryover strength (Gamma_t ~ N(alpha_mean, alpha_sd^2))
    add_state_noise: whether to add Gaussian noise to transitions
    round_digits: optional rounding (e.g., 3) to mimic R's round(..., digits=3);
        None disables rounding
    """
    taus: int = 48
    p: int = 3
    seed: int = 2026
    dtype: torch.dtype = torch.float32
    device: Optional[torch.device] = None

    # outcome (emission) model family
    emission: EmissionType = "linear"

    # half-widths of the uniform draws for the transition matrices
    phi1_range: float = 0.3
    Xi1_range: float = 0.2

    # mean / standard deviation of the Gaussian carryover vector (Gamma_t)
    alpha_mean: float = 0.0
    alpha_sd: float = 0.3

    # optional transition noise
    add_state_noise: bool = False

    # optional rounding of every generated parameter
    round_digits: Optional[int] = None


@dataclass
class DGPParams:
    """
    Container for the time-varying DGP parameters (index t runs over T slices).

    taus: number of time slices (T)
    p: state dimension
    beta0: intercept in the outcome model, one scalar per slice
    beta1: linear state coefficient in the outcome model
    xi1: action-by-state interaction coefficient (a * X) in the outcome model
    gammas: direct treatment effect coefficient (a * gamma_t)
    phi0: intercept in the state transition
    phi1: linear state transition matrix Phi_t
    Xi1: interaction transition matrix Xi_t (multiplied by the action)
    alphas: carryover vector Gamma_t (multiplied by the action)
    eps_state_cov: covariance of the optional transition noise
    """
    taus: int
    p: int
    beta0: torch.Tensor        # (T,)  outcome intercept
    beta1: torch.Tensor        # (T,p)
    xi1: torch.Tensor          # (T,p)
    gammas: torch.Tensor       # (T,)
    phi0: torch.Tensor         # (T,p)
    phi1: torch.Tensor         # (T,p,p)
    Xi1: torch.Tensor          # (T,p,p)
    alphas: torch.Tensor       # (T,p)
    eps_state_cov: torch.Tensor  # (p,p)


# ============================================================
# Utilities
# ============================================================
def set_torch_seed(seed: int) -> None:
    """Seed torch's default RNG and, when CUDA is available, every CUDA device."""
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _maybe_round(x: torch.Tensor, digits: Optional[int]) -> torch.Tensor:
    """Round `x` to `digits` decimals; return `x` itself when `digits` is None."""
    if digits is not None:
        factor = 10.0 ** digits
        x = torch.round(x * factor) / factor
    return x


def _rand_sign_uniform(low_neg: float, high_neg: float, low_pos: float, high_pos: float,
                       shape: Tuple[int, ...], device=None, dtype=None) -> torch.Tensor:
    """
    Draw each element from one of two uniform intervals chosen by a fair coin:
    [low_neg, high_neg] or [low_pos, high_pos].

    Three independent uniform draws are consumed, in this order: the coin,
    the negative-interval candidate, the positive-interval candidate.
    """
    def _uniform(lo: float, hi: float) -> torch.Tensor:
        return (hi - lo) * torch.rand(shape, device=device, dtype=dtype) + lo

    take_negative = torch.rand(shape, device=device, dtype=dtype) > 0.5
    negative_candidate = _uniform(low_neg, high_neg)
    positive_candidate = _uniform(low_pos, high_pos)
    return torch.where(take_negative, negative_candidate, positive_candidate)


# ============================================================
# Parameter generation
# ============================================================
def generate_params(cfg: DGPConfig) -> DGPParams:
    """
    Draw one full set of DGP parameters according to `cfg`.

    The torch RNG is seeded with cfg.seed first, and the tensors are drawn in
    a fixed order (beta0, beta1, xi1, gammas, phi0, phi1, Xi1, alphas) so the
    result is reproducible for a given seed/device/dtype. Every tensor is
    passed through the optional rounding controlled by cfg.round_digits.
    """
    if cfg.device:
        device = cfg.device
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_torch_seed(cfg.seed)

    T, p = cfg.taus, cfg.p
    digits = cfg.round_digits
    tensor_kw = {"device": device, "dtype": cfg.dtype}

    def two_sided(lo: float, hi: float, shape: Tuple[int, ...]) -> torch.Tensor:
        # magnitude uniform in [lo, hi], sign chosen by a fair coin
        return _maybe_round(_rand_sign_uniform(-hi, -lo, lo, hi, shape, **tensor_kw), digits)

    def centered_uniform(half_width: float, shape: Tuple[int, ...]) -> torch.Tensor:
        # Unif(-half_width, half_width)
        return _maybe_round((2 * half_width) * torch.rand(shape, **tensor_kw) - half_width, digits)

    # ---- outcome model ----
    beta0 = two_sided(0.5, 1.0, (T,))
    beta1 = two_sided(0.1, 0.3, (T, p))
    xi1 = two_sided(0.05, 0.1, (T, p))          # interaction coef, |xi1| in [0.05, 0.1]
    gammas = _maybe_round(                      # gamma_t ~ Uniform(0.5, 0.8)
        (0.8 - 0.5) * torch.rand((T,), **tensor_kw) + 0.5,
        digits,
    )

    # ---- transition model ----
    phi0 = two_sided(0.5, 1.0, (T, p))
    phi1 = centered_uniform(float(cfg.phi1_range), (T, p, p))   # Phi_t
    Xi1 = centered_uniform(float(cfg.Xi1_range), (T, p, p))     # Xi_t
    alphas = _maybe_round(                      # Gamma_t ~ Normal(alpha_mean, alpha_sd^2)
        cfg.alpha_mean + cfg.alpha_sd * torch.randn((T, p), **tensor_kw),
        digits,
    )

    # transition noise covariance: 1.5 * I_p (only used when noise is requested)
    eps_state_cov = 1.5 * torch.eye(p, **tensor_kw)

    return DGPParams(
        taus=T,
        p=p,
        beta0=beta0,
        beta1=beta1,
        xi1=xi1,
        gammas=gammas,
        phi0=phi0,
        phi1=phi1,
        Xi1=Xi1,
        alphas=alphas,
        eps_state_cov=eps_state_cov,
    )



# ============================================================
# Stability checks
# ============================================================
def check_transition_stability(phi1: torch.Tensor) -> torch.Tensor:
    """Return a (T,) bool tensor: True where every eigenvalue of Phi_t has modulus < 1."""
    moduli = torch.abs(torch.linalg.eigvals(phi1))
    return torch.all(moduli < 1.0, dim=-1)


def check_transition_stability_signed(phi1: torch.Tensor, Xi1: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Stability of the effective transition matrices under action 0, +1 and -1.

    Returns a dict with keys "phi", "plus", "minus"; each value is a (T,) bool
    tensor that is True where all eigenvalue moduli of Phi_t, Phi_t + Xi_t and
    Phi_t - Xi_t respectively are below 1.
    """
    def _all_inside_unit_disc(mats: torch.Tensor) -> torch.Tensor:
        return (torch.linalg.eigvals(mats).abs() < 1.0).all(dim=-1)

    return {
        "phi": _all_inside_unit_disc(phi1),
        "plus": _all_inside_unit_disc(phi1 + Xi1),
        "minus": _all_inside_unit_disc(phi1 - Xi1),
    }


# ============================================================
# Emission / Transition models
# ============================================================
def emission_distribution(
    beta0_t: torch.Tensor,     # scalar or ()
    beta1_t: torch.Tensor,     # (p,)
    xi1_t: torch.Tensor,       # (p,)
    states_t: torch.Tensor,    # (N,p)
    act,                       # scalar / (N,) / (N,1)
    gamma_t: torch.Tensor,     # scalar or ()
    emission: EmissionType = "linear",
) -> torch.Tensor:
    """
    Noise-free outcome mean for N trajectories at one time slice. Returns y: (N,).

    Linear (emission == "linear"):
      y = beta0 + X beta1 + a*gamma + a*(X xi1)

    Nonlinear (any other value of `emission`):
      y = beta0 + X beta1 + a*gamma + a*(X xi1) + sum_j cos(X_j)
      i.e. the linear mean plus the row-wise sum of cos over the p state
      coordinates.

    `act` is converted to a tensor on the device/dtype of `states_t` and
    normalized to shape (N,).

    Raises:
      ValueError: if `act` is not a scalar, an (N,) vector or an (N,1) column,
        or if its length does not match the number of rows of `states_t`.
    """
    device = states_t.device
    dtype = states_t.dtype
    N = states_t.shape[0]

    # ---- normalize act to shape (N,) ----
    act_t = torch.as_tensor(act, device=device, dtype=dtype)
    if act_t.ndim == 0:
        act_t = act_t.expand(N)              # scalar -> (N,)
    elif act_t.ndim == 2 and act_t.shape[1] == 1:
        # (N,1) -> (N,); validate the length explicitly so a wrong-sized column
        # fails here with a clear message rather than in a later broadcast
        if act_t.shape[0] != N:
            raise ValueError(f"act has shape {act_t.shape}, expected (N,1) with N={N}")
        act_t = act_t.squeeze(1)
    elif act_t.ndim == 1:
        if act_t.shape[0] != N:
            raise ValueError(f"act has shape {act_t.shape}, expected (N,) with N={N}")
    else:
        raise ValueError(f"act must be scalar, (N,), or (N,1). Got {act_t.shape}")

    # shared linear pieces
    xb = states_t @ beta1_t         # (N,)
    xxi = states_t @ xi1_t          # (N,)
    linear_mean = beta0_t + xb + act_t * gamma_t + act_t * xxi

    if emission == "linear":
        return linear_mean

    # ---- nonlinear: add sum_j cos(X_j) to the linear mean ----
    cos_part = torch.sum(torch.cos(states_t), dim=1)   # (N,)
    return linear_mean + cos_part
    



def trans_distribution(
    phi0_t: torch.Tensor,     # (p,) or (1,p)
    phi1_t: torch.Tensor,     # (p,p)
    Xi1_t: torch.Tensor,      # (p,p)
    states_t: torch.Tensor,   # (N,p)
    act,                      # scalar / (N,) / (N,1)
    alpha_t: torch.Tensor,    # (p,) or (1,p)
    state_noise: Optional[torch.Tensor] = None,  # (N,p)
) -> torch.Tensor:
    """
    One-step state transition with an action-by-state interaction:

      x_{t+1} = phi0 + states_t @ Phi^T + a*alpha + a*(states_t @ Xi^T) + noise

    The action of each row is a scalar that multiplies the whole carryover
    vector alpha and the whole row of the interaction term.

    Shapes:
      - states_t: (N,p)
      - act: scalar, (N,) or (N,1); used internally as an (N,1) column
      - phi0_t, alpha_t: (p,) or (1,p); used internally as (1,p) rows
      - state_noise: optional (N,p), added as-is
      - returns: (N,p)

    Raises:
      ValueError: on any shape that does not match the list above.
    """
    if states_t.ndim != 2:
        raise ValueError(f"states_t must be 2D (N,p), got {states_t.shape}")

    N, p = states_t.shape
    like_states = {"device": states_t.device, "dtype": states_t.dtype}

    def _as_row(value, label: str) -> torch.Tensor:
        # Coerce a (p,) or (1,p) parameter to a (1,p) row on the states' device/dtype.
        row = torch.as_tensor(value, **like_states)
        if row.ndim == 1:
            if row.shape[0] != p:
                raise ValueError(f"{label} has shape {row.shape}, expected ({p},)")
            return row.unsqueeze(0)
        if row.ndim == 2:
            if row.shape != (1, p):
                raise ValueError(f"{label} has shape {row.shape}, expected (1,{p})")
            return row
        raise ValueError(f"{label} must be (p,) or (1,p). Got {row.shape}")

    # ---- action as an (N,1) column ----
    a = torch.as_tensor(act, **like_states)
    if a.ndim == 0:
        a = a.expand(N, 1)
    elif a.ndim == 1:
        if a.shape[0] != N:
            raise ValueError(f"act has shape {a.shape}, expected (N,) with N={N}")
        a = a.unsqueeze(1)
    elif not (a.ndim == 2 and a.shape == (N, 1)):
        raise ValueError(f"act must be scalar, (N,), or (N,1). Got {a.shape}")

    phi0 = _as_row(phi0_t, "phi0_t")
    alpha = _as_row(alpha_t, "alpha_t")

    # ---- assemble the next state ----
    drift = phi0 + (states_t @ phi1_t.T)            # (1,p) broadcast against (N,p)
    interaction = states_t @ Xi1_t.T                # (N,p)
    next_states = drift + a * alpha + a * interaction

    if state_noise is None:
        return next_states

    if state_noise.shape != (N, p):
        raise ValueError(f"state_noise shape {state_noise.shape} != {(N, p)}")
    return next_states + state_noise

# ============================================================
# Analytic ATE (linear) for (+1) vs (-1) via recursion (fast)
# ============================================================
@torch.no_grad()
def compute_ATE_linear_true(
    params: DGPParams,
    act_hi: float = 1.0,
    act_lo: float = -1.0,
    init_state_mean: Optional[torch.Tensor] = None,   # (p,), defaults to 0
) -> Dict[str, float]:
    """
    Exact analytic ATE for the LINEAR model under two constant policies:
      all actions = act_hi  versus all actions = act_lo.

    Uses the recursion for mean states:
      m_{t+1}^{(a)} = phi0_t + (Phi_t + a*Xi_t) m_t^{(a)} + a*Gamma_t
    and mean outcomes:
      E[Y_t^{(a)}] = beta0_t + (beta1_t + a*xi_t)^T m_t^{(a)} + a*gamma_t

    Returns a dict with:
      - "Policy_hi_mean_totalY": sum over t of E[Y_t] under act_hi
      - "Policy_lo_mean_totalY": sum over t of E[Y_t] under act_lo
      - "ATE_linear_true": difference of the two totals divided by T
      - "DE_only_component": part of the ATE coming from the a*gamma_t term
        alone, i.e. (act_hi - act_lo) * mean_t(gamma_t)
    """
    device, dtype = params.beta0.device, params.beta0.dtype
    T, p = params.taus, params.p

    if init_state_mean is None:
        m_hi = torch.zeros((p,), device=device, dtype=dtype)
        m_lo = torch.zeros((p,), device=device, dtype=dtype)
    else:
        # independent copies so the two recursions never alias the caller's tensor
        m_hi = init_state_mean.to(device=device, dtype=dtype).clone()
        m_lo = init_state_mean.to(device=device, dtype=dtype).clone()

    sumY_hi = torch.zeros((), device=device, dtype=dtype)
    sumY_lo = torch.zeros((), device=device, dtype=dtype)

    aH = torch.as_tensor(act_hi, device=device, dtype=dtype)
    aL = torch.as_tensor(act_lo, device=device, dtype=dtype)

    for t in range(T):
        # expected outcome at slice t given the current mean state
        EY_hi = params.beta0[t] + (params.beta1[t] + aH * params.xi1[t]) @ m_hi + aH * params.gammas[t]
        EY_lo = params.beta0[t] + (params.beta1[t] + aL * params.xi1[t]) @ m_lo + aL * params.gammas[t]
        sumY_hi += EY_hi
        sumY_lo += EY_lo

        # propagate the mean state to slice t+1
        A_hi = params.phi1[t] + aH * params.Xi1[t]
        A_lo = params.phi1[t] + aL * params.Xi1[t]
        m_hi = params.phi0[t] + (A_hi @ m_hi) + aH * params.alphas[t]
        m_lo = params.phi0[t] + (A_lo @ m_lo) + aL * params.alphas[t]

    ate = (sumY_hi - sumY_lo) / T
    # Direct-effect contrast from a*gamma_t only. The gap between the two
    # actions is (act_hi - act_lo); it equals 2 only for the default (+1, -1).
    direct_effect = (aH - aL) * params.gammas.mean()
    return {
        "Policy_hi_mean_totalY": float(sumY_hi.item()),
        "Policy_lo_mean_totalY": float(sumY_lo.item()),
        "ATE_linear_true": float(ate.item()),
        "DE_only_component": float(direct_effect.item()),
    }


# ============================================================
# Closed-form ATE (linear) aligned with Eq.(49) (slower, for verification)
# ============================================================
@torch.no_grad()
def mean_state_closed_form(
    phi0: torch.Tensor, phi1: torch.Tensor, Xi1: torch.Tensor, alphas: torch.Tensor,
    t_1based: int, a: float, EX1: torch.Tensor
) -> torch.Tensor:
    """
    Closed-form mean state E[X_t^{(a)}] under the constant action `a`
    (paper-style, t is 1-based):

      E[X_t^{(a)}] = sum_{k=1}^{t-1} [ prod_{l=k+1}^{t-1} (Phi_l + a Xi_l) ] (phi0_k + a Gamma_k)
                   + prod_{l=1}^{t-1} (Phi_l + a Xi_l) E[X_1]

    Empty sums are 0 and empty products are the identity, so for t <= 1 the
    input EX1 is returned unchanged (the same tensor object).

    Every product is rebuilt from scratch, so the cost grows quadratically in
    t; this function is meant as an independent check of the fast recursion.
    """
    device, dtype = phi0.device, phi0.dtype
    p = phi0.shape[1]

    if t_1based <= 1:
        return EX1

    step_mats = phi1 + a * Xi1          # (T,p,p): effective transition per slice
    n_steps = t_1based - 1              # transitions applied to reach slice t

    def _identity() -> torch.Tensor:
        return torch.eye(p, device=device, dtype=dtype)

    total = torch.zeros((p,), device=device, dtype=dtype)

    # forcing terms: 0-based k0 corresponds to the paper's k = k0 + 1
    for k0 in range(n_steps):
        term = phi0[k0] + a * alphas[k0]            # (p,)
        if k0 + 1 < n_steps:
            # left-multiply by the later transitions l = k+1 .. t-1, earliest first
            carried = _identity()
            for l0 in range(k0 + 1, n_steps):
                carried = step_mats[l0] @ carried
            term = carried @ term
        total = total + term

    # contribution of the initial mean, propagated through all t-1 transitions
    full_prod = _identity()
    for l0 in range(n_steps):
        full_prod = step_mats[l0] @ full_prod
    return total + full_prod @ EX1



@torch.no_grad()
def compute_ATE_linear_true_closed_form(
    params: DGPParams,
    init_state_mean: Optional[torch.Tensor] = None,
    act_hi: float = 1.0,
    act_lo: float = -1.0,
) -> Dict[str, float]:
    """
    ATE of the linear model from the closed-form mean states (aligned with
    Eq.(49)); mainly a cross-check of the fast recursion.

    For constant policies act_hi and act_lo:

      T * ATE = (act_hi - act_lo) * sum_t gamma_t
              + sum_t [ (beta1_t + act_hi*xi_t)^T E[X_t^{(hi)}]
                      - (beta1_t + act_lo*xi_t)^T E[X_t^{(lo)}] ]

    With the default actions (+1, -1) this is the original expression
      2 * sum_t gamma_t + sum_t [(beta1_t + xi_t)^T E+ - (beta1_t - xi_t)^T E-].

    Args:
      params: DGP parameters.
      init_state_mean: E[X_1] of shape (p,); zero when None.
      act_hi, act_lo: the two constant actions being contrasted.

    Returns a dict with "T_times_ATE_closed" and "ATE_linear_true_closed".
    """
    device, dtype = params.beta0.device, params.beta0.dtype
    T, p = params.taus, params.p
    if init_state_mean is None:
        EX1 = torch.zeros((p,), device=device, dtype=dtype)
    else:
        EX1 = init_state_mean.to(device=device, dtype=dtype)

    # direct-effect part: the a*gamma_t terms of the two policies
    T_ATE = (act_hi - act_lo) * params.gammas.sum()

    for t in range(1, T + 1):
        m_hi = mean_state_closed_form(params.phi0, params.phi1, params.Xi1, params.alphas, t, act_hi, EX1)
        m_lo = mean_state_closed_form(params.phi0, params.phi1, params.Xi1, params.alphas, t, act_lo, EX1)

        bt = params.beta1[t - 1]
        xt = params.xi1[t - 1]
        T_ATE = T_ATE + (bt + act_hi * xt) @ m_hi - (bt + act_lo * xt) @ m_lo

    ate = T_ATE / T
    return {
        "T_times_ATE_closed": float(T_ATE.item()),
        "ATE_linear_true_closed": float(ate.item()),
    }



@torch.no_grad()
def compute_index_u_vectors(
    params: DGPParams,
    init_state_mean: Optional[torch.Tensor] = None,
    act_hi: float = +1.0,
    act_lo: float = -1.0,
    return_matrix: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Build the index vectors u_t for t=1..T.

    With E^{hi}(X_t), E^{lo}(X_t) the closed-form mean states under the
    constant actions act_hi and act_lo:

      u_t = ( 0,
              [E^{hi}(X_t) - E^{lo}(X_t)]^T,
              act_hi - act_lo,
              [act_hi*E^{hi}(X_t) - act_lo*E^{lo}(X_t)]^T )^T

    For the default actions (+1, -1) this is the original
      u_t = (0, [E^{+} - E^{-}]^T, 2, [E^{+} + E^{-}]^T)^T.
    The four slots line up with (beta0_t, beta1_t, gamma_t, xi1_t): their
    inner product with u_t is the slice-t contribution to T * ATE in the
    linear model.

    Then stack/flatten:
      u = (u_1^T, ..., u_T^T)^T.

    Returns:
      - "u": (T*(2p+2),) flattened vector
      - "U": (T, 2p+2) matrix where row t-1 is u_t^T  (only if return_matrix)
    """
    device, dtype = params.beta0.device, params.beta0.dtype
    T, p = params.taus, params.p

    if init_state_mean is None:
        EX1 = torch.zeros((p,), device=device, dtype=dtype)
    else:
        EX1 = init_state_mean.to(device=device, dtype=dtype).reshape(p,)

    # constant slots, shared by every row
    beta0_slot = torch.zeros((1,), device=device, dtype=dtype)
    gamma_slot = torch.full((1,), float(act_hi - act_lo), device=device, dtype=dtype)

    u_rows = []
    for t in range(1, T + 1):
        m_hi = mean_state_closed_form(params.phi0, params.phi1, params.Xi1, params.alphas,
                                      t, act_hi, EX1).reshape(p,)
        m_lo = mean_state_closed_form(params.phi0, params.phi1, params.Xi1, params.alphas,
                                      t, act_lo, EX1).reshape(p,)

        beta1_slot = m_hi - m_lo                       # multiplies beta1_t, (p,)
        xi1_slot = act_hi * m_hi - act_lo * m_lo       # multiplies xi1_t,   (p,)

        # u_t shape: (2p+2,)
        u_rows.append(torch.cat([beta0_slot, beta1_slot, gamma_slot, xi1_slot], dim=0))

    if u_rows:
        U = torch.stack(u_rows, dim=0)                 # (T, 2p+2)
    else:
        # T == 0: torch.stack rejects an empty list, so return an empty matrix
        U = torch.zeros((0, 2 * p + 2), device=device, dtype=dtype)
    u = U.reshape(-1)                                  # (T*(2p+2),)

    if return_matrix:
        return {"u": u, "U": U}
    return {"u": u}


# ============================================================
# Monte Carlo ATE for (+1) vs (-1)
# ============================================================
@torch.no_grad()
def simulate_policy_totals(
    params: DGPParams,
    N: int,
    emission: EmissionType = "linear",
    seed: int = 2026,
    add_state_noise: bool = False,
    act_hi: float = 1.0,
    act_lo: float = -1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Simulate N trajectories under the constant actions act_hi and act_lo and
    return the two Monte Carlo means of the total outcome (sum over t of Y_t).

    Common random numbers: both policies start from the same N(0, I) initial
    states and, at every slice, share the same outcome noise and (when
    add_state_noise is True) the same transition noise, which reduces the
    variance of the difference between the two totals.

    The torch RNG is re-seeded with `seed` on entry.
    """
    device, dtype = params.beta0.device, params.beta0.dtype
    set_torch_seed(seed)

    T, p = params.taus, params.p

    # shared initial states ~ N(0, I)
    x_hi = torch.randn((N, p), device=device, dtype=dtype)
    x_lo = x_hi.clone()

    total_hi = torch.zeros((N,), device=device, dtype=dtype)
    total_lo = torch.zeros((N,), device=device, dtype=dtype)

    # Cholesky factor of the transition-noise covariance, only when needed
    noise_factor = torch.linalg.cholesky(params.eps_state_cov) if add_state_noise else None

    for t in range(T):
        beta0_t, beta1_t, xi1_t, gamma_t = (
            params.beta0[t], params.beta1[t], params.xi1[t], params.gammas[t]
        )

        # outcome noise, shared by both policies
        shared_eps = torch.randn((N,), device=device, dtype=dtype)
        total_hi += emission_distribution(
            beta0_t, beta1_t, xi1_t, x_hi, act_hi, gamma_t, emission
        ) + shared_eps
        total_lo += emission_distribution(
            beta0_t, beta1_t, xi1_t, x_lo, act_lo, gamma_t, emission
        ) + shared_eps

        # transition noise, also shared by both policies
        shock = None
        if add_state_noise:
            standard = torch.randn((N, p), device=device, dtype=dtype)
            shock = standard @ noise_factor.T

        phi0_t, phi1_t, Xi1_t, alpha_t = (
            params.phi0[t], params.phi1[t], params.Xi1[t], params.alphas[t]
        )
        x_hi = trans_distribution(phi0_t, phi1_t, Xi1_t, x_hi, act_hi, alpha_t, shock)
        x_lo = trans_distribution(phi0_t, phi1_t, Xi1_t, x_lo, act_lo, alpha_t, shock)

    return total_hi.mean(), total_lo.mean()



@torch.no_grad()
def monte_carlo_ATE_true(
    params: DGPParams,
    Iter_n: List[int] = (500, 1000),
    emission: EmissionType = "linear",
    seed: int = 2026,
    add_state_noise: bool = False,
    act_hi: float = 1.0,
    act_lo: float = -1.0,
    verify_closed_form: bool = True,
) -> Dict[int, Dict[str, float]]:
    """
    Monte Carlo estimate of the ATE (act_hi vs act_lo) for each sample size in
    `Iter_n`, with analytic references when the emission is linear.

    Output keys:
      [-1] analytic recursion result (linear emission only)
      [-2] closed-form result (linear emission and verify_closed_form only)
      [-3] recursion minus closed-form ATE (same condition as [-2])
      [N]  Monte Carlo totals and ATE for N trajectories, plus
           "MC_minus_analytic" when the analytic value is available

    Note: the closed-form check is called without action arguments, so it is
    evaluated at that function's default actions regardless of act_hi/act_lo.
    Every sample size reuses the same `seed`.
    """
    T = params.taus
    out: Dict[int, Dict[str, float]] = {}

    analytic = None
    if emission == "linear":
        analytic = compute_ATE_linear_true(params, act_hi=act_hi, act_lo=act_lo)
        out[-1] = analytic

        if verify_closed_form:
            closed = compute_ATE_linear_true_closed_form(params)
            gap = analytic["ATE_linear_true"] - closed["ATE_linear_true_closed"]
            out[-2] = closed
            out[-3] = {"rec_minus_closed": float(gap)}

    for N in Iter_n:
        total_hi, total_lo = simulate_policy_totals(
            params,
            N,
            emission=emission,
            seed=seed,
            add_state_noise=add_state_noise,
            act_hi=act_hi,
            act_lo=act_lo,
        )
        ate_mc = (total_hi - total_lo) / T

        row = {
            "Policy_hi_true": float(total_hi.item()),
            "Policy_lo_true": float(total_lo.item()),
            "ATE_true_MC": float(ate_mc.item()),
        }
        if analytic is not None:
            row["MC_minus_analytic"] = float(ate_mc.item() - analytic["ATE_linear_true"])
        out[N] = row

    return out


@torch.no_grad()
def monte_carlo_ATE_with_linear_baseline(
    params: DGPParams,
    Iter_n: List[int] = (500, 1000),
    emission: EmissionType = "linear",
    seed: int = 2026,
    add_state_noise: bool = False,
    act_hi: float = 1.0,
    act_lo: float = -1.0,
    include_closed_form: bool = True,
) -> Dict[int, Dict[str, float]]:
    """
    Monte Carlo ATE under the chosen emission, always reported next to the
    linear analytic baseline.

    The baseline is the true ATE only when emission == "linear"; for the
    nonlinear emission it is a comparator and the difference is a
    misspecification gap.

    Output keys:
      [-1] recursion baseline (plus a "note" string)
      [-2] closed-form baseline (if included, plus a "note" string)
      [-3] baseline consistency check (if the closed form is included)
      [N]  MC estimates and their difference from the recursion baseline;
           an "interpretation" string is added when emission == "nonlinear"

    Note: the closed-form baseline is called without action arguments, so it
    is evaluated at that function's default actions regardless of
    act_hi/act_lo.
    """
    T = params.taus
    out: Dict[int, Dict[str, float]] = {}

    # --- Baseline 1: linear analytic recursion ---
    baseline_rec = compute_ATE_linear_true(params, act_hi=act_hi, act_lo=act_lo, init_state_mean=None)
    rec_entry = dict(baseline_rec)
    rec_entry["note"] = "Linear analytic baseline (valid as true ATE only when emission='linear')."
    out[-1] = rec_entry

    # --- Baseline 2 (optional): closed form aligned with Eq.(49) ---
    if include_closed_form:
        baseline_cf = compute_ATE_linear_true_closed_form(params, init_state_mean=None)
        cf_entry = dict(baseline_cf)
        cf_entry["note"] = "Closed-form baseline aligned with Eq.(49) (linear model)."
        out[-2] = cf_entry
        out[-3] = {
            "rec_minus_closed": float(baseline_rec["ATE_linear_true"] - baseline_cf["ATE_linear_true_closed"])
        }

    # --- Monte Carlo under the requested emission ---
    is_nonlinear = emission == "nonlinear"
    for N in Iter_n:
        total_hi, total_lo = simulate_policy_totals(
            params,
            N,
            emission=emission,
            seed=seed,
            add_state_noise=add_state_noise,
            act_hi=act_hi,
            act_lo=act_lo,
        )
        ate_mc = (total_hi - total_lo) / T

        row = {
            "Policy_hi_true": float(total_hi.item()),
            "Policy_lo_true": float(total_lo.item()),
            "ATE_true_MC": float(ate_mc.item()),
            "ATE_linear_baseline": float(baseline_rec["ATE_linear_true"]),
            "MC_minus_linear_baseline": float(ate_mc.item() - baseline_rec["ATE_linear_true"]),
            "DE_only_component": float(baseline_rec["DE_only_component"]),
        }
        if is_nonlinear:
            # make explicit that the baseline is only a comparator here
            row["interpretation"] = "MC is nonlinear ATE estimate; linear_baseline is comparator (misspecification gap)."

        out[N] = row

    return out




NoiseType = Literal["iid", "ar1", "exchangeable", "ma1"]

@dataclass
class OutcomeNoiseConfig:
    """
    Configuration of the within-day outcome noise drawn by sample_outcome_noise.

    type: covariance structure across the T slices of a day
        ("iid", "ar1", "exchangeable" or "ma1")
    sigma: marginal standard deviation scale of the noise
    rho: correlation parameter, read by the "ar1" and "exchangeable" types
    theta: moving-average coefficient, read by the "ma1" type
    jitter: value added to the covariance diagonal before a Cholesky
        factorization
    """
    type: NoiseType = "iid"
    sigma: float = 1.0
    rho: float = 0.6
    theta: float = 0.6
    jitter: float = 1e-6



def _toeplitz_ar1_cov(T: int, rho: float, sigma: float, device, dtype) -> torch.Tensor:
    """
    AR(1) covariance matrix of shape (T, T): cov[i, j] = sigma^2 * rho^|i - j|.

    The powers are evaluated in `dtype` itself (or in float32 when `dtype` is
    narrower than float32), so a float64 request is not silently limited to
    float32 precision by an intermediate default-dtype result.
    """
    # never compute below float32; use the requested dtype when it is wider
    work_dtype = torch.promote_types(dtype, torch.float32)

    idx = torch.arange(T, device=device)
    lag = (idx[:, None] - idx[None, :]).abs().to(work_dtype)   # |i - j|, integer-valued
    base = torch.tensor(rho, device=device, dtype=work_dtype)
    cov = torch.pow(base, lag) * (sigma ** 2)
    return cov.to(dtype=dtype)


def _exchangeable_cov(T: int, rho: float, sigma: float, device, dtype) -> torch.Tensor:
    """
    Exchangeable (compound-symmetry) covariance of shape (T, T):
      sigma^2 * ((1 - rho) * I + rho * 11^T).
    """
    tensor_kw = {"device": device, "dtype": dtype}
    identity = torch.eye(T, **tensor_kw)
    all_ones = torch.ones((T, T), **tensor_kw)
    correlation = (1.0 - rho) * identity + rho * all_ones
    return (sigma ** 2) * correlation


def _ma1_cov(T: int, theta: float, sigma: float, device, dtype) -> torch.Tensor:
    """
    Approximate MA(1) covariance of shape (T, T) for
      e_t = u_t + theta * u_{t-1},  u_t iid N(0, sigma_u^2).

    sigma_u^2 is chosen so that an interior e_t has variance sigma^2:
      sigma_u^2 * (1 + theta^2) = sigma^2  =>  sigma_u^2 = sigma^2 / (1 + theta^2)

    Resulting matrix:
      diagonal            : sigma^2
      first off-diagonals : sigma_u^2 * theta
      everything else     : 0
    """
    innovation_var = (sigma ** 2) / (1.0 + theta ** 2)

    cov = torch.zeros((T, T), device=device, dtype=dtype)
    cov.fill_diagonal_(sigma ** 2)
    if T < 2:
        return cov

    # diagonal() returns views, so the fills write into `cov` directly
    lag1 = innovation_var * theta
    cov.diagonal(offset=1).fill_(lag1)
    cov.diagonal(offset=-1).fill_(lag1)
    return cov

def sample_outcome_noise(
    ndays: int,
    T: int,
    cfg: OutcomeNoiseConfig,
    device: torch.device,
    dtype: torch.dtype,
    seed: Optional[int] = None,
    return_chol: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Draw outcome noise for `ndays` independent days of T slices each.

    Returns a dict with:
      - "eps":   (ndays, T) noise samples
      - "Sigma": (T, T) within-day covariance; for "ar1"/"exchangeable" this
                 is the jittered matrix that was actually factorized
      - "chol":  (T, T) Cholesky factor, only when return_chol is True

    Per type:
      - "iid": eps = sigma * z with z standard normal.
      - "ar1" / "exchangeable": eps = z @ chol(Sigma + jitter*I)^T.
      - "ma1": eps_t = u_t + theta * u_{t-1} built directly from innovations u.
        The first column has no lagged innovation (eps_0 = u_0), so its
        variance is sigma^2 / (1 + theta^2) while the reported Sigma has
        sigma^2 on the whole diagonal. "chol", if requested, is the factor of
        Sigma + jitter*I and is not used for sampling.

    When `seed` is not None the torch RNG is re-seeded before sampling.

    Raises:
      ValueError: if cfg.type is not one of the four supported types.
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    if seed is not None:
        set_torch_seed(seed)

    kind = cfg.type
    tensor_kw = {"device": device, "dtype": dtype}

    # ---- IID ----
    if kind == "iid":
        result = {
            "eps": cfg.sigma * torch.randn((ndays, T), **tensor_kw),
            "Sigma": (cfg.sigma ** 2) * torch.eye(T, **tensor_kw),
        }
        if return_chol:
            result["chol"] = cfg.sigma * torch.eye(T, **tensor_kw)
        return result

    # ---- AR1 / Exchangeable: sample through the Cholesky factor ----
    if kind in ("ar1", "exchangeable"):
        build_cov = _toeplitz_ar1_cov if kind == "ar1" else _exchangeable_cov
        base_cov = build_cov(T, cfg.rho, cfg.sigma, device, dtype)

        jittered_cov = base_cov + cfg.jitter * torch.eye(T, **tensor_kw)
        factor = torch.linalg.cholesky(jittered_cov)      # (T,T)

        standard = torch.randn((ndays, T), **tensor_kw)
        result = {"eps": standard @ factor.T, "Sigma": jittered_cov}
        if return_chol:
            result["chol"] = factor
        return result

    # ---- MA(1): build from innovations instead of a factorization ----
    if kind == "ma1":
        reported_cov = _ma1_cov(T, cfg.theta, cfg.sigma, device, dtype)

        innovation_sd = cfg.sigma / ((1.0 + cfg.theta ** 2) ** 0.5)
        innovations = innovation_sd * torch.randn((ndays, T), **tensor_kw)
        noise = innovations.clone()
        noise[:, 1:] = noise[:, 1:] + cfg.theta * innovations[:, :-1]

        result = {"eps": noise, "Sigma": reported_cov}
        if return_chol:
            jittered_cov = reported_cov + cfg.jitter * torch.eye(T, **tensor_kw)
            result["chol"] = torch.linalg.cholesky(jittered_cov)
        return result

    raise ValueError(f"Unknown noise type: {cfg.type}")







def time_features(t_idx0: int, T: int, device, dtype) -> torch.Tensor:
    """
    Periodic time-of-day features for the 0-based slice index `t_idx0`.

    Returns a (2,) tensor [sin(angle), cos(angle)] with
    angle = 2*pi*(t_idx0 + 1)/T.
    """
    tensor_kw = {"device": device, "dtype": dtype}
    slot = torch.tensor(float(t_idx0 + 1), **tensor_kw)     # 1-based slice number
    n_slots = torch.tensor(float(T), **tensor_kw)
    angle = 2.0 * torch.pi * slot / n_slots
    return torch.stack((torch.sin(angle), torch.cos(angle)), dim=0)






def make_psi_legendre_tensor_torch(
    S: torch.Tensor,
    degree: int = 3,
    scaler: Optional[Dict[str, torch.Tensor]] = None,
    include_intercept: bool = True,
    max_total_degree: Optional[int] = None,
    interaction_order: Optional[int] = 1,
    per_dim_degree: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[str], Dict[str, torch.Tensor]]:
    """
    Multivariate Legendre tensor-product basis evaluated at the rows of S.

    Each column of S is mapped affinely to [-1, 1] using the scaler's min/max
    (a constant column is mapped to 0). Basis functions are products of
    univariate Legendre polynomials, one factor per dimension.

    Args:
      S: (n, d) tensor
      degree: maximum univariate degree, used for every dimension when
        per_dim_degree is None
      scaler: None (min/max are taken from S) or {"min": (d,), "max": (d,)}
      include_intercept: whether the constant column "Intercept" is emitted
        (always first)
      max_total_degree: if given, drop terms whose degrees sum above it
      interaction_order: if given, drop terms with more than this many
        dimensions of non-zero degree
      per_dim_degree: optional length-d degrees overriding `degree`

    Returns:
      Psi: (n, L) tensor
      names: list[str] of length L, e.g. "P2(s1)*P1(s3)" (dimensions 1-based)
      scaler: dict with 'min' and 'max' as (d,) tensors

    Columns follow itertools.product order over the per-dimension degrees,
    i.e. the last dimension varies fastest.

    Raises:
      ValueError: if S is not 2-D, per_dim_degree has the wrong length or a
        negative entry, or the scaler entries are not of shape (d,).
    """
    if S.ndim != 2:
        raise ValueError(f"S must be (n,d), got {tuple(S.shape)}")
    n, d = S.shape
    tensor_kw = {"device": S.device, "dtype": S.dtype}

    # ---- resolve the maximum degree of every dimension ----
    if per_dim_degree is None:
        degrees = [int(degree)] * d
    else:
        raw = per_dim_degree.detach().cpu().tolist() if isinstance(per_dim_degree, torch.Tensor) \
            else per_dim_degree
        degrees = [int(v) for v in raw]
        if len(degrees) != d:
            raise ValueError("per_dim_degree must have length d")
        if any(k < 0 for k in degrees):
            raise ValueError("per_dim_degree entries must be >= 0")

    # ---- affine map of each column to [-1, 1] ----
    if scaler is None:
        lo = S.amin(dim=0)
        hi = S.amax(dim=0)
        scaler = {"min": lo, "max": hi}
    else:
        lo = torch.as_tensor(scaler["min"], **tensor_kw)
        hi = torch.as_tensor(scaler["max"], **tensor_kw)
        if lo.shape != (d,) or hi.shape != (d,):
            raise ValueError("scaler['min'] and scaler['max'] must be shape (d,)")

    span = hi - lo
    flat = span == 0                                   # constant columns
    divisor = torch.where(flat, torch.ones_like(span), span)
    X = 2.0 * (S - lo) / divisor - 1.0
    if flat.any():
        X = X.clone()
        X[:, flat] = 0.0

    # ---- univariate Legendre values: basis_1d[j][k] = P_k(x_j), shape (n,) ----
    basis_1d: List[List[torch.Tensor]] = []
    for j, top in enumerate(degrees):
        p0 = torch.ones((n,), **tensor_kw)
        if top == 0:
            basis_1d.append([p0])
            continue

        xj = X[:, j]
        polys = [p0, xj.clone()]
        for k in range(1, top):
            # Bonnet recurrence: (k+1) P_{k+1} = (2k+1) x P_k - k P_{k-1}
            polys.append(((2 * k + 1) * xj * polys[k] - k * polys[k - 1]) / (k + 1))
        basis_1d.append(polys)

    # ---- choose the multi-indices that become columns ----
    def is_kept(multi: Tuple[int, ...]) -> bool:
        active = sum(1 for v in multi if v != 0)
        if active == 0:
            # the constant term is never produced here; it is added separately
            return False
        if max_total_degree is not None and sum(multi) > max_total_degree:
            return False
        if interaction_order is not None and active > interaction_order:
            return False
        return True

    kept = [m for m in product(*(range(top + 1) for top in degrees)) if is_kept(m)]

    # ---- assemble columns and their names ----
    cols: List[torch.Tensor] = []
    names: List[str] = []
    if include_intercept:
        cols.append(torch.ones((n, 1), **tensor_kw))
        names.append("Intercept")

    for multi in kept:
        column = torch.ones((n,), **tensor_kw)
        labels = []
        for j, kj in enumerate(multi):
            if kj == 0:
                continue            # P_0 is identically 1: no factor, no label
            column = column * basis_1d[j][kj]
            labels.append(f"P{kj}(s{j + 1})")
        cols.append(column.view(n, 1))
        names.append("*".join(labels))

    if cols:
        Psi = torch.cat(cols, dim=1)
    else:
        Psi = torch.empty((n, 0), **tensor_kw)
    return Psi, names, scaler









def make_psi_legendre_tensor_batch_torch(
    S: torch.Tensor,
    degree: int = 3,
    scaler: Optional[Dict[str, torch.Tensor]] = None,
    include_intercept: bool = True,
    max_total_degree: Optional[int] = None,
    interaction_order: Optional[int] = 1,
    per_dim_degree: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[str], Dict[str, torch.Tensor]]:
    """
    Torch batch version of the tensor-product Legendre basis expansion.

    Each input column is affinely rescaled to [-1, 1] using a per-column
    min/max, Legendre polynomials P_0..P_K are evaluated per column, and
    basis columns are formed as products of one polynomial per dimension.

    Args:
      S: (B,n,d) or (n,d); a 2D input is treated as a batch of size 1.
      degree: maximum polynomial degree per dimension (used when
        per_dim_degree is None).
      scaler: optional dict with "min" and "max" entries of shape (d,).
        When None, min/max are computed jointly over the batch and sample
        axes of S.
      include_intercept: whether to prepend a column of ones named
        "Intercept".
      max_total_degree: if given, drop multi-indices whose degrees sum to
        more than this value.
      interaction_order: if given, drop multi-indices with more than this
        many non-zero degrees (1 => additive basis, no cross terms).
      per_dim_degree: optional per-dimension maximum degrees (tensor or
        sequence of length d, entries >= 0); overrides `degree`.

    Returns:
      Psi: (B,n,L) basis matrix.
      names: list of L column names such as "P2(s1)*P1(s3)".
      scaler: the {"min", "max"} dict that was used (the input one if given).

    Raises:
      ValueError: on wrong S rank, wrong per_dim_degree length / negative
        entries, or scaler entries not of shape (d,).

    Notes:
      Columns are ordered lexicographically by multi-index with the last
      dimension varying fastest. Multi-indices are generated with a pruned
      depth-first walk, so the work is proportional to the number of kept
      terms rather than to the full (K+1)^d grid.
    """
    if S.ndim == 2:
        S = S.unsqueeze(0)
    if S.ndim != 3:
        raise ValueError(f"S must be (B,n,d) or (n,d), got {tuple(S.shape)}")

    B, n, d = S.shape
    device, dtype = S.device, S.dtype

    # per-dimension degrees
    if per_dim_degree is None:
        per_dim_degree_list = [int(degree)] * d
    else:
        if isinstance(per_dim_degree, torch.Tensor):
            per_dim_degree_list = [int(x) for x in per_dim_degree.detach().cpu().tolist()]
        else:
            per_dim_degree_list = [int(x) for x in per_dim_degree]
        if len(per_dim_degree_list) != d:
            raise ValueError("per_dim_degree must have length d")
        if any(k < 0 for k in per_dim_degree_list):
            raise ValueError("per_dim_degree entries must be >= 0")

    # scaler: global over (B,n)
    if scaler is None:
        col_min = S.amin(dim=(0, 1))
        col_max = S.amax(dim=(0, 1))
        scaler = {"min": col_min, "max": col_max}
    else:
        col_min = torch.as_tensor(scaler["min"], device=device, dtype=dtype)
        col_max = torch.as_tensor(scaler["max"], device=device, dtype=dtype)
        if col_min.shape != (d,) or col_max.shape != (d,):
            raise ValueError("scaler['min'] and scaler['max'] must be shape (d,)")

    # Map each column to [-1, 1]; zero-range columns would divide by zero,
    # so they are scaled by 1 and then pinned to the interval midpoint 0.
    rng = col_max - col_min
    rng_safe = torch.where(rng == 0, torch.ones_like(rng), rng)
    X = 2.0 * (S - col_min.view(1, 1, d)) / rng_safe.view(1, 1, d) - 1.0
    if (rng == 0).any():
        X = X.clone()
        X[:, :, rng == 0] = 0.0

    # leg_vals[j][k] is P_k(x_j), shape (B,n); built with the three-term
    # recurrence (k+1) P_{k+1} = (2k+1) x P_k - k P_{k-1}.
    ones = torch.ones((B, n), device=device, dtype=dtype)
    leg_vals: List[List[torch.Tensor]] = []
    for j, K in enumerate(per_dim_degree_list):
        cache = [ones]
        if K >= 1:
            xj = X[:, :, j]
            cache.append(xj)
            for k in range(1, K):
                cache.append(((2 * k + 1) * xj * cache[k] - k * cache[k - 1]) / (k + 1))
        leg_vals.append(cache)

    def iter_kept_multi_indices():
        # Depth-first walk over dimensions with an explicit stack (no
        # recursion limit). Both the running degree sum and the count of
        # non-zero degrees can only grow as more dimensions are fixed, so a
        # prefix that violates a constraint can be discarded together with
        # every completion of it.
        stack = [(0, 0, ())]
        while stack:
            total, nnz, prefix = stack.pop()
            j = len(prefix)
            if j == d:
                yield prefix
                continue
            children = []
            for k in range(per_dim_degree_list[j] + 1):
                new_total = total + k
                new_nnz = nnz + (1 if k else 0)
                # Larger k only makes either constraint worse => stop early.
                if max_total_degree is not None and new_total > max_total_degree:
                    break
                if interaction_order is not None and new_nnz > interaction_order:
                    break
                children.append((new_total, new_nnz, prefix + (k,)))
            # Reverse so that the smallest k is popped (and fully expanded)
            # first, which reproduces itertools.product ordering.
            stack.extend(reversed(children))

    cols: List[torch.Tensor] = []
    names: List[str] = []

    if include_intercept:
        cols.append(torch.ones((B, n, 1), device=device, dtype=dtype))
        names.append("Intercept")

    for m in iter_kept_multi_indices():
        factors = [(j, kj) for j, kj in enumerate(m) if kj > 0]
        if not factors:
            # The all-zero index is the intercept: already added above when
            # requested, dropped otherwise.
            continue
        first_j, first_k = factors[0]
        col = leg_vals[first_j][first_k]
        for j, kj in factors[1:]:
            col = col * leg_vals[j][kj]
        cols.append(col.unsqueeze(-1))
        names.append("*".join(f"P{kj}(s{j + 1})" for j, kj in factors))

    Psi = torch.cat(cols, dim=2) if len(cols) > 0 else torch.empty((B, n, 0), device=device, dtype=dtype)
    return Psi, names, scaler
###

def est_u_tilde_torch(C_hat: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """
    Estimate a basis of the numerical null space of C_hat^T via a full SVD.

    The numerical rank r counts singular values strictly above
    s_max * max(L, q) * eps (eps being the machine epsilon of C_hat's dtype);
    the left singular vectors beyond the first r form the returned basis.

    Args:
      C_hat: (L, q) tensor (e.g., Psi^T X / n)

    Returns:
      U_tilde: (L, L-r) tensor of left singular vectors
      r: numerical rank (int)

    Raises:
      ValueError: if C_hat is not 2D.
    """
    if C_hat.ndim != 2:
        raise ValueError(f"C_hat must be 2D, got {tuple(C_hat.shape)}")

    # Full SVD so that the left factor is square (L, L) and contains the
    # vectors orthogonal to the column space of C_hat.
    left_vectors, singular_values, _ = torch.linalg.svd(C_hat, full_matrices=True)

    # No singular values at all (an empty dimension): rank 0, keep all of U.
    if singular_values.numel() == 0:
        return left_vectors, 0

    machine_eps = torch.finfo(C_hat.dtype).eps
    cutoff = singular_values.max() * max(C_hat.shape) * machine_eps
    above_cutoff = singular_values > cutoff
    rank = int(above_cutoff.sum().item())

    return left_vectors[:, rank:], rank

# ============================================================
# Ridge inverse utilities
# ============================================================


def gradual_ridge_inverse_torch(
    A: torch.Tensor,
    ridge_start: float = 1e-4,
    factor: float = 10.0,
    max_tries: int = 8,
    cond_threshold: float = 1e12,
) -> torch.Tensor:
    """
    Invert a square matrix, regularising with a ridge when needed.

    Steps:
      1. If the condition number of A is below cond_threshold, try inv(A).
      2. Otherwise (or if that inversion raises), try inv(A + ridge * I),
         multiplying ridge by `factor` each time the inversion raises, for
         at most max_tries attempts.
      3. If every attempt raises, return pinv(A).

    Args:
      A: (d,d) tensor
      ridge_start: first ridge value added to the diagonal
      factor: multiplicative growth of the ridge after a failed inversion
      max_tries: maximum number of ridge attempts
      cond_threshold: condition numbers at or above this skip the plain inverse

    Returns:
      (d,d) tensor: inverse, ridge-regularised inverse, or pseudo-inverse.

    Raises:
      ValueError: if A is not a 2D square tensor.
    """
    is_square_matrix = A.ndim == 2 and A.shape[0] == A.shape[1]
    if not is_square_matrix:
        raise ValueError("A must be a 2D square tensor")

    # A failure while computing the condition number is treated the same as
    # an infinitely ill-conditioned matrix.
    try:
        cond_number = torch.linalg.cond(A).item()
    except Exception:
        cond_number = float("inf")

    if cond_number < cond_threshold:
        try:
            return torch.linalg.inv(A)
        except RuntimeError:
            # fall through to the ridge attempts below
            pass

    identity = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
    penalty = float(ridge_start)

    for _attempt in range(max_tries):
        try:
            return torch.linalg.inv(A + penalty * identity)
        except RuntimeError:
            penalty *= factor

    # Last resort: Moore-Penrose pseudo-inverse of the unregularised matrix.
    return torch.linalg.pinv(A)








def torch_gradual_ridge_inverse_batched(
    A: torch.Tensor,
    ridge_start: float = 1e-4,
    factor: float = 10.0,
    max_tries: int = 3,
) -> torch.Tensor:
    """
    Ridge-regularised inverse for a batch of square matrices.

    The same ridge is added to every matrix in the batch. Note that the
    ridge is always applied: the first attempt is inv(A + ridge_start * I),
    not inv(A). Whenever the inversion raises, the ridge is multiplied by
    `factor` and retried, up to max_tries attempts; if all attempts raise,
    pinv(A) of the unregularised batch is returned.

    Args:
      A: (..., d, d) tensor
      ridge_start: ridge used on the first attempt
      factor: multiplicative growth of the ridge after a failed attempt
      max_tries: maximum number of inversion attempts

    Returns:
      (..., d, d) tensor

    Raises:
      ValueError: if the last two dimensions of A differ.
    """
    n_rows, n_cols = A.shape[-2], A.shape[-1]
    if n_cols != n_rows:
        raise ValueError("A must be batched square matrices")

    # (d, d) identity broadcasts over the leading batch dimensions.
    identity = torch.eye(n_cols, device=A.device, dtype=A.dtype)
    penalty = float(ridge_start)

    for _attempt in range(max_tries):
        try:
            return torch.linalg.inv(A + penalty * identity)
        except RuntimeError:
            penalty *= float(factor)

    return torch.linalg.pinv(A)


# ============================================================
# Stable inverse for PSD-ish matrix
# ============================================================
def stable_inv_psd(A: torch.Tensor, ridge: float = 1e-4, max_tries: int = 8) -> torch.Tensor:
    """
    Symmetric, ridge-regularised inverse of an (approximately) PSD matrix.

    A is first symmetrised as (A + A^T) / 2. The ridge is always added:
    inv(A_sym + r * I) is tried with r starting at `ridge` and multiplied by
    10 each time the inversion raises, for at most max_tries attempts. If
    all attempts raise, pinv(A_sym + r * I) is used with the ridge reached
    after the last escalation. The result is symmetrised before returning.

    Args:
      A: square 2D tensor (A.T and A.shape[0] are used)
      ridge: initial ridge added to the diagonal
      max_tries: maximum number of inversion attempts

    Returns:
      Symmetric tensor with the same shape as A.
    """

    def symmetrize(M: torch.Tensor) -> torch.Tensor:
        # Average with the transpose to remove asymmetry.
        return 0.5 * (M + M.T)

    sym_A = symmetrize(A)
    identity = torch.eye(sym_A.shape[0], device=sym_A.device, dtype=sym_A.dtype)

    penalty = ridge
    for _attempt in range(max_tries):
        try:
            return symmetrize(torch.linalg.inv(sym_A + penalty * identity))
        except RuntimeError:
            penalty *= 10.0

    # Every attempt raised: pseudo-inverse with the escalated ridge.
    return symmetrize(torch.linalg.pinv(sym_A + penalty * identity))
