# loc/mpc_monogp_ms.py
import os, csv
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import numpy as np
import torch

from MonotonicGP.monotonicGP import build_success_prob_fn_multisrc_monogp
from MonotonicGP.optimize import JointAdamCfg, solve_joint_horizon_torch_multi

VecPair = Tuple[Tuple[int, ...], float]

@dataclass
class ExperimentCfgMulti:
    """Configuration for a multi-source MPC experiment (consumed by ``mpc_run_multi``).

    The three per-source vectors default to ``None`` and must be filled in
    before the config is used: ``mpc_run_multi`` calls ``list()`` on each of
    them. Field order is part of the positional constructor and is kept stable.
    """
    # --- outer MPC loop / objective ---
    T: int = 3            # maximum number of MPC rounds (the loop may stop earlier)
    Vstar: float = 90.0   # target metric; the loop stops once an observed metric is >= Vstar
    P: float = 1e6        # weight of the terminal term P * S(q_final) added to the objective
    # --- per-source vectors (required; None only as a "not yet set" placeholder) ---
    q0_vec: Optional[List[int]] = None     # starting quantity vector
    q_cap_vec: Optional[List[int]] = None  # per-source caps, passed to the model builder and planner
    c_vec: Optional[List[float]] = None    # per-source unit costs, dotted with each round's increment
    # --- joint-horizon optimizer: number of steps (JointAdamCfg.steps) ---
    j_steps: int = 400
    # --- monotonic-GP model builder settings (forwarded as same-named kwargs) ---
    minibatch_size: int = 5
    virt_bins: int = 10
    mu: float = 1e-2
    epochs: int = 200         # used for both epochs_init and epochs_update
    num_inducing: int = 100
    num_direction: int = 2
    # --- joint-horizon optimizer (JointAdamCfg) ---
    j_lr: float = 200.0       # JointAdamCfg.lr
    j_beta1: float = 0.9      # JointAdamCfg.beta1
    j_beta2: float = 0.999    # JointAdamCfg.beta2
    j_eps: float = 1e-8       # JointAdamCfg.eps
    j_tol: float = 1e-3       # JointAdamCfg.tol_move

def append_curve_csv_multi(path: str, q_vec: List[int], V: float):
    """Append one observation row ``q_1, ..., q_n, V`` to a CSV file.

    Parent directories are created as needed. A header ``q1, ..., qn, V``
    (sized from ``len(q_vec)``) is written first whenever the file is missing
    *or empty*, so a pre-created or truncated file still gets a header. An
    existing header is never rewritten, so rows of a different length than
    the header are not detected.

    Args:
        path: destination CSV path (opened in append mode).
        q_vec: quantity vector for this observation; written as-is.
        V: observed metric value.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Checking for existence alone misses an existing-but-empty file, which
    # would otherwise receive data rows without a header.
    try:
        needs_header = os.path.getsize(path) == 0
    except FileNotFoundError:
        needs_header = True
    with open(path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow([f"q{i + 1}" for i in range(len(q_vec))] + ["V"])
        writer.writerow([*q_vec, V])

def _dedupe_sort_points(points: List[VecPair]) -> List[VecPair]:
    """Collapse exact-duplicate ``(q, v)`` pairs and sort the result by ``q``.

    ``q`` is coerced to a tuple and ``v`` to float before deduplication, so
    pairs that share ``q`` but differ in ``v`` are all kept. The sort keys on
    ``q`` only (stable), matching the ordering used for the model's data.
    """
    unique = {(tuple(q), float(v)) for (q, v) in points}
    return sorted(unique, key=lambda pair: pair[0])


def mpc_run_multi(initial_points: List[VecPair], oracle,
                  cfg: ExperimentCfgMulti, curve_out_csv: Optional[str]=None,
                  *, verbose: bool = True) -> Dict:
    """Run a receding-horizon (MPC) acquisition loop over multiple sources.

    Each round ``r = 1..cfg.T``:

    1. (Re)fit the monotonic-GP model on all observations so far,
       warm-starting from the previous round's ``gp_state``. This yields
       ``F`` and ``S``.
    2. Solve a joint plan over the remaining horizon ``H = cfg.T - r + 1``
       and apply only its first step, ``plan[0]``.
    3. Ask the oracle to collect data for that step, evaluate the metric
       there, record the observation and optionally append it to
       ``curve_out_csv``.
    4. Accumulate ``J_round = c . (q_next - q_prev) * (1 - F(q_prev))``,
       using this round's ``F``.

    The loop stops early once an observed metric reaches ``cfg.Vstar``.
    Afterwards the terminal term ``cfg.P * S(q_final)`` is added to the
    objective, using the last fitted ``S``.

    Args:
        initial_points: seed ``(q_vector, metric)`` observations.
        oracle: object with ``ensure_collected(q_vec)`` and
            ``__call__(q_vec) -> metric``. The per-round return value of
            ``ensure_collected`` is stored as ``added_per_src`` (each entry
            is passed through ``int``).
        cfg: experiment configuration; ``q0_vec``, ``q_cap_vec`` and
            ``c_vec`` must be set.
        curve_out_csv: optional CSV path that receives one row per round.
        verbose: print each round's chosen quantity vector. Defaults to
            True, which preserves the previous unconditional print.

    Returns:
        dict with keys ``"trace"`` (one record per executed round),
        ``"qT"`` (final quantity vector), ``"metricT"`` (last observed
        metric), ``"curve"`` (deduplicated observations sorted by q) and
        ``"objective_total"``.

    Raises:
        ValueError: if ``cfg.T < 1`` (no round would run, so there would be
            no model, trace or final metric), or if ``c_vec`` and ``q0_vec``
            have different lengths.
    """
    if cfg.T < 1:
        raise ValueError(f"cfg.T must be >= 1, got {cfg.T!r}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    points: List[VecPair] = _dedupe_sort_points(initial_points)

    q_prev_vec = list(cfg.q0_vec)
    q_cap_vec = list(cfg.q_cap_vec)
    c_vec = list(cfg.c_vec)
    # Fail before any model fitting or oracle work; otherwise the mismatch
    # only surfaces in np.dot after the first round's expensive steps.
    if len(c_vec) != len(q_prev_vec):
        raise ValueError(
            f"c_vec has {len(c_vec)} entries but q0_vec has {len(q_prev_vec)}")
    c_arr = np.array(c_vec)

    oracle.ensure_collected(q_prev_vec)

    gp_state = None
    trace = []
    J_total = 0.0
    for r in range(1, cfg.T + 1):
        H = cfg.T - r + 1  # rounds remaining, including this one

        # The third return value is not used by this loop.
        F, S, _K, gp_state = build_success_prob_fn_multisrc_monogp(
            points, cfg.Vstar, q_cap_vec, gp_state=gp_state, device=device, dtype=torch.float32,
            minibatch_size=cfg.minibatch_size, virt_bins=cfg.virt_bins,
            num_direction=cfg.num_direction, num_inducing=cfg.num_inducing,
            mu=cfg.mu, epochs_init=cfg.epochs, epochs_update=cfg.epochs
        )

        # A fresh optimizer config each round, as before: whether the solver
        # mutates it is not visible here.
        joint_cfg = JointAdamCfg(
            steps=cfg.j_steps, lr=cfg.j_lr, beta1=cfg.j_beta1, beta2=cfg.j_beta2,
            eps=cfg.j_eps, tol_move=cfg.j_tol
        )

        plan = solve_joint_horizon_torch_multi(
            F, S, H=H, q_prev_vec=q_prev_vec, c_vec=c_vec, q_cap_vec=q_cap_vec,
            P=cfg.P, cfg=joint_cfg
        )

        # Receding horizon: only the first planned step is executed.
        q_next_vec = plan[0].tolist()
        if verbose:
            print(q_next_vec)
        added_per_src = oracle.ensure_collected(q_next_vec)
        V_next = oracle(q_next_vec)

        # NOTE(review): int() truncates here (and in the trace below), while the
        # oracle, the CSV and q_prev_vec use the raw plan values. Confirm that
        # the planner always returns integral quantities.
        points.append((tuple(int(x) for x in q_next_vec), float(V_next)))
        points = _dedupe_sort_points(points)
        if curve_out_csv:
            append_curve_csv_multi(curve_out_csv, q_next_vec, V_next)

        # Increment cost weighted by (1 - F) at the pre-step point, using the
        # model fitted at the start of this round.
        q_prev_t = torch.tensor(q_prev_vec, dtype=torch.float32, device=device)
        F_prev = float(F(q_prev_t))
        step_cost = float(np.dot(np.array(q_next_vec) - np.array(q_prev_vec), c_arr))
        J_round = step_cost * (1.0 - F_prev)

        trace.append({
            "round": r,
            "q_prev_vec": list(map(int, q_prev_vec)),
            "q_next_vec": list(map(int, q_next_vec)),
            "added_per_src": list(map(int, added_per_src)),
            "metric": float(V_next),
            "plan": plan.tolist(),
            "J_round_no_penalty": float(J_round),
        })
        J_total += J_round
        q_prev_vec = q_next_vec
        if V_next >= cfg.Vstar:
            break

    # Terminal penalty, evaluated with the S from the last fitted model.
    qT_t = torch.tensor(q_prev_vec, dtype=torch.float32, device=device)
    J_total += cfg.P * float(S(qT_t))

    return {"trace": trace, "qT": q_prev_vec, "metricT": trace[-1]["metric"],
            "curve": points, "objective_total": float(J_total)}
