import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
import argparse
import time
from pathlib import Path
from torch.distributions import Bernoulli
import os
import pandas as pd
from copy import deepcopy

# Ask the common numerical back ends (OpenMP, MKL, OpenBLAS, vecLib, numexpr)
# to use a single thread each, via their environment variables.
# NOTE(review): these assignments run after numpy/torch were imported above,
# so they may only influence libraries or processes initialised afterwards
# (e.g. worker processes that inherit the environment) — confirm that this is
# the intended scope.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

try:
    from utils_self import (
        DGPConfig, generate_params, compute_index_u_vectors,
        trans_distribution, emission_distribution, time_features,
        monte_carlo_ATE_with_linear_baseline, sample_outcome_noise,
        OutcomeNoiseConfig, stable_inv_psd, make_psi_legendre_tensor_batch_torch
    )
except ImportError:
    print("no utils_self.py")
    raise


# ==============================================================================

# ==============================================================================
def solve_linear_regression_fast(X, y, X_test, ridge=1e-4):
    """Fit a ridge-stabilised linear model with intercept and predict on ``X_test``.

    A column of ones is prepended to the regressors, ``ridge`` is added to the
    whole diagonal of the Gram matrix (so the intercept is penalised too) and
    the normal equations are solved directly.  If the solve fails with
    ``LinAlgError`` the pseudo-inverse is used instead.

    Args:
        X: training regressors, 1-D (single feature) or 2-D ``(n, k)``.
        y: training targets with ``n`` rows.
        X_test: regressors to predict at, 1-D or 2-D.
        ridge: value added to every diagonal entry of the Gram matrix.

    Returns:
        Predictions for each row of ``X_test``.
    """
    n_train = X.shape[0]
    train = X.reshape(-1, 1) if X.ndim == 1 else X
    test = X_test.reshape(-1, 1) if X_test.ndim == 1 else X_test

    # Design matrix: intercept column followed by the features.
    design = np.column_stack((np.ones(n_train), train))
    gram = design.T @ design
    # Stepping through the flat view by (width + 1) visits exactly the diagonal.
    gram.flat[:: gram.shape[0] + 1] += ridge
    moment = design.T @ y

    try:
        coef = np.linalg.solve(gram, moment)
    except np.linalg.LinAlgError:
        coef = np.linalg.pinv(gram) @ moment

    test_design = np.column_stack((np.ones(test.shape[0]), test))
    return test_design @ coef

def save_one_run_sqerr_csv(results: dict, ate_true: float, seed: int, B: int, out_dir: str):
    """Write per-replicate squared errors of every method to two CSV files.

    Files are placed in ``<out_dir>/seed<seed>_B<B>/``:

    * ``sqerr_wide.csv`` -- columns: rep, base_seed, n_eval, ate_true, then one
      squared-error column per method.
    * ``sqerr_long.csv`` -- columns: rep, base_seed, n_eval, ate_true, method, sqerr.

    Args:
        results: mapping method name -> sequence of ATE estimates (equal lengths).
        ate_true: reference ATE subtracted from every estimate.
        seed: base seed, stored in the ``base_seed`` column and the folder name.
        B: number of replicates, stored in ``n_eval`` and the folder name.
        out_dir: parent directory for the run folder.

    Returns:
        Path of the wide CSV file.

    Raises:
        ValueError: if a method's array length differs from the first method's.
    """
    target_dir = os.path.join(out_dir, f"seed{seed}_B{B}")
    os.makedirs(target_dir, exist_ok=True)

    # The first method fixes the expected number of replicates.
    reference = next(iter(results))
    n_rep = len(np.asarray(results[reference]))

    id_cols = ["rep", "base_seed", "n_eval", "ate_true"]
    df_wide = pd.DataFrame({
        "rep": np.arange(n_rep, dtype=int),
        "base_seed": int(seed),
        "n_eval": int(B),
        "ate_true": float(ate_true),
    })

    for method in results:
        estimates = np.asarray(results[method], dtype=float)
        if len(estimates) != n_rep:
            raise ValueError(f"Length mismatch: method={method}, len={len(estimates)} != {n_rep}")
        df_wide[method] = (estimates - ate_true) ** 2

    wide_path = os.path.join(target_dir, "sqerr_wide.csv")
    df_wide.to_csv(wide_path, index=False)
    print(f"[Saved per-rep squared error (wide)] {wide_path}")

    # Long format: one row per (replicate, method).
    df_long = df_wide.melt(
        id_vars=id_cols,
        value_vars=list(results),
        var_name="method",
        value_name="sqerr",
    )
    long_path = os.path.join(target_dir, "sqerr_long.csv")
    df_long.to_csv(long_path, index=False)
    print(f"[Saved per-rep squared error (long)] {long_path}")

    return wide_path
# ==============================================================================
# ==============================================================================
class MLPActorCritic(nn.Module):
    """MLP encoder followed by a single-logit policy head.

    ``encoder`` is a stack of ``Linear -> activation`` pairs, one per entry of
    ``hidden_sizes``; ``pi_head`` maps the last hidden width (or ``d_state``
    when there are no hidden layers) to one output.
    """

    def __init__(self, d_state: int, hidden_sizes=(256, 256), activation=nn.GELU):
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each layer.
        widths = [d_state, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), activation()))
        self.encoder = nn.Sequential(*stack)
        self.pi_head = nn.Linear(widths[-1], 1)

    def forward(self, state_vec):
        """Return ``(logits, None)``; the last dimension of the logits is squeezed."""
        features = self.encoder(state_vec)
        logits = self.pi_head(features)
        return logits.squeeze(-1), None


def load_actors_strict(N_days: int, d_state: int, device, checkpoint_dir: str, require_days_from: int = 2):
    """Load one actor per day from ``checkpoint_dir``, failing on any gap.

    For every day in ``require_days_from .. N_days`` the file
    ``actor_day_<day>.pt`` must exist.  A checkpoint may be either a raw state
    dict or a dict holding it under ``"actor_state_dict"``.  Keys the model
    expects but the checkpoint lacks are an error; extra keys in the checkpoint
    are ignored (the load uses ``strict=False`` and only ``missing`` is checked).

    Args:
        N_days: last day (inclusive) for which an actor is required.
        d_state: input width used to build each ``MLPActorCritic``.
        device: device the models and checkpoint tensors are mapped to.
        checkpoint_dir: directory containing the ``actor_day_<day>.pt`` files.
        require_days_from: first day (inclusive) for which an actor is required.

    Returns:
        Dict mapping day -> model in eval mode.  Empty when
        ``require_days_from > N_days``.

    Raises:
        FileNotFoundError: directory or a required checkpoint file is missing.
        RuntimeError: a checkpoint lacks parameters the model needs.
    """
    path = Path(checkpoint_dir).resolve()
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {path}")
    print(f"[Init] Strict loading RL models from {path} ...")

    actors = {}
    for day in range(require_days_from, N_days + 1):
        model_path = (path / f"actor_day_{day}.pt").resolve()
        if not model_path.exists():
            raise FileNotFoundError(f"Missing checkpoint for day={day}: {model_path}")

        model = MLPActorCritic(d_state=d_state).to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        ckpt = torch.load(model_path, map_location=device)
        sd = ckpt["actor_state_dict"] if isinstance(ckpt, dict) and "actor_state_dict" in ckpt else ckpt

        missing, unexpected = model.load_state_dict(sd, strict=False)
        if len(missing) > 0:
            raise RuntimeError(
                f"State dict missing keys in {model_path}: {missing[:10]}{'...' if len(missing)>10 else ''}"
            )

        model.eval()
        actors[day] = model

    # Every iteration above either stored an actor or raised, so reaching this
    # point means all requested days are present; no further check is needed.
    print(f"[Init] Loaded {len(actors)} actors OK. Days={require_days_from}..{N_days}")
    return actors


# ==============================================================================
# ==============================================================================
def build_state_vector_standard(G1, G2, XP, AP, X, A, f, T):
    """Assemble the flat per-sample input of the standard actor.

    Layout along dim 1: flattened ``G1``, flattened ``G2``, the within-day
    history (each past step is its state followed by its action, zero-padded on
    the right to ``(T - 1) * (p + 1)`` columns), the current state ``X`` and the
    time features ``f`` repeated for every sample.  ``A`` is accepted but not
    used.
    """
    batch = X.shape[0]
    dev = X.device

    if XP.shape[1] > 0:
        history = torch.cat([XP, AP], dim=2).reshape(batch, -1)
    else:
        history = torch.zeros(batch, 0, device=dev)

    # Right-pad so the history always occupies the same number of columns.
    pad = (T - 1) * (X.shape[1] + 1) - history.shape[1]
    if pad > 0:
        history = torch.cat([history, torch.zeros(batch, pad, device=dev)], 1)

    pieces = [
        G1.reshape(batch, -1),
        G2.reshape(batch, -1),
        history,
        X,
        f.view(1, -1).expand(batch, -1),
    ]
    return torch.cat(pieces, 1)


def build_state_vector_robust(G1, G2, Cm, Vm, XP, AP, X, C, A, f, T, inc=False):
    """Assemble the flat per-sample input of the robust actor.

    Layout along dim 1: flattened ``G1``, ``G2``, ``Cm``, ``Vm`` and ``C``, then
    the within-day history (state followed by action per past step, zero-padded
    on the right to ``(T - 1) * (p + 1)`` columns), then the current state
    (with ``A`` appended when ``inc`` is true) and finally the time features
    ``f`` repeated for every sample.
    """
    batch = X.shape[0]
    dev = X.device

    if XP.shape[1] > 0:
        history = torch.cat([XP, AP], dim=2).reshape(batch, -1)
    else:
        history = torch.zeros(batch, 0, device=dev)

    # Right-pad so the history always occupies the same number of columns.
    pad = (T - 1) * (X.shape[1] + 1) - history.shape[1]
    if pad > 0:
        history = torch.cat([history, torch.zeros(batch, pad, device=dev)], 1)

    if inc:
        current = torch.cat([X, A], 1)
    else:
        current = X

    flattened = [block.reshape(batch, -1) for block in (G1, G2, Cm, Vm, C)]
    tail = [history, current, f.view(1, -1).expand(batch, -1)]
    return torch.cat(flattened + tail, 1)


# ==============================================================================
# ==============================================================================
def estimate_ate_from_trajectories_batch(Xf, Af, Yf, u_long, ridge=1e-4):
    """Batched per-step ridge regression followed by a weighted read-out.

    For each time step ``t`` and each batch element, ``Yf[:, :, t]`` is
    regressed on ``[1, x, a, a * x]`` (built from ``Xf`` and ``Af``) across the
    second axis; the coefficient vector is dotted with ``u_long[t]`` and the
    resulting per-step values are averaged over ``t``.

    Args:
        Xf: tensor unpacked as ``(B, N, T, p)``.
        Af: tensor indexed as ``[:, :, t, :]``; multiplied into the regressors.
        Yf: tensor indexed as ``[:, :, t, :]``; regression target.
        u_long: indexable by ``t``; each entry is reshaped to ``(1, 2 * (p + 1))``.
        ridge: multiple of the identity added to every normalised Gram matrix.

    Returns:
        Tensor of shape ``(B,)``.
    """
    device, dtype = Xf.device, Xf.dtype
    n_batch, n_days, horizon, n_feat = Xf.shape
    width = 2 * (n_feat + 1)

    eye = torch.eye(width, device=device, dtype=dtype).unsqueeze(0)
    intercept = torch.ones((n_batch, n_days, 1), device=device, dtype=dtype)

    def _step_estimate(t):
        # Regressors: [1, x] and the same block multiplied by the action.
        base = torch.cat([intercept, Xf[:, :, t, :]], dim=2)
        design = torch.cat([base, base * Af[:, :, t, :]], dim=2)

        gram = torch.einsum("bnd,bne->bde", design, design) / n_days
        cross = torch.einsum("bnd,bnq->bdq", design, Yf[:, :, t, :]) / n_days
        gram = gram + ridge * eye

        coef = torch.linalg.solve(gram, cross).squeeze(-1)  # [B,D]
        weights = u_long[t].to(device=device, dtype=dtype).view(1, -1)  # [1,D]
        return (coef * weights).sum(dim=1)  # [B]

    per_step = [_step_estimate(t) for t in range(horizon)]
    return torch.stack(per_step, dim=0).mean(dim=0)


# ==============================================================================
# ==============================================================================
def make_common_randomness(params, N_days, T, B, base_seed=2026):
    """Pre-draw the random inputs that all methods share across ``B`` replicates.

    Outcome noise is drawn per replicate with ``sample_outcome_noise`` using
    seed ``base_seed + b * 100``; initial states, raw state noise and action
    uniforms come from one generator seeded with ``base_seed + 99991``.

    Returns:
        ``(eps_y, X0, eps_x, U_action)`` where ``X0`` is ``(B, N_days, p)``
        standard normal, ``eps_x`` is ``(B, N_days, T, p)`` standard normal and
        ``U_action`` is ``(B, N_days, T)`` uniform on [0, 1).  ``eps_y`` stacks
        the per-replicate noise along a new leading axis.
    """
    device = params.beta0.device
    dtype = params.beta0.dtype
    p = params.p

    def _draw_outcome_noise(rep):
        noise = sample_outcome_noise(
            N_days, T, OutcomeNoiseConfig(type="iid", sigma=1.0, rho=0.5),
            device, dtype, seed=base_seed + rep * 100
        )["eps"]
        # Drop a trailing singleton axis if the sampler returned one.
        if noise.ndim == 3 and noise.shape[-1] == 1:
            noise = noise[..., 0]
        return noise

    eps_y = torch.stack([_draw_outcome_noise(rep) for rep in range(B)], dim=0)  # [B,N_days,T]

    gen = torch.Generator(device=device)
    gen.manual_seed(base_seed + 99991)
    common = dict(device=device, dtype=dtype, generator=gen)

    # Draw order matters for reproducibility: X0, then eps_x, then U_action.
    X0 = torch.randn((B, N_days, p), **common)
    eps_x = torch.randn((B, N_days, T, p), **common)
    U_action = torch.rand((B, N_days, T), **common)

    return eps_y, X0, eps_x, U_action


# ==============================================================================
# ==============================================================================
@torch.no_grad()
def run_std_exp_batch_fair(params, actors, mode, N_days, u_long,
                           eps_y, X0, eps_x_raw, U_action,
                           ridge=1e-4, eval_skip_days=0):
    """Simulate ``N_days`` days for ``B`` replicates at once and estimate the ATE.

    The action rule depends on ``mode``:

    * ``"switchback"``: deterministic alternation by day and by step, with a
      per-replicate starting sign taken from ``U_action[:, 0, 0]``.
    * ``"random"``: +1 when the pre-drawn uniform is below 0.5, else -1.
    * ``"robust"``: actor for the current day fed with the robust state vector.
    * any other value: actor for the current day fed with the standard state vector.

    For the two actor-driven modes, day 1 always uses the random rule.

    Args:
        params: DGP parameters; ``beta0``, ``taus``, ``p``, ``eps_state_cov`` and
            the per-step coefficient sequences are read from it.
        actors: mapping day -> actor module; consulted only from day 2 on in
            the actor-driven modes.
        mode: action rule selector, see above.
        N_days: number of simulated days.
        u_long: per-step weight vectors passed to the final estimator.
        eps_y: outcome noise, indexed ``[:, day - 1, t]``.
        X0: initial state per day, indexed ``[:, day - 1, :]``.
        eps_x_raw: raw state noise, multiplied by the Cholesky factor of
            ``params.eps_state_cov`` before use.
        U_action: uniforms used for the action draws, indexed ``[:, day - 1, t]``.
        ridge: ridge level forwarded to ``estimate_ate_from_trajectories_batch``.
        eval_skip_days: days ``<= eval_skip_days`` are simulated but excluded
            from the final regression.

    Returns:
        Tensor of shape ``(B,)`` with one estimate per replicate; all zeros if
        no day was retained.

    Raises:
        RuntimeError: the Legendre feature width differs from the hard-coded
            ``L`` or the emission mean has an unexpected shape.
        KeyError: an actor-driven mode needs an actor for a day that is absent.
    """
    device = params.beta0.device
    dtype = params.beta0.dtype
    T, p = params.taus, params.p
    B = eps_y.shape[0]

    d1 = p + 1
    # Width expected of the flattened Psi features (checked inside the loop).
    L = 10
    chol = torch.linalg.cholesky(params.eps_state_cov)

    # Running per-step averages over days; these feed the actors' state vectors.
    G1 = torch.zeros((B, T, d1, d1), device=device, dtype=dtype)
    G2 = torch.zeros((B, T, d1, d1), device=device, dtype=dtype)
    Cm = torch.zeros((B, T, L, d1), device=device, dtype=dtype)
    Vm = torch.zeros((B, T, 2 * d1, L), device=device, dtype=dtype)

    X_days, A_days, Y_days = [], [], []

    sw_seed = None
    if mode == "switchback":
        # Per-replicate starting sign (+1 / -1) from the first pre-drawn uniform.
        u0 = U_action[:, 0, 0]
        sw_seed = torch.where(u0 < 0.5, torch.ones_like(u0), -torch.ones_like(u0)).to(dtype)  # [B]

    for day in range(1, N_days + 1):
        Xt = X0[:, day - 1, :].contiguous()

        # Within-day history starts empty every day.
        Xp = torch.zeros((B, 0, p), device=device, dtype=dtype)
        Ap = torch.zeros((B, 0, 1), device=device, dtype=dtype)
        last_a = torch.zeros((B, 1), device=device, dtype=dtype)

        if mode == "switchback":
            # Flip the starting sign on even days.
            day_sign = 1.0 if (day % 2 != 0) else -1.0
            day_base = sw_seed * day_sign  # [B]

        dX, dA, dY = [], [], []
        for t in range(T):
            Psi, _, _ = make_psi_legendre_tensor_batch_torch(Xt)
            Psi = Psi.reshape(B, -1)
            if Psi.shape[1] != L:
                raise RuntimeError(f"Psi dim mismatch: expected L={L}, got {tuple(Psi.shape)}")

            X1 = torch.cat([torch.ones((B, 1), device=device, dtype=dtype), Xt], dim=1)
            # Outer product of the Psi features with [1, x].
            Cnow = Psi.unsqueeze(2) @ X1.unsqueeze(1)  # [B,L,d1]

            u = U_action[:, day - 1, t].view(B, 1)

            # ---- choose the action for this step ----
            if mode == "switchback":
                # Alternate the sign on every step within the day.
                act_val = day_base if (t % 2 == 0) else -day_base
                action = act_val.view(B, 1)

            elif mode == "random" or day == 1:
                action = torch.where(u < 0.5, torch.ones_like(u), -torch.ones_like(u)).to(dtype)

            else:
                if day not in actors:
                    raise KeyError(f"Missing actor for day={day}. Strict mode requires checkpoints day=2..N_days.")
                actor = actors[day]

                ft = time_features(t, T, device, dtype)
                if mode == "robust":
                    sv = build_state_vector_robust(G1[:, t], G2[:, t], Cm[:, t], Vm[:, t],
                                                   Xp, Ap, Xt, Cnow, last_a, ft, T)
                else:
                    sv = build_state_vector_standard(G1[:, t], G2[:, t], Xp, Ap, Xt, last_a, ft, T)

                logits = actor(sv)[0].view(B, 1)
                prob = torch.sigmoid(logits)
                # +1 where the policy probability exceeds the shared uniform, else -1.
                action = torch.where(prob > u, torch.ones_like(u), -torch.ones_like(u)).to(dtype)

            # ---- outcome for this step ----
            eps = eps_y[:, day - 1, t].view(B, 1)
            mu = emission_distribution(params.beta0[t], params.beta1[t], params.xi1[t],
                                       Xt, action, params.gammas[t], "nonlinear")
            if mu.ndim == 1:
                mu = mu.unsqueeze(1)
            elif mu.ndim == 2 and mu.shape[1] != 1:
                raise RuntimeError(f"emission_distribution returned shape {tuple(mu.shape)}, expected (B,) or (B,1).")
            y = mu + eps

            dX.append(Xt)
            dA.append(action)
            dY.append(y)

            XA = X1 * action
            SS = X1.unsqueeze(2) @ X1.unsqueeze(1)
            ASS = XA.unsqueeze(2) @ X1.unsqueeze(1)

            # Incremental mean over days: old mean weighted by (day - 1), new term by 1.
            G1[:, t] = (G1[:, t] * (day - 1) + SS + ASS) / day
            G2[:, t] = (G2[:, t] * (day - 1) + SS - ASS) / day
            Cm[:, t] = (Cm[:, t] * (day - 1) + Cnow) / day
            Vm[:, t] = (Vm[:, t] * (day - 1) + (torch.cat([X1, XA], dim=1).unsqueeze(2) @ Psi.unsqueeze(1))) / day

            # Colour the raw state noise, then advance the state.
            state_noise = eps_x_raw[:, day - 1, t, :] @ chol.T
            Xt = trans_distribution(params.phi0[t], params.phi1[t], params.Xi1[t],
                                    Xt, action, params.alphas[t], state_noise)

            # Extend the within-day history with the step just taken.
            Xp = torch.cat([Xp, dX[-1].unsqueeze(1)], dim=1)
            Ap = torch.cat([Ap, dA[-1].unsqueeze(1)], dim=1)
            last_a = action

        # Days up to eval_skip_days still update the running averages above but
        # are not passed to the estimator.
        if day > eval_skip_days:
            X_days.append(torch.stack(dX, dim=1))  # [B,T,p]
            A_days.append(torch.stack(dA, dim=1))  # [B,T,1]
            Y_days.append(torch.stack(dY, dim=1))  # [B,T,1]

    if len(X_days) == 0:
        return torch.zeros((B,), device=device, dtype=dtype)

    Xf = torch.stack(X_days, dim=1)  # [B,N_eff,T,p]
    Af = torch.stack(A_days, dim=1)  # [B,N_eff,T,1]
    Yf = torch.stack(Y_days, dim=1)  # [B,N_eff,T,1]
    return estimate_ate_from_trajectories_batch(Xf, Af, Yf, u_long, ridge=ridge)


# ==============================================================================
# ==============================================================================
@torch.no_grad()
def run_nmdp_experiment_fair(params, N_days, u_long,
                            eps_y_1, X0_1, eps_x_raw_1, U_1,
                            burn_in=5, ridge=1e-4):
    """Run the NMDP baseline on a single replicate and return its ATE estimate.

    One action is chosen per day and held for all ``T`` steps.  The first
    ``burn_in`` days use +1, the next ``burn_in`` days use -1; afterwards the
    action is drawn with a probability derived from per-arm estimates of the
    residual spread of the daily outcome sum given the day's initial state.

    Args:
        params: DGP parameters (``beta0``, ``taus``, ``p``, ``eps_state_cov``
            and the per-step coefficient sequences are read).
        N_days: number of simulated days.
        u_long: per-step weight vectors applied to the regression coefficients.
        eps_y_1: outcome noise for this replicate, indexed ``[day - 1, t]``.
        X0_1: initial state per day, indexed ``[day - 1]``.
        eps_x_raw_1: raw state noise, indexed ``[day - 1, t]``.
        U_1: uniforms; only ``U_1[day - 1, 0]`` is used, for the daily action draw.
        burn_in: length of each of the two fixed-action phases.
        ridge: ridge level for both the helper regressions and the final fit.

    Returns:
        Python float: mean over steps of ``u_long[t] . beta_t``.
    """
    device = params.beta0.device
    dtype = params.beta0.dtype
    T, p = params.taus, params.p
    chol = torch.linalg.cholesky(params.eps_state_cov)

    # Per-day history used by the action rule: initial state, action, outcome sum.
    X_hist_np, A_hist_np, G_hist_np = [], [], []
    XL, AL, YL = [], [], []

    def get_seed_action(curr_state_tensor, u_scalar):
        """Pick the day's action (+1.0 / -1.0) from the history collected so far."""
        if len(X_hist_np) < 2:
            return 1.0

        S_arr = np.array(X_hist_np)
        A_arr = np.array(A_hist_np)
        G_arr = np.array(G_hist_np)
        s_curr = curr_state_tensor.detach().cpu().numpy().reshape(1, -1)

        def predict_sigma(target_a):
            """Predicted residual std. dev. at ``s_curr`` for arm ``target_a`` (1.0 if < 2 days)."""
            mask = (A_arr == target_a)
            if np.sum(mask) < 2:
                return 1.0
            S_sub = S_arr[mask]
            G_sub = G_arr[mask]

            # Stage 1: mean model; stage 2: regress log squared residuals on the state.
            mu = solve_linear_regression_fast(S_sub, G_sub, S_sub, ridge=ridge)
            resid_sq = (G_sub.flatten() - mu.flatten()) ** 2
            log_var_target = np.log(resid_sq + 1e-12)
            log_var = solve_linear_regression_fast(S_sub, log_var_target, s_curr, ridge=ridge)[0]
            return float(np.sqrt(np.exp(log_var)))

        sig_pos = predict_sigma(1)
        sig_neg = predict_sigma(-1)
        # Probability of +1 is proportional to that arm's predicted spread.
        prob = sig_pos / (sig_pos + sig_neg + 1e-6)
        return 1.0 if (u_scalar < prob) else -1.0

    for day in range(1, N_days + 1):
        Xt = X0_1[day - 1].view(1, p).to(device=device, dtype=dtype)
        u_seed = float(U_1[day - 1, 0].item())

        if day <= burn_in:
            seed = 1.0
        elif day <= 2 * burn_in:
            seed = -1.0
        else:
            seed = get_seed_action(Xt, u_seed)

        dX, dA, dY = [], [], []
        g_val = 0.0

        for t in range(T):
            # The same action is applied at every step of the day.
            act = torch.tensor([seed], device=device, dtype=dtype)
            eps = eps_y_1[day - 1, t].view(1).to(device=device, dtype=dtype)

            mu = emission_distribution(params.beta0[t], params.beta1[t], params.xi1[t],
                                       Xt, act.unsqueeze(1), params.gammas[t], "nonlinear")
            y = mu + eps
            dX.append(Xt)
            dA.append(act.unsqueeze(1))
            dY.append(y.view(1, 1))
            g_val += float(y.item())

            state_noise = (eps_x_raw_1[day - 1, t].view(1, p).to(device=device, dtype=dtype)) @ chol.T
            Xt = trans_distribution(params.phi0[t], params.phi1[t], params.Xi1[t],
                                    Xt, act.unsqueeze(1), params.alphas[t], state_noise)

        # Record the day's initial state, its action and the summed outcome.
        X_hist_np.append(dX[0].detach().cpu().numpy().flatten())
        A_hist_np.append(int(seed))
        G_hist_np.append(g_val)

        XL.append(torch.stack(dX, 1))
        AL.append(torch.stack(dA, 1))
        YL.append(torch.stack(dY, 1))

    Xf = torch.cat(XL, 0)
    Af = torch.cat(AL, 0)
    Yf = torch.cat(YL, 0)
    n_effective = Xf.shape[0]

    # Per-step regression of y on [1, x, a, a * x] across days, then weight by u_long[t].
    ests = []
    for t in range(T):
        xt, at, yt = Xf[:, t], Af[:, t], Yf[:, t]
        ones = torch.ones((n_effective, 1), device=device, dtype=dtype)
        Z = torch.cat([torch.cat([ones, xt], 1), torch.cat([ones, xt], 1) * at], 1)
        beta = stable_inv_psd((Z.T @ Z) / n_effective, ridge) @ ((Z.T @ yt) / n_effective)
        ests.append((u_long[t].view(-1, 1).T @ beta).item())
    return float(np.mean(ests))


@torch.no_grad()
def run_tmdp_experiment_fair(params, N_days, u_long,
                            eps_y_1, X0_1, eps_x_raw_1, U_1,
                            burn_in=5, ridge=1e-4):
    """Run the TMDP baseline on a single replicate and return its ATE estimate.

    One action is chosen per day and held for all ``T`` steps.  The first
    ``burn_in`` days use -1, the next ``burn_in`` days use +1; afterwards the
    action is drawn with a probability derived from per-arm root-mean-square
    temporal-difference errors of linear value fits on the collected days.

    Args:
        params: DGP parameters (``beta0``, ``taus``, ``p``, ``eps_state_cov``
            and the per-step coefficient sequences are read).
        N_days: number of simulated days.
        u_long: per-step weight vectors applied to the regression coefficients.
        eps_y_1: outcome noise for this replicate, indexed ``[day - 1, t]``.
        X0_1: initial state per day, indexed ``[day - 1]``.
        eps_x_raw_1: raw state noise, indexed ``[day - 1, t]``.
        U_1: uniforms; only ``U_1[day - 1, 0]`` is used, for the daily action draw.
        burn_in: length of each of the two fixed-action phases.
        ridge: ridge level for both the helper regressions and the final fit.

    Returns:
        Python float: mean over steps of ``u_long[t] . beta_t``.
    """
    device = params.beta0.device
    dtype = params.beta0.dtype
    T, p = params.taus, params.p
    chol = torch.linalg.cholesky(params.eps_state_cov)

    XH_np, AH_np, RH_np = [], [], []
    XL, AL, YL = [], [], []

    def stack_steps(steps):
        """Stack the per-step tensors of one day along dim 1."""
        return torch.stack(steps, 1)

    def get_seed_action(u_scalar):
        """Pick the day's action (+1.0 / -1.0) from the history collected so far."""
        # Fall back to a fair coin while the history is too short.  The explicit
        # emptiness test covers burn_in=0 on day 1, where np.concatenate below
        # would otherwise be called on an empty list and raise ValueError.
        if not AH_np or len(AH_np) < 2 * burn_in:
            return 1.0 if (u_scalar < 0.5) else -1.0

        Sh = np.concatenate(XH_np, axis=0)   # [N, T, p]
        Ah = np.concatenate(AH_np, axis=0)   # [N, T, 1]
        Rh = np.concatenate(RH_np, axis=0)   # [N, T, 1]

        # Index 0 <-> action -1, index 1 <-> action +1; 1.0 is kept when an arm
        # has fewer than two days of data.
        sigmas = {0: 1.0, 1: 1.0}
        for a_idx, a_val in enumerate([-1, 1]):
            msk = (Ah[:, 0, 0] == a_val)
            if np.sum(msk) < 2:
                continue
            Ss = Sh[msk]
            Rs = Rh[msk]

            # Reward-to-go targets and in-sample linear value fits for each step.
            y_targets = [np.sum(Rs[:, t:, :], axis=1).flatten() for t in range(T)]
            v_preds = []
            for tt in range(T):
                X_train = Ss[:, tt, :]
                y_train = y_targets[tt]
                v_hat = solve_linear_regression_fast(X_train, y_train, X_train, ridge=ridge).flatten()
                v_preds.append(v_hat)

            # Sum of squared TD errors; the value after the last step is zero.
            err_sq = 0.0
            for tt in range(T):
                v_curr = v_preds[tt]
                v_next = v_preds[tt + 1] if tt < T - 1 else np.zeros_like(v_curr)
                td_err = Rs[:, tt, 0] + v_next - v_curr
                err_sq += np.sum(td_err ** 2)

            sigmas[a_idx] = np.sqrt(err_sq / np.sum(msk))

        # Probability of +1 is proportional to that arm's error scale.
        prob = sigmas[1] / (sigmas[1] + sigmas[0] + 1e-6)
        return 1.0 if (u_scalar < prob) else -1.0

    for day in range(1, N_days + 1):
        Xt = X0_1[day - 1].view(1, p).to(device=device, dtype=dtype)
        u_seed = float(U_1[day - 1, 0].item())

        if day <= burn_in:
            seed = -1.0
        elif day <= 2 * burn_in:
            seed = 1.0
        else:
            seed = get_seed_action(u_seed)

        dX, dA, dY, dR = [], [], [], []
        for t in range(T):
            # The same action is applied at every step of the day.
            act = torch.tensor([seed], device=device, dtype=dtype)
            eps = eps_y_1[day - 1, t].view(1).to(device=device, dtype=dtype)

            y = emission_distribution(params.beta0[t], params.beta1[t], params.xi1[t],
                                      Xt, act.unsqueeze(1), params.gammas[t], "nonlinear") + eps

            dX.append(Xt)
            dA.append(act.unsqueeze(1))
            dY.append(y.view(1, 1))
            dR.append(y.view(1, 1))  # the reward is the outcome itself

            state_noise = (eps_x_raw_1[day - 1, t].view(1, p).to(device=device, dtype=dtype)) @ chol.T
            Xt = trans_distribution(params.phi0[t], params.phi1[t], params.Xi1[t],
                                    Xt, act.unsqueeze(1), params.alphas[t], state_noise)

        XH_np.append(stack_steps(dX).detach().cpu().numpy())
        AH_np.append(stack_steps(dA).detach().cpu().numpy())
        RH_np.append(stack_steps(dR).detach().cpu().numpy())

        XL.append(stack_steps(dX))
        AL.append(stack_steps(dA))
        YL.append(stack_steps(dY))

    Xf = torch.cat(XL, 0)
    Af = torch.cat(AL, 0)
    Yf = torch.cat(YL, 0)
    n_effective = Xf.shape[0]

    # Per-step regression of y on [1, x, a, a * x] across days, then weight by u_long[t].
    ests = []
    for t in range(T):
        xt, at, yt = Xf[:, t], Af[:, t], Yf[:, t]
        ones = torch.ones((n_effective, 1), device=device, dtype=dtype)
        Z = torch.cat([torch.cat([ones, xt], 1), torch.cat([ones, xt], 1) * at], 1)
        beta = stable_inv_psd((Z.T @ Z) / n_effective, ridge) @ ((Z.T @ yt) / n_effective)
        ests.append((u_long[t].view(-1, 1).T @ beta).item())
    return float(np.mean(ests))


# ==============================================================================
# ==============================================================================
def run_evaluation(args):
    """Evaluate every estimator on shared random draws and save a histogram.

    Steps: build the DGP parameters, compute the reference ATE by Monte Carlo,
    load the standard and robust actors, pre-draw the shared randomness for
    ``args.n_eval`` replicates, run the batched methods (Standard, Robust and,
    unless ``args.no_baselines``, Random and Switchback), run NMDP/TMDP per
    replicate unless ``args.no_nmdp_tmdp``, then save a histogram of all
    estimates to ``args.out_png``.

    Args:
        args: namespace providing ``device``, ``taus``, ``p``, ``N_days``,
            ``std_ckpt``, ``rob_ckpt``, ``n_eval``, ``base_seed``,
            ``no_baselines``, ``no_nmdp_tmdp``, ``n_jobs``, ``burn_in`` and
            ``out_png``.

    Returns:
        ``{"ate_true": float, "results": {method name: numpy array of estimates}}``.
    """
    device = torch.device("cuda" if (torch.cuda.is_available() and args.device != "cpu") else "cpu")
    print(f"Running Evaluation on {device} (Fair=B, optional NMDP/TMDP)")

    cfg = DGPConfig(
        taus=args.taus, p=args.p, seed=2026, dtype=torch.float32, device=device,
        emission="nonlinear", phi1_range=0.3, Xi1_range=0.2,
        alpha_mean=0.0, alpha_sd=0.3, add_state_noise=True, round_digits=None
    )
    params = generate_params(cfg)
    params.add_state_noise = True
    u_long = compute_index_u_vectors(params)["U"].to(device)

    print("Computing True ATE...")
    res = monte_carlo_ATE_with_linear_baseline(
        params, Iter_n=[5000], emission="nonlinear", seed=2026,
        add_state_noise=True, act_hi=1.0, act_lo=-1.0, include_closed_form=True
    )
    ate_true = float(res[5000]["ATE_true_MC"])
    print(f"True ATE: {ate_true:.6f}")

    # Build dummy inputs once, only to measure the width of each state vector
    # so the actors can be constructed with the right input size.
    with torch.no_grad():
        d1 = args.p + 1
        dummy_G = torch.zeros((1, d1, d1), device=device)
        dummy_f = time_features(0, args.taus, device, torch.float32)
        dummy_X = torch.zeros((1, args.p), device=device)
        dummy_A = torch.zeros((1, 1), device=device)
        dummy_XP = torch.zeros((1, 0, args.p), device=device)
        dummy_AP = torch.zeros((1, 0, 1), device=device)
        dummy_C = torch.zeros((1, 10, d1), device=device)
        # Only the element count matters here, since the builder flattens its inputs.
        dummy_V = torch.cat([dummy_C, dummy_C], 2)

        ds = build_state_vector_standard(dummy_G, dummy_G, dummy_XP, dummy_AP, dummy_X, dummy_A, dummy_f, args.taus).shape[1]
        dr = build_state_vector_robust(dummy_G, dummy_G, dummy_C, dummy_V, dummy_XP, dummy_AP, dummy_X, dummy_C, dummy_A, dummy_f, args.taus, False).shape[1]

    astd = load_actors_strict(args.N_days, ds, device, args.std_ckpt, require_days_from=2)
    arob = load_actors_strict(args.N_days, dr, device, args.rob_ckpt, require_days_from=2)

    B = args.n_eval
    print(f"\n[Common RNG] base_seed={args.base_seed}, B={B} generating ...")
    eps_y, X0, eps_x_raw, U_action = make_common_randomness(params, args.N_days, args.taus, B, base_seed=args.base_seed)

    results = {}

    print("\n[Batch] Evaluating RL-Standard ...")
    t1 = time.time()
    results["Standard"] = run_std_exp_batch_fair(
        params, astd, "standard", args.N_days, u_long,
        eps_y, X0, eps_x_raw, U_action
    ).cpu().numpy()
    print(f"Done. time={time.time() - t1:.2f}s")

    print("\n[Batch] Evaluating RL-Robust ...")
    t2 = time.time()
    results["Robust"] = run_std_exp_batch_fair(
        params, arob, "robust", args.N_days, u_long,
        eps_y, X0, eps_x_raw, U_action
    ).cpu().numpy()
    print(f"Done. time={time.time() - t2:.2f}s")

    if not args.no_baselines:
        print("\n[Batch] Evaluating baselines (Random  / Switchback) ...")
        tb = time.time()
        results["Random"] = run_std_exp_batch_fair(
            params, {}, "random", args.N_days, u_long,
            eps_y, X0, eps_x_raw, U_action
        ).cpu().numpy()
        results["Switchback"] = run_std_exp_batch_fair(
            params, {}, "switchback", args.N_days, u_long,
            eps_y, X0, eps_x_raw, U_action
        ).cpu().numpy()
        print(f"Done. time={time.time() - tb:.2f}s")

    if not args.no_nmdp_tmdp:
        print(f"\n[NMDP/TMDP] Running (n_jobs={args.n_jobs}) ...")
        t3 = time.time()

        def one_run(b_idx: int):
            """Run NMDP and TMDP on replicate ``b_idx`` of the shared draws."""
            eps_y_1 = eps_y[b_idx].detach()
            X0_1 = X0[b_idx].detach()
            eps_x_1 = eps_x_raw[b_idx].detach()
            U_1 = U_action[b_idx].detach()

            nmdp = run_nmdp_experiment_fair(params, args.N_days, u_long,
                                           eps_y_1, X0_1, eps_x_1, U_1,
                                           burn_in=args.burn_in)
            tmdp = run_tmdp_experiment_fair(params, args.N_days, u_long,
                                           eps_y_1, X0_1, eps_x_1, U_1,
                                           burn_in=args.burn_in)
            return nmdp, tmdp

        idxs = list(range(B))
        if args.n_jobs > 1:
            # Best effort: any failure of the parallel path (including a
            # missing joblib) falls back to the serial loop.
            try:
                from joblib import Parallel, delayed
                out = Parallel(n_jobs=args.n_jobs)(delayed(one_run)(i) for i in idxs)
            except Exception as e:
                print(f"[Warn] joblib parallel failed, fallback to serial. err={e}")
                out = [one_run(i) for i in idxs]
        else:
            out = [one_run(i) for i in idxs]

        results["NMDP"] = np.array([x[0] for x in out], dtype=float)
        results["TMDP"] = np.array([x[1] for x in out], dtype=float)
        print(f"Done. time={time.time() - t3:.2f}s")

    fig = plt.figure(figsize=(14, 7))
    for name, arr in results.items():
        plt.hist(arr, bins=30, alpha=0.35, label=name, density=True)
    plt.axvline(ate_true, color="black", linestyle="--", label="True ATE")
    plt.title(f"ATE Estimation Histogram  seed={args.base_seed}, B={B}")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(args.out_png, dpi=150)
    # Release the figure; otherwise one figure per call stays open when this
    # function is invoked repeatedly.
    plt.close(fig)
    print(f"Histogram saved to {args.out_png}")

    return {"ate_true": ate_true, "results": results}


def compute_metrics(arr: np.ndarray, ate_true: float):
    """Summarise a set of ATE estimates against the reference value.

    Args:
        arr: estimates, converted to a float array.
        ate_true: reference ATE.

    Returns:
        Dict of plain Python floats with keys ``mean``, ``std`` (population
        standard deviation), ``bias``, ``mse`` and ``mse_CI_lower`` /
        ``mse_CI_upper``, the latter being ``mse -/+ 1.96 * std(sq. error) /
        sqrt(n)``.
    """
    arr = np.asarray(arr, dtype=float)
    # Squared errors are computed once and reused for the MSE and its interval.
    sq_err = (arr - ate_true) ** 2
    mse = float(sq_err.mean())
    half_width = float(1.96 * sq_err.std() / np.sqrt(len(arr)))
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "bias": float(arr.mean() - ate_true),
        "mse": mse,
        "mse_CI_lower": mse - half_width,
        "mse_CI_upper": mse + half_width,
    }


def run_grid_and_plot(args,
                      base_seeds=(666,),
                      n_evals=(1000,),
                      out_dir="./grid_outputs"):
    """Run ``run_evaluation`` over a grid of seeds and replicate counts.

    For every ``(base_seed, n_eval)`` pair a copy of ``args`` is evaluated, the
    per-replicate squared errors are saved, and summary metrics are collected.
    The summary is written to ``<out_dir>/grid_summary.csv`` and an MSE-vs-n_eval
    line plot (one panel per seed) to ``<out_dir>/grid_mse_lines.png``.

    Args:
        args: namespace forwarded to ``run_evaluation``; ``base_seed``,
            ``n_eval`` and ``out_png`` are overridden on a deep copy.
        base_seeds: iterable of seeds, or a single integer.
        n_evals: iterable of replicate counts, or a single integer.
        out_dir: output directory (created if needed).

    Returns:
        DataFrame with one row per (seed, n_eval, method).  Empty when the grid
        is empty, in which case no summary CSV or plot is written.
    """
    def _as_list(values):
        # A bare integer (e.g. ``666`` instead of ``(666,)``) is treated as a
        # one-element grid rather than failing inside ``list(...)``.
        if isinstance(values, (int, np.integer)):
            return [values]
        return list(values)

    os.makedirs(out_dir, exist_ok=True)

    rows = []
    ate_true_ref = None

    base_seeds = _as_list(base_seeds)
    n_evals = _as_list(n_evals)

    for seed in base_seeds:
        for B in n_evals:
            a = deepcopy(args)
            a.base_seed = int(seed)
            a.n_eval = int(B)
            a.out_png = os.path.join(out_dir, f"hist_seed{seed}_B{B}.png")

            print("\n" + "=" * 100)
            print(f"[GRID] base_seed={seed}, n_eval={B}")
            print("=" * 100)

            out = run_evaluation(a)
            ate_true = float(out["ate_true"])
            results = out["results"]
            save_one_run_sqerr_csv(results, ate_true, seed, B, out_dir)
            # The reference ATE is expected to be identical across runs; warn if not.
            if ate_true_ref is None:
                ate_true_ref = ate_true
            elif abs(ate_true - ate_true_ref) > 1e-8:
                print(f"[Warn] ate_true changed across runs: {ate_true_ref} -> {ate_true}")

            for method, arr in results.items():
                m = compute_metrics(arr, ate_true)
                rows.append({
                    "base_seed": seed,
                    "n_eval": B,
                    "method": method,
                    "ate_true": ate_true,
                    **m
                })

    df = pd.DataFrame(rows)
    if df.empty:
        # Nothing was evaluated: there are no columns to summarise or plot.
        print("[Warn] empty grid: nothing evaluated, skipping summary CSV and plot.")
        return df

    csv_path = os.path.join(out_dir, "grid_summary.csv")
    df.to_csv(csv_path, index=False)
    print(f"\n[Saved] {csv_path}")
    methods_order = ["Random", "Switchback", "NMDP", "TMDP", "Standard", "Robust"]
    present_methods = [m for m in methods_order if m in df["method"].unique()]
    seeds_sorted = sorted(df["base_seed"].unique())
    ncols = 2
    nrows = int(np.ceil(len(seeds_sorted) / ncols))
    fig = plt.figure(figsize=(16, 5 * nrows))

    for i, seed in enumerate(seeds_sorted, 1):
        ax = plt.subplot(nrows, ncols, i)
        d0 = df[df["base_seed"] == seed]

        for method in present_methods:
            d1 = d0[d0["method"] == method].sort_values("n_eval")
            if len(d1) == 0:
                continue
            ax.plot(d1["n_eval"].values, d1["mse"].values, marker="o", label=method)

        ax.set_title(f"MSE vs n_eval (base_seed={seed})")
        ax.set_xlabel("n_eval")
        ax.set_ylabel("MSE")
        ax.grid(True, alpha=0.3)
        ax.set_xscale("log")
        ax.legend()

    plt.tight_layout()
    fig_path = os.path.join(out_dir, "grid_mse_lines.png")
    plt.savefig(fig_path, dpi=160)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"[Saved] {fig_path}")

    return df


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), registered
    # in this order.
    cli_options = [
        ("--N_days", dict(type=int, default=35)),
        ("--n_eval", dict(type=int, default=100)),
        ("--taus", dict(type=int, default=6)),
        ("--p", dict(type=int, default=3)),
        ("--device", dict(type=str, default="auto")),
        ("--std_ckpt", dict(type=str, default="./runs_1/final_clean")),
        ("--rob_ckpt", dict(type=str, default="./runs_robust/final_clean")),
        ("--base_seed", dict(type=int, default=2026)),
        ("--no_baselines", dict(action="store_true", help="no Random/Switchback")),
        ("--no_nmdp_tmdp", dict(action="store_true", help="no NMDP/TMDP")),
        ("--n_jobs", dict(type=int, default=6, help="NMDP/TMDP(need joblib)")),
        ("--burn_in", dict(type=int, default=5)),
        ("--out_png", dict(type=str, default="eval_hist.png")),
        ("--grid_out_dir", dict(type=str, default="./grid_outputs_CI_details")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # The grid is fixed here; run_grid_and_plot sets base_seed, n_eval and
    # out_png on a copy of args for every grid cell, so the corresponding
    # command-line values are not used by this entry point.
    base_seeds = [666]
    n_evals = [1000]

    run_grid_and_plot(args, base_seeds=base_seeds, n_evals=n_evals, out_dir=args.grid_out_dir)
