import os
import math
import time
import argparse
from pathlib import Path
from collections import deque, namedtuple
import logging
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from typing import Callable, List, Dict, Any, Optional, Tuple

# Command-line options, declared as (flag, add_argument keyword arguments) pairs
# and registered in exactly the order listed, so `--help` output keeps its order.
_CLI_OPTIONS = (
    # Input artefacts (checkpoints and CSV files).
    ("--transition_model", dict(type=str, default="checkpoints/transition_model_with_epoch_100.pth")),
    ("--revenue_model", dict(type=str, default="checkpoints/revenue_model_with_epoch_100.pth")),
    ("--initial_dist", dict(type=str, default="initial_state_distribution.csv")),
    ("--residual_csv", dict(type=str, default="nn_residual_analysis.csv")),
    ("--transformer_path", dict(type=Path, default=Path("checkpoints/transformer.pt"), help="Transformer saved path")),
    # Simulation sizes.
    ("--num_envs", dict(type=int, default=16)),
    ("--times", dict(type=int, default=20)),
    ("--dates", dict(type=int, default=35)),
    # Optimisation / replay hyper-parameters.
    ("--batch_size", dict(type=int, default=32)),
    ("--replay_capacity", dict(type=int, default=150000)),
    ("--lr", dict(type=float, default=1e-5)),
    ("--gamma", dict(type=float, default=0.95)),
    ("--target_update_steps", dict(type=int, default=1200)),
    # Exploration schedule.
    ("--eps_start", dict(type=float, default=0.9)),
    ("--eps_end", dict(type=float, default=0.3)),
    ("--eps_decay_steps", dict(type=int, default=500000)),
    ("--warmup_days", dict(type=int, default=10)),
    # Monte-Carlo, output and runtime settings.
    ("--ate_mc_sims_init", dict(type=int, default=20000)),
    ("--save_path", dict(type=str, default="policy_dqn_transformer_perstep.pth")),
    ("--posenc_max_len", dict(type=int, default=8192)),
    ("--max_chunk_mc", dict(type=int, default=5000)),
    ("--epochs", dict(type=int, default=1500)),
    ("--gpu", dict(type=int, default=7)),
    ("--seed", dict(type=int, default=2025)),
)

parser = argparse.ArgumentParser()
for _cli_flag, _cli_kwargs in _CLI_OPTIONS:
    parser.add_argument(_cli_flag, **_cli_kwargs)
# Parsed once at import time; the rest of the module reads settings from `args`.
args = parser.parse_args()

def init_logger(name="DQN-Transformer-PerStep",
                log_dir=None,
                fname="train.log",
                level=logging.INFO,
                when="midnight",
                backup_count=7,
                overwrite_on_start=False):
    """Create (or re-create) a logger that writes to the console and a rotating file.

    Any handlers already attached to the logger called ``name`` are removed and
    closed, so calling this function repeatedly neither duplicates output nor
    leaks open log-file descriptors.

    Args:
        name: Name passed to ``logging.getLogger``.
        log_dir: Directory for the log file; defaults to ``<cwd>/trainlog``.
            Created (with parents) when missing.
        fname: Log file name inside ``log_dir``.
        level: Level applied to the logger and to both handlers.
        when: Rotation schedule forwarded to ``TimedRotatingFileHandler``.
        backup_count: Number of rotated files kept by the file handler.
        overwrite_on_start: When true, an existing log file is deleted (or,
            if deletion fails, truncated) before logging starts.

    Returns:
        The configured ``logging.Logger`` (``propagate`` is disabled).
    """
    from pathlib import Path
    import logging
    import logging.handlers

    log_dir = Path.cwd() / "trainlog" if log_dir is None else Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / fname

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Detach AND close stale handlers. Closing releases the previous file
    # handle; doing it before the unlink below also matters on platforms that
    # refuse to delete a file that is still open.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    if overwrite_on_start and log_path.exists():
        try:
            log_path.unlink()
        except OSError:
            # Best effort: if the file cannot be removed, empty it instead.
            with open(log_path, "w"):
                pass

    fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s", "%Y-%m-%d %H:%M:%S")

    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(fmt)
    logger.addHandler(ch)

    # delay=True: the file is only opened when the first record is emitted.
    fh = logging.handlers.TimedRotatingFileHandler(
        filename=str(log_path), when=when, backupCount=backup_count, encoding="utf-8", delay=True
    )
    fh.setLevel(level)
    fh.setFormatter(fmt)
    logger.addHandler(fh)

    logger.info(f"Logger initialized. logging to {log_path}")
    return logger

logger = init_logger(name="DQN-Transformer-PerStep", overwrite_on_start=True)

# Validate the CUDA setup with explicit raises: an `assert` would be stripped
# when Python runs with -O, deferring the failure to the first tensor transfer.
if not torch.cuda.is_available():
    raise RuntimeError("GPU 不可用")
if not 0 <= args.gpu < torch.cuda.device_count():
    raise RuntimeError(
        f"--gpu {args.gpu} is out of range: {torch.cuda.device_count()} CUDA device(s) visible"
    )
device = torch.device(f"cuda:{args.gpu}")

# Seed every RNG used in this file (torch CPU + all CUDA devices, NumPy, stdlib random).
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Record type for one stored transition; the field names are the only contract visible here.
Transition = namedtuple('Transition', ('history', 'history_mask', 'action', 'reward', 'next_history', 'next_history_mask', 'done'))

class MLP(nn.Module):
    """Fully connected network: two ReLU hidden layers of equal width, then a linear output.

    The sub-modules live in ``self.layers`` (an ``nn.Sequential``), so state-dict
    keys are ``layers.0.*``, ``layers.2.*`` and ``layers.4.*``.
    """

    def __init__(self, input_size, output_size, hidden_size=64):
        super().__init__()
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)

def map_state_dict_prefix(state_dict, from_prefix, to_prefix):
    """Return a new dict in which keys starting with ``from_prefix`` start with ``to_prefix`` instead.

    Keys without the prefix are copied unchanged; values are never modified and
    the input mapping is left untouched.
    """
    cut = len(from_prefix)
    return {
        (to_prefix + key[cut:] if key.startswith(from_prefix) else key): value
        for key, value in state_dict.items()
    }

def robust_load_model(model, path, model_name="model"):
    """Load a checkpoint into ``model``, tolerating common key-naming differences.

    The checkpoint may be a bare state dict or a dict wrapping one under
    ``'state_dict'`` / ``'model_state_dict'``. Attempts, in order:

    1. strict load of the keys as stored;
    2. strict load after stripping a ``module.`` prefix (when the checkpoint
       has it and the model does not);
    3. strict load after renaming ``layers.`` -> ``net.``;
    4. strict load after renaming ``net.`` -> ``layers.``;
    5. ``strict=False`` load, accepted only if at least one model parameter
       was actually matched.

    The non-strict load is deliberately last: ``load_state_dict(strict=False)``
    does not raise on key-name mismatches, so trying it before the renames
    would report success while leaving the model at its random initialisation.

    Args:
        model: ``nn.Module`` to load into (modified in place).
        path: Checkpoint path, passed to ``torch.load`` with ``map_location="cpu"``.
        model_name: Label used in log messages and errors.

    Returns:
        The state dict (possibly with renamed keys) that was loaded.

    Raises:
        RuntimeError: If no attempt manages to load the checkpoint.
    """
    sd_loaded = torch.load(str(path), map_location="cpu")
    sd = sd_loaded
    if isinstance(sd_loaded, dict):
        for wrapper_key in ('state_dict', 'model_state_dict'):
            if wrapper_key in sd_loaded:
                sd = sd_loaded[wrapper_key]
                break
    keys = list(sd.keys())
    model_keys = list(model.state_dict().keys())

    # 1) Strict load of the checkpoint exactly as stored.
    try:
        model.load_state_dict(sd)
        logger.info(f"{model_name}: strict load ok from {path}")
        return sd
    except Exception as e:
        logger.warning(f"{model_name}: strict load failed: {e}")

    # 2) Strip a DataParallel-style 'module.' prefix. When stripping applies,
    #    the stripped keys are also the basis for the renames below.
    base = sd
    if any(k.startswith("module.") for k in keys) and not any(k.startswith("module.") for k in model_keys):
        base = map_state_dict_prefix(sd, "module.", "")
        try:
            model.load_state_dict(base)
            logger.info(f"{model_name}: success after stripping 'module.'")
            return base
        except Exception as e:
            logger.warning(f"{model_name}: strip module failed: {e}")
    base_keys = list(base.keys())

    # 3) layers.* -> net.*
    if any(k.startswith("layers.") for k in base_keys) and not any(k.startswith("net.") for k in base_keys):
        sd2 = map_state_dict_prefix(base, "layers.", "net.")
        try:
            model.load_state_dict(sd2)
            logger.info(f"{model_name}: success mapping layers.->net.")
            return sd2
        except Exception as e:
            logger.warning(f"{model_name}: mapping layers->net failed: {e}")

    # 4) net.* -> layers.*
    if any(k.startswith("net.") for k in base_keys) and not any(k.startswith("layers.") for k in base_keys):
        sd2 = map_state_dict_prefix(base, "net.", "layers.")
        try:
            model.load_state_dict(sd2)
            logger.info(f"{model_name}: success mapping net.->layers.")
            return sd2
        except Exception as e:
            logger.warning(f"{model_name}: mapping net->layers failed: {e}")

    # 5) Last resort: partial load. Refuse it when nothing matched, because
    #    that would silently keep randomly initialised weights.
    try:
        res = model.load_state_dict(base, strict=False)
    except Exception as e:
        logger.warning(f"{model_name}: loose load failed: {e}")
    else:
        missing = set(res.missing_keys)
        if any(k not in missing for k in model_keys):
            logger.warning(f"{model_name}: loaded with strict=False; missing_keys={res.missing_keys}; unexpected_keys={res.unexpected_keys}")
            return base
        logger.warning(f"{model_name}: loose load matched no model parameters; rejecting it")

    sample_k = keys[:80] if len(keys) > 80 else keys
    logger.error(f"{model_name}: failed to load checkpoint {path}. sample keys: {sample_k}")
    raise RuntimeError(f"{model_name}: failed to load checkpoint {path}. See logs.")

def load_required_assets_or_die():
    """Load every asset named in ``args`` onto ``device``, raising if one is unusable.

    Returns:
        A 10-tuple: ``(transition_model, revenue_model, initial_orders_values,
        initial_orders_probs, orders_resid_mean, orders_resid_var,
        drivers_resid_mean, drivers_resid_var, revenue_resid_mean,
        revenue_resid_var)``. Both models are in eval mode; all other entries are
        float32 tensors built from CSV columns. ``initial_orders_probs`` is
        normalised to sum to (approximately) one.

    Raises:
        FileNotFoundError: If any of the four input paths does not exist.
        ValueError: If a CSV lacks a required column.
    """
    # Fail fast on missing files, checked in the same order as the CLI lists them.
    for required in (args.transition_model, args.revenue_model, args.initial_dist, args.residual_csv):
        if not Path(required).exists():
            raise FileNotFoundError(f"{required}")

    def _frozen_mlp(out_dim, checkpoint, label):
        # Both networks take 4 input features and differ only in output width.
        net = MLP(4, out_dim).to(device)
        robust_load_model(net, checkpoint, model_name=label)
        net.eval()
        return net

    def _column(frame, col):
        return torch.tensor(frame[col].values, dtype=torch.float32, device=device)

    transition_model = _frozen_mlp(2, args.transition_model, "transition_model")
    revenue_model = _frozen_mlp(1, args.revenue_model, "revenue_model")

    df_init = pd.read_csv(args.initial_dist)
    if not {'orders','probability'}.issubset(df_init.columns):
        raise ValueError("initial_state_distribution.csv 必须包含 orders 和 probability 列")
    initial_orders_values = _column(df_init, 'orders')
    raw_probs = _column(df_init, 'probability')
    initial_orders_probs = raw_probs / (raw_probs.sum() + 1e-12)

    df_r = pd.read_csv(args.residual_csv)
    # Listed in the order the corresponding tensors appear in the return value.
    need_cols = ['ordersNext_resid_mean','ordersNext_resid_var','driversNext_resid_mean','driversNext_resid_var','revenue_resid_mean','revenue_resid_var']
    if not set(need_cols).issubset(df_r.columns):
        raise ValueError(f"缺少列: {need_cols}")
    residual_tensors = [_column(df_r, col) for col in need_cols]

    return (transition_model, revenue_model, initial_orders_values, initial_orders_probs,
            *residual_tensors)

@torch.no_grad()
def estimate_true_ate_paired_mc_strict(
    trans_model, rev_model, initial_orders_values, initial_orders_probs, resid_data,
    round_reward=True, num_trajs=10000, T=20, per_step=True, max_chunk_size=5000
):
    """Monte-Carlo estimate of the always-0 minus always-1 reward gap using paired noise.

    Two trajectories are rolled out from the same sampled initial state with
    the same noise draws: one always takes action 0, the other always action 1.
    The per-trajectory difference of cumulative rewards is averaged.

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to the next-state mean.
        rev_model: Module mapping the same input to the reward mean.
        initial_orders_values: Candidate initial order counts.
        initial_orders_probs: Sampling weights for those candidates (re-normalised here).
        resid_data: Dict with ``orders_mean``, ``drivers_mean``, ``orders_var``,
            ``drivers_var``, ``revenue_mean``, ``revenue_var`` sequences; each is
            truncated to ``T`` or extended by repeating its last entry.
        round_reward: If true, rewards are rounded and clamped at zero.
        num_trajs: Number of paired trajectories.
        T: Horizon (steps per trajectory).
        per_step: If true, differences are divided by ``T`` before summarising.
        max_chunk_size: Maximum number of trajectories simulated at once.

    Returns:
        ``(ate, diffs, se)``: mean difference, the per-trajectory differences,
        and the standard error of the mean.
    """
    dev = next(trans_model.parameters()).device
    n_total = int(num_trajs)

    # Dedicated generator with a fixed seed, so the estimate is repeatable.
    rng = torch.Generator(device=dev)
    rng.manual_seed(2025)

    init_p = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)
    picks = torch.multinomial(init_p, n_total, replacement=True, generator=rng)
    start_orders = initial_orders_values[picks].unsqueeze(1).float().to(dev)
    start_drivers = torch.full((n_total, 1), 50.0, device=dev)

    def fit_to_horizon(seq):
        vec = seq.to(dev).float()
        have = vec.shape[0]
        if have >= T:
            return vec[:T]
        tail = vec[-1].unsqueeze(0).expand(T - have)
        return torch.cat([vec, tail], dim=0)

    orders_shift = fit_to_horizon(resid_data['orders_mean'])
    drivers_shift = fit_to_horizon(resid_data['drivers_mean'])
    orders_var = fit_to_horizon(resid_data['orders_var'])
    drivers_var = fit_to_horizon(resid_data['drivers_var'])
    reward_shift = fit_to_horizon(resid_data['revenue_mean'])
    reward_var = fit_to_horizon(resid_data['revenue_var'])

    state_shift = torch.stack([orders_shift, drivers_shift], dim=1)
    state_sd = torch.stack([orders_var, drivers_var], dim=1).clamp_min(1e-12).sqrt()
    reward_sd = reward_var.clamp_min(1e-12).sqrt()

    def finish_reward(raw):
        if round_reward:
            raw = raw.round().clamp(min=0)
        return raw.view(-1).float()

    batch = int(max(1, min(max_chunk_size, n_total)))
    gaps = []
    for lo in range(0, n_total, batch):
        width = min(batch, n_total - lo)
        hi = lo + width

        # Draw order (state noise first, then reward noise) is part of the
        # reproducible stream and must not be swapped.
        eps_state = torch.randn((T, width, 2), device=dev, generator=rng) * state_sd.view(T, 1, 2)
        eps_reward = torch.randn((T, width), device=dev, generator=rng) * reward_sd.view(T, 1)

        s_zero = torch.cat([start_orders[lo:hi].clone(), start_drivers[lo:hi].clone()], dim=1)
        s_one = s_zero.clone()
        act_zero = torch.full((width, 1), 0.0, device=dev)
        act_one = torch.full((width, 1), 1.0, device=dev)
        total_zero = torch.zeros((width,), device=dev)
        total_one = torch.zeros((width,), device=dev)

        for step in range(T):
            clock = torch.full((width, 1), float(step), device=dev)
            x_zero = torch.cat([s_zero, act_zero, clock], dim=1)
            x_one = torch.cat([s_one, act_one, clock], dim=1)

            next_zero = trans_model(x_zero)
            next_one = trans_model(x_one)
            rev_zero = rev_model(x_zero).squeeze(1)
            rev_one = rev_model(x_one).squeeze(1)

            # Both arms receive the identical noise realisation (paired design).
            bump = state_shift[step].unsqueeze(0) + eps_state[step]
            s_zero = (next_zero + bump).round().clamp(min=0).float()
            s_one = (next_one + bump).round().clamp(min=0).float()

            total_zero += finish_reward(rev_zero + reward_shift[step] + eps_reward[step])
            total_one += finish_reward(rev_one + reward_shift[step] + eps_reward[step])

        gaps.append(total_zero - total_one)

    diffs = torch.cat(gaps, dim=0)
    root_n = math.sqrt(max(1, n_total))
    if per_step:
        diffs_per = diffs / float(T)
        return float(diffs_per.mean().item()), diffs_per, diffs_per.std(unbiased=True).item() / root_n
    return float(diffs.mean().item()), diffs, diffs.std(unbiased=True).item() / root_n
def Q_eta_est_poly_tensor_batch(data, treatment, device=None):
    """Batched linear value-difference fit and average-reward estimate for one treatment arm.

    For every batch row, only the steps where ``A == treatment`` are used. The
    features are the raw state ``(orders, drivers)``; features and rewards are
    centred over the selected steps, a ridge-regularised 2x2 linear system is
    solved for ``beta``, and ``eta`` is the mean over selected steps of
    ``revenue - (phi(S) - phi(S'))_centred @ beta``.

    Args:
        data: Mapping with keys ``'A'``, ``'revenue'``, ``'orders'``, ``'drivers'``,
            ``'ordersNext'``, ``'driversNext'``. ``'A'`` must be 2-D ``(B, N)``;
            the other entries are expected to share that shape. Values may be
            tensors or anything ``torch.as_tensor`` accepts.
        treatment: Action value selecting the arm.
        device: Device to compute on; defaults to the device of the first entry of ``data``.

    Returns:
        ``(eta_est, TD_error, beta_a)`` with shapes ``(B,)``, ``(B, N)`` and
        ``(B, 2, 1)``. ``eta_est`` is 0 for rows with no step in the arm;
        ``TD_error`` is 0 at steps outside the arm.

    Raises:
        KeyError: If a required key is missing from ``data``.
        ValueError: If ``'A'`` is not 2-D.
    """
    required_keys = ['A','revenue','orders','drivers','ordersNext','driversNext']
    for k in required_keys:
        if k not in data:
            raise KeyError(f"Q_eta_est_poly_tensor_batch: missing key '{k}' in data")

    tensors = {k: (v if torch.is_tensor(v) else torch.as_tensor(v)) for k, v in data.items()}
    if device is None:
        # `tensors` is non-empty here because the required keys were checked above.
        device = next(iter(tensors.values())).device
    tensors = {k: v.to(device=device, dtype=torch.float32) for k, v in tensors.items()}

    A = tensors['A']
    if A.dim() != 2:
        raise ValueError(f"Q_eta_est_poly_tensor_batch: A must be 2D, got dim={A.dim()}")
    B, N = A.shape

    mask_2d = (A == treatment).float()          # (B, N): 1 where the step belongs to the arm
    mask_f = mask_2d.unsqueeze(-1)              # (B, N, 1)
    count = mask_2d.sum(dim=1)                  # (B,): number of selected steps per row
    count_safe = count.clamp_min(1.0).view(B, 1, 1)

    revenue = tensors['revenue'].unsqueeze(-1)
    phi_S = torch.cat([tensors['orders'].unsqueeze(-1), tensors['drivers'].unsqueeze(-1)], dim=-1)
    phi_next_S = torch.cat([tensors['ordersNext'].unsqueeze(-1), tensors['driversNext'].unsqueeze(-1)], dim=-1)

    def centre(x):
        # Subtract the mean over selected steps and zero out unselected steps.
        mean = (x * mask_f).sum(dim=1, keepdim=True) / count_safe
        return (x - mean) * mask_f

    revenue_c = centre(revenue)
    phi_S_c = centre(phi_S)
    diff_phi_S_c = phi_S_c - centre(phi_next_S)

    # NOTE(review): lhs is (phi - phi')^T phi, i.e. the transpose of the
    # phi^T (phi - phi') matrix used in the usual LSTD normal equations.
    # Kept as written; confirm this orientation is intended.
    lhs = torch.matmul(diff_phi_S_c.transpose(1, 2), phi_S_c)
    rhs = torch.matmul(phi_S_c.transpose(1, 2), revenue_c)
    P = lhs.shape[1]
    eye = torch.eye(P, device=device)
    try:
        # Ridge term 3*I, as in the original formulation.
        beta_a = torch.linalg.solve(lhs + 3 * eye.unsqueeze(0), rhs)
    except RuntimeError:
        # The batched solve fails if any row's system is singular. Fall back
        # to per-row solves with a tiny ridge, then to least squares.
        per_row = []
        for i in range(B):
            try:
                bi = torch.linalg.solve(lhs[i] + 1e-6 * eye, rhs[i])
            except RuntimeError:
                bi = torch.linalg.lstsq(lhs[i], rhs[i]).solution
            per_row.append(bi)
        # stack (not cat): each bi is (P, 1) and the result must be (B, P, 1)
        # for the batched matmul below.
        beta_a = torch.stack(per_row, dim=0)

    Q_diff_vec = torch.matmul(diff_phi_S_c, beta_a).squeeze(-1)
    residual = revenue.squeeze(-1) - Q_diff_vec
    eta_est = (residual * mask_2d).sum(dim=1) / count.clamp_min(1.0)
    # Rows with no selected step get eta = 0 (also guards against NaN residuals).
    eta_est = torch.where(count == 0, torch.zeros_like(eta_est), eta_est)

    TD_error = (residual - eta_est.unsqueeze(1)) * mask_2d
    return eta_est, TD_error, beta_a


@torch.no_grad()
def run_one_batch_random_sims_and_estimate(
    trans_model, revenue_model,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    NUM_ENVS=128,
    times=20,
    days=30,
    num_sims=100,
    seed=2025,
    device=None
):
    """Simulate ``num_sims`` trajectories under uniformly random actions and estimate the ATE.

    Each trajectory has ``days * times`` steps. At every step an action in
    {0, 1} is drawn uniformly, the transition/revenue models give the means,
    and residual mean plus Gaussian noise (per time-of-day) is added before
    rounding and clamping at zero. ``Qeta_fn`` is then evaluated for
    treatment 0 and treatment 1 and the per-trajectory difference
    ``eta(0) - eta(1)`` is summarised.

    All randomness (initial states, actions, noise) comes from one generator
    seeded with ``seed``, so equal seeds give equal results.

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to next-state means.
        revenue_model: Module mapping the same input to the reward mean.
        initial_orders_values: Candidate initial order counts.
        initial_orders_probs: Sampling weights for those candidates.
        resid_data: Dict of 1-D residual sequences (``orders_mean``, ``drivers_mean``,
            ``orders_var``, ``drivers_var``, ``revenue_mean``, ``revenue_var``);
            each is truncated to ``times`` or padded by repeating its last entry.
        Qeta_fn: Callable ``(data, treatment=..., device=...)`` returning a
            3-tuple whose first element is a per-trajectory eta tensor.
        NUM_ENVS: Unused; kept for call compatibility.
        times: Steps per day.
        days: Number of days.
        num_sims: Number of simulated trajectories.
        seed: Seed for the internal generator.
        device: Device to run on; defaults to the transition model's device.

    Returns:
        ``(eta_diff, ate_mean, ate_se)``: the per-trajectory difference tensor,
        its mean, and the standard error of that mean (0.0 for a single trajectory).

    Raises:
        ValueError: If ``num_sims`` is not positive or a residual sequence is empty.
    """
    if device is None:
        device = next(trans_model.parameters()).device

    trans_model.eval()
    revenue_model.eval()

    initial_orders_values = initial_orders_values.to(device)
    initial_orders_probs  = initial_orders_probs.to(device)

    def pad_repeat(seq, L):
        s = seq.to(device).float()
        if s.shape[0] >= L:
            return s[:L]
        if s.shape[0] == 0:
            raise ValueError("residual sequence must not be empty")
        # s[-1:] is already 1-D with one element, so it expands directly to
        # the missing length (an extra unsqueeze would make expand() fail).
        last = s[-1:].expand(L - s.shape[0])
        return torch.cat([s, last], dim=0)

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    N = int(days * times)
    E_total = int(num_sims)
    if E_total <= 0:
        raise ValueError("num_sims must be > 0")

    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))

    probs = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)
    idx = torch.multinomial(probs, E_total, replacement=True, generator=gen)
    orders0 = initial_orders_values[idx].unsqueeze(1).float().to(device)
    drivers0 = torch.full((E_total,1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    orders_hist      = torch.zeros((E_total, N), device=device)
    drivers_hist     = torch.zeros((E_total, N), device=device)
    actions_hist     = torch.zeros((E_total, N), device=device)
    revenue_hist     = torch.zeros((E_total, N), device=device)
    ordersNext_hist  = torch.zeros((E_total, N), device=device)
    driversNext_hist = torch.zeros((E_total, N), device=device)

    # Per-time-of-day residual statistics, shape (times, 2) for the state and (times,) for the reward.
    state_mean_per_time = torch.stack([orders_mean_t, drivers_mean_t], dim=1)
    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))

    for s in range(N):
        t_in_day = s % times
        # Actions use the seeded generator too; otherwise `seed` would not
        # control the simulated design.
        a_val = torch.randint(0, 2, (E_total,), generator=gen, device=device, dtype=torch.long)
        t_tensor = torch.full((E_total,1), float(t_in_day), device=device)
        X = torch.cat([state, a_val.unsqueeze(1).float(), t_tensor], dim=1)
        next_mean = trans_model(X)
        rev_mean = revenue_model(X).squeeze(1)

        noise_state = torch.randn((E_total,2), device=device, generator=gen) * state_std_per_time[t_in_day].view(1,2)
        noise_r = torch.randn((E_total,), device=device, generator=gen) * reward_std_per_time[t_in_day]

        next_state = (next_mean + state_mean_per_time[t_in_day].view(1,2) + noise_state).round().clamp(min=0).float()
        reward = (rev_mean + revenue_mean_t[t_in_day] + noise_r).round().clamp(min=0).float()

        orders_hist[:, s]      = state[:, 0]
        drivers_hist[:, s]     = state[:, 1]
        actions_hist[:, s]     = a_val.float()
        revenue_hist[:, s]     = reward
        ordersNext_hist[:, s]  = next_state[:, 0]
        driversNext_hist[:, s] = next_state[:, 1]
        state = next_state

    data_for_ope = {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }

    eta_t0, _, _ = Qeta_fn(data_for_ope, treatment=0, device=device)
    eta_t1, _, _ = Qeta_fn(data_for_ope, treatment=1, device=device)
    eta_diff = (eta_t0 - eta_t1)

    # Keep one value per trajectory: averaging first would leave a single
    # number and force the standard error to 0.
    ate_per_sim = eta_diff.detach().cpu().numpy().reshape(-1)
    ate_mean = float(ate_per_sim.mean()) if ate_per_sim.size > 0 else 0.0
    ate_se = float(ate_per_sim.std(ddof=1) / math.sqrt(max(1, ate_per_sim.size))) if ate_per_sim.size > 1 else 0.0

    return eta_diff , ate_mean, ate_se


def get_divisors(n):
    """Return the positive divisors of ``n`` in ascending order (empty list when ``n < 1``)."""
    found = set()
    candidate = 1
    # Trial-divide up to sqrt(n); each hit contributes the divisor and its cofactor.
    while candidate * candidate <= n:
        cofactor, remainder = divmod(n, candidate)
        if remainder == 0:
            found.update((candidate, cofactor))
        candidate += 1
    return sorted(found)

@torch.no_grad()
def run_batch_generate_and_estimate(
    trans_model, revenue_model,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    num_sims=100,
    times=20,
    days=30,
    f=1,
    seed=2025,
    ATE_true_hat=0.0,
    device=None,
    save_csv_path=None
):
    """Simulate trajectories under an alternating block design and summarise the ATE estimate.

    Design: within a day, actions alternate in blocks of ``f`` consecutive
    steps (``(t // f) % 2``); each trajectory flips that pattern by a random
    0/1 phase; odd-numbered days use the complement of day 0's pattern.
    The environment is then rolled out for ``days * times`` steps and
    ``Qeta_fn`` is evaluated for treatments 0 and 1.

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to next-state means.
        revenue_model: Module mapping the same input to the reward mean.
        initial_orders_values: Candidate initial order counts.
        initial_orders_probs: Sampling weights for those candidates.
        resid_data: Dict of 1-D residual sequences (``orders_mean``, ``drivers_mean``,
            ``orders_var``, ``drivers_var``, ``revenue_mean``, ``revenue_var``);
            each is truncated to ``times`` or padded by repeating its last entry.
        Qeta_fn: Callable ``(data, treatment=..., device=...)`` returning a 1-D
            tensor of length ``num_sims`` or a tuple/list whose first element is one.
        num_sims: Number of simulated trajectories.
        times: Steps per day.
        days: Number of days.
        f: Block length of the within-day alternation (values below 1 are treated as 1).
        seed: Base seed; initial states and phases use ``seed``, step noise uses ``seed + 1``.
        ATE_true_hat: Reference value used for ``mse`` and ``bias2``.
        device: Device to run on; defaults to the transition model's device.
        save_csv_path: If given, the action matrix is written there as CSV.

    Returns:
        Dict with ``ate_list``, ``ate_mean``, ``ate_se``, ``mse``, ``bias2``,
        ``variance``, ``eta_diff`` (NumPy array) and ``data_for_ope`` (the
        simulated history tensors).

    Raises:
        ValueError: If ``num_sims`` is not positive or a residual sequence is empty.
        RuntimeError: If ``Qeta_fn`` returns an eta of unexpected shape.
    """
    if device is None:
        device = next(trans_model.parameters()).device

    trans_model.eval(); revenue_model.eval()
    initial_orders_values = initial_orders_values.to(device)
    initial_orders_probs  = initial_orders_probs.to(device)

    N = int(days * times)
    E_total = int(num_sims)
    if E_total <= 0:
        raise ValueError("num_sims must be > 0")

    def pad_repeat(seq, L):
        s = seq.to(device).float()
        if s.shape[0] >= L:
            return s[:L]
        if s.shape[0] == 0:
            raise ValueError("residual sequence must not be empty")
        # s[-1:] is already 1-D with one element, so it expands directly to
        # the missing length (an extra unsqueeze would make expand() fail).
        last = s[-1:].expand(L - s.shape[0])
        return torch.cat([s, last], dim=0)

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))

    gen_init = torch.Generator(device=device); gen_init.manual_seed(int(seed))
    gen_step = torch.Generator(device=device); gen_step.manual_seed(int(seed + 1))

    probs = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)
    idx_init = torch.multinomial(probs, E_total, replacement=True, generator=gen_init)
    orders0 = initial_orders_values[idx_init].unsqueeze(1).float().to(device)
    drivers0 = torch.full((E_total,1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    # Day-0 pattern: alternate every f steps, then XOR with a per-trajectory phase.
    day_clock = torch.arange(times, device=device)
    f_eff = max(1, int(f))
    base_block = ((day_clock // f_eff) % 2).long()
    gen_phase = torch.Generator(device=device)
    gen_phase.manual_seed(int(seed))  # same seed value as gen_init, separate stream object
    phases = torch.randint(0, 2, (E_total,), generator=gen_phase, device=device).long()
    base_day0 = (base_block.unsqueeze(0).expand(E_total, -1) ^ phases.unsqueeze(1)).long()

    # Even days reuse the day-0 pattern, odd days use its complement.
    parts = [base_day0 if (d % 2) == 0 else 1 - base_day0 for d in range(int(days))]
    actions_big = torch.cat(parts, dim=1).to(dtype=torch.long, device=device)

    if save_csv_path is not None:
        pd.DataFrame(actions_big.cpu().numpy()).to_csv(save_csv_path, index=False)

    orders_hist = torch.zeros((E_total, N), device=device)
    drivers_hist = torch.zeros((E_total, N), device=device)
    actions_hist = torch.zeros((E_total, N), device=device)
    revenue_hist = torch.zeros((E_total, N), device=device)
    ordersNext_hist = torch.zeros((E_total, N), device=device)
    driversNext_hist = torch.zeros((E_total, N), device=device)

    for s in range(N):
        t_in_day = s % times
        a_val = actions_big[:, s]
        acts = a_val.unsqueeze(1).float()
        t_tensor = torch.full((E_total,1), float(t_in_day), device=device)
        X = torch.cat([state, acts, t_tensor], dim=1)
        next_mean = trans_model(X)
        rev_mean = revenue_model(X).squeeze(1)
        # State noise is drawn before reward noise; the order fixes the random stream.
        std_state = state_std_per_time[t_in_day].view(1,2)
        noise_state = torch.randn((E_total,2), device=device, generator=gen_step) * std_state.expand(E_total,2)
        std_reward = float(reward_std_per_time[t_in_day].item())
        noise_r = torch.randn((E_total,), device=device, generator=gen_step) * std_reward
        resid_mean_state = torch.stack([orders_mean_t[t_in_day], drivers_mean_t[t_in_day]], dim=0).unsqueeze(0)
        next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()
        resid_mean_r = float(revenue_mean_t[t_in_day].item())
        reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()
        orders_hist[:, s] = state[:, 0]
        drivers_hist[:, s] = state[:, 1]
        actions_hist[:, s] = a_val.float()
        revenue_hist[:, s] = reward
        ordersNext_hist[:, s] = next_state[:, 0]
        driversNext_hist[:, s] = next_state[:, 1]
        state = next_state

    data_for_ope = {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }

    out0 = Qeta_fn(data_for_ope, treatment=0, device=device)
    out1 = Qeta_fn(data_for_ope, treatment=1, device=device)
    eta_t0 = out0[0] if isinstance(out0, (tuple, list)) else out0
    eta_t1 = out1[0] if isinstance(out1, (tuple, list)) else out1

    if not torch.is_tensor(eta_t0) or eta_t0.dim() != 1 or eta_t0.numel() != E_total:
        raise RuntimeError(f"eta_t0 shape unexpected: got {getattr(eta_t0,'shape',None)}, expected ({E_total},)")
    if not torch.is_tensor(eta_t1) or eta_t1.dim() != 1 or eta_t1.numel() != E_total:
        raise RuntimeError(f"eta_t1 shape unexpected: got {getattr(eta_t1,'shape',None)}, expected ({E_total},)")

    eta_diff = (eta_t0 - eta_t1).cpu().numpy()
    ate_list = eta_diff.tolist()
    ate_mean = float(np.mean(eta_diff))
    ate_se = float(np.std(eta_diff, ddof=1) / np.sqrt(max(1, E_total))) if E_total > 1 else 0.0
    mse = float(np.mean((eta_diff - float(ATE_true_hat))**2))
    bias2 = float((ate_mean - float(ATE_true_hat))**2)
    variance = float(np.var(eta_diff, ddof=1)) if E_total > 1 else 0.0

    return {
        'ate_list': ate_list,
        'ate_mean': ate_mean,
        'ate_se': ate_se,
        'mse': mse,
        'bias2': bias2,
        'variance': variance,
        'eta_diff': eta_diff,
        'data_for_ope': data_for_ope
    }


class GenerativeEnvGPU:
    """Vectorised simulator driven by a learned transition model and revenue model.

    ``num_envs`` independent environments are stepped together. The state of
    each environment is ``(orders, drivers)``. A step feeds
    ``[orders, drivers, action, t]`` to both models, adds the residual mean and
    Gaussian noise for time index ``t`` (clamped to the last available entry),
    then rounds and clamps the result at zero.

    All tensors are created on the device of ``trans_model``'s parameters
    (falling back to the device of the initial-value tensor for a
    parameter-free model), so the class does not depend on a module-level
    ``device`` variable.
    """

    def __init__(self, trans_model, rev_model, num_envs, initial_dist_data, residuals_data):
        """Store models and residual statistics and create the seeded generator.

        Args:
            trans_model: Module producing next-state means, shape ``(E, 2)``.
            rev_model: Module producing reward means, shape ``(E, 1)``.
            num_envs: Number of parallel environments ``E``.
            initial_dist_data: Dict with ``'values'`` (candidate initial orders)
                and ``'probs'`` (their sampling weights).
            residuals_data: Dict with ``orders_mean``, ``orders_var``,
                ``drivers_mean``, ``drivers_var``, ``revenue_mean``, ``revenue_var``.
        """
        self.trans_model = trans_model
        self.rev_model = rev_model
        first_param = next(trans_model.parameters(), None)
        self.device = first_param.device if first_param is not None else initial_dist_data['values'].device
        self.E = int(num_envs)
        self.init_vals = initial_dist_data['values']
        self.init_probs = initial_dist_data['probs']
        self.orders_resid_mean = residuals_data['orders_mean']
        self.orders_resid_var  = residuals_data['orders_var']
        self.drivers_resid_mean= residuals_data['drivers_mean']
        self.drivers_resid_var = residuals_data['drivers_var']
        self.revenue_resid_mean= residuals_data['revenue_mean']
        self.revenue_resid_var = residuals_data['revenue_var']
        # Fixed seed: every instance produces the same random stream.
        self.gen = torch.Generator(device=self.device); self.gen.manual_seed(2025)
        self.state = None

    def reset(self):
        """Sample fresh initial states (orders from the initial distribution, drivers = 50).

        Returns:
            A copy of the new state tensor, shape ``(E, 2)``.
        """
        probs = self.init_probs / (self.init_probs.sum() + 1e-12)
        idx = torch.multinomial(probs, self.E, replacement=True, generator=self.gen)
        orders = self.init_vals[idx].unsqueeze(1).float()
        drivers = torch.full((self.E,1), 50.0, device=self.device)
        self.state = torch.cat([orders, drivers], dim=1)
        return self.state.clone()

    @torch.no_grad()
    def step(self, actions_val, t):
        """Advance all environments by one step.

        Args:
            actions_val: Length-``E`` actions (tensor or array-like, any numeric dtype).
            t: Integer time index passed to the models and used to pick residual statistics.

        Returns:
            ``(next_state, reward, done, info)``: next state ``(E, 2)``, reward
            ``(E,)``, an all-False boolean ``done`` tensor, and an empty dict.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.state is None:
            raise RuntimeError("GenerativeEnvGPU.step() called before reset()")
        E = self.E
        dev = self.device
        # Normalise dtype/device so integer action tensors concatenate cleanly with the float state.
        actions_for_nn = torch.as_tensor(actions_val).to(device=dev, dtype=torch.float32).unsqueeze(1)
        t_tensor = torch.full((E,1), float(t), device=dev)
        X = torch.cat([self.state, actions_for_nn, t_tensor], dim=1)
        next_mean = self.trans_model(X)
        rev_mean = self.rev_model(X).squeeze(1)

        # Time indices beyond the residual tables reuse the last entry.
        ti = min(t, self.orders_resid_mean.shape[0]-1)
        resid_mean_state = torch.stack([self.orders_resid_mean[ti], self.drivers_resid_mean[ti]], dim=0).unsqueeze(0)
        resid_var_state = torch.stack([self.orders_resid_var[ti], self.drivers_resid_var[ti]], dim=0).unsqueeze(0)
        std_state = torch.sqrt(resid_var_state.clamp_min(1e-9))
        noise_state = torch.randn((E,2), device=dev, generator=self.gen) * std_state.expand(E,2)
        next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()

        ti_r = min(t, self.revenue_resid_mean.shape[0]-1)
        resid_mean_r = self.revenue_resid_mean[ti_r]
        resid_std_r = math.sqrt(float(self.revenue_resid_var[ti_r].clamp_min(1e-9)))
        noise_r = torch.randn((E,), device=dev, generator=self.gen) * resid_std_r
        reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()

        self.state = next_state.clone()
        done = torch.zeros((E,), dtype=torch.bool, device=dev)
        return next_state.clone(), reward, done, {}
    
def run_design_experiment(
    make_A_fn: Callable,
    env,
    times: int,
    days: int,
    num_sims: int = 100,
    sim_size: int = 1,
    reps: int = 200,
    seed: int = 2025,
    estimator_list: Optional[List[str]] = None,
    wager_ms: Optional[List[int]] = None,
    ATE_true: Optional[float] = None,
    device: Optional[torch.device] = None,
    save_prefix: Optional[str] = None
) -> Dict:
    """Repeat a design experiment: draw actions, simulate, and apply the chosen estimators.

    For each of ``reps`` repetitions, ``make_A_fn(rng, E, times)`` produces an
    action matrix for ``E = num_sims * sim_size`` trajectories, which is
    binarised and fed to ``simulate_env_with_action_matrix_binary``. Each
    requested estimator yields one value per trajectory; values are averaged
    within consecutive groups of ``sim_size`` to give ``num_sims`` numbers per
    repetition.

    Args:
        make_A_fn: Callable ``(rng, N, T)`` returning an array or tensor of
            actions coded as 0/1 or -1/+1.
        env: Environment handed to the simulator.
        times: Steps per day.
        days: Number of days.
        num_sims: Number of groups reported per repetition.
        sim_size: Trajectories averaged into each group.
        reps: Number of repetitions; repetition ``r`` uses seed ``seed + r``.
        seed: Base seed.
        estimator_list: Any of ``'xiong'`` and ``'wager'`` (default ``['xiong']``).
        wager_ms: ``m`` values used for ``'wager'`` (default ``[2]``).
        ATE_true: If given, ``mse`` and ``bias`` against it are added to each summary.
        device: Device for the simulator; defaults to ``env.device`` or CPU.
        save_prefix: If given, each result matrix is saved as ``{save_prefix}_{key}.csv``.

    Returns:
        Dict keyed by ``'xiong'`` / ``'wager_m{m}'``; each value holds
        ``per_rep_per_sim`` (array of shape ``(reps, num_sims)``), ``mean``,
        ``var`` and, when ``ATE_true`` is given, ``mse`` and ``bias``.

    Raises:
        ValueError: For an unknown estimator name.
        RuntimeError: If an estimator does not return ``E`` values.
    """
    estimator_list = ['xiong'] if estimator_list is None else estimator_list
    wager_ms = [2] if wager_ms is None else wager_ms
    if device is None:
        device = getattr(env, "device", torch.device("cpu"))

    n_traj = num_sims * sim_size

    def _nan_table():
        return np.full((reps, num_sims), np.nan, dtype=float)

    # One (reps, num_sims) table per estimator variant, NaN until filled.
    tables = {}
    for est_name in estimator_list:
        if est_name == 'xiong':
            tables['xiong'] = _nan_table()
        elif est_name == 'wager':
            for m in wager_ms:
                tables[f"wager_m{m}"] = _nan_table()
        else:
            raise ValueError(f"Unknown estimator '{est_name}'")

    want_xiong = 'xiong' in estimator_list
    want_wager = 'wager' in estimator_list

    for rep in range(reps):
        rep_seed = seed + rep
        rng = np.random.default_rng(rep_seed)

        # NOTE(review): the design is requested with T=times, whereas the
        # simulator's visible 2-D check expects times * days columns; confirm
        # what shape make_A_fn is meant to return.
        design = make_A_fn(rng, n_traj, times)
        if isinstance(design, torch.Tensor):
            design_np = design.detach().cpu().numpy()
        else:
            design_np = np.asarray(design)

        # Map -1/+1 coding to 0/1; anything else is treated as "non-zero means 1".
        levels = set(np.unique(design_np).tolist())
        if levels <= {-1, 1}:
            design_np = (design_np == 1).astype(np.int64)
        else:
            design_np = (design_np != 0).astype(np.int64)

        hist = simulate_env_with_action_matrix_binary(
            env, design_np, times=times, days=days, device=device, seed=int(rep_seed)
        )
        rewards = hist['revenue_hist']
        actions = hist['actions_hist']

        per_traj = {}
        if want_xiong:
            xiong_est = estimate_ATE_Xiong_binary_torch(rewards, actions, flatten_days=True, control_minus_treat=True)
            per_traj['xiong'] = xiong_est.detach().cpu().numpy()
        if want_wager:
            for m in wager_ms:
                wager_est = estimate_ATE_wager_binary_torch(rewards, actions, times=times, days=days, m=m, control_minus_treat=True)
                per_traj[f"wager_m{m}"] = wager_est.detach().cpu().numpy()

        for key, values in per_traj.items():
            flat = np.asarray(values).ravel()
            if flat.size != n_traj:
                raise RuntimeError(f"Estimator {key} returned length {flat.size} but expected {n_traj}")
            tables[key][rep, :] = flat.reshape(num_sims, sim_size).mean(axis=1)

        # Release the per-repetition history before the next simulation.
        del hist, rewards, actions

    summary = {}
    for key, table in tables.items():
        flat = table.reshape(-1)
        mean = float(np.nanmean(flat)) if flat.size > 0 else float('nan')
        var = float(np.nanvar(flat, ddof=1)) if flat.size > 1 else 0.0
        entry = {'per_rep_per_sim': table, 'mean': mean, 'var': var}
        if ATE_true is not None:
            target = float(ATE_true)
            entry['mse'] = float(np.nanmean((flat - target)**2))
            entry['bias'] = float(mean - target)
        summary[key] = entry

        if save_prefix is not None:
            pd.DataFrame(table).to_csv(f"{save_prefix}_{key}.csv", index_label="rep")
    return summary


@torch.no_grad()
def simulate_env_with_action_matrix_binary(
    env,
    actions_big,
    times: int,
    days: int,
    device: Optional[torch.device] = None,
    seed: Optional[int] = None
) -> Dict[str, torch.Tensor]:
    """
    Roll a fixed action matrix through the environment and record the history.

    Args:
        env: Environment exposing ``reset()`` and ``step(actions, t_in_day)``.
            If it has an ``E`` attribute that differs from the number of rows
            in ``actions_big``, a fresh ``GenerativeEnvGPU`` with the right
            number of environments is built from ``env``'s models/residuals.
        actions_big: Tensor or array-like of shape (E, times*days) or
            (E, days, times). Values are binarised: if every value is in
            {-1, 1} then 1 -> 1 and -1 -> 0, otherwise non-zero -> 1.
        times: Number of steps per day.
        days: Number of days.
        device: Device for the recorded tensors. Defaults to ``env.device``,
            then the device of ``actions_big`` (if a tensor), then CPU.
        seed: Optional seed applied to ``env.gen`` (a ``torch.Generator``) or
            via ``env.manual_seed``; seeding is best-effort.

    Returns:
        Dict with keys orders_hist, drivers_hist, actions_hist, revenue_hist,
        ordersNext_hist, driversNext_hist; each a tensor of shape (E, N) with
        N = times * days.

    Raises:
        ValueError: If ``actions_big`` is not 2D/3D or its shape does not
            match ``times``/``days``.
    """
    if device is None:
        device = getattr(env, "device", None)
        if device is None:
            device = actions_big.device if torch.is_tensor(actions_big) else torch.device("cpu")

    if torch.is_tensor(actions_big):
        A = actions_big.to(device=device)
    else:
        A = torch.as_tensor(actions_big, device=device)

    if A.dim() == 3:
        E, d, t = A.shape
        if d != days or t != times:
            raise ValueError("actions_big shape mismatch")
        # reshape (not view): the input may be a permuted / non-contiguous tensor.
        A = A.reshape(E, d * t)
    elif A.dim() == 2:
        E, N = A.shape
        if N != times * days:
            raise ValueError("actions_big length mismatch")
    else:
        raise ValueError("actions_big must be 2D or 3D")

    # Normalise the encoding to {0, 1}.
    if set(torch.unique(A).tolist()) <= {-1, 1}:
        A = (A == 1).to(torch.int64)
    else:
        A = (A != 0).to(torch.int64)

    if hasattr(env, "E") and env.E != E:
        # The supplied env has the wrong batch size; clone its configuration.
        env = GenerativeEnvGPU(
            trans_model=env.trans_model,
            rev_model=env.rev_model,
            num_envs=E,
            initial_dist_data={'values': env.init_vals, 'probs': env.init_probs},
            residuals_data={
                'orders_mean':  env.orders_resid_mean,
                'orders_var':   env.orders_resid_var,
                'drivers_mean': env.drivers_resid_mean,
                'drivers_var':  env.drivers_resid_var,
                'revenue_mean': env.revenue_resid_mean,
                'revenue_var':  env.revenue_resid_var,
            }
        )

    if seed is not None:
        # Seeding stays best-effort, but a failure is reported instead of hidden
        # so non-reproducible runs can be diagnosed.
        try:
            env_gen = getattr(env, "gen", None)
            if isinstance(env_gen, torch.Generator):
                env_gen.manual_seed(int(seed))
            elif hasattr(env, "manual_seed"):
                env.manual_seed(int(seed))
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not seed environment with seed=%s; continuing unseeded.",
                seed, exc_info=True
            )

    state = env.reset()
    if torch.is_tensor(state):
        state = state.to(device=device)
    else:
        state = torch.as_tensor(state, device=device, dtype=torch.float32)

    _, N = A.shape
    orders_hist = torch.zeros((E, N), device=device)
    drivers_hist = torch.zeros((E, N), device=device)
    actions_hist = torch.zeros((E, N), device=device)
    revenue_hist = torch.zeros((E, N), device=device)
    ordersNext_hist = torch.zeros((E, N), device=device)
    driversNext_hist = torch.zeros((E, N), device=device)

    for s in range(N):
        t_in_day = s % times
        a_for_env = A[:, s].to(torch.float32)
        next_state, reward, done, info = env.step(a_for_env, t_in_day)

        if torch.is_tensor(next_state):
            next_state = next_state.to(device=device, dtype=torch.float32)
        else:
            next_state = torch.as_tensor(next_state, device=device, dtype=torch.float32)
        if torch.is_tensor(reward):
            reward = reward.to(device=device, dtype=torch.float32)
        else:
            reward = torch.as_tensor(reward, device=device, dtype=torch.float32)

        # `state` is the pre-step state; `next_state` is what the env returned.
        orders_hist[:, s] = state[:, 0]
        drivers_hist[:, s] = state[:, 1]
        actions_hist[:, s] = a_for_env
        revenue_hist[:, s] = reward.view(-1)
        ordersNext_hist[:, s] = next_state[:, 0]
        driversNext_hist[:, s] = next_state[:, 1]

        state = next_state

    return {
        'orders_hist': orders_hist,
        'drivers_hist': drivers_hist,
        'actions_hist': actions_hist,
        'revenue_hist': revenue_hist,
        'ordersNext_hist': ordersNext_hist,
        'driversNext_hist': driversNext_hist
    }


@torch.no_grad()
def estimate_ATE_Xiong_binary_torch(
    rewards_hist: torch.Tensor,
    actions_hist: torch.Tensor,
    flatten_days: bool = True,
    control_minus_treat: bool = True
) -> torch.Tensor:
    """
    Per-trajectory difference-in-means estimate for binary actions.

    Args:
        rewards_hist: Rewards, shape (E, N) or (E, days, times).
        actions_hist: Actions in {0, 1}, same layout as ``rewards_hist``.
        flatten_days: If True, 3D inputs are flattened to (E, days*times).
        control_minus_treat: If True return mean(Y|A==0) - mean(Y|A==1),
            otherwise the opposite sign.

    Returns:
        Tensor of shape (E,). If a trajectory has no samples in a group, that
        group's mean is taken as 0.

    Raises:
        ValueError: If either input is not 2D after the optional flattening.
    """
    rewards = rewards_hist
    if rewards.dim() == 3 and flatten_days:
        n_traj, n_days, n_times = rewards.shape
        # reshape instead of view so permuted / non-contiguous inputs work too.
        rewards = rewards.reshape(n_traj, n_days * n_times)
    actions = actions_hist
    if actions.dim() == 3 and flatten_days:
        actions = actions.reshape(actions.shape[0], actions.shape[1] * actions.shape[2])

    if rewards.dim() != 2 or actions.dim() != 2:
        raise ValueError("rewards_hist/actions_hist must be 2D (E,N) or 3D (E,days,times) with flatten_days=True")

    device = rewards.device
    rewards = rewards.to(device=device, dtype=torch.float32)
    actions = actions.to(device=device)

    def _group_mean(mask: torch.Tensor) -> torch.Tensor:
        # Mean reward over the masked entries of each row; 0 where the group is empty.
        count = mask.sum(dim=1)
        total = (rewards * mask.to(rewards.dtype)).sum(dim=1)
        safe_count = count.clamp(min=1).to(total.dtype)
        return torch.where(count > 0, total / safe_count, torch.zeros_like(total))

    mean_t = _group_mean(actions == 1)
    mean_c = _group_mean(actions == 0)

    return (mean_c - mean_t) if control_minus_treat else (mean_t - mean_c)


@torch.no_grad()
def estimate_ATE_wager_binary_torch(
    rewards_hist: torch.Tensor,
    actions_hist: torch.Tensor,
    times: int,
    days: int,
    m: int,
    control_minus_treat: bool = True
) -> torch.Tensor:
    """
    Wager-style segment estimator, computed per trajectory.

    Every day is cut into consecutive segments of length ``m`` (the last one
    may be shorter). Within a segment the first step is discarded and the
    remaining rewards are summed; the segment is labelled by the action taken
    at its first step (0 = control, 1 = treat). Segment sums are averaged per
    label and the difference is divided by ``m - 1``.

    Args:
        rewards_hist: Rewards, shape (E, times*days) or (E, days, times).
        actions_hist: Actions, same layout as ``rewards_hist``.
        times: Steps per day.
        days: Number of days.
        m: Segment length, must be > 1.
        control_minus_treat: If True return (control_avg - treat_avg)/(m-1),
            otherwise the opposite sign.

    Returns:
        Tensor of shape (E,). A label with no segments contributes an
        average of 0.

    Raises:
        ValueError: If ``m <= 1``.
    """
    if m <= 1:
        raise ValueError("Wager requires m > 1")

    if rewards_hist.dim() == 2:
        E, N = rewards_hist.shape
        assert N == times * days, f"times*days ({times*days}) != N ({N})"
        R = rewards_hist.view(E, days, times)
    else:
        R = rewards_hist
        E, d_, t_ = R.shape
        assert d_ == days and t_ == times

    A = actions_hist.view(E, days, times) if actions_hist.dim() == 2 else actions_hist

    dev = R.device
    R = R.to(device=dev, dtype=torch.float32)
    A = A.to(device=dev)

    # (first index, one-past-last index) of every segment within a day.
    bounds = [(lo, min(lo + m, times)) for lo in range(0, times, m)]
    tail_sums = torch.zeros((E, days, len(bounds)), dtype=R.dtype, device=dev)
    lead_actions = torch.zeros((E, days, len(bounds)), dtype=A.dtype, device=dev)

    for k, (lo, hi) in enumerate(bounds):
        if hi - lo > 1:
            tail_sums[:, :, k] = R[:, :, (lo + 1):hi].sum(dim=2)
        # A one-step segment has nothing after its first step; its sum stays 0.
        lead_actions[:, :, k] = A[:, :, lo]

    sums_flat = tail_sums.view(E, -1)
    labels_flat = lead_actions.view(E, -1)

    def _label_average(selected: torch.Tensor) -> torch.Tensor:
        # Average segment sum over the selected segments of each trajectory.
        n_sel = selected.sum(dim=1).float()
        total = (sums_flat * selected.to(sums_flat.dtype)).sum(dim=1)
        return torch.where(n_sel > 0, total / n_sel.clamp(min=1.0), torch.zeros_like(total))

    avg_t = _label_average(labels_flat == 1)
    avg_c = _label_average(labels_flat == 0)

    denom = float(max(1, m - 1))
    gap = (avg_c - avg_t) if control_minus_treat else (avg_t - avg_c)
    return gap / denom


def aggregate_per_sim_mean(est_vector: np.ndarray, sim_size: int) -> np.ndarray:
    """
    Collapse per-trajectory estimates into one mean per simulation.

    Consecutive groups of ``sim_size`` entries form one simulation.

    Input: est_vector of length E, where sim_size divides E.
    Output: array of shape (E // sim_size,).
    Raises: ValueError if E is not a multiple of sim_size.
    """
    values = np.asarray(est_vector)
    total = values.shape[0]
    if total % sim_size != 0:
        raise ValueError("E must be divisible by sim_size")
    grouped = values.reshape(total // sim_size, sim_size)
    return grouped.mean(axis=1)


def run_design_experiment(
    make_A_fn: Callable,
    env,
    times: int,
    days: int,
    num_sims: int = 100,
    sim_size: int = 1,
    reps: int = 200,
    seed: int = 2025,
    estimator_list: Optional[List[str]] = None,
    wager_ms: Optional[List[int]] = None,
    ATE_true: Optional[float] = None,
    device: Optional[torch.device] = None,
    save_prefix: Optional[str] = None
) -> Dict:
    """
    Repeat a design -> simulate -> estimate experiment ``reps`` times.

    For repetition ``r`` (seeded with ``seed + r``):
      - ``make_A_fn(rng, E, times)`` produces the action matrix for
        E = num_sims * sim_size trajectories;
      - actions are binarised to {0, 1}; a single-day (E, times) matrix is
        tiled ``days`` times along the time axis;
      - the env is rolled out and the requested estimators are evaluated
        per trajectory;
      - trajectories are averaged in consecutive groups of ``sim_size``.

    Args:
        estimator_list: Any of 'xiong', 'wager' (default ['xiong']).
        wager_ms: Segment lengths for the Wager estimator (default [2]);
            each one is stored under the key ``wager_m{m}``.
        ATE_true: If given, 'mse' and 'bias' are added to every summary.
        save_prefix: If given, each (reps, num_sims) matrix is written to
            ``{save_prefix}_{key}.csv``.

    Returns:
        Dict keyed by estimator name; each value holds 'per_rep_per_sim',
        'mean', 'var' and, when ``ATE_true`` is given, 'mse' and 'bias'.

    Raises:
        ValueError: For an unknown estimator name.
        RuntimeError: If an estimator returns the wrong number of values.
    """
    estimator_list = ['xiong'] if estimator_list is None else estimator_list
    wager_ms = [2] if wager_ms is None else wager_ms
    if device is None:
        device = getattr(env, "device", torch.device("cpu"))

    n_traj = num_sims * sim_size

    def _empty_table() -> np.ndarray:
        # One row per repetition, one column per simulation; NaN until filled.
        return np.full((reps, num_sims), np.nan, dtype=float)

    perrep_storage: Dict[str, np.ndarray] = {}
    for name in estimator_list:
        if name == 'xiong':
            perrep_storage['xiong'] = _empty_table()
        elif name == 'wager':
            for m in wager_ms:
                perrep_storage[f"wager_m{m}"] = _empty_table()
        else:
            raise ValueError(f"Unknown estimator '{name}'")

    for r in range(reps):
        rep_seed = seed + r
        rng = np.random.default_rng(rep_seed)
        raw = make_A_fn(rng, n_traj, times)
        A_np = raw.detach().cpu().numpy() if isinstance(raw, torch.Tensor) else np.asarray(raw)

        # Accept either a {-1, +1} or a {0, non-zero} encoding.
        pm_one = set(np.unique(A_np).tolist()) <= {-1, 1}
        A_np = ((A_np == 1) if pm_one else (A_np != 0)).astype(np.int64)

        # A single-day pattern is repeated for every day.
        if A_np.ndim == 2 and A_np.shape[1] == times and days > 1:
            A_np = np.tile(A_np, reps=days).astype(np.int64)

        hist = simulate_env_with_action_matrix_binary(
            env, A_np, times=times, days=days, device=device, seed=int(rep_seed)
        )
        rewards = hist['revenue_hist']
        actions = hist['actions_hist']

        traj_estimates = {}
        if 'xiong' in estimator_list:
            xiong_est = estimate_ATE_Xiong_binary_torch(
                rewards, actions, flatten_days=True, control_minus_treat=True
            )
            traj_estimates['xiong'] = xiong_est.detach().cpu().numpy()
        if 'wager' in estimator_list:
            for m in wager_ms:
                wager_est = estimate_ATE_wager_binary_torch(
                    rewards, actions, times=times, days=days, m=m, control_minus_treat=True
                )
                traj_estimates[f"wager_m{m}"] = wager_est.detach().cpu().numpy()

        for key, values in traj_estimates.items():
            flat = np.asarray(values).ravel()
            if flat.size != n_traj:
                raise RuntimeError(f"Estimator {key} returned {flat.size} but expected {n_traj}")
            perrep_storage[key][r, :] = flat.reshape(num_sims, sim_size).mean(axis=1)

        # Drop the rollout tensors before the next repetition allocates new ones.
        del hist, rewards, actions

    summary = {}
    for key, mat in perrep_storage.items():
        flat = mat.reshape(-1)
        mean = float(np.nanmean(flat)) if flat.size > 0 else float('nan')
        var = float(np.nanvar(flat, ddof=1)) if flat.size > 1 else 0.0
        entry = {'per_rep_per_sim': mat, 'mean': mean, 'var': var}
        if ATE_true is not None:
            target = float(ATE_true)
            entry['mse'] = float(np.nanmean((flat - target) ** 2))
            entry['bias'] = float(mean - target)
        summary[key] = entry
        if save_prefix is not None:
            pd.DataFrame(mat).to_csv(f"{save_prefix}_{key}.csv", index_label="rep")
    return summary


# ================= Bojinov + IS-HT (control - treat) =================
import numpy as np
from typing import Optional
import torch


# ------------------ Bojinov day-level design (±1 encoding, outputs 0/1) ------------------

import numpy as np
from typing import Optional, Tuple, Dict

# 1) Generate Bojinov day-level action matrix (0/1)
def actions_dgp_bojinov(taus: int, ndays: int, ti: int, rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """
    Returns (taus, ndays) action matrix (each column = one day pattern), actions in {0,1}.

    - A base day pattern is built from consecutive segments of length ``ti``
      (the last one truncated to fit ``taus``); each segment gets one
      independent fair coin flip.
    - Each column is the base pattern or, with probability 0.5, its complement.
    - Odd columns are complemented once more.

    Args:
        taus: Number of time steps per day.
        ndays: Number of days (columns).
        ti: Segment length, must be >= 1.
        rng: NumPy generator; a fresh unseeded one is created when omitted.

    Raises:
        ValueError: If ``ti < 1`` (a non-positive segment length would never
            fill the day and the construction would not terminate).
    """
    if ti < 1:
        raise ValueError("ti must be >= 1")
    if rng is None:
        rng = np.random.default_rng()

    # Random draws are made one at a time, segment by segment and then day by
    # day, so results for a given seed do not depend on how they are batched.
    segments = []
    filled = 0
    while filled < taus:
        val = 1 if rng.random() < 0.5 else 0
        take = min(ti, taus - filled)
        segments.append(np.full(take, val, dtype=int))
        filled += take
    base = np.concatenate(segments) if segments else np.zeros(0, dtype=int)

    A = np.empty((taus, ndays), dtype=int)
    for d in range(ndays):
        complemented = rng.random() >= 0.5
        if d % 2 == 1:
            # The extra flip on odd days toggles whichever choice was drawn.
            complemented = not complemented
        A[:, d] = (1 - base) if complemented else base
    return A


# 2) Factory: returns make_A_fn(rng, N, T) -> (N, T) with 0/1 encoding
def make_bojinov_design(ti: int):
    """
    Build a design function ``make_A_fn(rng, N, T) -> (N, T)`` with 0/1 actions.

    Every call of the returned function draws one base day pattern of length T
    (segment length ``ti``). Each of the N rows is then that pattern or, with
    probability 0.5, its complement; odd-indexed rows are complemented once more.
    """
    def _fn(rng: np.random.Generator, N: int, T: int):
        day_pattern = actions_dgp_bojinov(taus=T, ndays=1, ti=ti, rng=rng)[:, 0]
        design = np.empty((N, T), dtype=int)
        for row in range(N):
            keep_drawn = rng.random() < 0.5
            is_odd = (row % 2 == 1)
            # On odd rows the extra flip inverts the drawn choice, so the base
            # pattern survives exactly when the two flags disagree.
            design[row, :] = day_pattern if keep_drawn != is_odd else 1 - day_pattern
        return design
    return _fn


# 3) Simulation: generate dataset with given design
def simulate_dataset_with_design(dgp, A_mat: np.ndarray, rng: Optional[np.random.Generator] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simulate one day per row of ``A_mat`` with the given data-generating process.

    ``dgp`` must provide ``sample_initial_states(n, rng=...)`` and
    ``simulate_day(x0, a_seq, rng=...)`` (a_seq in 0/1). For each of the N rows
    an initial state is sampled and the day is simulated under that row's actions.

    Returns X_all of shape (N, T+1, 2) and Y_all of shape (N, T); the state
    dimension is fixed at 2.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_units, horizon = A_mat.shape
    states = np.zeros((n_units, horizon + 1, 2))  # state dimension assumed to be 2
    outcomes = np.zeros((n_units, horizon))
    for unit in range(n_units):
        start = dgp.sample_initial_states(1, rng=rng)[0]
        day_states, day_outcomes = dgp.simulate_day(start, A_mat[unit], rng=rng)
        states[unit] = day_states
        outcomes[unit] = day_outcomes
    return states, outcomes


# 4) IS-HT estimator (0/1, returns control - treat)
def estimate_ATE_ISHT_binary_direct(Y_mat: np.ndarray, Act_mat: np.ndarray, TI_loop: int) -> float:
    """
    IS-HT estimate for 0/1 actions, returned as control - treat.

    Inputs are (taus, ndays) matrices (one column per day). For each step t of
    a day the contribution is:
      - t < TI_loop:  4*a_t*y_t - 2*y_t  (i.e. +2*y_t if treated, -2*y_t if not);
      - otherwise, looking at the actions from t-TI_loop through t:
        +4*y_t if all are 1, -4*y_t if all are 0, and 0 if mixed.
    Contributions are averaged within each day, the day means are averaged,
    and the sign is flipped so the result reads control - treat.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    outcomes = np.asarray(Y_mat, dtype=float)
    acts = np.asarray(Act_mat).astype(int)
    if outcomes.shape != acts.shape:
        raise ValueError("Y_mat and Act_mat must have same shape (taus, ndays)")

    n_steps, n_days = outcomes.shape

    def _contribution(a_day: np.ndarray, y_day: np.ndarray, t: int) -> float:
        # Weighted outcome of step t given that day's action sequence.
        if t < TI_loop:
            return 4.0 * a_day[t] * y_day[t] - 2.0 * y_day[t]
        window = a_day[(t - TI_loop):(t + 1)]
        if np.all(window == 1):
            return 4.0 * y_day[t]
        if np.all(window == 0):
            return -4.0 * y_day[t]
        return 0.0

    day_means = np.full((n_days,), np.nan, dtype=float)
    for day in range(n_days):
        a_day = acts[:, day]
        y_day = outcomes[:, day]
        terms = np.array([_contribution(a_day, y_day, t) for t in range(n_steps)], dtype=float)
        day_means[day] = np.nanmean(terms)

    treat_minus_control = float(np.nanmean(day_means))  # ±1 convention
    return -treat_minus_control  # return control - treat


# 5) Evaluation wrapper: repeat experiment multiple times
def evaluate_bojinov_isht(
    dgp,
    ATE_true: float,
    ti: int,
    N: int = 100,
    reps: int = 200,
    seed: int = 2025,
    TI_loop: Optional[int] = None
) -> Dict:
    """
    Monte-Carlo evaluation of the Bojinov design with the IS-HT estimator.

    Repetition ``r`` uses a generator seeded with ``seed + r`` for both the
    design and the simulation. ``dgp.T`` supplies the day length.

    Args:
        dgp: Data-generating process (see simulate_dataset_with_design).
        ATE_true: Reference value for bias and MSE.
        ti: Segment length of the design.
        N: Number of simulated days per repetition.
        reps: Number of repetitions.
        seed: Base seed.
        TI_loop: Look-back used by the estimator; defaults to ``ti``.

    Returns dict with mse, mean, var (ddof=1), bias, ests.
    """
    lookback = ti if TI_loop is None else TI_loop
    design_fn = make_bojinov_design(ti)
    ests = np.zeros((reps,), dtype=float)

    for r in range(reps):
        rng = np.random.default_rng(seed + r)
        plan = design_fn(rng, N, dgp.T)
        _, outcomes = simulate_dataset_with_design(dgp, plan, rng=rng)
        # The estimator expects (time, day) layout, hence the transposes.
        ests[r] = estimate_ATE_ISHT_binary_direct(outcomes.T, plan.T, TI_loop=lookback)

    mean = ests.mean()
    return {
        "mse": np.mean((ests - ATE_true)**2),
        "mean": mean,
        "var": ests.var(ddof=1),
        "bias": mean - ATE_true,
        "ests": ests,
    }


# ---------------------------- Burn-in design for NMDP ----------------------------
import torch
import math
from typing import Callable, Dict, Any, Optional, Tuple

@torch.no_grad()
def generate_burnin_all0_all1_env(
    trans_model: torch.nn.Module,
    revenue_model: torch.nn.Module,
    initial_orders_values: torch.Tensor,
    initial_orders_probs: torch.Tensor,
    resid_data: dict,
    m0_pos: int,
    m0_neg: int,
    times: int,
    batch_size: int = 64,
    seed: int = 2025,
    device: Optional[torch.device] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate burn-in data: first m0_pos trajectories with all actions=1,
    then m0_neg trajectories with all actions=0.

    Trajectories are simulated in batches of at most ``batch_size``; batch
    ``k`` (counted across both phases) uses a generator seeded with
    ``seed + k``.

    Args:
        trans_model, revenue_model: Models handed to GenerativeEnvGPU; both
            are switched to eval mode.
        initial_orders_values, initial_orders_probs: Initial-state distribution.
        resid_data: Residual statistics handed to GenerativeEnvGPU.
        m0_pos: Number of all-treated trajectories (>= 0).
        m0_neg: Number of all-control trajectories (>= 0).
        times: Steps per trajectory.
        batch_size: Maximum trajectories per environment batch (>= 1).
        seed: Base seed.
        device: Device for the rollout; defaults to ``trans_model``'s device.

    Returns:
      X_all: (N, times+1, 2) numpy
      Y_all: (N, times) numpy
      A_mat: (N, times) numpy int (0/1)
    with N = m0_pos + m0_neg.

    Raises:
        ValueError: If ``batch_size < 1`` or a trajectory count is negative.
    """
    # A non-positive batch size would make the batching loops below spin forever.
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if m0_pos < 0 or m0_neg < 0:
        raise ValueError("m0_pos and m0_neg must be >= 0")

    if device is None:
        device = next(trans_model.parameters()).device
    trans_model.eval(); revenue_model.eval()

    N = int(m0_pos + m0_neg)
    T = int(times)
    p = 2

    X_all = np.zeros((N, T + 1, p), dtype=float)
    Y_all = np.zeros((N, T), dtype=float)
    A_mat = np.zeros((N, T), dtype=np.int32)

    def _run_const_batch(cur: int, a_val: int, write_start: int, batch_seed_offset: int):
        # Simulate `cur` trajectories under the constant action `a_val` and
        # write them into rows [write_start, write_start + cur).
        env = GenerativeEnvGPU(
            trans_model=trans_model,
            rev_model=revenue_model,
            num_envs=cur,
            initial_dist_data={'values': initial_orders_values, 'probs': initial_orders_probs},
            residuals_data=resid_data
        )
        env.gen = torch.Generator(device=device); env.gen.manual_seed(int(seed + batch_seed_offset))
        st0 = env.reset().to(device=device).float()
        states = torch.zeros((cur, T + 1, p), device=device, dtype=torch.float32)
        rewards = torch.zeros((cur, T), device=device, dtype=torch.float32)
        states[:, 0, :] = st0
        for t in range(T):
            acts = torch.full((cur,), float(a_val), device=device)
            next_state, reward, done, _ = env.step(acts, t)
            rewards[:, t] = reward.view(-1).float()
            states[:, t+1, :] = next_state
        rows = slice(write_start, write_start + cur)
        X_all[rows, :, :] = states.detach().cpu().numpy()
        Y_all[rows, :] = rewards.detach().cpu().numpy()
        A_mat[rows, :] = a_val

    ptr = 0
    batch_idx = 0
    # Treated trajectories first, then control; the batch counter keeps
    # running so every batch gets a distinct seed.
    for a_val, count in ((1, int(m0_pos)), (0, int(m0_neg))):
        rem = count
        while rem > 0:
            cur = min(batch_size, rem)
            _run_const_batch(cur, a_val, ptr, batch_idx)
            ptr += cur
            rem -= cur
            batch_idx += 1

    assert ptr == N
    return X_all, Y_all, A_mat


# ----------------------------
# 2) Prepare burn-in arrays: O1, G, A1
# ----------------------------
def prepare_burnin_from_arrays_np(X_all: np.ndarray, Y_all: np.ndarray, A_mat: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reduce burn-in trajectories to the quantities used for fitting.

    Returns (O1, G, A1):
      O1 = X_all[:, 0, :]   initial observation of each trajectory, (N, p)
      G  = sum_t Y_t        total reward of each trajectory, (N,)
      A1 = A_mat[:, 0]      first action as 0/1, (N,)
    First actions given purely in {-1, +1} are mapped to {0, 1}.
    """
    first_obs = X_all[:, 0, :]
    total_reward = Y_all.sum(axis=1)
    first_action = A_mat[:, 0].astype(int)
    if set(np.unique(first_action).tolist()) <= {-1, 1}:
        # -1 -> 0 and +1 -> 1
        first_action = ((first_action + 1) // 2).astype(int)
    return first_obs, total_reward, first_action


# ----------------------------
# 3) Fit sigma functions (NMDP, two-step regression, 0/1)
# ----------------------------
def _add_intercept(X: np.ndarray) -> np.ndarray:
    X = np.atleast_2d(X)
    return np.column_stack([np.ones(X.shape[0]), X])

def _ols_fit(X: np.ndarray, y: np.ndarray, ridge: float = 1e-8) -> np.ndarray:
    XT_X = X.T @ X
    d = XT_X.shape[0]
    return np.linalg.solve(XT_X + ridge * np.eye(d), X.T @ y)

def fit_sigma_functions_NMDP_linear_binary(
    O1: np.ndarray,  
    G: np.ndarray,   
    A1: np.ndarray,  
    use_logvar: bool = False,
    min_var: float = 1e-12
) -> Dict[str, Any]:
    """
    Fit, separately for each action, a linear mean model and a linear
    (log-)variance model of the return G given the initial observation O1.

    Step 1 regresses G on [1, O1]; step 2 regresses the squared residuals
    (or their log when ``use_logvar``) on the same design. An action with
    fewer than (number of coefficients + 1) samples falls back to constants:
    the sample mean (0 if no samples) and the sample variance (1 if fewer
    than two samples).

    Args:
        O1: Initial observations, (N, p).
        G: Returns, (N,).
        A1: First actions, in {0, 1} or purely in {-1, +1}.
        use_logvar: Model log-variance instead of variance.
        min_var: Lower bound applied to squared residuals and predicted variances.

    Returns:
        Dict with 'theta_mu' and 'theta_var' (coefficients keyed by action 1/0)
        and the callables 'pred_mu(O1_new, a)' and 'pred_sigma(O1_new, a)';
        pred_sigma returns a standard deviation.
    """
    O1 = np.atleast_2d(O1)
    G = np.asarray(G).reshape(-1)
    raw_actions = np.asarray(A1).reshape(-1)
    if set(np.unique(raw_actions).tolist()) <= {-1, 1}:
        arms = ((raw_actions + 1) // 2).astype(int)
    else:
        arms = raw_actions.astype(int)

    design = _add_intercept(O1)
    n_coef = design.shape[1]

    def _regression_fit(rows: np.ndarray, targets: np.ndarray):
        # Two-step fit: mean first, then the (log-)squared residuals.
        coef_mu = _ols_fit(rows, targets)
        sq_resid = np.clip((targets - rows @ coef_mu) ** 2, min_var, None)
        var_target = np.log(sq_resid + 1e-12) if use_logvar else sq_resid
        return coef_mu, _ols_fit(rows, var_target)

    def _constant_fit(targets: np.ndarray, n_obs: int):
        # Too few samples for a regression: intercept-only models.
        level = float(targets.mean()) if n_obs > 0 else 0.0
        spread = float(np.var(targets - level, ddof=1)) if n_obs > 1 else 1.0
        spread = max(spread, min_var)
        coef_mu = np.zeros(n_coef)
        coef_mu[0] = level
        coef_var = np.zeros(n_coef)
        coef_var[0] = np.log(spread) if use_logvar else spread
        return coef_mu, coef_var

    res = {'theta_mu': {}, 'theta_var': {}}
    for a in (1, 0):
        in_arm = (arms == a)
        n_a = int(in_arm.sum())
        if n_a >= n_coef + 1:
            theta_mu, theta_var = _regression_fit(design[in_arm], G[in_arm])
        else:
            theta_mu, theta_var = _constant_fit(G[in_arm], n_a)
        res['theta_mu'][a] = theta_mu
        res['theta_var'][a] = theta_var

    def pred_mu(O1_new: np.ndarray, a: int) -> np.ndarray:
        rows = _add_intercept(np.atleast_2d(O1_new))
        return (rows @ res['theta_mu'][int(a)]).reshape(-1)

    def pred_sigma(O1_new: np.ndarray, a: int) -> np.ndarray:
        rows = _add_intercept(np.atleast_2d(O1_new))
        linear = rows @ res['theta_var'][int(a)]
        variance = np.exp(linear) if use_logvar else linear
        variance = np.clip(variance, min_var, None)
        return np.sqrt(variance).reshape(-1)

    res['pred_mu'] = pred_mu
    res['pred_sigma'] = pred_sigma
    return res


# ----------------------------
# 4) Construct NMDP design factory from pred_sigma (0/1, full-day actions)
# ----------------------------
def make_nmdp_design_from_sigma(pred_sigma_fn: Callable[[np.ndarray, int], np.ndarray]) -> Callable:
    """
    Build ``make_A_fn(rng, N, T) -> (N, T)`` giving each row one 0/1 action
    repeated for all T steps.

    The treatment probability of a row is sigma_1 / (sigma_1 + sigma_0),
    clipped to [1e-6, 1 - 1e-6], where sigma_a = pred_sigma_fn(O1, a).
    NOTE: O1 is drawn here as standard-normal noise of shape (N, 2), not
    taken from any environment state.
    """
    def _fn(rng: np.random.Generator, N: int, T: int):
        covariates = rng.normal(size=(N, 2))
        sigma_treat = pred_sigma_fn(covariates, 1).reshape(-1)
        sigma_ctrl = pred_sigma_fn(covariates, 0).reshape(-1)
        prob_treat = np.clip(sigma_treat / (sigma_treat + sigma_ctrl + 1e-12), 1e-6, 1-1e-6)
        draws = rng.random(size=N)
        first_action = (draws < prob_treat).astype(int)
        # Same action for the whole day.
        return np.tile(first_action.reshape(-1, 1), (1, T)).astype(int)
    return _fn


# ----------------------------
# 5) Run NMDP design in env and estimate using Q_eta
# ----------------------------
@torch.no_grad()
def run_nmdp_generate_and_estimate(
    trans_model: torch.nn.Module,
    revenue_model: torch.nn.Module,
    initial_orders_values: torch.Tensor,
    initial_orders_probs: torch.Tensor,
    resid_data: Dict[str, torch.Tensor],
    pred_sigma_fn: Callable[[np.ndarray, int], np.ndarray],
    Qeta_fn: Callable[..., Any],
    times: int = 20,
    days: int = 30,
    num_sims: int = 100,
    seed: int = 2025,
    device: Optional[torch.device] = None,
    max_chunk_size: int = 5000
) -> Tuple[np.ndarray, float, float]:
    """
    Simulate the NMDP design and estimate the effect with ``Qeta_fn``.

    Each simulation samples an initial order count from the given
    distribution (initial drivers are fixed at 50), draws one action with
    probability p1 = sigma_1 / (sigma_1 + sigma_0) from ``pred_sigma_fn`` and
    holds it for all ``days * times`` steps. Transitions and rewards are the
    model predictions plus the per-time residual mean plus Gaussian noise
    with the per-time residual variance, rounded and clamped at 0.
    Simulations are processed in chunks of at most ``max_chunk_size``.

    Args:
        resid_data: Per-time-of-day residual statistics with keys
            orders_mean/var, drivers_mean/var, revenue_mean/var. Sequences
            shorter than ``times`` are extended by repeating the last entry;
            longer ones are truncated.
        Qeta_fn: Called as ``Qeta_fn(data, treatment=a, device=device)`` for
            a in {0, 1}; if it returns a tuple/list the first item is used.

    Returns (eta_diff_array, ate_mean, ate_se).
    eta_diff_array shape = (num_sims,) — each sim's (eta_t0 - eta_t1).

    Raises:
        ValueError: If ``num_sims <= 0``, the initial distribution is empty,
            or a residual sequence is empty.
    """
    E_total = int(num_sims)
    if E_total <= 0:
        raise ValueError("num_sims must be > 0")

    if device is None:
        device = next(trans_model.parameters()).device
    trans_model.eval(); revenue_model.eval()

    orders_vals_np = initial_orders_values.detach().cpu().numpy().ravel()
    orders_probs_np = initial_orders_probs.detach().cpu().numpy().ravel()
    rng = np.random.default_rng(int(seed))
    K = len(orders_vals_np)
    if K == 0:
        raise ValueError("initial_orders_values empty")
    idxs = rng.choice(K, size=E_total, p=(orders_probs_np / (orders_probs_np.sum() + 1e-12)), replace=True)
    orders0_np = orders_vals_np[idxs]
    drivers0_np = np.full((E_total,), 50.0)
    O1_np = np.stack([orders0_np, drivers0_np], axis=1)

    s1 = np.asarray(pred_sigma_fn(O1_np, 1)).reshape(-1)
    s0 = np.asarray(pred_sigma_fn(O1_np, 0)).reshape(-1)
    p1 = s1 / (s1 + s0 + 1e-12)
    p1 = np.clip(p1, 1e-6, 1-1e-6)

    U = rng.random(size=E_total)
    A1 = (U < p1).astype(np.int64)
    N = int(days * times)
    actions_big = np.repeat(A1[:, None], N, axis=1).astype(np.int64)

    def pad_repeat(seq: torch.Tensor, L: int):
        # Truncate to L entries, or extend to L by repeating the last entry.
        s = seq.to(device).float()
        n = s.shape[0]
        if n >= L:
            return s[:L]
        if n == 0:
            raise ValueError("residual sequence is empty; cannot pad")
        # s[-1:] keeps the leading axis, so expanding it yields the L - n copies.
        last = s[-1:].expand(L - n, *s.shape[1:])
        return torch.cat([s, last], dim=0)

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    # Noise scales depend only on time of day, so compute them once.
    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))

    chunk = int(max(1, min(max_chunk_size, E_total)))
    eta_diffs_chunks = []

    for start in range(0, E_total, chunk):
        cur = min(chunk, E_total - start)
        acts_chunk_np = actions_big[start:start+cur, :]
        acts_chunk = torch.from_numpy(acts_chunk_np).to(device=device)

        orders_hist      = torch.zeros((cur, N), device=device)
        drivers_hist     = torch.zeros((cur, N), device=device)
        actions_hist     = torch.zeros((cur, N), device=device)
        revenue_hist     = torch.zeros((cur, N), device=device)
        ordersNext_hist  = torch.zeros((cur, N), device=device)
        driversNext_hist = torch.zeros((cur, N), device=device)

        orders0_chunk = torch.tensor(orders0_np[start:start+cur], device=device).unsqueeze(1).float()
        drivers0_chunk= torch.full((cur,1), 50.0, device=device)
        state = torch.cat([orders0_chunk, drivers0_chunk], dim=1)

        # One generator per chunk, seeded from the chunk's start index.
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed + start + 12345))

        for s in range(N):
            t_in_day = s % times
            a_val = acts_chunk[:, s].to(dtype=torch.float32)
            acts_for_nn = a_val.unsqueeze(1)
            t_tensor = torch.full((cur,1), float(t_in_day), device=device)

            # Model input columns: [orders, drivers, action, time-of-day].
            X = torch.cat([state, acts_for_nn, t_tensor], dim=1)
            next_mean = trans_model(X)
            rev_mean = revenue_model(X).squeeze(1)

            std_state = state_std_per_time[t_in_day].view(1,2)
            noise_state = torch.randn((cur,2), device=device, generator=gen) * std_state.expand(cur,2)
            std_reward = float(reward_std_per_time[t_in_day].item())
            noise_r = torch.randn((cur,), device=device, generator=gen) * std_reward

            resid_mean_state = torch.stack([orders_mean_t[t_in_day], drivers_mean_t[t_in_day]], dim=0).unsqueeze(0)
            resid_mean_r = float(revenue_mean_t[t_in_day].item())

            next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()
            reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()

            orders_hist[:, s] = state[:, 0]
            drivers_hist[:, s] = state[:, 1]
            actions_hist[:, s] = a_val
            revenue_hist[:, s] = reward
            ordersNext_hist[:, s] = next_state[:, 0]
            driversNext_hist[:, s] = next_state[:, 1]

            state = next_state

        data_for_ope = {
            'A': actions_hist, 'revenue': revenue_hist,
            'orders': orders_hist, 'drivers': drivers_hist,
            'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
        }

        out0 = Qeta_fn(data_for_ope, treatment=0, device=device)
        out1 = Qeta_fn(data_for_ope, treatment=1, device=device)
        eta_t0 = out0[0] if isinstance(out0, (tuple, list)) else out0
        eta_t1 = out1[0] if isinstance(out1, (tuple, list)) else out1

        eta_t0 = torch.as_tensor(eta_t0, device=device).float() if not torch.is_tensor(eta_t0) else eta_t0.float()
        eta_t1 = torch.as_tensor(eta_t1, device=device).float() if not torch.is_tensor(eta_t1) else eta_t1.float()

        # atleast_1d: a scalar estimate per chunk must still be concatenable.
        eta_diff_chunk = np.atleast_1d((eta_t0 - eta_t1).detach().cpu().numpy())
        eta_diffs_chunks.append(eta_diff_chunk)

    eta_diff = np.concatenate(eta_diffs_chunks, axis=0)
    ate_mean = float(np.mean(eta_diff))
    ate_se = float(np.std(eta_diff, ddof=1) / math.sqrt(max(1, eta_diff.size))) if eta_diff.size > 1 else 0.0
    return eta_diff, ate_mean, ate_se

# ---------- Convenience wrapper: burn-in -> fit sigma* -> evaluation ----------
@torch.no_grad()
def run_full_tmdp_pipeline_binary(
    trans_model: torch.nn.Module,
    revenue_model: torch.nn.Module,
    initial_orders_values: torch.Tensor,
    initial_orders_probs: torch.Tensor,
    resid_data: Dict[str, torch.Tensor],
    m0_pos: int,
    m0_neg: int,
    times: int,
    days: int,
    num_sims: int = 200,
    seed: int = 2025,
    batch_size: int = 64,
    device: Optional[torch.device] = None,
    Qeta_fn: Callable[..., Any] = None,
    max_chunk_size: int = 5000
) -> Dict[str, Any]:
    """
    One-shot TMDP pipeline:
      1) Generate burn-in (all-1 / all-0)
      2) Fit sigma* (0/1)
      3) Construct p1 = sigma_pos / (sigma_pos + sigma_neg) and evaluate

    ``Qeta_fn`` is required despite its ``None`` default; it is checked before
    any simulation work starts so a missing estimator fails immediately.

    NOTE(review): ``batch_size`` is accepted but not forwarded to the burn-in
    generator — confirm whether that callee takes a batch size.

    Returns dict: sigma_pos_star, sigma_neg_star, p1, eta_diff, ate_mean, ate_se

    Raises:
        ValueError: If ``Qeta_fn`` is None.
    """
    # Fail fast: the burn-in and fit below are expensive.
    if Qeta_fn is None:
        raise ValueError("Qeta_fn must be provided (e.g., Q_eta_est_poly_tensor_batch)")

    if device is None:
        device = next(trans_model.parameters()).device

    # 1) Burn-in
    X_bi, Y_bi, A_bi = generate_burnin_all1_all0_env(
        trans_model, revenue_model,
        initial_orders_values, initial_orders_probs,
        resid_data,
        m0_pos=m0_pos, m0_neg=m0_neg,
        times=times,
        device=device,
        seed=seed
    )

    # 2) Fit sigma*
    sig_star = fit_sigma_star_TMDP_binary_from_burnin(X_bi, Y_bi, A_bi)
    sigma_pos_star = sig_star['sigma_pos_star']
    sigma_neg_star = sig_star['sigma_neg_star']
    sigma_total = sigma_pos_star + sigma_neg_star
    # Guard against a zero denominator when both sigmas vanish.
    denom = sigma_total if sigma_total > 0 else 1.0
    p1 = sigma_pos_star / denom

    # 3) Evaluate
    eta_diff, ate_mean, ate_se = run_tmdp_generate_and_estimate(
        trans_model, revenue_model,
        initial_orders_values, initial_orders_probs,
        resid_data,
        p1=p1, Qeta_fn=Qeta_fn,
        times=times, days=days, num_sims=num_sims,
        seed=seed, device=device, max_chunk_size=max_chunk_size
    )

    return {
        'sigma_pos_star': sigma_pos_star,
        'sigma_neg_star': sigma_neg_star,
        'p1': p1,
        'eta_diff': eta_diff,
        'ate_mean': ate_mean,
        'ate_se': ate_se
    }
# =================== End of TMDP module =================
# =================== Begin the TRL evaluation ===================

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for batch-first sequences.

    A ``(max_len, d_model)`` table is precomputed and stored as the buffer
    ``pe``: even feature columns hold ``sin(position * div_term)`` and odd
    columns hold ``cos(position * div_term)``.

    Args:
        d_model: Size of the feature dimension. Odd values are supported.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        # `device` is the module-level torch.device defined elsewhere in this
        # file; the table is created there directly, as before.
        pe = torch.zeros(max_len, d_model, device=device)
        position = torch.arange(0, max_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim the angles to match; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        L = x.size(1)
        if L > self.pe.size(0):
            raise ValueError(
                f"sequence length {L} exceeds positional-encoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:L].unsqueeze(0)

class TransformerDQN(nn.Module):
    """Causal Transformer encoder that maps a feature history to per-action outputs.

    The input of shape ``(batch, seq, feature_dim)`` is linearly projected to
    ``d_model``, given positional encodings, passed through a batch-first
    Transformer encoder under a causal (upper-triangular) attention mask, and
    the encoder output at the final sequence position is fed to an MLP head
    producing ``num_actions`` values.

    Sub-modules are placed on the module-level ``device`` at construction.
    """

    def __init__(self, feature_dim=8, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, num_actions=2, dropout=0.1, max_len=8192):
        super().__init__()
        # Construction order is kept stable so parameter initialisation
        # consumes the RNG in the same sequence.
        input_projection = nn.Linear(feature_dim, d_model)
        self.in_proj = input_projection.to(device)
        self.posenc = PositionalEncoding(d_model, max_len=max_len)

        layer_config = {
            "d_model": d_model,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            "batch_first": True,
            "activation": "gelu",
        }
        template_layer = nn.TransformerEncoderLayer(**layer_config).to(device)
        self.encoder = nn.TransformerEncoder(template_layer, num_layers=num_layers).to(device)

        head_layers = [
            nn.LayerNorm(d_model).to(device),
            nn.Linear(d_model, dim_feedforward).to(device),
            nn.GELU(),
            nn.Linear(dim_feedforward, num_actions).to(device),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x, key_padding_mask=None):
        """Return the head output computed from the last sequence position.

        Args:
            x: Input features, indexed as ``(batch, seq, feature_dim)``.
            key_padding_mask: Optional mask forwarded to the encoder as
                ``src_key_padding_mask``.
        """
        embedded = self.posenc(self.in_proj(x))
        seq_len = embedded.size(1)
        # True strictly above the diagonal: position i may not attend to j > i.
        future_mask = torch.ones(seq_len, seq_len, device=embedded.device).triu(diagonal=1).bool()
        encoded = self.encoder(
            embedded,
            mask=future_mask,
            src_key_padding_mask=key_padding_mask,
        )
        final_token = encoded[:, -1, :]
        return self.head(final_token)


def map_state_dict_prefix(state_dict, from_prefix, to_prefix):
    """Return a new dict in which keys starting with ``from_prefix`` start with ``to_prefix`` instead.

    Keys that do not start with ``from_prefix`` are copied unchanged; values
    are never modified. Iteration order of ``state_dict`` is preserved.
    """
    cut = len(from_prefix)
    return {
        (to_prefix + key[cut:] if key.startswith(from_prefix) else key): value
        for key, value in state_dict.items()
    }


def _pad_repeat_last(seq, length, device):
    """Return ``seq`` as a float tensor on ``device`` with exactly ``length`` rows.

    Longer inputs are truncated; shorter inputs are padded by repeating the
    last row.
    """
    s = seq.to(device).float()
    n = s.shape[0]
    if n >= length:
        return s[:length]
    if n == 0:
        raise ValueError("run_policy_test: cannot pad an empty residual sequence")
    # Expand the last row along dim 0 only; trailing dims (if any) are kept.
    tail = s[-1:].expand(length - n, *s.shape[1:])
    return torch.cat([s, tail], dim=0)


def _eta_difference(Qeta_fn, data_for_ope, device):
    """Return ``Qeta_fn(treatment=0) - Qeta_fn(treatment=1)`` as a tensor on ``device``.

    If ``Qeta_fn`` returns a tuple/list, its first element is used.
    """
    etas = []
    for treatment in (0, 1):
        out = Qeta_fn(data_for_ope, treatment=treatment, device=device)
        eta = out[0] if isinstance(out, (tuple, list)) else out
        if not torch.is_tensor(eta):
            eta = torch.as_tensor(eta, device=device, dtype=torch.float32)
        etas.append(eta)
    # NOTE(review): the difference is treatment 0 minus treatment 1, exactly as
    # in the original code; confirm this matches Qeta_fn's sign convention.
    return (etas[0] - etas[1]).to(device)


def _simulate_policy_rollout(
    trans_model, revenue_model, policy,
    initial_orders_values, initial_orders_probs,
    resid_data, Qeta_fn,
    E, N, times, days, seed, epsilon,
    ATE_true_hat, warmup_days, decay_base, device
):
    """Roll out ``E`` parallel simulations for ``N`` steps and return the history dict."""
    orders_mean_t  = _pad_repeat_last(resid_data['orders_mean'], times, device)
    drivers_mean_t = _pad_repeat_last(resid_data['drivers_mean'], times, device)
    orders_var_t   = _pad_repeat_last(resid_data['orders_var'], times, device)
    drivers_var_t  = _pad_repeat_last(resid_data['drivers_var'], times, device)
    revenue_mean_t = _pad_repeat_last(resid_data['revenue_mean'], times, device)
    revenue_var_t  = _pad_repeat_last(resid_data['revenue_var'], times, device)

    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))

    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))

    # Initial state: orders sampled from the given distribution, drivers fixed at 50.
    probs = initial_orders_probs.to(device) / (initial_orders_probs.sum() + 1e-12)
    idx = torch.multinomial(probs, E, replacement=True, generator=gen)
    orders0 = initial_orders_values.to(device)[idx].unsqueeze(1).float()
    drivers0 = torch.full((E, 1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    FEATURE_DIM = 8
    L_history = N
    H = torch.zeros((E, L_history, FEATURE_DIM), device=device)
    # The history buffer is exactly N long, so step s always writes slot s and
    # every slot after s is still padding.
    idx_range = torch.arange(L_history, device=device).unsqueeze(0)

    prev_actions = torch.full((E,), -1.0, device=device)
    prev_step_rewards = torch.zeros((E,), device=device)
    prev_day_reward_feature = torch.zeros((E,), device=device)

    orders_hist      = torch.zeros((E, N), device=device)
    drivers_hist     = torch.zeros((E, N), device=device)
    actions_hist     = torch.zeros((E, N), device=device)
    revenue_hist     = torch.zeros((E, N), device=device)
    ordersNext_hist  = torch.zeros((E, N), device=device)
    driversNext_hist = torch.zeros((E, N), device=device)

    eps = float(epsilon)
    ate_true = float(ATE_true_hat)

    for s in range(N):
        mask = (idx_range > s).expand(E, -1)

        new_feat = make_feature_step_global(
            state[:, 0], state[:, 1], prev_actions, prev_step_rewards,
            prev_day_reward_feature, s, L_history
        )
        H[:, s, :] = new_feat

        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            qvals = policy(H, key_padding_mask=mask)
        greedy_a = qvals.argmax(dim=1)
        if eps <= 0.0:
            a_val = greedy_a
        elif eps >= 1.0:
            # Use the seeded generator so fully random runs are reproducible too.
            a_val = torch.randint(0, 2, (E,), device=device, generator=gen)
        else:
            rand_mask = (torch.rand(E, device=device, generator=gen) < eps)
            rand_a = torch.randint(0, 2, (E,), device=device, generator=gen)
            a_val = torch.where(rand_mask, rand_a, greedy_a)

        t_in_day = s % times
        acts = a_val.unsqueeze(1).float()
        t_tensor = torch.full((E, 1), float(t_in_day), device=device)

        X = torch.cat([state, acts, t_tensor], dim=1)
        with torch.no_grad():
            next_mean = trans_model(X)
            rev_mean = revenue_model(X).squeeze(1)

        std_state = state_std_per_time[t_in_day].view(1, 2)
        noise_state = torch.randn((E, 2), device=device, generator=gen) * std_state.expand(E, 2)
        std_reward = float(reward_std_per_time[t_in_day].item())
        noise_r = torch.randn((E,), device=device, generator=gen) * std_reward

        resid_mean_state = torch.stack([orders_mean_t[t_in_day], drivers_mean_t[t_in_day]], dim=0).unsqueeze(0)
        resid_mean_r = float(revenue_mean_t[t_in_day].item())

        # Next state and reward are rounded and clipped at zero.
        next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()
        reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()

        orders_hist[:, s]      = state[:, 0].view(-1)
        drivers_hist[:, s]     = state[:, 1].view(-1)
        actions_hist[:, s]     = a_val.float().view(-1)
        revenue_hist[:, s]     = reward.view(-1)
        ordersNext_hist[:, s]  = next_state[:, 0].view(-1)
        driversNext_hist[:, s] = next_state[:, 1].view(-1)

        prev_actions = a_val.clone().view(-1)
        prev_step_rewards = reward.clone().view(-1)
        state = next_state.clone()

        if t_in_day != times - 1:
            continue

        # End of day: refresh the day-level reward feature fed to the policy.
        day_idx = s // times
        if day_idx < int(warmup_days):
            prev_day_reward_feature = torch.zeros((E,), device=device)
            continue

        end_col = s + 1
        data_for_ope = {
            'A': actions_hist[:, :end_col].clone(),
            'revenue': revenue_hist[:, :end_col].clone(),
            'orders': orders_hist[:, :end_col].clone(),
            'drivers': drivers_hist[:, :end_col].clone(),
            'ordersNext': ordersNext_hist[:, :end_col].clone(),
            'driversNext': driversNext_hist[:, :end_col].clone()
        }
        ATE_est_per_env = _eta_difference(Qeta_fn, data_for_ope, device)

        # Later days get a larger weight: decay_base ** (days - day_number).
        exponent = max(0, int(days) - (int(day_idx) + 1))
        weight = float(decay_base) ** exponent
        R_day = - weight * (ATE_est_per_env - ate_true) ** 2
        prev_day_reward_feature = R_day.clone().view(-1)

    return {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }


def run_policy_test(
    trans_model, revenue_model, policy,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    num_sims=100,
    times=20,
    days=30,
    seed=2025,
    epsilon=0.0,
    ATE_true_hat=None,
    warmup_days=0,
    decay_base=0.8,
    device=None
):
    """Roll out ``policy`` in the learned simulator and score the resulting ATE estimates.

    ``num_sims`` simulations are run in parallel for ``days * times`` steps.
    At each step the policy sees the full (padded) feature history and picks
    an action epsilon-greedily; the transition and revenue models plus
    per-time-of-day residual mean/variance produce the next state and reward.
    At the end of each non-warmup day ``Qeta_fn`` is evaluated on the data so
    far to form the day-level reward feature. After the rollout ``Qeta_fn`` is
    evaluated on the full history to get one ATE estimate per simulation.

    The policy is evaluated in eval mode (dropout disabled) and without
    gradient tracking; its previous train/eval flags are restored on exit.

    Args:
        trans_model, revenue_model: Modules called on ``[state, action, t]``.
        policy: Callable ``policy(H, key_padding_mask=mask)`` returning one
            value per action.
        initial_orders_values, initial_orders_probs: Tensors defining the
            initial-orders distribution.
        resid_data: Dict of tensors with keys ``orders_mean``, ``orders_var``,
            ``drivers_mean``, ``drivers_var``, ``revenue_mean``, ``revenue_var``.
        Qeta_fn: Callable ``Qeta_fn(data, treatment=..., device=...)``.
        num_sims: Number of parallel simulations (must be > 0).
        times: Steps per day.
        days: Number of days.
        seed: Seed for the torch generator used for all sampling here.
        epsilon: Exploration probability; ``<= 0`` is greedy, ``>= 1`` is
            fully random.
        ATE_true_hat: Reference ATE value (required).
        warmup_days: Days during which the day-level reward feature stays zero.
        decay_base: Base of the day weight ``decay_base ** (days - day_number)``.
        device: Torch device; defaults to the device of ``trans_model``.

    Returns:
        Dict with ``ate_list``, ``ate_mean``, ``ate_se``, ``mse``, ``bias2``,
        ``variance``, ``eta_diff`` (NumPy array) and ``data_for_ope`` (the
        full history tensors).

    Raises:
        ValueError: If ``ATE_true_hat`` is missing, ``num_sims <= 0``, or
            ``days * times <= 0``.
    """
    if ATE_true_hat is None:
        raise ValueError("run_policy_test: ATE_true_hat is required (pass per-step true ATE).")
    if device is None:
        device = next(trans_model.parameters()).device
    trans_model.eval()
    revenue_model.eval()

    N = int(days * times)
    E = int(num_sims)
    if E <= 0:
        raise ValueError("num_sims must be > 0")
    if N <= 0:
        raise ValueError("days * times must be > 0")

    # Evaluate the policy deterministically (no dropout), then put every
    # sub-module back into whatever mode it was in.
    saved_modes = (
        [(module, module.training) for module in policy.modules()]
        if isinstance(policy, nn.Module) else []
    )
    if saved_modes:
        policy.eval()
    try:
        data_for_ope_full = _simulate_policy_rollout(
            trans_model, revenue_model, policy,
            initial_orders_values, initial_orders_probs,
            resid_data, Qeta_fn,
            E, N, times, days, seed, epsilon,
            ATE_true_hat, warmup_days, decay_base, device
        )
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # detach() guards against Qeta_fn returning a tensor that tracks gradients.
    eta_diff = _eta_difference(Qeta_fn, data_for_ope_full, device).detach().cpu().numpy()

    ate_list = eta_diff.tolist()
    ate_mean = float(np.mean(eta_diff))
    ate_se = float(np.std(eta_diff, ddof=1) / math.sqrt(max(1, E))) if E > 1 else 0.0
    mse = float(np.mean((eta_diff - float(ATE_true_hat))**2))
    bias2 = float((ate_mean - float(ATE_true_hat))**2)
    variance = float(np.var(eta_diff, ddof=1)) if E > 1 else 0.0

    return {
        'ate_list': ate_list,
        'ate_mean': ate_mean,
        'ate_se': ate_se,
        'mse': mse,
        'bias2': bias2,
        'variance': variance,
        'eta_diff': eta_diff,
        'data_for_ope': data_for_ope_full
    }
###

# -------------------- Main training pipeline --------------------
def _to_flat_numpy(values):
    """Convert a tensor, array or sequence of numbers to a flat float NumPy array."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy().astype(float).ravel()
    return np.asarray(values, dtype=float).ravel()


def _random_binary_design(rng, N, T):
    """Draw an ``(N, T)`` matrix of i.i.d. 0/1 actions.

    The supplied ``rng`` is used when it looks like a NumPy generator; if not,
    the global NumPy RNG is used (the previous behaviour).
    """
    if hasattr(rng, "integers"):   # numpy.random.Generator
        return rng.integers(0, 2, size=(N, T))
    if hasattr(rng, "randint"):    # numpy.random.RandomState
        return rng.randint(0, 2, size=(N, T))
    return np.random.randint(0, 2, (N, T))


def main_train():
    """Evaluate a saved Transformer-DQN policy against baseline designs/estimators.

    Despite its name this function does no training. It loads the simulator
    assets and the policy checkpoint, estimates the reference ATE by Monte
    Carlo, and then logs metrics for: the loaded policy (``run_policy_test``),
    a random policy, a sweep over block lengths (divisors of the number of
    steps per day), the Xiong, Wager and Bojinov estimators, NMDP and TMDP.

    Returns:
        Dict collecting the logged metrics, keyed by method.
    """
    NUM_ENVS = args.num_envs
    TOTAL_TIME_STEPS = args.times
    DATES = args.dates
    L_history = TOTAL_TIME_STEPS * DATES
    FEATURE_DIM = 8
    # NOTE(review): relies on the argument parser defining `--seed`; confirm.
    Seed = args.seed

    (transition_model, revenue_model, initial_orders_values, initial_orders_probs,
     orders_resid_mean, orders_resid_var, drivers_resid_mean, drivers_resid_var,
     revenue_resid_mean, revenue_resid_var) = load_required_assets_or_die()

    resid_data = {'orders_mean': orders_resid_mean, 'orders_var': orders_resid_var,
                  'drivers_mean': drivers_resid_mean, 'drivers_var': drivers_resid_var,
                  'revenue_mean': revenue_resid_mean, 'revenue_var': revenue_resid_var}

    env = GenerativeEnvGPU(transition_model, revenue_model, NUM_ENVS,
                           initial_dist_data={'values': initial_orders_values, 'probs': initial_orders_probs},
                           residuals_data=resid_data)
    transition_model.eval()
    revenue_model.eval()

    policy = TransformerDQN(feature_dim=FEATURE_DIM, max_len=max(args.posenc_max_len, L_history)).to(device)
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(args.transformer_path, map_location=device)
    state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    # strict=False tolerates key mismatches, so report them instead of silently
    # evaluating a partially (or not at all) initialised policy.
    load_result = policy.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning(
            f"Policy checkpoint key mismatch: missing={list(load_result.missing_keys)} "
            f"unexpected={list(load_result.unexpected_keys)}"
        )
    # Evaluation only: disable dropout.
    policy.eval()

    logger.info("Estimating ATE_true with paired Monte-Carlo (per-step)...")
    ATE_true_hat, _, se_init = estimate_true_ate_paired_mc_strict(
        transition_model, revenue_model, initial_orders_values, initial_orders_probs,
        resid_data, num_trajs=args.ate_mc_sims_init, T=TOTAL_TIME_STEPS, per_step=True, max_chunk_size=args.max_chunk_mc)
    logger.info(f"Initial ATE_true_hat (per-step) = {ATE_true_hat:.6e}  SE={se_init:.6e}")
    ate_true = float(ATE_true_hat)

    REP_count = 200

    # ---- Loaded policy ----
    TRL = run_policy_test(
        transition_model, revenue_model, policy,
        initial_orders_values, initial_orders_probs,
        resid_data, Q_eta_est_poly_tensor_batch,
        num_sims=REP_count, times=args.times, days=DATES,
        seed=Seed, epsilon=0.3, ATE_true_hat=ATE_true_hat,
        warmup_days=args.warmup_days, decay_base=0.8, device=device)
    TRL_mse = float(TRL['mse'])
    TRL_bias2 = float(TRL['bias2'])
    TRL_var = float(TRL['variance'])
    logger.info(f"[TRL_mse={TRL_mse:.6e} | bias2={TRL_bias2:.6e} | var={TRL_var:.6e}]")

    # ---- Random policy ----
    ate_list, ate_mean, ate_se = run_one_batch_random_sims_and_estimate(
        transition_model, revenue_model,
        initial_orders_values, initial_orders_probs,
        resid_data,
        Q_eta_est_poly_tensor_batch,
        NUM_ENVS=args.num_envs,
        times=args.times,
        days=DATES,
        num_sims=REP_count,
        seed=Seed,
        device=device
    )
    # Work on a NumPy copy so the metrics do not depend on whether the helper
    # returned a list, array or tensor; variance uses ddof=1 like the others.
    random_ates = _to_flat_numpy(ate_list)
    random_mse = float(np.mean((random_ates - ate_true) ** 2))
    random_bias2 = float((random_ates.mean() - ate_true) ** 2)
    random_vari = float(np.var(random_ates, ddof=1)) if random_ates.size > 1 else 0.0
    logger.info(f"Random policy metrics: MSE={random_mse:.6e} | bias^2={random_bias2:.6e} | variance={random_vari:.6e}")

    # ---- Block-length sweep ----
    fs = get_divisors(TOTAL_TIME_STEPS)
    logger.info(f"TOTAL_TIME_STEPS={TOTAL_TIME_STEPS}, divisors to evaluate: {fs}")
    results = []
    for f in fs:
        logger.info(f"[SWEEP] evaluating f={f} (block length) ...")
        res = run_batch_generate_and_estimate(
            transition_model,
            revenue_model,
            initial_orders_values,
            initial_orders_probs,
            resid_data,
            Q_eta_est_poly_tensor_batch,
            num_sims=REP_count,
            times=TOTAL_TIME_STEPS,
            days=DATES,
            f=f,
            seed=Seed,
            ATE_true_hat=ATE_true_hat,
            device=device,
            save_csv_path=None
        )
        mse = float(res.get('mse', float('nan')))
        ate_mean = float(res.get('ate_mean', float('nan')))
        variance = float(res.get('variance', float('nan')))
        ate_se = float(res.get('ate_se', float('nan')))
        logger.info(f"[SWEEP:f={f}] mse={mse:.6e} | ate_mean={ate_mean:.6e} | var={variance:.6e} | ate_se={ate_se:.6e}")
        results.append({
            'f': int(f),
            'mse': mse,
            'ate_mean': ate_mean,
            'variance': variance,
            'ate_se': ate_se,
        })

    # ---- Xiong / Wager ----
    wager_ms = [m for m in fs if m > 1]
    logger.info(f"Evaluating Wager for m in: {wager_ms}")
    num_sims_each = 1
    reps_to_run = REP_count
    sim_size = 1
    xiong_summary = run_design_experiment(
        make_A_fn=_random_binary_design,
        env=env,
        times=TOTAL_TIME_STEPS,
        days=DATES,
        num_sims=num_sims_each,
        sim_size=sim_size,
        reps=reps_to_run,
        seed=Seed,
        estimator_list=['xiong'],
        wager_ms=None,
        ATE_true=ATE_true_hat,
        device=device,
        save_prefix=None
    )
    logger.info(f"Xiong: mean={xiong_summary['xiong']['mean']:.6g} var={xiong_summary['xiong']['var']:.6g}")

    wager_results = {}
    for m in wager_ms:
        logger.info(f"Start evaluating Wager m={m} (reps={reps_to_run}) ...")
        s = run_design_experiment(
            make_A_fn=_random_binary_design,
            env=env,
            times=TOTAL_TIME_STEPS,
            days=DATES,
            num_sims=num_sims_each,
            sim_size=sim_size,
            reps=reps_to_run,
            seed=Seed,
            estimator_list=['wager'],
            wager_ms=[m],
            ATE_true=ATE_true_hat,
            device=device,
            save_prefix=None
        )
        key = f"wager_m{m}"
        wager_results[key] = s[key]
        logger.info(f"  Wager m={m}: mean={s[key]['mean']:.6g} var={s[key]['var']:.6g} mse={s[key].get('mse', float('nan')):.6g}")

    # ---- Bojinov ----
    bojinov_tis = fs[:]
    bojinov_res = {}
    for ti in bojinov_tis:
        logger.info(f"Start evaluating Bojinov IS-HT ti={ti} for {reps_to_run} reps ...")
        ests = np.full((reps_to_run,), np.nan, dtype=float)
        design_fn = make_bojinov_design(ti)
        for r in range(reps_to_run):
            rng = np.random.default_rng(Seed + r)
            A_np = design_fn(rng, num_sims_each, TOTAL_TIME_STEPS)
            hist = simulate_env_with_action_matrix_binary(env, A_np, times=TOTAL_TIME_STEPS, days=DATES, device=device, seed=int(Seed + r))
            out = compute_bojinov_from_hist(hist, ti=ti, TI_loop=ti, sim_size=sim_size, times=TOTAL_TIME_STEPS, days=DATES, device=device)
            ests[r] = float(out[0])
        mean = float(np.nanmean(ests))
        var = float(np.nanvar(ests, ddof=1)) if ests.size > 1 else 0.0
        mse = float(np.nanmean((ests - ate_true)**2))
        bojinov_res[ti] = {'ests': ests, 'mean': mean, 'var': var, 'mse': mse}
        logger.info(f"  Bojinov ti={ti}: mean={mean:.6g} var={var:.6g} mse={mse:.6g}")

    # ---- NMDP ----
    logger.info("Start evaluating NMDP (multiple reps)...")
    nmdp_out = evaluate_nmdp_multiple_runs(
        trans_model=transition_model,
        revenue_model=revenue_model,
        initial_orders_values=initial_orders_values,
        initial_orders_probs=initial_orders_probs,
        resid_data=resid_data,
        pred_sigma_fn=fit_sigma_functions_NMDP_linear_binary,
        Qeta_fn=Q_eta_est_poly_tensor_batch,
        times=TOTAL_TIME_STEPS,
        days=DATES,
        num_sims=num_sims_each,
        reps=reps_to_run,
        seed=Seed,
        device=device,
        max_chunk_size=args.max_chunk_mc
    )
    logger.info(f"NMDP: mean={nmdp_out.get('mean', float('nan')):.6g} var={nmdp_out.get('var', float('nan')):.6g}")

    # ---- TMDP ----
    logger.info("Start evaluating TMDP (run_full_tmdp_pipeline_binary), doing reps loop ...")
    tmdp_ests = np.full((reps_to_run,), np.nan, dtype=float)
    for r in range(reps_to_run):
        res_t = run_full_tmdp_pipeline_binary(
            trans_model=transition_model,
            revenue_model=revenue_model,
            initial_orders_values=initial_orders_values,
            initial_orders_probs=initial_orders_probs,
            resid_data=resid_data,
            m0_pos=200,
            m0_neg=200,
            times=TOTAL_TIME_STEPS,
            days=DATES,
            num_sims=num_sims_each,
            seed=int(Seed + r),
            batch_size=64,
            device=device,
            Qeta_fn=Q_eta_est_poly_tensor_batch,
            max_chunk_size=args.max_chunk_mc
        )
        if 'ate_mean' in res_t:
            tmdp_ests[r] = float(res_t['ate_mean'])
        elif 'eta_diff' in res_t:
            tmdp_ests[r] = float(np.mean(_to_flat_numpy(res_t['eta_diff'])))
        else:
            tmdp_ests[r] = float(res_t.get('mean', float('nan')))
    tmdp_mean = float(np.nanmean(tmdp_ests))
    tmdp_var = float(np.nanvar(tmdp_ests, ddof=1)) if tmdp_ests.size > 1 else 0.0
    tmdp_mse = float(np.nanmean((tmdp_ests - ate_true)**2))
    logger.info(f"TMDP: mean={tmdp_mean:.6g} var={tmdp_var:.6g} mse={tmdp_mse:.6g}")

    return {
        'ATE_true_hat': ate_true,
        'trl': {'mse': TRL_mse, 'bias2': TRL_bias2, 'variance': TRL_var},
        'random': {'mse': random_mse, 'bias2': random_bias2, 'variance': random_vari},
        'sweep': results,
        'xiong': xiong_summary['xiong'],
        'wager': wager_results,
        'bojinov': bojinov_res,
        'nmdp': nmdp_out,
        'tmdp': {'ests': tmdp_ests, 'mean': tmdp_mean, 'var': tmdp_var, 'mse': tmdp_mse},
    }
    

# Script entry point: run the evaluation pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
     main_train()
