from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import torch





def _objective_torch_multi(F, S, Q: torch.Tensor, q_prev_vec: torch.Tensor, c_vec: torch.Tensor, P: float) -> torch.Tensor:
    """
    J(Q) = sum_t c^T (q_t - q_{t-1}) (1 - F(q_{t-1})) + P * S(q_H)
    where S(q) = E[(V* - f(q))_+].
    """
    H, K = Q.shape
    J = torch.zeros((), dtype=Q.dtype, device=Q.device)
    prev = q_prev_vec
    for t in range(H):
        F_prev = F(prev)                      # scalar
        delta  = (Q[t] - prev)                # [K]
        J = J + (delta * c_vec).sum() * (1.0 - F_prev)
        prev = Q[t]
    # <-- replace terminal penalty with shortfall:
    J = J + P * S(Q[-1])
    return J



# ============================================================
# ---------------- Torch Adam solvers ------------------------
# ============================================================
@dataclass
class JointAdamCfg:
    """Settings for the joint-horizon Adam solver.

    Fields keep their original declaration order, so positional construction
    such as ``JointAdamCfg(steps, lr, ...)`` behaves exactly as before.
    """

    steps: int = 400        # number of Adam iterations
    lr: float = 200.0       # Adam learning rate
    beta1: float = 0.9      # Adam first-moment decay (betas[0])
    beta2: float = 0.999    # Adam second-moment decay (betas[1])
    eps: float = 1e-8       # Adam denominator epsilon
    tol_move: float = 1e-3  # movement tolerance; not read by solve_joint_horizon_torch_multi

def _softplus_inverse(y: torch.Tensor) -> torch.Tensor:
    """Return x such that softplus(x) == y, for y > 0.

    log(expm1(y)) is rewritten as y + log(-expm1(-y)) so it stays finite for
    large y, where expm1(y) would overflow.
    """
    return y + torch.log(-torch.expm1(-y))


def _plan_from_raw_increments(D_raw: torch.Tensor, q_prev: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Map unconstrained parameters [H, K] to a capped cumulative plan [H, K].

    softplus makes each increment nonnegative, cumsum makes the plan
    nondecreasing along the horizon, and the elementwise min with the constant
    cap enforces the upper bound without breaking monotonicity.
    """
    delta = torch.nn.functional.softplus(D_raw)                 # [H, K]
    Q = q_prev.unsqueeze(0) + torch.cumsum(delta, dim=0)        # [H, K]
    return torch.minimum(Q, hi.unsqueeze(0))


def solve_joint_horizon_torch_multi(F, S, H:int, q_prev_vec: List[int], c_vec: List[float], q_cap_vec: List[int],
                                    P: float, init: Optional[List[List[int]]] = None,
                                    cfg: Optional[JointAdamCfg] = None,
                                    device: Optional[torch.device] = None,
                                    dtype: torch.dtype = torch.float32) -> np.ndarray:
    """Optimize a monotone, capped H-step plan for K items with Adam, then round it to integers.

    The plan is parameterized as

        Q[t] = min(q_prev + sum_{s<=t} softplus(D_raw[s]), q_cap)

    so it is nondecreasing along the horizon and never exceeds ``q_cap_vec``.
    ``D_raw`` is optimized with Adam against ``_objective_torch_multi``.

    Args:
        F, S: callables forwarded unchanged to ``_objective_torch_multi``.
        H: number of horizon steps (rows of the result); must be >= 1.
        q_prev_vec: starting level per item (length K); also the lower bound
            applied after rounding.
        c_vec: per-item cost coefficients (length K).
        q_cap_vec: per-item upper bound (length K).
        P: weight of the terminal term ``S(Q[-1])``.
        init: optional [H, K] starting plan. It is clipped to
            [q_prev_vec, q_cap_vec]. Its per-step increments are floored at a
            small positive value, because softplus cannot produce exactly 0,
            and are then mapped back through the inverse softplus. When None,
            the raw parameters start at 0, so every increment starts at
            softplus(0) = ln 2.
        cfg: Adam settings; a fresh ``JointAdamCfg()`` is used when None.
            ``cfg.tol_move`` is not read by this solver.
        device: torch device; defaults to CUDA when available, else CPU.
        dtype: floating dtype used during optimization.

    Returns:
        int64 array of shape [H, K]. It is the ceil of the optimized plan,
        clipped to [q_prev_vec, q_cap_vec] and made nondecreasing along axis 0.

    Raises:
        ValueError: if H < 1, the per-item vectors differ in length, or
            ``init`` is not of shape [H, K].
    """
    if cfg is None:
        cfg = JointAdamCfg()  # fresh instance per call; avoids a shared mutable default
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    K = len(q_prev_vec)
    if H < 1:
        raise ValueError(f"H must be >= 1, got {H}")
    if len(c_vec) != K or len(q_cap_vec) != K:
        raise ValueError(
            f"q_prev_vec, c_vec and q_cap_vec must have equal length, "
            f"got {K}, {len(c_vec)}, {len(q_cap_vec)}"
        )

    q_prev = torch.tensor(q_prev_vec, dtype=dtype, device=device)  # [K]
    hi = torch.tensor(q_cap_vec, dtype=dtype, device=device)       # [K]
    c = torch.tensor(c_vec, dtype=dtype, device=device)            # [K]

    if init is None:
        D_raw0 = torch.zeros((H, K), dtype=dtype, device=device)
    else:
        Q0 = torch.tensor(init, dtype=dtype, device=device)        # [H, K]
        if Q0.shape != (H, K):
            raise ValueError(f"init must have shape ({H}, {K}), got {tuple(Q0.shape)}")
        Q0 = torch.minimum(torch.maximum(Q0, q_prev.unsqueeze(0)), hi.unsqueeze(0))
        increments = torch.diff(Q0, dim=0, prepend=q_prev.unsqueeze(0))
        # softplus output is > 0, so flat (or, after clipping, negative) steps
        # are floored to a tiny positive increment before inversion.
        min_increment = 1e-6
        D_raw0 = _softplus_inverse(increments.clamp_min(min_increment))

    D_raw = torch.nn.Parameter(D_raw0)  # [H, K]
    opt = torch.optim.Adam([D_raw], lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), eps=cfg.eps)

    for _ in range(cfg.steps):
        opt.zero_grad(set_to_none=True)
        Q = _plan_from_raw_increments(D_raw, q_prev, hi)
        loss = _objective_torch_multi(F, S, Q, q_prev, c, P)
        # retain_graph kept from the original. The graph is rebuilt every
        # iteration, but F/S may reuse graph pieces across calls (TODO confirm
        # before removing).
        loss.backward(retain_graph=True)
        opt.step()

    with torch.no_grad():
        Q = _plan_from_raw_increments(D_raw, q_prev, hi)
        # Integer bounds come straight from the inputs, not from the float
        # tensors, so large levels are not distorted by float rounding.
        lo_int = torch.tensor(q_prev_vec, dtype=torch.int64, device=Q.device).unsqueeze(0)
        hi_int = torch.tensor(q_cap_vec, dtype=torch.int64, device=Q.device).unsqueeze(0)
        Q_int = torch.ceil(Q).to(torch.int64)
        Q_int = torch.minimum(torch.maximum(Q_int, lo_int), hi_int)
        # Running max along the horizon: identical to the element-wise
        # "raise to previous value" pass, as a safety net after rounding.
        Q_int = torch.cummax(Q_int, dim=0).values
    return Q_int.cpu().numpy()




