from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import torch
import numpy as np

from LOC.model import fit_gmm_success_prob
from LOC.curve import append_curve_ms_csv, MSPair



def objective_ms_cdf(F, Q, q_prev, c_vec, P):
    """Expected cost of a multi-step, multi-source collection plan.

    With ``Q[-1]`` taken to mean ``q_prev`` for the first step, the value is

        J = sum_t (c_vec * (Q[t] - Q[t-1])).sum() * (1 - F(Q[t-1]))
            + P * (1 - F(Q[H-1]))

    so the cost of step ``t`` is weighted by ``1 - F`` evaluated at the level
    reached before that step, and ``P`` is weighted by ``1 - F`` at the final
    level. (Presumably ``F(q)`` is a success probability / CDF as produced by
    ``fit_gmm_success_prob`` -- confirm with the model module.)

    Args:
        F: callable mapping a length-S tensor to a scalar tensor.
        Q: 2-D tensor ``[H, S]`` of cumulative per-source levels for each step.
        q_prev: length-S tensor of levels before the first step.
        c_vec: length-S per-source unit cost tensor.
        P: scalar penalty weight for the terminal term.

    Returns:
        0-dim tensor with ``Q``'s dtype and device (differentiable through ``Q``).
    """
    horizon, _ = Q.shape  # unpacking also insists that Q is 2-D
    total = torch.zeros((), dtype=Q.dtype, device=Q.device)
    anchor = q_prev
    for step in range(horizon):
        row = Q[step]
        step_cost = (c_vec * (row - anchor)).sum()
        total = total + step_cost * (1.0 - F(anchor))
        anchor = row
    terminal = P * (1.0 - F(Q[-1]))
    return total + terminal

def _inverse_softplus(x):
    """Return ``y`` with ``softplus(y) == x`` for ``x > 0``.

    Uses ``x + log(-expm1(-x))``, which equals ``log(expm1(x))`` but does not
    overflow for large ``x``.
    """
    return x + torch.log(-torch.expm1(-x))


def _raw_to_plan(D_raw, q_prev, hi):
    """Map unconstrained parameters ``D_raw`` [H, S] to a monotone, capped plan [H, S]."""
    delta = torch.nn.functional.softplus(D_raw)           # increments >= 0
    Q = q_prev.unsqueeze(0) + torch.cumsum(delta, dim=0)  # non-decreasing in t by construction
    # Element-wise min with a constant cap keeps each column non-decreasing.
    return torch.min(Q, hi.unsqueeze(0))


def solve_joint_horizon_multisrc_cdf(F, H, q_prev_vec, c_vec, q_cap_vec, P,
                                    init=None, steps=400, lr=200.0, betas=(0.9, 0.999),
                                    eps=1e-8, tol_move=1e-3, device=None, dtype=torch.float32):
    """Optimise an H-step, S-source cumulative plan against ``objective_ms_cdf``.

    The plan is parameterised as ``Q = min(q_prev + cumsum(softplus(D_raw)), q_cap)``.
    It is therefore non-decreasing over steps and never exceeds the cap, and
    ``D_raw`` is optimised with Adam.

    Args:
        F: callable passed through to ``objective_ms_cdf``.
        H: number of plan steps (rows of the result); must be >= 1.
        q_prev_vec: length-S current levels (lower bound of the plan).
        c_vec: length-S per-source unit costs.
        q_cap_vec: length-S upper bounds.
        P: terminal penalty weight passed to ``objective_ms_cdf``.
        init: optional ``[H, S]`` starting plan. It is clamped to
            ``[q_prev, q_cap]``, made non-decreasing and mapped back to raw
            parameters. When ``None``, all raw parameters start at zero.
        steps: maximum number of Adam steps.
        lr, betas, eps: Adam hyper-parameters.
        tol_move: stop early once no entry of the continuous plan moves by
            ``tol_move`` or more between consecutive steps. ``None`` or ``0``
            disables early stopping.
        device, dtype: torch placement. ``device`` defaults to CUDA when available.

    Returns:
        int64 numpy array ``[H, S]``. Each entry is the ceiled plan, clamped
        per source to ``[int(q_prev), int(q_cap)]`` and non-decreasing along
        axis 0.

    Raises:
        ValueError: if ``H < 1`` or ``init`` does not have shape ``[H, S]``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if H < 1:
        raise ValueError(f"horizon H must be >= 1, got {H}")

    q_prev = torch.as_tensor(q_prev_vec, dtype=dtype, device=device)  # [S]
    hi = torch.as_tensor(q_cap_vec, dtype=dtype, device=device)        # [S]
    cvec = torch.as_tensor(c_vec, dtype=dtype, device=device)          # [S]
    lo = q_prev
    S = q_prev.numel()

    if init is None:
        # Default start: every raw increment is 0, i.e. softplus(0) = ln 2 per step.
        D_raw0 = torch.zeros((H, S), dtype=dtype, device=device)
    else:
        Q0 = torch.as_tensor(init, dtype=dtype, device=device)
        if Q0.shape != (H, S):
            raise ValueError(f"init must have shape ({H}, {S}), got {tuple(Q0.shape)}")
        Q0 = torch.minimum(torch.maximum(Q0, lo.unsqueeze(0)), hi.unsqueeze(0))
        Q0 = torch.cummax(Q0, dim=0).values  # the parameterisation only represents monotone plans
        increments = torch.diff(Q0, dim=0, prepend=lo.unsqueeze(0))
        # softplus never reaches exactly 0, so zero increments are floored before inversion.
        min_increment = 1e-6
        D_raw0 = _inverse_softplus(increments.clamp_min(min_increment))

    D_raw = torch.nn.Parameter(D_raw0.clone())  # [H, S]
    opt = torch.optim.Adam([D_raw], lr=lr, betas=betas, eps=eps)

    Q_old = None
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        Q = _raw_to_plan(D_raw, q_prev, hi)
        Q_now = Q.detach()
        if (Q_old is not None and tol_move is not None
                and torch.max(torch.abs(Q_now - Q_old)).item() < tol_move):
            break  # the plan has stopped moving
        Q_old = Q_now

        loss = objective_ms_cdf(F, Q, q_prev, cvec, P)
        loss.backward()
        opt.step()

    with torch.no_grad():
        Q = _raw_to_plan(D_raw, q_prev, hi)
        Qi = torch.ceil(Q).to(torch.int64)
        # .to(int64) truncates toward zero, matching int(x) on the float bounds.
        lo_i = lo.to(torch.int64).unsqueeze(0)
        hi_i = hi.to(torch.int64).unsqueeze(0)
        Qi = torch.minimum(torch.maximum(Qi, lo_i), hi_i)  # same semantics as clamp(min=lo, max=hi)
        Qi = torch.cummax(Qi, dim=0).values                # enforce non-decreasing rows
    return Qi.cpu().numpy()

@dataclass
class ExperimentCfg:
    """Configuration consumed by ``mpc_run_ms``."""

    # Planning loop / target.
    T: int = 3                 # number of rounds in mpc_run_ms
    Vstar: float = 90.0        # target metric; the loop stops once it is reached
    P: float = 1e6             # terminal penalty passed to the planner

    # Per-source vectors (one entry per source).
    q0_vec: Tuple[int, ...] = (5000, 5000)      # initial levels collected before round 1
    qcap_vec: Tuple[int, ...] = (50000, 50000)  # upper bounds on the plan
    c_vec: Tuple[float, ...] = (1.0, 1.0)       # per-source unit costs

    # Adam settings for solve_joint_horizon_multisrc_cdf.
    steps: int = 400
    lr: float = 200.0
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8
    tol: float = 1e-3          # forwarded as tol_move

    # Number of mixture components passed to fit_gmm_success_prob.
    gmm_K: int = 3

def _dedup_sort_points(pts) -> List[Tuple[tuple, float]]:
    """Return the distinct ``(tuple(q), float(v))`` pairs of *pts*, sorted by ``q``."""
    return sorted({(tuple(q), float(v)) for q, v in pts}, key=lambda x: x[0])


def mpc_run_ms(initial_points: List[MSPair], oracle, cfg: ExperimentCfg,
               curve_out_csv: Optional[str]=None, seed: Optional[int]=None) -> Dict:
    """Run a shrinking-horizon planning loop for up to ``cfg.T`` rounds.

    In each round ``r`` (1-based) the loop does the following:

    1. Refit the success model ``F`` on all observed ``(q, metric)`` points.
    2. Solve a joint plan over the remaining ``H = T - r + 1`` rounds.
    3. Commit only the plan's first row via ``oracle.ensure_collected``.
    4. Record the metric returned by ``oracle(q_next)``.

    The loop stops early once that metric is ``>= cfg.Vstar``.

    Args:
        initial_points: observed ``(q_vec, metric)`` pairs used to fit ``F``.
        oracle: object exposing ``ensure_collected(q)``, ``current_q_vec()``
            and ``__call__(q) -> metric`` (interface as used below).
        cfg: experiment configuration.
        curve_out_csv: if given, each new ``(q_next, metric)`` is appended to it.
        seed: base seed. Round ``r`` fits ``F`` with ``seed * 1000 + r``.

    Returns:
        dict with the following keys:

        - ``trace``: list of per-round records.
        - ``qT``: final allocation.
        - ``metricT``: last observed metric, or ``None`` if no round ran
          (``cfg.T < 1``).
        - ``curve``: sorted unique ``(q, metric)`` points.
    """
    S = len(cfg.q0_vec)
    points = _dedup_sort_points(initial_points)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def build_F(pts: List[MSPair], round_r: int):
        # Per-round deterministic seed when a base seed is given.
        seed_F = None if seed is None else int(seed) * 1000 + int(round_r)
        return fit_gmm_success_prob(
            pts, Vstar=cfg.Vstar, S=S, K=cfg.gmm_K,
            device=device, dtype=dtype,
            c_vec=cfg.c_vec, q_cap_vec=cfg.qcap_vec,
            B=400, seed=seed_F
        )

    # Bring every source to its initial level q0 before planning.
    oracle.ensure_collected(list(cfg.q0_vec))
    q_prev = oracle.current_q_vec()

    trace = []
    for r in range(1, cfg.T + 1):
        H = cfg.T - r + 1  # remaining rounds, including this one
        F = build_F(points, r)

        plan = solve_joint_horizon_multisrc_cdf(
            F, H=H, q_prev_vec=q_prev, c_vec=cfg.c_vec, q_cap_vec=cfg.qcap_vec, P=cfg.P,
            steps=cfg.steps, lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), eps=cfg.eps, tol_move=cfg.tol,
            device=device, dtype=dtype
        )
        q_next = plan[0].tolist()  # commit only the first step of the plan
        # NOTE(review): unconditional stdout print; consider switching to logging.
        print(q_next)
        added = oracle.ensure_collected(q_next)
        V_next = oracle(q_next)

        points = _dedup_sort_points([*points, (tuple(q_next), float(V_next))])
        if curve_out_csv:
            append_curve_ms_csv(curve_out_csv, q_next, V_next)

        trace.append({"round": r, "q_prev": q_prev, "q_next": q_next, "added_vec": added,
                      "metric": float(V_next), "plan": plan.tolist()})
        q_prev = q_next

        if V_next >= cfg.Vstar:
            break

    # With cfg.T < 1 no round runs, so there is no last metric to report.
    metric_T = trace[-1]["metric"] if trace else None
    return {"trace": trace, "qT": q_prev, "metricT": metric_T, "curve": points}
