
from __future__ import annotations

import os
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.distributions import Bernoulli

from utils_self import *

# ==============================================================================
# 0. Small perf knobs
# ==============================================================================

def set_perf_flags(device: torch.device):
    """Turn on global CUDA speed knobs when ``device`` is a CUDA device.

    * TF32 for matmul and cuDNN: much faster on Ampere-or-newer GPUs, at a
      small cost in precision.
    * cuDNN benchmark mode: auto-tunes kernels per input shape. This usually
      helps, but can be slower when shapes keep changing.

    For any non-CUDA device this is a no-op.
    """
    if device.type != "cuda":
        return
    cuda_backend = torch.backends.cuda
    cudnn_backend = torch.backends.cudnn
    cuda_backend.matmul.allow_tf32 = True
    cudnn_backend.allow_tf32 = True
    cudnn_backend.benchmark = True


# ==============================================================================
# 1. Network & State Builder
# ==============================================================================

class MLPActorCritic(nn.Module):
    """Shared-trunk MLP with a policy-logit head and a non-positive value head.

    Args:
        d_state: width of the flat input state vector.
        hidden_sizes: widths of the trunk's hidden layers; each layer is
            followed by ``activation()``.
        activation: zero-argument callable that returns an activation module.
        value_hidden: width of the value head's hidden layer.

    ``forward(state_vec)`` returns ``(logits, value)``, each with shape
    ``state_vec.shape[:-1]``. ``value`` is the negated Softplus output, so it
    is always <= 0.
    """

    def __init__(self, d_state: int, hidden_sizes=(256, 256), activation=nn.GELU, value_hidden=128):
        super().__init__()
        widths = [d_state, *hidden_sizes]
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk += [nn.Linear(fan_in, fan_out), activation()]
        self.encoder = nn.Sequential(*trunk)

        feat_dim = widths[-1]
        # One logit per state (parameter of a Bernoulli policy).
        self.pi_head = nn.Linear(feat_dim, 1)
        self.v_head = nn.Sequential(
            nn.Linear(feat_dim, value_hidden),
            activation(),
            nn.Linear(value_hidden, 1),
            nn.Softplus(),
        )

    def forward(self, state_vec: torch.Tensor):
        features = self.encoder(state_vec)
        logits = self.pi_head(features).squeeze(-1)
        # Softplus >= 0; negating it keeps the value estimate non-positive.
        value = self.v_head(features).squeeze(-1).neg()
        return logits, value


def build_state_vector_hiera_batch_final(
    G1_summary_batch, G2_summary_batch, C_mean_t, v_mean_t,
    X_past_batch, A_past_batch,
    X_batch, C_now, A_batch,
    f_t, T,
    include_current_action: bool = False,
):
    """Concatenate summaries, history and the current observation into one state row per sample.

    Layout along dim 1, in order:
      G1 (d1*d1) | G2 (d1*d1) | C_mean | v_mean | C_now | padded history |
      X_batch [| A_batch if ``include_current_action``] | f_t (same for every row)

    The history block stacks ``[X_past | A_past]`` for each past step, so each
    step contributes ``X_batch.shape[1] + 1`` values. It is right-padded with
    zeros to a fixed ``T - 1`` steps, so the state width does not depend on how
    many steps have elapsed.

    Args:
        G1_summary_batch, G2_summary_batch: (B, d1, d1) tensors.
        C_mean_t, v_mean_t, C_now: tensors with leading batch dim B; flattened per row.
        X_past_batch: (B, t, p) past observations.
        A_past_batch: (B, t, 1) past actions.
        X_batch: (B, p) current observation.
        A_batch: (B, 1) action, appended only when ``include_current_action``.
        f_t: time features; flattened and repeated for every row.
        T: horizon length; determines the padded history width.
        include_current_action: whether to append ``A_batch`` after ``X_batch``.

    Returns:
        (B, D) tensor.
    """
    n = X_batch.shape[0]
    dev, dt = X_batch.device, X_batch.dtype

    side = G1_summary_batch.shape[1]
    parts = [
        G1_summary_batch.reshape(n, side * side),
        G2_summary_batch.reshape(n, side * side),
        C_mean_t.reshape(n, -1),
        v_mean_t.reshape(n, -1),
        C_now.reshape(n, -1),
    ]

    # History block, zero-padded on the right to (T - 1) steps.
    target_width = (T - 1) * (X_batch.shape[1] + 1)
    if X_past_batch.shape[1] > 0:
        history = torch.cat([X_past_batch, A_past_batch], dim=2).reshape(n, -1)
    else:
        history = torch.zeros((n, 0), device=dev, dtype=dt)
    missing = target_width - history.shape[1]
    if missing > 0:
        history = torch.cat([history, torch.zeros((n, missing), device=dev, dtype=dt)], dim=1)
    parts.append(history)

    parts.append(torch.cat([X_batch, A_batch], dim=1) if include_current_action else X_batch)
    parts.append(f_t.view(1, -1).expand(n, -1))
    return torch.cat(parts, dim=1)


# ==============================================================================
# 2. Linear-algebra helpers
# ==============================================================================

@torch.no_grad()
def stable_solve_psd_batch(G: torch.Tensor, u: torch.Tensor, ridge: float = 1e-4, max_tries: int = 8):
    """Solve ``(G + r I) w = u`` for a batch of (nearly) PSD matrices.

    ``G`` is symmetrised first. Every batch item starts with ``r = ridge``.
    Each round attempts a Cholesky factorisation; items whose factorisation
    fails get ``r *= 10`` and are retried, for at most ``max_tries`` rounds.
    Items that never factor are solved with the pseudo-inverse of
    ``G + r I``, using their last ``r``.

    Args:
        G: (B, d, d) matrices.
        u: (B, d, k) right-hand sides.
        ridge: initial diagonal shift.
        max_tries: number of Cholesky rounds before the pinv fallback.

    Returns:
        (w, r): solutions shaped like ``u``, and the per-item shift, shape (B,).
    """
    G = 0.5 * (G + G.transpose(-1, -2))
    n_batch, dim, _ = G.shape
    # Broadcast against the batch instead of materialising a (B, d, d) identity.
    eye = torch.eye(dim, device=G.device, dtype=G.dtype)

    shift = torch.full((n_batch,), ridge, device=G.device, dtype=G.dtype)
    pending = torch.ones((n_batch,), device=G.device, dtype=torch.bool)
    w = torch.zeros_like(u)

    for _ in range(max_tries):
        if not pending.any():
            break
        rows = pending.nonzero(as_tuple=False).squeeze(-1)
        factor, info = torch.linalg.cholesky_ex(G[rows] + shift[rows].view(-1, 1, 1) * eye)
        success = info == 0  # info != 0 means the factorisation failed
        if success.any():
            done = rows[success]
            w[done] = torch.cholesky_solve(u[rows][success], factor[success])
            pending[done] = False
        failed = rows[~success]
        if failed.numel() > 0:
            shift[failed] = shift[failed] * 10.0

    if pending.any():
        rows = pending.nonzero(as_tuple=False).squeeze(-1)
        w[rows] = torch.linalg.pinv(G[rows] + shift[rows].view(-1, 1, 1) * eye) @ u[rows]

    return w, shift


@torch.no_grad()
def nullspace_projector_batch(C_hat: torch.Tensor):
    """Orthogonal projector onto the complement of the column space of ``C_hat``.

    For each batch item this returns ``P_null = I - U_r U_r^T``. ``U_r`` holds
    the left singular vectors whose singular values exceed the numerical-rank
    tolerance ``max(S) * max(L, q) * eps``.

    Args:
        C_hat: (B, L, q), or (L, q), which is treated as a batch of size 1.

    Returns:
        P_null: (B, L, L) projector.
        r: (B,) numerical rank of each item.

    Raises:
        ValueError: if ``C_hat`` is neither 2-D nor 3-D.
    """
    batched = C_hat[None] if C_hat.ndim == 2 else C_hat
    if batched.ndim != 3:
        raise ValueError(f"C_hat must be (B,L,q) or (L,q), got {tuple(batched.shape)}")

    n_batch, n_rows, n_cols = batched.shape
    U, S, _ = torch.linalg.svd(batched, full_matrices=True)  # U: (B, L, L)
    tol = S.max(dim=-1).values * max(n_rows, n_cols) * torch.finfo(batched.dtype).eps
    rank = (S > tol.unsqueeze(-1)).sum(dim=-1)

    # Singular values are sorted in descending order, so the first `rank` columns of U span the range.
    keep = (torch.arange(n_rows, device=batched.device).unsqueeze(0) < rank.unsqueeze(-1)).to(U.dtype)
    U_kept = U * keep.unsqueeze(-2)
    eye = torch.eye(n_rows, device=batched.device, dtype=batched.dtype).expand(n_batch, n_rows, n_rows)
    return eye - U_kept @ U_kept.transpose(-1, -2), rank


# ==============================================================================
# 3. Offline summary generation (replay pool of already-observed days)
# ==============================================================================

def _resolve_device_dtype(params, device, dtype):
    """Default ``device`` / ``dtype`` to those of ``params.beta0`` when they are None."""
    if device is None:
        device = params.beta0.device
    if dtype is None:
        dtype = params.beta0.dtype
    return device, dtype


def _seeded_generator(device, seed):
    """Return a ``torch.Generator`` on ``device`` seeded with ``seed``."""
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    return g


def _simulate_and_summarize(*, params, A, g, device, dtype, add_state_noise):
    """Roll the state transition forward under a fixed action plan and summarise it over days.

    Args:
        params: model parameters. Reads ``p``, ``phi0``, ``phi1``, ``Xi1``,
            ``alphas`` and ``eps_state_cov``.
        A: (B, m, T) tensor of +/-1 actions: m days, T steps per day.
        g: generator used, in this order, for the initial states and then for
            the per-step transition noise.
        device, dtype: placement of every tensor created here.
        add_state_noise: if True, add Gaussian noise with covariance
            ``eps_state_cov`` at every transition; otherwise pass ``None``.

    Returns:
        dict with S = [1, X] (d1 = p + 1) and Psi the per-state basis (L = 3p + 1):
          "G1":     (B, T, d1, d1)   mean over days of (1 + A) * S S^T
          "G2":     (B, T, d1, d1)   mean over days of (1 - A) * S S^T
          "C_mean": (B, T, L, d1)    mean over days of Psi S^T
          "v_sum":  (B, T, 2*d1, L)  sum over days of [S, A*S] Psi^T
    """
    B, m, T = A.shape
    p = params.p

    def psi_of(x):  # (B, m, p) -> (B, m, L)
        flat_psi, _, _ = make_psi_legendre_tensor_torch(x.flatten(0, 1))
        return flat_psi.view(B, m, -1)

    X = torch.zeros((B, m, T, p), device=device, dtype=dtype)
    Psi_all = torch.zeros((B, m, T, 3 * p + 1), device=device, dtype=dtype)
    X[:, :, 0, :] = torch.randn((B, m, p), device=device, dtype=dtype, generator=g)
    Psi_all[:, :, 0, :] = psi_of(X[:, :, 0, :])

    chol_state = torch.linalg.cholesky(params.eps_state_cov) if add_state_noise else None

    for t in range(T - 1):
        X_t = X[:, :, t, :].reshape(-1, p)
        A_t = A[:, :, t].reshape(-1, 1)
        state_noise = None
        if add_state_noise:
            state_noise = torch.randn((B * m, p), device=device, dtype=dtype, generator=g) @ chol_state.T

        X_next = trans_distribution(
            params.phi0[t], params.phi1[t], params.Xi1[t], X_t, A_t, params.alphas[t], state_noise
        )
        X[:, :, t + 1, :] = X_next.view(B, m, p)
        Psi_all[:, :, t + 1, :] = psi_of(X[:, :, t + 1, :])

    S = torch.cat([torch.ones((B, m, T, 1), device=device, dtype=dtype), X], dim=-1)  # (B, m, T, d1)
    AS = S * A.unsqueeze(-1)
    Z_S_AS = torch.cat([S, AS], dim=-1)                                              # (B, m, T, 2*d1)

    # Contract over the day axis directly. This avoids materialising the
    # (B, m, T, ., .) per-day outer products, which are large for big B*m.
    SS_mean = torch.einsum("bmtk,bmtl->btkl", S, S) / m
    ASS_mean = torch.einsum("bmtk,bmtl->btkl", AS, S) / m
    C_mean = torch.einsum("bmti,bmtj->btij", Psi_all, S) / m
    v_sum = torch.einsum("bmti,bmtj->btij", Z_S_AS, Psi_all)

    return {"G1": SS_mean + ASS_mean,
            "G2": SS_mean - ASS_mean,
            "C_mean": C_mean,
            "v_sum": v_sum}


@torch.no_grad()
def gen_nimins_one_days_random(*, params, B: int, N_days: int, device=None, dtype=None, seed: int = 2026, add_state_noise: bool = True):
    """Simulate ``N_days`` days where every (day, step) action is an independent fair +/-1 draw.

    Returns the summary dict produced by ``_simulate_and_summarize``, or ``{}`` when ``N_days <= 0``.
    """
    device, dtype = _resolve_device_dtype(params, device, dtype)
    if N_days <= 0:
        return {}

    g = _seeded_generator(device, seed)
    T = params.taus
    A01 = torch.randint(0, 2, (B, N_days, T), device=device, generator=g).to(dtype)
    A = A01 * 2 - 1.0
    return _simulate_and_summarize(params=params, A=A, g=g, device=device, dtype=dtype,
                                   add_state_noise=add_state_noise)


@torch.no_grad()
def gen_nimins_all_T_random(*, params, B: int, N_days: int, device=None, dtype=None, seed: int = 2026, add_state_noise: bool = True):
    """Simulate ``N_days`` days with one fair +/-1 action per day, held for all T steps.

    Returns the summary dict produced by ``_simulate_and_summarize``, or ``{}`` when ``N_days <= 0``.
    """
    device, dtype = _resolve_device_dtype(params, device, dtype)
    if N_days <= 0:
        return {}

    g = _seeded_generator(device, seed)
    T = params.taus
    day_decisions = torch.randint(0, 2, (B, N_days, 1), device=device, generator=g)
    A = day_decisions.expand(-1, -1, T).to(dtype) * 2 - 1.0
    return _simulate_and_summarize(params=params, A=A, g=g, device=device, dtype=dtype,
                                   add_state_noise=add_state_noise)


@torch.no_grad()
def gen_nimins_one_days_switchback(*, params, B: int, N_days: int, device=None, dtype=None, seed: int = 2026, add_state_noise: bool = True):
    """Simulate ``N_days`` days of a switchback design.

    The action flips sign at every step and every day: ``A[b, i, t] = A0[b] * (-1)^(i + t)``,
    where ``A0[b]`` is a fair +/-1 draw. Returns the summary dict produced by
    ``_simulate_and_summarize``, or ``{}`` when ``N_days <= 0``.
    """
    device, dtype = _resolve_device_dtype(params, device, dtype)
    if N_days <= 0:
        return {}

    g = _seeded_generator(device, seed)
    T = params.taus
    A0 = torch.randint(0, 2, (B,), device=device, generator=g).to(dtype) * 2 - 1.0
    day_idx = torch.arange(N_days, device=device).view(-1, 1)
    step_idx = torch.arange(T, device=device).view(1, -1)
    # (m, T) sign pattern: +1 where (i + t) is even, -1 where it is odd.
    parity_sign = 1.0 - 2.0 * ((day_idx + step_idx) % 2).to(dtype)
    A = A0.view(B, 1, 1) * parity_sign
    return _simulate_and_summarize(params=params, A=A, g=g, device=device, dtype=dtype,
                                   add_state_noise=add_state_noise)


@torch.no_grad()
def gen_nimins_all_T_half(*, params, B: int, N_days: int, device=None, dtype=None, seed: int = 2026, add_state_noise: bool = True):
    """Simulate ``N_days`` days alternating whole-day treatment: +1 on even days, -1 on odd days.

    Returns the summary dict produced by ``_simulate_and_summarize``, or ``{}`` when ``N_days <= 0``.
    """
    device, dtype = _resolve_device_dtype(params, device, dtype)
    if N_days <= 0:
        return {}

    g = _seeded_generator(device, seed)
    T = params.taus
    A = torch.ones((B, N_days, T), device=device, dtype=dtype)
    A[:, 1::2, :] = -1.0
    return _simulate_and_summarize(params=params, A=A, g=g, device=device, dtype=dtype,
                                   add_state_noise=add_state_noise)

@torch.no_grad()
def generate_initial_data_mixed_tensorpool(*, params, n_minus_k, device, dtype, sample_number=5000, seed=2026):
    """Build an on-device replay pool of summaries over ``n_minus_k`` already-observed days.

    Draws ``sample_number`` samples from each of four logging designs, in this
    order: per-step random, switchback, alternating whole days, per-day random.
    The results are concatenated along the batch axis and all keys are
    shuffled with one shared permutation.

    Args:
        params: simulation parameters forwarded to the generators.
        n_minus_k: number of past days to summarise.
        device, dtype: placement of the generated tensors.
        sample_number: samples per design; the pool holds ``4 * sample_number`` rows.
        seed: base seed. Design j uses ``seed + j``; the shuffle uses ``seed + 4``.

    Returns:
        dict with keys "G1", "G2", "C_mean", "v_sum", all with leading dim
        ``4 * sample_number``; or None when ``n_minus_k <= 0``.
    """
    if n_minus_k <= 0:
        return None

    generators = (
        gen_nimins_one_days_random,
        gen_nimins_one_days_switchback,
        gen_nimins_all_T_half,
        gen_nimins_all_T_random,
    )
    parts = [
        gen(params=params, B=sample_number, N_days=n_minus_k, device=device, dtype=dtype, seed=seed + offset)
        for offset, gen in enumerate(generators)
    ]

    keys = ("G1", "G2", "C_mean", "v_sum")
    pool = {key: torch.cat([part[key] for part in parts], dim=0) for key in keys}

    # Dedicated generator: the shuffle is reproducible from `seed` and leaves the global RNG untouched.
    shuffle_gen = torch.Generator(device=device)
    shuffle_gen.manual_seed(seed + 4)
    perm = torch.randperm(pool["G1"].shape[0], device=device, generator=shuffle_gen)
    return {key: value[perm] for key, value in pool.items()}

# ==============================================================================
# 4. Reward & Rollout
# ==============================================================================

@torch.no_grad()
def compute_reward_rollout_linear_batch(
    *,
    G1_accum_batch, G2_accum_batch,
    C_mean_batch, v_sum_batch_t,
    u_t,
    eta_t,
    T_length,
    ridge=1e-4,
    N_days: int
):
    """Per-sample terminal reward for one time step.

    Builds the block Gram matrix
        G = [[G1 + G2, G1 - G2],
             [G1 - G2, G1 + G2]]        shape (B, 2*d1, 2*d1)
    and solves ``G w = u_t`` with ``stable_solve_psd_batch``. The reward is
        ( -u_t^T w / N_days
          - eta_t^2 * (L / N_days^2) * ||P_null v^T w||^2 * T_length ) / T_length^2
    where ``P_null`` projects onto the orthogonal complement of the column
    space of ``C_mean_batch`` and ``L = C_mean_batch.shape[1]``.

    Args:
        G1_accum_batch, G2_accum_batch: (B, d1, d1) summaries.
        C_mean_batch: (B, L, d1) summary.
        v_sum_batch_t: (B, 2*d1, L) summary.
        u_t: vector with 2*d1 entries.
        eta_t, T_length, N_days: scalars.
        ridge: initial diagonal shift passed to the solver.

    Returns:
        (B,) tensor of rewards.
    """
    n, d1, _ = G1_accum_batch.shape
    diag_blk = G1_accum_batch + G2_accum_batch
    off_blk = G1_accum_batch - G2_accum_batch
    G_big = torch.cat(
        [
            torch.cat([diag_blk, off_blk], dim=-1),
            torch.cat([off_blk, diag_blk], dim=-1),
        ],
        dim=-2,
    )

    rhs = u_t.view(1, 2 * d1, 1).expand(n, 2 * d1, 1)
    w, _ = stable_solve_psd_batch(G_big, rhs, ridge=ridge)
    quad_form = (rhs.transpose(1, 2) @ w).reshape(n)

    # (B, L, 2*d1) @ (B, 2*d1, 1) -> (B, L, 1)
    moment = v_sum_batch_t.transpose(-1, -2) @ w
    P_null, _ = nullspace_projector_batch(C_mean_batch)
    resid_sq = torch.linalg.norm((P_null @ moment).squeeze(-1), dim=-1) ** 2

    n_basis = C_mean_batch.shape[1]
    penalty = -(eta_t ** 2) * (n_basis / (N_days ** 2)) * resid_sq * T_length
    return (-quad_form / N_days + penalty) / (T_length ** 2)


def rollout_n_k_plus_1_to_n_days(
    *,
    actor_now, actor_future_dict,
    G1_sum_batch, G2_sum_batch, C_mean_batch, v_sum_batch,
    params, eta_t, u_long_mat,
    batch_size, current_day, N_days,
    device, dtype,
    debug: bool = False,
):
    """Simulate days ``current_day .. N_days`` and return training data for ``current_day``.

    On ``current_day``, actions are sampled from ``actor_now`` and the actor
    states, critic states and actions are recorded. Later days use
    ``actor_future_dict[day]``. Every step folds the new observation into the
    running summaries: G1 / G2 / C_mean are running means over days, v_sum is
    a running sum. After the last day, one reward per time step is computed
    from the final summaries.

    Args:
        actor_now: policy being trained; its forward returns ``(logits, _)``.
        actor_future_dict: maps day -> policy for days after ``current_day``.
        G1_sum_batch, G2_sum_batch, C_mean_batch, v_sum_batch: (B, T, ...)
            starting summaries. They are cloned, never modified.
        params: model parameters (``taus``, ``p``, ``phi0``, ``phi1``, ``Xi1``,
            ``alphas``, ``eps_state_cov``).
        eta_t, u_long_mat: reward parameters; ``u_long_mat[t]`` is used for step t.
        batch_size: number of parallel trajectories B.
        current_day, N_days: inclusive range of simulated days.
        device, dtype: placement of newly created tensors.
        debug: print the policy's mean P(a=1) and entropy at t == 0 of ``current_day``.

    Returns:
        (states_actor (B, T, D_a), states_critic (B, T, D_c),
         actions in {0, 1} (B, T), rewards (B, T)).
    """
    T, p = params.taus, params.p

    states_actor = []
    states_critic = []
    actions = []

    # Work on copies so the caller's tensors (e.g. rows of the replay pool) stay untouched.
    G1_curr = G1_sum_batch.clone()
    G2_curr = G2_sum_batch.clone()
    C_mean_curr = C_mean_batch.clone()
    v_sum_curr = v_sum_batch.clone()

    ones_B1 = torch.ones((batch_size, 1), device=device, dtype=dtype)
    # The transition-noise covariance never changes: factor it once, not once per day.
    chol_state = torch.linalg.cholesky(params.eps_state_cov)

    for day in range(current_day, N_days + 1):
        is_training_day = (day == current_day)
        actor_net = actor_now if is_training_day else actor_future_dict[day]
        if not is_training_day:
            actor_net.eval()

        # Per-day trajectory buffers; step t only reads entries [:t].
        X_hist = torch.empty((batch_size, T, p), device=device, dtype=dtype)
        A_hist = torch.empty((batch_size, T, 1), device=device, dtype=dtype)

        X_batch = torch.randn((batch_size, p), device=device, dtype=dtype)
        last_action = torch.zeros((batch_size,), device=device, dtype=dtype)

        # Running-mean weights: summaries cover day-1 days before this day's update, `day` after it.
        prev_n, curr_n = float(day - 1), float(day)

        for t in range(T):
            # Features of the CURRENT observation. They are used both in the state and in the
            # summary update, so C_now always matches X_batch (not the previous step's X).
            X1 = torch.cat([ones_B1, X_batch], dim=1)
            Psi_torch, _, _ = make_psi_legendre_tensor_batch_torch(X_batch)
            # reshape rather than squeeze(): squeeze() also drops the batch axis when batch_size == 1.
            # Assumes Psi holds one basis vector per sample.
            Psi_torch = Psi_torch.reshape(batch_size, -1)
            C_now = Psi_torch.unsqueeze(2) @ X1.unsqueeze(1)

            G1_t = G1_curr[:, t]
            G2_t = G2_curr[:, t]
            C_mean_t = C_mean_curr[:, t]
            # NOTE(review): v_sum_curr[:, t] is a sum over day-1 days here, yet it is divided by `day`,
            # whereas G1/G2/C_mean are means over day-1 days. Kept as is because trained policies
            # depend on this scaling -- confirm it is intended.
            v_mean_t = v_sum_curr[:, t] / curr_n

            f_t = time_features(t, T, device, dtype)
            X_past = X_hist[:, :t, :]
            A_past = A_hist[:, :t, :]

            s_actor = build_state_vector_hiera_batch_final(
                G1_t, G2_t, C_mean_t, v_mean_t,
                X_past, A_past,
                X_batch, C_now,
                last_action.unsqueeze(1),
                f_t, T, include_current_action=False
            )

            # Sampling needs no autograd graph: update_A2C_batch re-evaluates the stored
            # states with gradients enabled.
            with torch.no_grad():
                logits, _ = actor_net(s_actor)
            dist = Bernoulli(logits=logits)

            if is_training_day and debug and t == 0:
                p_mean = torch.sigmoid(logits).mean()
                ent = dist.entropy().mean()
                print(f"[debug] day={current_day} p(a01=1) mean={p_mean.item():.3f} | entropy={ent.item():.3f}")

            a01 = dist.sample()
            action = a01 * 2 - 1.0  # {0, 1} -> {-1, +1}

            if is_training_day:
                states_actor.append(s_actor)
                actions.append(a01)
                # The critic additionally sees the action taken at this step.
                s_critic = build_state_vector_hiera_batch_final(
                    G1_t, G2_t, C_mean_t, v_mean_t,
                    X_past, A_past,
                    X_batch, C_now,
                    action.unsqueeze(1),
                    f_t, T, include_current_action=True
                )
                states_critic.append(s_critic)

            X_hist[:, t, :] = X_batch
            A_hist[:, t, 0] = action

            # Fold this step's observation into the day-level summaries.
            XA = X1 * action.unsqueeze(1)
            term_SS = X1.unsqueeze(2) @ X1.unsqueeze(1)
            term_ASS = XA.unsqueeze(2) @ X1.unsqueeze(1)
            Z_S_AS = torch.cat([X1, XA], dim=-1)
            v_sum_curr[:, t] += Z_S_AS.unsqueeze(2) @ Psi_torch.unsqueeze(1)

            G1_curr[:, t] = (G1_curr[:, t] * prev_n + (term_SS + term_ASS)) / curr_n
            G2_curr[:, t] = (G2_curr[:, t] * prev_n + (term_SS - term_ASS)) / curr_n
            C_mean_curr[:, t] = (C_mean_curr[:, t] * prev_n + C_now) / curr_n

            last_action = action
            if t < T - 1:
                # Only draw transition noise when there is a next state to produce.
                state_noise = torch.randn((batch_size, p), device=device, dtype=dtype) @ chol_state.T
                X_batch = trans_distribution(
                    params.phi0[t], params.phi1[t], params.Xi1[t],
                    X_batch, action.unsqueeze(1),
                    params.alphas[t],
                    state_noise
                )

    # Terminal rewards, one per time step, from the final summaries.
    rewards = [
        compute_reward_rollout_linear_batch(
            G1_accum_batch=G1_curr[:, t],
            G2_accum_batch=G2_curr[:, t],
            C_mean_batch=C_mean_curr[:, t],
            v_sum_batch_t=v_sum_curr[:, t],
            u_t=u_long_mat[t],
            eta_t=eta_t,
            T_length=T,
            ridge=1e-4,
            N_days=N_days
        )
        for t in range(T)
    ]

    return (
        torch.stack(states_actor, 1),
        torch.stack(states_critic, 1),
        torch.stack(actions, 1),
        torch.stack(rewards, 1),
    )


def update_A2C_batch(
    *,
    actor, critic,
    optimizer_actor, optimizer_critic,
    states_actor, states_critic,
    actions, rewards,
    gamma=1.0,
    entropy_coef=0.01,
    value_coef=0.5,
    grad_clip=1.0,
    normalize_adv: bool = True,
):
    """Perform one advantage actor-critic update on a batch of rollouts.

    Args:
        actor: module whose forward on flat states returns ``(logits, _)``.
        critic: module whose forward on flat states returns ``(_, values)``.
        optimizer_actor, optimizer_critic: optimizers for the two modules.
        states_actor: (B, T, D_a) actor inputs.
        states_critic: (B, T, D_c) critic inputs.
        actions: (B, T) sampled actions in {0, 1}.
        rewards: (B, T) per-step rewards.
        gamma: discount used for the reward-to-go returns.
        entropy_coef: weight of the entropy bonus in the actor loss.
        value_coef: weight of the critic's mean squared error.
        grad_clip: maximum global gradient norm, applied to each module separately.
        normalize_adv: standardise advantages over the batch. This is skipped
            for a single element, where the standard deviation is undefined
            and would produce NaN.

    Returns:
        (actor_loss, critic_loss, mean_reward) as tensors. No ``.item()`` is
        called, to avoid forcing a device synchronisation.
    """
    B, T, _ = states_actor.shape

    # Discounted reward-to-go, accumulated backwards; same dtype/device as the rewards.
    returns = torch.zeros_like(rewards)
    running = rewards.new_zeros((B,))
    for t in reversed(range(T)):
        running = rewards[:, t] + gamma * running
        returns[:, t] = running

    logits, _ = actor(states_actor.reshape(-1, states_actor.shape[-1]))
    dist = Bernoulli(logits=logits.reshape(B, T))
    logps = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    _, values = critic(states_critic.reshape(-1, states_critic.shape[-1]))
    values = values.reshape(B, T)

    advantages = returns - values.detach()
    if normalize_adv and advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    actor_loss = -(logps * advantages).mean() - entropy_coef * entropy
    critic_loss = value_coef * (returns - values).pow(2).mean()

    optimizer_actor.zero_grad(set_to_none=True)
    actor_loss.backward()
    nn.utils.clip_grad_norm_(actor.parameters(), grad_clip)
    optimizer_actor.step()

    optimizer_critic.zero_grad(set_to_none=True)
    critic_loss.backward()
    nn.utils.clip_grad_norm_(critic.parameters(), grad_clip)
    optimizer_critic.step()

    return actor_loss, critic_loss, rewards.mean()


# ==============================================================================
# 5. Training Loop (TensorPool + reduced sync)
# ==============================================================================

def train_A2C_for_n_k_plus_1(
    *,
    actor_now, actor_future_dict,
    critic,
    optimizer_actor, optimizer_critic,
    summary_pool,  # dict of tensors or None
    params, eta_t, u_long_mat,
    n_episodes,
    current_day, N_days,
    batch_size=32,
    save_actor=True,
    save_dir="./checkpoints",
    print_every=50,
):
    """Train the policy for ``current_day`` with A2C for ``n_episodes`` rollouts.

    In each episode:
      1. Sample ``batch_size`` starting summaries from ``summary_pool``, or use
         all-zero summaries when the pool is None.
      2. Roll out days ``current_day .. N_days`` with
         ``rollout_n_k_plus_1_to_n_days``.
      3. Apply one ``update_A2C_batch`` step.

    Metrics stay on the device and are copied to the host once at the end.

    Args:
        actor_now, critic: networks being trained.
        actor_future_dict: day -> policy for the days after ``current_day``.
        optimizer_actor, optimizer_critic: their optimizers.
        summary_pool: dict with "G1", "G2", "C_mean", "v_sum" sharing a leading
            pool dim, or None.
        params: model parameters (device/dtype are taken from ``params.beta0``).
        eta_t, u_long_mat: reward parameters forwarded to the rollout.
        n_episodes: number of rollout + update iterations.
        current_day, N_days: day being trained and the horizon in days.
        batch_size: trajectories per episode.
        save_actor: save ``actor_now``'s state_dict to
            ``<save_dir>/actor_day_<current_day>.pt``.
        save_dir: checkpoint directory (created if missing).
        print_every: print metrics (and rollout debug info) every this many episodes.

    Returns:
        (reward_means, total_losses, actor_losses, critic_losses): lists of
        Python floats, one per episode. All four are empty when ``n_episodes == 0``.
    """
    device = params.beta0.device
    dtype = params.beta0.dtype

    loss_list_t = []
    actor_loss_list_t = []
    critic_loss_list_t = []
    reward_mean_list_t = []

    print(f"--- [Day {current_day}] Start. Future: {list(actor_future_dict.keys())} ---")

    # Loop-invariant setup. The rollout clones its summary inputs, so the zero
    # tensors can safely be reused in every episode.
    if summary_pool is None:
        d1 = params.p + 1
        L_basis = 3 * params.p + 1
        zero_summaries = (
            torch.zeros((batch_size, params.taus, d1, d1), device=device, dtype=dtype),
            torch.zeros((batch_size, params.taus, d1, d1), device=device, dtype=dtype),
            torch.zeros((batch_size, params.taus, L_basis, d1), device=device, dtype=dtype),
            torch.zeros((batch_size, params.taus, 2 * d1, L_basis), device=device, dtype=dtype),
        )
    else:
        pool_size = summary_pool["G1"].shape[0]

    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    for ep in range(n_episodes):
        if summary_pool is None:
            G1_batch, G2_batch, C_mean_batch, v_sum_batch = zero_summaries
        else:
            idx = torch.randint(0, pool_size, (batch_size,), device=device)
            G1_batch = summary_pool["G1"][idx]
            G2_batch = summary_pool["G2"][idx]
            C_mean_batch = summary_pool["C_mean"][idx]
            v_sum_batch = summary_pool["v_sum"][idx]

        debug_now = ((ep + 1) % print_every == 0)

        states_a, states_c, acts, rews = rollout_n_k_plus_1_to_n_days(
            actor_now=actor_now,
            actor_future_dict=actor_future_dict,
            G1_sum_batch=G1_batch,
            G2_sum_batch=G2_batch,
            C_mean_batch=C_mean_batch,
            v_sum_batch=v_sum_batch,
            params=params,
            eta_t=eta_t,
            u_long_mat=u_long_mat,
            batch_size=batch_size,
            current_day=current_day,
            N_days=N_days,
            device=device,
            dtype=dtype,
            debug=debug_now,
        )

        a_loss_t, c_loss_t, avg_r_t = update_A2C_batch(
            actor=actor_now,
            critic=critic,
            optimizer_actor=optimizer_actor,
            optimizer_critic=optimizer_critic,
            states_actor=states_a,
            states_critic=states_c,
            actions=acts,
            rewards=rews
        )

        total_loss_t = a_loss_t + c_loss_t

        loss_list_t.append(total_loss_t.detach())
        actor_loss_list_t.append(a_loss_t.detach())
        critic_loss_list_t.append(c_loss_t.detach())
        reward_mean_list_t.append(avg_r_t.detach())

        if debug_now:
            print(f"[Day {current_day} | Ep {ep+1}] "
                  f"R: {avg_r_t.item():.4f} | Total: {total_loss_t.item():.4f} | "
                  f"Act: {a_loss_t.item():.4f} | Crit: {c_loss_t.item():.4f}")

    if device.type == "cuda":
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    print(f"--- [Day {current_day}] Done. Train time: {dt:.2f}s ({dt/60:.2f} min) ---")

    if save_actor:
        os.makedirs(save_dir, exist_ok=True)
        torch.save(actor_now.state_dict(), os.path.join(save_dir, f"actor_day_{current_day}.pt"))

    def _to_host(history):
        # A single device->host copy per metric; torch.stack rejects an empty list.
        return torch.stack(history).cpu().tolist() if history else []

    return (
        _to_host(reward_mean_list_t),
        _to_host(loss_list_t),
        _to_host(actor_loss_list_t),
        _to_host(critic_loss_list_t),
    )


def plot_training_results(day, r, l, a_l, c_l, save_dir):
    """Save a 2x2 grid of training curves to ``<save_dir>/day_<day>.png``.

    The four panels are reward, total loss, actor loss and critic loss, each
    plotted against the episode index.

    Args:
        day: day index, used in the figure title and the file name.
        r, l, a_l, c_l: per-episode reward, total loss, actor loss and critic
            loss sequences.
        save_dir: output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle(f"Training Results - Day {day}")

    panels = (
        (axs[0, 0], r, "Reward"),
        (axs[0, 1], l, "Total Loss"),
        (axs[1, 0], a_l, "Actor Loss"),
        (axs[1, 1], c_l, "Critic Loss"),
    )
    for ax, series, title in panels:
        ax.plot(series)
        ax.set_title(title)
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"day_{day}.png"))
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def train_all_layers(params, u_long_mat, d_actor, d_critic, N_days, eta_t, n_episodes, batch_size, lr, save_dir, device, print_every=50, compile_models=False):
    """Backward induction: train one policy per day, from day ``N_days`` down to day 2.

    Day ``N_days`` is always trained first, even when ``N_days < 2``. Each
    later (earlier-in-time) day is then trained with the already-trained
    policies of all subsequent days in the rollout. For every day this function:
      * builds a replay pool over the ``day - 1`` earlier days;
      * creates a fresh actor/critic pair with Adam optimizers;
      * trains them and plots the training curves;
      * stores the actor (in eval mode).

    Args:
        params: model parameters (dtype taken from ``params.beta0``).
        u_long_mat, eta_t: reward parameters.
        d_actor, d_critic: input widths of the actor and critic.
        N_days: horizon in days.
        n_episodes, batch_size, lr, print_every: training hyper-parameters.
        save_dir: directory for checkpoints and plots.
        device: device for the networks and the replay pool.
        compile_models: compile each network in place with ``nn.Module.compile``
            (PyTorch >= 2.2; silently skipped when unavailable). Compiling in
            place keeps the state_dict keys unchanged.

    Returns:
        dict mapping day -> trained actor.
    """
    print(f"\n{'='*60}")
    print(f"STARTING BACKWARD INDUCTION (Day {N_days} -> 2)")
    print(f"{'='*60}\n")

    def _build_model(d_state):
        model = MLPActorCritic(d_state).to(device)
        if compile_models and hasattr(model, "compile"):
            model.compile()
        return model

    trained_actors = {}

    # Phase 1 trains the last day; Phase 2 walks backwards from N_days - 1 to 2.
    for day in [N_days, *range(N_days - 1, 1, -1)]:
        if day == N_days:
            print(f">>> [Phase 1] Training Last Day {N_days}")
        else:
            print(f"\n>>> [Phase 2] Training Day {day}")

        pool = generate_initial_data_mixed_tensorpool(
            params=params, n_minus_k=day - 1, device=device, dtype=params.beta0.dtype
        )

        actor = _build_model(d_actor)
        critic = _build_model(d_critic)
        opt_a = torch.optim.Adam(actor.parameters(), lr=lr)
        opt_c = torch.optim.Adam(critic.parameters(), lr=lr)

        r, l, al, cl = train_A2C_for_n_k_plus_1(
            actor_now=actor,
            actor_future_dict=trained_actors,
            critic=critic,
            optimizer_actor=opt_a,
            optimizer_critic=opt_c,
            summary_pool=pool,
            params=params,
            eta_t=eta_t,
            u_long_mat=u_long_mat,
            n_episodes=n_episodes,
            current_day=day,
            N_days=N_days,
            batch_size=batch_size,
            save_dir=save_dir,
            print_every=print_every,
        )
        plot_training_results(day, r, l, al, cl, save_dir)
        actor.eval()
        trained_actors[day] = actor

    return trained_actors


# ==============================================================================
# 6. Main
# ==============================================================================

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--taus", type=int, default=6)
    parser.add_argument("--p", type=int, default=3)
    parser.add_argument("--n_episodes", type=int, default=300)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--print_every", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--eta_t", type=float, default=1.0)
    parser.add_argument("--N_days", type=int, default=35)
    parser.add_argument("--save_dir", type=str, default="./runs_robust/final_clean")
    parser.add_argument("--compile", action="store_true", help="use torch.compile for actor/critic (PyTorch2+)")
    args = parser.parse_args()

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)  # "cuda:0" / "cuda:1" / "cpu"

    print(f"Running on {device}")
    set_perf_flags(device)

    cfg = DGPConfig(
        taus=args.taus, p=args.p, seed=2026, dtype=torch.float32, device=device,
        emission="linear", phi1_range=0.3, Xi1_range=0.2,
        alpha_mean=0.0, alpha_sd=0.3,
        add_state_noise=True, round_digits=None
    )
    params = generate_params(cfg)
    params.add_state_noise = True

    u_dict = compute_index_u_vectors(params)
    u_long_mat = u_dict["U"].to(device)

    # Detect the state widths by building one dummy state:
    # the actor state has no current action, the critic state includes it.
    with torch.no_grad():
        d1 = args.p + 1
        L_basis = 3 * args.p + 1

        dummy_G = torch.zeros((1, d1, d1), device=device)
        dummy_C = torch.zeros((1, L_basis, d1), device=device)
        dummy_V = torch.zeros((1, 2 * d1, L_basis), device=device)

        dummy_X = torch.zeros((1, args.p), device=device)
        dummy_A = torch.zeros((1, 1), device=device)
        dummy_f = time_features(0, args.taus, device, torch.float32)
        dummy_XP = torch.zeros((1, 0, args.p), device=device)
        dummy_AP = torch.zeros((1, 0, 1), device=device)

        d_actor, d_critic = (
            build_state_vector_hiera_batch_final(
                dummy_G, dummy_G, dummy_C, dummy_V,
                dummy_XP, dummy_AP,
                dummy_X, dummy_C, dummy_A,
                dummy_f, args.taus, with_action
            ).shape[1]
            for with_action in (False, True)
        )

    print(f"Dims Detected: Actor={d_actor}, Critic={d_critic}")

    train_all_layers(
        params=params,
        u_long_mat=u_long_mat,
        d_actor=d_actor,
        d_critic=d_critic,
        N_days=args.N_days,
        eta_t=args.eta_t,
        n_episodes=args.n_episodes,
        batch_size=args.batch_size,
        lr=args.lr,
        save_dir=args.save_dir,
        device=device,
        print_every=args.print_every,
        compile_models=args.compile,
    )

    print("\n All Training Completed!")
