# train_dqn_transformer_perstep_gpu.py

import os
import math
import time
import argparse
from pathlib import Path
from collections import deque, namedtuple
import logging
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# ---- Command-line configuration ------------------------------------------------
# All hyper-parameters and file locations are taken from the command line and
# parsed once at import time into the module-level `args` namespace.
parser = argparse.ArgumentParser(description="Transformer-DQN full-history training (per-step ATE)")
# Input assets: simulator checkpoints and the CSV files describing the initial
# state distribution and the per-time residual statistics.
parser.add_argument("--transition_model", type=str, default="checkpoints/transition_model_with_epoch_100.pth", help="checkpoint path for transition model")
parser.add_argument("--revenue_model",    type=str, default="checkpoints/revenue_model_with_epoch_100.pth", help="checkpoint path for revenue model")
parser.add_argument("--initial_dist",     type=str, default="initial_state_distribution.csv", help="initial state distribution csv")
parser.add_argument("--residual_csv",     type=str, default="nn_residual_analysis.csv", help="residual analysis csv")
# Environment / horizon sizes.
parser.add_argument("--num_envs", type=int, default=128, help="number of parallel environments E")
parser.add_argument("--times",    type=int, default=20,  help="time steps per day")
parser.add_argument("--dates",    type=int, default=25,  help="history days (for L_history = times * dates)")
# DQN optimisation and exploration settings.
parser.add_argument("--batch_size",  type=int, default=256,   help="training batch size")
parser.add_argument("--replay_capacity", type=int, default=150000, help="replay buffer capacity")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--gamma", type=float, default=0.95, help="DQN discount factor")
parser.add_argument("--target_update_steps", type=int, default=1000, help="steps to update target network")
parser.add_argument("--eps_start", type=float, default=0.9, help="epsilon start value")
parser.add_argument("--eps_end",   type=float, default=0.1, help="epsilon end value")
parser.add_argument("--eps_decay_steps", type=int, default=500000, help="epsilon decay steps")
parser.add_argument("--warmup_days", type=int, default=20, help="warmup days: rewards are set to 0 before this")
# Monte-Carlo ATE estimation, output path and model-size limits.
parser.add_argument("--ate_mc_sims_init", type=int, default=20000, help="initial number of MC trajectories for ATE estimation")
parser.add_argument("--save_path", type=str, default="policy_dqn_transformer_perstep.pth", help="path to save policy")
parser.add_argument("--posenc_max_len", type=int, default=8192, help="max length for positional encoding (>= L_history)")
parser.add_argument("--max_chunk_mc", type=int, default=5000, help="max trajectories per chunk for MC (memory control)")
parser.add_argument("--epochs", type=int, default=2000, help="number of epochs (1 epoch = DATES*times steps)")
# NOTE(review): the help text mentions cuda:0 but the default index is 5 — confirm
# which one is intended before relying on the default.
parser.add_argument("--gpu", type=int, default=5, help="GPU index to use (cuda:0).")
parser.add_argument("--seed", type=int, default=2026, help="random seed")

# Parsing happens at import time, so importing this module consumes sys.argv.
args = parser.parse_args()

def init_logger(name="DQN-Transformer-PerStep",
                log_dir=None,
                fname="train.log",
                level=logging.INFO,
                when="midnight",
                backup_count=7,
                overwrite_on_start=False):
    """Create (or re-create) a logger that writes to the console and a rotating file.

    Args:
        name: Name passed to ``logging.getLogger``.
        log_dir: Directory for the log file; defaults to ``<cwd>/trainlog``.
            The directory is created if it does not exist.
        fname: Log file name inside ``log_dir``.
        level: Level applied to the logger and to both handlers.
        when: Rotation schedule forwarded to ``TimedRotatingFileHandler``.
        backup_count: Number of rotated files to keep.
        overwrite_on_start: When true, an existing log file is removed (or
            truncated if it cannot be removed) before logging starts.

    Returns:
        The configured ``logging.Logger``. Calling this function again with the
        same ``name`` replaces the previous handlers.
    """
    # `logging.handlers` is a submodule that a plain `import logging` does not load.
    import logging.handlers
    from pathlib import Path

    log_dir = Path.cwd() / "trainlog" if log_dir is None else Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / fname

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Detach AND close handlers left over from a previous initialisation.
    # Removing without closing leaks the file descriptor, and an open handle
    # can prevent the unlink below on platforms that lock open files.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    if overwrite_on_start and log_path.exists():
        try:
            log_path.unlink()
        except OSError:
            # Best effort: if the file cannot be deleted, truncate it instead.
            with open(log_path, "w"):
                pass

    fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s", "%Y-%m-%d %H:%M:%S")

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(fmt)
    logger.addHandler(console_handler)

    # delay=True postpones opening the file until the first record is emitted.
    file_handler = logging.handlers.TimedRotatingFileHandler(
        filename=str(log_path), when=when, backupCount=backup_count, encoding="utf-8", delay=True
    )
    file_handler.setLevel(level)
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info(f"Logger initialized. logging to {log_path} (rotate={when}, backups={backup_count}, overwrite={overwrite_on_start})")
    return logger
# Module-wide logger; the previous train.log (if any) is discarded on start.
logger = init_logger(name="DQN-Transformer-PerStep", overwrite_on_start=True)

# The script is GPU-only: every model and tensor below is placed on `device`.
assert torch.cuda.is_available(), "GPU required; modify for CPU if needed."
device = torch.device(f"cuda:{args.gpu}")

# Seed every RNG used in this file (torch CPU/CUDA, NumPy, Python `random`).
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Record layout for one replay transition (field order matches the tuples
# stored by ReplayBufferGPU.push_batch below).
Transition = namedtuple('Transition', ('history', 'history_mask', 'action', 'reward', 'next_history', 'next_history_mask', 'done'))

class MLP(nn.Module):
    """Three-layer fully connected network with ReLU activations.

    The sub-module is stored as ``self.layers`` (an ``nn.Sequential`` with the
    linear layers at indices 0, 2 and 4), so state-dict keys have the form
    ``layers.<idx>.weight`` / ``layers.<idx>.bias``.
    """

    def __init__(self, input_size, output_size, hidden_size=64):
        super().__init__()
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        out = self.layers(x)
        return out

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    A ``(max_len, d_model)`` table is built once on the module-level ``device``
    and registered as the buffer ``pe``; even feature columns hold sines and odd
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: Feature width of the sequences to be encoded.
        max_len: Number of positions in the pre-computed table.
    """

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model, device=device)
        position = torch.arange(0, max_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine block must be trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the pre-computed table.
        """
        L = x.size(1)
        if L > self.pe.size(0):
            raise ValueError(
                f"PositionalEncoding: sequence length {L} exceeds max_len {self.pe.size(0)}"
            )
        return x + self.pe[:L].unsqueeze(0)

class TransformerDQN(nn.Module):
    """Causal Transformer encoder that maps a feature history to per-action values.

    The input is projected to ``d_model``, given positional encodings, passed
    through a causally masked ``nn.TransformerEncoder`` and the representation
    at the last sequence position is fed to a small feed-forward head that
    emits ``num_actions`` outputs. All sub-modules live on the module-level
    ``device``.
    """

    def __init__(self, feature_dim=8, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, num_actions=2, dropout=0.1, max_len=8192):
        super().__init__()
        self.in_proj = nn.Linear(feature_dim, d_model).to(device)
        self.posenc = PositionalEncoding(d_model, max_len=max_len)

        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        block = nn.TransformerEncoderLayer(**layer_kwargs).to(device)
        self.encoder = nn.TransformerEncoder(block, num_layers=num_layers).to(device)

        head_modules = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, num_actions),
        ]
        self.head = nn.Sequential(*head_modules).to(device)

    def forward(self, x, key_padding_mask=None):
        """Encode ``x`` (batch-first) and return the head output for the last position.

        ``key_padding_mask`` is forwarded unchanged as ``src_key_padding_mask``.
        """
        tokens = self.posenc(self.in_proj(x))
        seq_len = tokens.size(1)
        # Boolean mask that is True wherever the key index lies strictly after
        # the query index, i.e. the strict upper triangle.
        steps = torch.arange(seq_len, device=tokens.device)
        future = steps.unsqueeze(0) > steps.unsqueeze(1)
        encoded = self.encoder(tokens, mask=future, src_key_padding_mask=key_padding_mask)
        return self.head(encoded[:, -1, :])


def map_state_dict_prefix(state_dict, from_prefix, to_prefix):
    """Return a new dict in which keys starting with ``from_prefix`` start with ``to_prefix`` instead.

    Keys without the prefix are copied unchanged; values are never modified and
    the iteration order of ``state_dict`` is kept.
    """
    cut = len(from_prefix)
    return {
        (to_prefix + key[cut:] if key.startswith(from_prefix) else key): value
        for key, value in state_dict.items()
    }


def robust_load_model(model, path, model_name="model"):
    """Load a checkpoint into ``model``, tolerating common key-prefix differences.

    The checkpoint may be a bare state dict or a dict wrapping one under
    ``'state_dict'`` or ``'model_state_dict'``. Attempts, in order:

    1. strict load of the state dict as stored;
    2. strict load after stripping a ``module.`` prefix;
    3. strict load after renaming ``layers.`` -> ``net.``;
    4. strict load after renaming ``net.`` -> ``layers.``;
    5. non-strict load of whichever variant shares the most keys with the
       model, and only if at least one key matches.

    The non-strict attempt is deliberately last: ``load_state_dict(strict=False)``
    does not fail when *no* key matches, so trying it before the prefix remaps
    would report success while leaving the model at its random initialisation.

    Args:
        model: ``nn.Module`` that receives the weights.
        path: Checkpoint location (passed to ``torch.load`` on CPU).
        model_name: Label used in log and error messages.

    Returns:
        The state dict (possibly key-remapped) that was loaded.

    Raises:
        RuntimeError: If none of the attempts succeeds.
    """
    sd_loaded = torch.load(str(path), map_location="cpu")
    sd = sd_loaded
    if isinstance(sd_loaded, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in sd_loaded:
                sd = sd_loaded[wrapper_key]
                break

    keys = list(sd.keys())
    model_keys = set(model.state_dict().keys())

    # Broad `except Exception` is kept on purpose throughout: every failed
    # attempt must fall through to the next one and finally to the RuntimeError.
    try:
        model.load_state_dict(sd)
        logger.info(f"{model_name}: strict load ok from {path}")
        return sd
    except Exception as e:
        logger.warning(f"{model_name}: strict load failed: {e}")

    # (from_prefix, to_prefix, success message, failure message)
    remaps = []
    if any(k.startswith("module.") for k in keys) and not any(k.startswith("module.") for k in model_keys):
        remaps.append(("module.", "", "success after stripping 'module.'", "strip module failed"))
    if any(k.startswith("layers.") for k in keys) and not any(k.startswith("net.") for k in keys):
        remaps.append(("layers.", "net.", "success mapping layers.->net.", "mapping layers->net failed"))
    if any(k.startswith("net.") for k in keys) and not any(k.startswith("layers.") for k in keys):
        remaps.append(("net.", "layers.", "success mapping net.->layers.", "mapping net->layers failed"))

    candidates = [sd]
    for from_prefix, to_prefix, ok_msg, fail_msg in remaps:
        sd2 = map_state_dict_prefix(sd, from_prefix, to_prefix)
        try:
            model.load_state_dict(sd2)
            logger.info(f"{model_name}: {ok_msg}")
            return sd2
        except Exception as e:
            logger.warning(f"{model_name}: {fail_msg}: {e}")
        candidates.append(sd2)

    # Last resort: partial (non-strict) load of the best-matching variant.
    best = max(candidates, key=lambda cand: len(model_keys.intersection(cand.keys())))
    if not model_keys.isdisjoint(best.keys()):
        try:
            res = model.load_state_dict(best, strict=False)
            logger.warning(f"{model_name}: loaded with strict=False; missing_keys={res.missing_keys}; unexpected_keys={res.unexpected_keys}")
            return best
        except Exception as e:
            logger.warning(f"{model_name}: loose load failed: {e}")

    sample_k = keys[:80]
    logger.error(f"{model_name}: failed to robustly load checkpoint {path}. sample keys: {sample_k}")
    raise RuntimeError(f"{model_name}: failed to load checkpoint {path}. See logs.")


def load_required_assets_or_die():
    """Load both simulator networks and the two CSV inputs named in ``args``.

    Returns:
        A 10-tuple: transition model, revenue model, initial order values,
        normalised initial order probabilities, then the residual mean/variance
        tensors for next-orders, next-drivers and revenue (mean before variance
        for each). All tensors are float32 on the module-level ``device``.

    Raises:
        FileNotFoundError: If any of the four input files does not exist.
        ValueError: If a CSV lacks the required columns.
    """
    required_files = (
        ("transition model", args.transition_model),
        ("revenue model", args.revenue_model),
        ("initial distribution", args.initial_dist),
        ("residual csv", args.residual_csv),
    )
    for label, location in required_files:
        if not Path(location).exists():
            raise FileNotFoundError(f"Missing {label}: {location}")

    def column_tensor(frame, column):
        # One CSV column as a float32 tensor on the training device.
        return torch.tensor(frame[column].values, dtype=torch.float32, device=device)

    # Both simulator networks take 4 inputs and are used in eval mode.
    transition_model = MLP(4,2).to(device)
    robust_load_model(transition_model, args.transition_model, model_name="transition_model")
    transition_model.eval()

    revenue_model = MLP(4,1).to(device)
    robust_load_model(revenue_model, args.revenue_model, model_name="revenue_model")
    revenue_model.eval()

    df_init = pd.read_csv(args.initial_dist)
    if not all(col in df_init.columns for col in ('orders', 'probability')):
        raise ValueError("initial_state_distribution.csv must contain 'orders' and 'probability' columns.")
    initial_orders_values = column_tensor(df_init, 'orders')
    raw_probs = column_tensor(df_init, 'probability')
    # The epsilon keeps the normalisation finite if the column sums to zero.
    initial_orders_probs = raw_probs / (raw_probs.sum() + 1e-12)

    df_r = pd.read_csv(args.residual_csv)
    need_cols = [
        'ordersNext_resid_mean','ordersNext_resid_var',
        'driversNext_resid_mean','driversNext_resid_var',
        'revenue_resid_mean','revenue_resid_var'
    ]
    if any(col not in df_r.columns for col in need_cols):
        raise ValueError(f"residual csv missing columns: {need_cols}")
    # need_cols is already in the order in which the tensors are returned.
    residual_tensors = tuple(column_tensor(df_r, col) for col in need_cols)

    return (transition_model, revenue_model, initial_orders_values, initial_orders_probs,
            *residual_tensors)


# -------------------- Strictly paired noise MC ATE (per-step optional, default per-step=True) --------------------
@torch.no_grad()
def estimate_true_ate_paired_mc_strict(trans_model, rev_model, initial_orders_values, initial_orders_probs, resid_data,
                                       round_reward=True, num_trajs=10000, T=20, per_step=True, max_chunk_size=5000):
    """Monte-Carlo estimate of the reward difference between two constant-action rollouts.

    For every sampled initial state two trajectories of length ``T`` are rolled
    out, one with the action input fixed to 0 and one with it fixed to 1. Both
    share the same initial state and the same noise draws, and the reported
    difference is (action-0 cumulative reward) minus (action-1 cumulative reward).

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to the next
            ``[orders, drivers]`` mean; its parameters define the device used.
        rev_model: Module mapping the same input to the reward mean.
        initial_orders_values: 1-D tensor of candidate initial order counts.
        initial_orders_probs: Sampling weights for ``initial_orders_values``.
        resid_data: Dict with 1-D tensors under ``orders_mean``, ``drivers_mean``,
            ``orders_var``, ``drivers_var``, ``revenue_mean`` and ``revenue_var``.
            Sequences shorter than ``T`` are extended by repeating the last entry.
        round_reward: If true, rewards are rounded and clamped at 0.
        num_trajs: Number of paired trajectories.
        T: Steps per trajectory.
        per_step: If true, differences are divided by ``T`` before aggregation.
        max_chunk_size: Maximum number of trajectories simulated at once.

    Returns:
        ``(ate, diffs, se)``: mean difference (float), per-trajectory differences
        (tensor of length ``num_trajs``) and the standard error of the mean
        (0.0 when only one trajectory is simulated).

    Raises:
        ValueError: If ``num_trajs`` or ``T`` is not positive.
    """
    device_local = next(trans_model.parameters()).device
    N = int(num_trajs)
    T = int(T)
    if N <= 0:
        raise ValueError(f"num_trajs must be > 0, got {num_trajs}")
    if T <= 0:
        raise ValueError(f"T must be > 0, got {T}")

    probs = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)

    # Dedicated generator with a fixed seed so the estimate is reproducible.
    gen = torch.Generator(device=device_local)
    gen.manual_seed(2026)

    idx = torch.multinomial(probs, N, replacement=True, generator=gen)
    orders0_all = initial_orders_values[idx].unsqueeze(1).float().to(device_local)
    drivers0_all = torch.full((N, 1), 50.0, device=device_local)

    def pad_repeat(seq):
        # Truncate to T, or extend to T by repeating the final entry.
        s = seq.to(device_local).float()
        if s.shape[0] >= T:
            return s[:T]
        return torch.cat([s, s[-1:].expand(T - s.shape[0])], dim=0)

    orders_mean_seq = pad_repeat(resid_data['orders_mean'])
    drivers_mean_seq = pad_repeat(resid_data['drivers_mean'])
    orders_var_seq = pad_repeat(resid_data['orders_var'])
    drivers_var_seq = pad_repeat(resid_data['drivers_var'])
    revenue_mean_seq = pad_repeat(resid_data['revenue_mean'])
    revenue_var_seq = pad_repeat(resid_data['revenue_var'])

    state_std_per_t = torch.sqrt(torch.stack([orders_var_seq, drivers_var_seq], dim=1).clamp_min(1e-12))
    reward_std_per_t = torch.sqrt(revenue_var_seq.clamp_min(1e-12))
    state_mean_per_t = torch.stack([orders_mean_seq, drivers_mean_seq], dim=1)  # (T, 2)

    # Action inputs of the two arms, encoded as (1 - z) / 2 for z = +1 and z = -1.
    action_p = (1.0 - 1.0) / 2.0
    action_m = (1.0 - (-1.0)) / 2.0

    diffs_chunks = []
    chunk = int(max(1, min(max_chunk_size, N)))
    for start in range(0, N, chunk):
        cur = min(chunk, N - start)
        orders0 = orders0_all[start:start + cur]
        drivers0 = drivers0_all[start:start + cur]

        # One noise tensor per chunk, shared by both arms (paired noise).
        noise_s_all = torch.randn((T, cur, 2), device=device_local, generator=gen) * state_std_per_t.view(T, 1, 2)
        noise_r_all = torch.randn((T, cur), device=device_local, generator=gen) * reward_std_per_t.view(T, 1)

        st_p = torch.cat([orders0, drivers0], dim=1)
        st_m = st_p.clone()
        cum_p = torch.zeros((cur,), device=device_local)
        cum_m = torch.zeros((cur,), device=device_local)

        # The action columns do not depend on t, so build them once per chunk.
        a_p = torch.full((cur, 1), action_p, device=device_local)
        a_m = torch.full((cur, 1), action_m, device=device_local)

        for t in range(T):
            t_e = torch.full((cur, 1), float(t), device=device_local)

            Xp = torch.cat([st_p, a_p, t_e], dim=1)
            Xm = torch.cat([st_m, a_m, t_e], dim=1)
            nm_p = trans_model(Xp)
            nm_m = trans_model(Xm)
            rm_p = rev_model(Xp).squeeze(1)
            rm_m = rev_model(Xm).squeeze(1)

            noise_s_t = noise_s_all[t]
            noise_r_t = noise_r_all[t]
            resid_mean_state_t = state_mean_per_t[t].unsqueeze(0)
            resid_mean_r_t = revenue_mean_seq[t]

            # States are rounded and clamped at 0 after adding residual mean and noise.
            st_p = (nm_p + resid_mean_state_t + noise_s_t).round().clamp(min=0).float()
            st_m = (nm_m + resid_mean_state_t + noise_s_t).round().clamp(min=0).float()

            r_p = rm_p + resid_mean_r_t + noise_r_t
            r_m = rm_m + resid_mean_r_t + noise_r_t
            if round_reward:
                r_p = r_p.round().clamp(min=0)
                r_m = r_m.round().clamp(min=0)

            cum_p += r_p.view(-1).float()
            cum_m += r_m.view(-1).float()

        diffs_chunks.append(cum_p - cum_m)

    diffs = torch.cat(diffs_chunks, dim=0)
    if per_step:
        diffs = diffs / float(T)

    ate = float(diffs.mean().item())
    # The unbiased std is undefined for a single sample; report 0.0 in that
    # case, matching the other estimators in this file.
    se = diffs.std(unbiased=True).item() / math.sqrt(N) if N > 1 else 0.0
    return ate, diffs, se


# -------------------- Q_eta_est_poly_tensor_batch (OPE estimator) --------------------
def Q_eta_est_poly_tensor_batch(data, treatment, device=None):
    """Batched regularised linear estimator of the per-row average reward under one action.

    For every row, only the time steps whose action equals ``treatment`` are
    used. Features are the raw state ``[orders, drivers]``; rewards and features
    are centred over the selected steps, a 2x2 linear system is solved for the
    coefficients, and the average reward is the mean of
    ``revenue - (phi(S) - phi(S')) @ beta`` over the selected steps.

    Args:
        data: Mapping with keys 'A', 'revenue', 'orders', 'drivers',
            'ordersNext', 'driversNext'. Each value is a 2-D tensor (or
            array-like) of shape (B, N); no transposition is performed.
        treatment: Scalar action value used to select time steps.
        device: Target ``torch.device``; defaults to the device of ``data['A']``.

    Returns:
        eta_est (B,): estimate per row; 0 for rows with no selected step.
        TD_error (B, N): ``revenue - Q_diff - eta`` at selected steps, 0 elsewhere.
        beta_a (B, P, 1): regression coefficients, with P = 2.

    Raises:
        KeyError: If a required key is missing from ``data``.
        ValueError: If ``data['A']`` is not 2-D.
    """
    required_keys = ['A','revenue','orders','drivers','ordersNext','driversNext']
    for k in required_keys:
        if k not in data:
            raise KeyError(f"Q_eta_est_poly_tensor_batch: missing key '{k}' in data")

    tensors = {k: (data[k] if torch.is_tensor(data[k]) else torch.as_tensor(data[k])) for k in required_keys}
    if device is None:
        device = tensors['A'].device
    tensors = {k: v.to(device=device, dtype=torch.float32) for k, v in tensors.items()}

    A = tensors['A']
    if A.dim() != 2:
        raise ValueError(f"Q_eta_est_poly_tensor_batch: A must be 2D, got dim={A.dim()}")
    B, N = A.shape

    mask_2d = (A == treatment).float()            # (B, N)
    mask_f = mask_2d.unsqueeze(-1)                # (B, N, 1)
    revenue = tensors['revenue'].unsqueeze(-1)    # (B, N, 1)

    phi_S = torch.stack([tensors['orders'], tensors['drivers']], dim=-1)               # (B, N, 2)
    phi_next_S = torch.stack([tensors['ordersNext'], tensors['driversNext']], dim=-1)  # (B, N, 2)

    mask_sum_raw = mask_2d.sum(dim=1)                           # (B,)
    # Clamp so rows without any selected step do not divide by zero.
    mask_sum = mask_f.sum(dim=1, keepdim=True).clamp_min(1.0)   # (B, 1, 1)

    # Centre reward and features over the selected steps only.
    revenue_mean = (revenue * mask_f).sum(dim=1, keepdim=True) / mask_sum
    revenue_c = (revenue - revenue_mean) * mask_f

    phi_mean = (phi_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_S_c = (phi_S - phi_mean) * mask_f
    phi_next_mean = (phi_next_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_next_S_c = (phi_next_S - phi_next_mean) * mask_f

    diff_phi_S_c = phi_S_c - phi_next_S_c

    lhs = torch.matmul(diff_phi_S_c.transpose(1, 2), phi_S_c)   # (B, P, P)
    rhs = torch.matmul(phi_S_c.transpose(1, 2), revenue_c)      # (B, P, 1)
    P = lhs.shape[1]
    eye = torch.eye(P, device=device)
    try:
        beta_a = torch.linalg.solve(lhs + 3 * eye.unsqueeze(0), rhs)
    except RuntimeError:
        # The batched solve fails if any single system is singular; fall back
        # to solving row by row, with least squares as the final resort.
        beta_list = []
        for i in range(B):
            try:
                bi = torch.linalg.solve(lhs[i] + 1e-6 * eye, rhs[i])
            except RuntimeError:
                bi = torch.linalg.lstsq(lhs[i], rhs[i]).solution
            beta_list.append(bi)
        # stack (not cat) so the result keeps the batched (B, P, 1) layout.
        beta_a = torch.stack(beta_list, dim=0)

    Q_diff_vec = torch.matmul(diff_phi_S_c, beta_a).squeeze(-1)   # (B, N)
    revenue_2d = revenue.squeeze(-1)
    numer = ((revenue_2d - Q_diff_vec) * mask_2d).sum(dim=1)
    denom = mask_sum.squeeze(-1).squeeze(-1)

    eta_est = numer / denom
    zero_mask = (mask_sum_raw == 0.0)
    if zero_mask.any():
        # Rows with no selected step get an explicit 0 estimate.
        eta_est = eta_est.clone()
        eta_est[zero_mask] = 0.0
    TD_error = (revenue_2d - Q_diff_vec - eta_est.unsqueeze(1)) * mask_2d
    return eta_est, TD_error, beta_a


@torch.no_grad()
def run_one_batch_random_sims_and_estimate(
    trans_model, revenue_model,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    NUM_ENVS=128,
    times=20,
    days=30,
    num_sims=100,
    seed=2025,
    device=None
):
    """Simulate trajectories under uniformly random actions and estimate the ATE with ``Qeta_fn``.

    ``num_sims`` independent trajectories of length ``days * times`` are rolled
    out in parallel. At each step an action in {0, 1} is drawn uniformly, the
    models provide the next-state and reward means, and residual mean plus
    Gaussian noise (per time-of-day statistics from ``resid_data``) is added
    before rounding and clamping at 0.

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to next-state means.
        revenue_model: Module mapping the same input to the reward mean.
        initial_orders_values: 1-D tensor of candidate initial order counts.
        initial_orders_probs: Sampling weights for ``initial_orders_values``.
        resid_data: Dict with 1-D tensors under ``orders_mean``, ``drivers_mean``,
            ``orders_var``, ``drivers_var``, ``revenue_mean`` and ``revenue_var``;
            sequences shorter than ``times`` are extended with their last entry.
        Qeta_fn: Callable ``(data, treatment=..., device=...)`` whose first
            return value is a per-trajectory estimate.
        NUM_ENVS: Accepted for interface compatibility; not used.
        times: Steps per day.
        days: Number of simulated days.
        num_sims: Number of parallel trajectories.
        seed: Seed of the generator driving all sampling in this function.
        device: Device to simulate on; defaults to the device of ``trans_model``.

    Returns:
        ``(eta_diff, ate_mean, ate_se)`` where ``eta_diff`` is the per-trajectory
        tensor ``eta(treatment=0) - eta(treatment=1)``, ``ate_mean`` its mean and
        ``ate_se`` the standard error of that mean across trajectories (0.0 for
        a single trajectory).

    Raises:
        ValueError: If ``num_sims`` is not positive.
    """
    if device is None:
        device = next(trans_model.parameters()).device

    trans_model.eval()
    revenue_model.eval()

    E_total = int(num_sims)
    if E_total <= 0:
        raise ValueError("num_sims must be > 0")

    initial_orders_values = initial_orders_values.to(device)
    initial_orders_probs  = initial_orders_probs.to(device)

    def pad_repeat(seq, L):
        # Truncate to L, or extend to L by repeating the final entry.
        s = seq.to(device).float()
        if s.shape[0] >= L:
            return s[:L]
        return torch.cat([s, s[-1:].expand(L - s.shape[0])], dim=0)

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    N = int(days * times)

    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))

    probs = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)
    idx = torch.multinomial(probs, E_total, replacement=True, generator=gen)
    orders0 = initial_orders_values[idx].unsqueeze(1).float().to(device)
    drivers0 = torch.full((E_total, 1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    orders_hist      = torch.zeros((E_total, N), device=device)
    drivers_hist     = torch.zeros((E_total, N), device=device)
    actions_hist     = torch.zeros((E_total, N), device=device)
    revenue_hist     = torch.zeros((E_total, N), device=device)
    ordersNext_hist  = torch.zeros((E_total, N), device=device)
    driversNext_hist = torch.zeros((E_total, N), device=device)

    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))
    state_mean_per_time = torch.stack([orders_mean_t, drivers_mean_t], dim=1)  # (times, 2)

    for s in range(N):
        t_in_day = s % times

        # Draw actions from the seeded generator so `seed` fully determines the run.
        a_val = torch.randint(0, 2, (E_total,), generator=gen, device=device, dtype=torch.long)
        actions_for_nn = a_val.unsqueeze(1).float()
        t_tensor = torch.full((E_total, 1), float(t_in_day), device=device)

        X = torch.cat([state, actions_for_nn, t_tensor], dim=1)
        next_mean = trans_model(X)
        rev_mean = revenue_model(X).squeeze(1)

        # Statistics stay on-device as tensors to avoid a host sync per step.
        noise_state = torch.randn((E_total, 2), device=device, generator=gen) * state_std_per_time[t_in_day].view(1, 2)
        noise_r = torch.randn((E_total,), device=device, generator=gen) * reward_std_per_time[t_in_day]

        next_state = (next_mean + state_mean_per_time[t_in_day].view(1, 2) + noise_state).round().clamp(min=0).float()
        reward = (rev_mean + revenue_mean_t[t_in_day] + noise_r).round().clamp(min=0).float()

        orders_hist[:, s]      = state[:, 0]
        drivers_hist[:, s]     = state[:, 1]
        actions_hist[:, s]     = a_val.float()
        revenue_hist[:, s]     = reward
        ordersNext_hist[:, s]  = next_state[:, 0]
        driversNext_hist[:, s] = next_state[:, 1]

        state = next_state

    data_for_ope = {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }

    eta_t0, _, _ = Qeta_fn(data_for_ope, treatment=0, device=device)
    eta_t1, _, _ = Qeta_fn(data_for_ope, treatment=1, device=device)

    eta_diff = (eta_t0 - eta_t1)

    # Aggregate across trajectories; the spread must be computed on the
    # per-trajectory values, not on their already-averaged mean.
    n_est = int(eta_diff.numel())
    ate_mean = float(eta_diff.mean().item()) if n_est > 0 else 0.0
    ate_se = float(eta_diff.std(unbiased=True).item() / math.sqrt(n_est)) if n_est > 1 else 0.0

    return eta_diff, ate_mean, ate_se

def get_divisors(n):
    """Return every positive divisor of n in ascending order."""
    found = set()
    candidate = 1
    # Trial division up to sqrt(n); each hit contributes the divisor and its cofactor.
    while candidate * candidate <= n:
        if n % candidate == 0:
            found.add(candidate)
            found.add(n // candidate)
        candidate += 1
    return sorted(found)


@torch.no_grad()
def run_batch_generate_and_estimate(
    trans_model, revenue_model,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    num_sims=100,
    times=20,
    days=30,
    f=1,
    seed=2025,
    ATE_true_hat=0.0,
    device=None,
    save_csv_path=None
):
    """Simulate trajectories under a block-switching action design and estimate the ATE.

    Action design: within a day the action alternates every ``f`` steps
    (``(t // f) % 2``); each trajectory XORs that pattern with a random 0/1
    phase; odd-numbered days use the complement of the day-0 pattern. The
    rollout itself is vectorised over ``num_sims`` trajectories of length
    ``days * times``.

    Args:
        trans_model: Module mapping ``[orders, drivers, action, t]`` to next-state means.
        revenue_model: Module mapping the same input to the reward mean.
        initial_orders_values: 1-D tensor of candidate initial order counts.
        initial_orders_probs: Sampling weights for ``initial_orders_values``.
        resid_data: Dict with 1-D tensors under ``orders_mean``, ``drivers_mean``,
            ``orders_var``, ``drivers_var``, ``revenue_mean`` and ``revenue_var``;
            sequences shorter than ``times`` are extended with their last entry.
        Qeta_fn: Callable ``(data, treatment=..., device=...)`` returning either a
            1-D tensor of length ``num_sims`` or a tuple/list whose first item is one.
        num_sims: Number of parallel trajectories.
        times: Steps per day.
        days: Number of simulated days.
        f: Block length of the within-day switching pattern (values < 1 act as 1).
        seed: Base seed for initial-state, phase and per-step noise generators.
        ATE_true_hat: Reference value used for the MSE and squared-bias figures.
        device: Device to simulate on; defaults to the device of ``trans_model``.
        save_csv_path: If given, the action matrix is written there as CSV.

    Returns:
        Dict with keys ``ate_list``, ``ate_mean``, ``ate_se``, ``mse``, ``bias2``,
        ``variance``, ``eta_diff`` (NumPy array of ``eta0 - eta1``) and
        ``data_for_ope`` (the simulated tensors passed to ``Qeta_fn``).

    Raises:
        ValueError: If ``num_sims`` is not positive.
        RuntimeError: If ``Qeta_fn`` returns estimates of an unexpected shape.
    """
    if device is None:
        device = next(trans_model.parameters()).device

    trans_model.eval()
    revenue_model.eval()

    N = int(days * times)
    E_total = int(num_sims)
    if E_total <= 0:
        raise ValueError("num_sims must be > 0")

    initial_orders_values = initial_orders_values.to(device)
    initial_orders_probs  = initial_orders_probs.to(device)

    def pad_repeat(seq, L):
        # Truncate to L, or extend to L by repeating the final entry.
        s = seq.to(device).float()
        if s.shape[0] >= L:
            return s[:L]
        return torch.cat([s, s[-1:].expand(L - s.shape[0])], dim=0)

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))
    state_mean_per_time = torch.stack([orders_mean_t, drivers_mean_t], dim=1)  # (times, 2)

    # Separate generators for initial states, phases and per-step noise.
    gen_init = torch.Generator(device=device)
    gen_init.manual_seed(int(seed))
    gen_step = torch.Generator(device=device)
    gen_step.manual_seed(int(seed + 1))

    probs = initial_orders_probs / (initial_orders_probs.sum() + 1e-12)
    idx_init = torch.multinomial(probs, E_total, replacement=True, generator=gen_init)
    orders0 = initial_orders_values[idx_init].unsqueeze(1).float().to(device)
    drivers0 = torch.full((E_total, 1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    # Within-day switching pattern: blocks of length f alternate between 0 and 1.
    day_steps = torch.arange(times, device=device)
    f_eff = max(1, int(f))
    base_block = ((day_steps // f_eff) % 2).long()

    gen_phase = torch.Generator(device=device)
    gen_phase.manual_seed(int(seed))
    phases = torch.randint(0, 2, (E_total,), generator=gen_phase, device=device).long()

    base_day0 = (base_block.unsqueeze(0).expand(E_total, -1) ^ phases.unsqueeze(1)).long()

    # Even days use the phase-shifted pattern, odd days its complement.
    parts = [base_day0 if d % 2 == 0 else 1 - base_day0 for d in range(int(days))]
    actions_big = torch.cat(parts, dim=1).to(dtype=torch.long, device=device)

    if save_csv_path is not None:
        pd.DataFrame(actions_big.cpu().numpy()).to_csv(save_csv_path, index=False)

    orders_hist = torch.zeros((E_total, N), device=device)
    drivers_hist = torch.zeros((E_total, N), device=device)
    actions_hist = torch.zeros((E_total, N), device=device)
    revenue_hist = torch.zeros((E_total, N), device=device)
    ordersNext_hist = torch.zeros((E_total, N), device=device)
    driversNext_hist = torch.zeros((E_total, N), device=device)

    for s in range(N):
        t_idx = s % times

        a_val = actions_big[:, s]
        acts = a_val.unsqueeze(1).float()
        t_tensor = torch.full((E_total, 1), float(t_idx), device=device)

        X = torch.cat([state, acts, t_tensor], dim=1)
        next_mean = trans_model(X)
        rev_mean = revenue_model(X).squeeze(1)

        # Statistics stay on-device as tensors to avoid a host sync per step.
        noise_state = torch.randn((E_total, 2), device=device, generator=gen_step) * state_std_per_time[t_idx].view(1, 2)
        noise_r = torch.randn((E_total,), device=device, generator=gen_step) * reward_std_per_time[t_idx]

        next_state = (next_mean + state_mean_per_time[t_idx].view(1, 2) + noise_state).round().clamp(min=0).float()
        reward = (rev_mean + revenue_mean_t[t_idx] + noise_r).round().clamp(min=0).float()

        orders_hist[:, s] = state[:, 0]
        drivers_hist[:, s] = state[:, 1]
        actions_hist[:, s] = a_val.float()
        revenue_hist[:, s] = reward
        ordersNext_hist[:, s] = next_state[:, 0]
        driversNext_hist[:, s] = next_state[:, 1]

        state = next_state

    data_for_ope = {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }

    out0 = Qeta_fn(data_for_ope, treatment=0, device=device)
    out1 = Qeta_fn(data_for_ope, treatment=1, device=device)

    eta_t0 = out0[0] if isinstance(out0, (tuple, list)) else out0
    eta_t1 = out1[0] if isinstance(out1, (tuple, list)) else out1

    if not torch.is_tensor(eta_t0) or eta_t0.dim() != 1 or eta_t0.numel() != E_total:
        raise RuntimeError(f"eta_t0 shape unexpected: got {getattr(eta_t0,'shape',None)}, expected ({E_total},)")
    if not torch.is_tensor(eta_t1) or eta_t1.dim() != 1 or eta_t1.numel() != E_total:
        raise RuntimeError(f"eta_t1 shape unexpected: got {getattr(eta_t1,'shape',None)}, expected ({E_total},)")

    eta_diff = (eta_t0 - eta_t1).cpu().numpy()

    ate_list = eta_diff.tolist()
    ate_mean = float(np.mean(eta_diff))
    ate_se = float(np.std(eta_diff, ddof=1) / np.sqrt(max(1, E_total))) if E_total > 1 else 0.0
    mse = float(np.mean((eta_diff - float(ATE_true_hat))**2))
    bias2 = float((ate_mean - float(ATE_true_hat))**2)
    variance = float(np.var(eta_diff, ddof=1)) if E_total > 1 else 0.0

    return {
        'ate_list': ate_list,
        'ate_mean': ate_mean,
        'ate_se': ate_se,
        'mse': mse,
        'bias2': bias2,
        'variance': variance,
        'eta_diff': eta_diff,
        'data_for_ope': data_for_ope
    }

# -------------------- GPU Replay Buffer --------------------
class ReplayBufferGPU:
    """Fixed-capacity FIFO replay buffer holding transitions as device tensors.

    Every entry is the tuple
    ``(state, state_mask, action, reward, next_state, next_state_mask, done)``
    where ``action`` is a Python ``int`` and ``done`` a Python ``bool``.

    Tensors are *copied* on insertion. Callers typically pass a history
    tensor that they keep mutating in place afterwards; without the copy,
    ``Tensor.to(device)`` is a no-op for tensors already on that device and
    the buffer would only hold views whose contents change under it.

    Args:
        capacity: Maximum number of transitions kept; the oldest are evicted.
        device: Device on which stored and sampled tensors live. ``None``
            (the default) uses the module-level ``device``, looked up at call
            time, which is the original behaviour.
    """

    def __init__(self, capacity: int, device=None):
        self.cap = int(capacity)
        self.buffer = deque(maxlen=self.cap)
        self._device = device

    def _storage_device(self):
        """Return the explicit device if one was given, else the module-level one."""
        if self._device is not None:
            return self._device
        return device

    @torch.no_grad()
    def push_batch(self, S, A, R, NS, D, S_mask, NS_mask):
        """Append one transition per row of the batch.

        Args:
            S, NS: (B, L, feat) state / next-state histories.
            S_mask, NS_mask: (B, L) bool padding masks.
            A, R, D: one action / reward / done value per row (B elements each).
        """
        dev = self._storage_device()

        def owned(t):
            # copy=True forces fresh storage even when `t` already lives on
            # `dev`, so later in-place edits by the caller cannot leak in.
            return t.detach().to(dev, copy=True)

        states, state_masks = owned(S), owned(S_mask)
        next_states, next_masks = owned(NS), owned(NS_mask)
        rewards = owned(R)
        # One host transfer per tensor instead of one `.item()` sync per row.
        actions = A.detach().reshape(-1).tolist()
        dones = D.detach().reshape(-1).tolist()

        for i in range(S.shape[0]):
            self.buffer.append((
                states[i], state_masks[i],
                int(actions[i]), rewards[i],
                next_states[i], next_masks[i],
                bool(dones[i])
            ))

    def __len__(self):
        return len(self.buffer)

    @torch.no_grad()
    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(s, a, r, ns, d, sm, nsm)``: stacked states, long actions, flat
            rewards, stacked next states, float32 done flags (1.0 / 0.0) and
            the two stacked masks.

        Raises:
            ValueError: from ``random.sample`` when ``batch_size`` exceeds the
                number of stored transitions.
        """
        dev = self._storage_device()
        batch = random.sample(self.buffer, batch_size)
        s, sm, a, r, ns, nsm, d = zip(*batch)
        s = torch.stack(s, dim=0)
        sm = torch.stack(sm, dim=0)
        ns = torch.stack(ns, dim=0)
        nsm = torch.stack(nsm, dim=0)
        a = torch.tensor(a, dtype=torch.long, device=dev)
        # Rewards are normally tensors; plain numbers are still accepted.
        r = torch.stack([
            v.to(dev) if torch.is_tensor(v) else torch.tensor(v, device=dev)
            for v in r
        ]).view(-1)
        d = torch.tensor([1.0 if v else 0.0 for v in d], dtype=torch.float32, device=dev)
        return s, a, r, ns, d, sm, nsm


# -------------------- Feature construction --------------------
def make_feature_step_global(orders, drivers, prev_action, prev_step_reward, prev_day_reward_feature, global_pos, L_history):
    """Build the 8-column per-step feature row for every environment.

    Columns, in order: orders, drivers, previous action, previous step
    reward, previous day reward feature, sin and cos of the position phase,
    and the position normalised to [0, 1].

    The positional columns are allocated on ``orders.device`` rather than on
    the module-level device, so the helper also works for callers that run
    on an explicitly chosen device.

    Args:
        orders, drivers, prev_action, prev_step_reward,
        prev_day_reward_feature: tensors with E elements each (any numeric
            dtype; they are cast to float).
        global_pos: Step index; it is wrapped modulo ``L_history``.
        L_history: History length. Values below 1 are treated as 1 so the
            modulo and the divisions stay defined.

    Returns:
        Float tensor of shape (E, 8).
    """
    E = orders.shape[0]
    dev = orders.device
    period = max(1, L_history)
    pos_mod = float(global_pos % period)

    def column(x):
        return x.view(E, 1).float()

    # Same phase for every environment; build it once and reuse for sin/cos.
    phase = 2 * math.pi * torch.full((E, 1), pos_mod / period, device=dev)
    pos_norm = torch.full((E, 1), pos_mod / max(1, period - 1), device=dev)

    feat = torch.cat([
        column(orders),
        column(drivers),
        column(prev_action),
        column(prev_step_reward),
        column(prev_day_reward_feature),
        torch.sin(phase),
        torch.cos(phase),
        pos_norm,
    ], dim=1)
    return feat  # (E,8)


# -------------------- Generative Environment (GPU) --------------------
class GenerativeEnvGPU:
    """Batch of ``num_envs`` simulated environments driven by two learned models.

    The state of each environment is ``[orders, drivers]``. ``trans_model``
    predicts the mean next state and ``rev_model`` the mean reward from the
    input ``[orders, drivers, action, t]``; per-time-step residual means and
    Gaussian noise (from the residual variances) are added on top and the
    result is rounded and clamped at zero.

    Args:
        trans_model: Callable mapping an (E, 4) input to an (E, 2) mean next state.
        rev_model: Callable mapping an (E, 4) input to an (E, 1) mean reward.
        num_envs: Number of parallel environments E.
        initial_dist_data: Dict with ``'values'`` and ``'probs'`` tensors
            describing the initial orders distribution.
        residuals_data: Dict with ``orders_/drivers_/revenue_`` ``mean`` and
            ``var`` tensors indexed by time step.
    """

    def __init__(self, trans_model, rev_model, num_envs, initial_dist_data, residuals_data):
        self.trans_model = trans_model
        self.rev_model = rev_model
        # Captured once from the module-level device; every method below uses
        # this attribute so the instance stays self-consistent.
        self.device = device
        self.E = int(num_envs)
        self.init_vals = initial_dist_data['values']
        self.init_probs = initial_dist_data['probs']
        self.orders_resid_mean = residuals_data['orders_mean']
        self.orders_resid_var = residuals_data['orders_var']
        self.drivers_resid_mean = residuals_data['drivers_mean']
        self.drivers_resid_var = residuals_data['drivers_var']
        self.revenue_resid_mean = residuals_data['revenue_mean']
        self.revenue_resid_var = residuals_data['revenue_var']
        # Dedicated generator with a fixed seed for the env's own randomness.
        self.gen = torch.Generator(device=self.device)
        self.gen.manual_seed(2025)
        self.state = None

    def reset(self):
        """Sample fresh initial states and return a copy of them, shape (E, 2).

        Orders are drawn from the initial distribution; drivers start at 50.
        """
        probs = self.init_probs / (self.init_probs.sum() + 1e-12)
        idx = torch.multinomial(probs, self.E, replacement=True, generator=self.gen)
        orders = self.init_vals[idx].unsqueeze(1).float()
        drivers = torch.full((self.E, 1), 50.0, device=self.device)
        self.state = torch.cat([orders, drivers], dim=1)
        return self.state.clone()

    @torch.no_grad()
    def step(self, actions_val, t):
        """Advance every environment by one step.

        Args:
            actions_val: Tensor with one action per environment, shape (E,).
            t: Time-of-day index; indices past the end of the residual tables
                reuse the last entry.

        Returns:
            ``(next_state, reward, done, info)`` with ``next_state`` of shape
            (E, 2), ``reward`` of shape (E,), ``done`` an all-False bool
            tensor and ``info`` an empty dict.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.state is None:
            raise RuntimeError("GenerativeEnvGPU.step() called before reset()")

        E = self.E
        dev = self.device
        # Cast explicitly: actions usually arrive as integer tensors while the
        # rest of the model input is float.
        actions_for_nn = actions_val.unsqueeze(1).float()
        t_tensor = torch.full((E, 1), float(t), device=dev)
        X = torch.cat([self.state, actions_for_nn, t_tensor], dim=1)
        next_mean = self.trans_model(X)
        rev_mean = self.rev_model(X).squeeze(1)

        # State transition: model mean + residual mean + Gaussian noise.
        ti = min(t, self.orders_resid_mean.shape[0] - 1)
        resid_mean_state = torch.stack(
            [self.orders_resid_mean[ti], self.drivers_resid_mean[ti]], dim=0
        ).unsqueeze(0)
        resid_var_state = torch.stack(
            [self.orders_resid_var[ti], self.drivers_resid_var[ti]], dim=0
        ).unsqueeze(0)
        std_state = torch.sqrt(resid_var_state.clamp_min(1e-9))
        noise_state = torch.randn((E, 2), device=dev, generator=self.gen) * std_state.expand(E, 2)
        next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()

        # Reward: same recipe with the revenue residual statistics.
        ti_r = min(t, self.revenue_resid_mean.shape[0] - 1)
        resid_mean_r = self.revenue_resid_mean[ti_r]
        resid_std_r = math.sqrt(float(self.revenue_resid_var[ti_r].clamp_min(1e-9)))
        noise_r = torch.randn((E,), device=dev, generator=self.gen) * resid_std_r
        reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()

        self.state = next_state.clone()
        done = torch.zeros((E,), dtype=torch.bool, device=dev)
        return next_state.clone(), reward, done, {}


@torch.no_grad()
def run_policy_test(
    trans_model, revenue_model, policy,
    initial_orders_values, initial_orders_probs,
    resid_data,
    Qeta_fn,
    num_sims=100,
    times=20,
    days=30,
    seed=2025,
    epsilon=0.0,
    ATE_true_hat=None,
    warmup_days=0,
    decay_base=0.8,
    device=None
):
    """Roll out ``policy`` in the learned simulator and score its ATE estimate.

    ``num_sims`` trajectories of ``days * times`` steps are simulated in
    parallel. At every step the policy sees the full feature history (padded
    positions masked) and picks an action epsilon-greedily. At the end of each
    non-warm-up day the off-policy estimator ``Qeta_fn`` is run on the data so
    far and the (decay-weighted, negated) squared error against
    ``ATE_true_hat`` becomes the "previous day reward" feature. After the last
    step the estimator is run on the full trajectories and summary statistics
    are returned.

    Args:
        trans_model, revenue_model: Simulator models; switched to eval mode.
        policy: Callable ``policy(H, key_padding_mask=mask)`` returning
            per-action values of shape (E, num_actions). If it exposes
            ``training``/``eval()``/``train()`` it is put in eval mode for the
            roll-out and restored to its previous mode afterwards.
        initial_orders_values, initial_orders_probs: Initial orders distribution.
        resid_data: Dict of 1-D residual mean/variance tensors; entries shorter
            than ``times`` are padded by repeating their last value.
        Qeta_fn: ``Qeta_fn(data, treatment=..., device=...)`` returning the
            per-trajectory estimate, either directly or as the first element
            of a tuple/list.
        num_sims: Number of parallel trajectories (must be > 0).
        times: Steps per day.
        days: Number of simulated days.
        seed: Seed of the generator used for all sampling in this function.
        epsilon: Exploration probability; <= 0 is greedy, >= 1 fully random.
        ATE_true_hat: Reference per-step ATE (required).
        warmup_days: Days for which the day-level feature is forced to zero.
        decay_base: Base of the per-day weight ``decay_base ** (days - day)``.
        device: Device to run on; defaults to the device of ``trans_model``.

    Returns:
        Dict with ``ate_list``, ``ate_mean``, ``ate_se``, ``mse``, ``bias2``,
        ``variance``, ``eta_diff`` (numpy array) and ``data_for_ope``.

    Raises:
        ValueError: If ``ATE_true_hat`` is missing, ``num_sims <= 0`` or a
            residual table is empty while padding is required.
    """
    if ATE_true_hat is None:
        raise ValueError("run_policy_test: ATE_true_hat is required (pass per-step true ATE).")
    if device is None:
        device = next(trans_model.parameters()).device
    trans_model.eval(); revenue_model.eval()

    def pad_repeat(seq, L):
        """Truncate or right-pad a 1-D tensor to length L by repeating its last value."""
        s = seq.to(device).float()
        if s.shape[0] >= L:
            return s[:L]
        if s.shape[0] == 0:
            raise ValueError("run_policy_test: cannot pad an empty residual table")
        # s[-1:] has shape (1,), which expands directly to (L - len,).
        last = s[-1:].expand(L - s.shape[0])
        return torch.cat([s, last], dim=0)

    def ate_per_env(data):
        """Return eta(treatment=0) - eta(treatment=1) per trajectory as a tensor."""
        out0 = Qeta_fn(data, treatment=0, device=device)
        out1 = Qeta_fn(data, treatment=1, device=device)
        eta_t0 = out0[0] if isinstance(out0, (tuple, list)) else out0
        eta_t1 = out1[0] if isinstance(out1, (tuple, list)) else out1
        if not torch.is_tensor(eta_t0):
            eta_t0 = torch.as_tensor(eta_t0, device=device, dtype=torch.float32)
        if not torch.is_tensor(eta_t1):
            eta_t1 = torch.as_tensor(eta_t1, device=device, dtype=torch.float32)
        return eta_t0 - eta_t1

    orders_mean_t  = pad_repeat(resid_data['orders_mean'], times)
    drivers_mean_t = pad_repeat(resid_data['drivers_mean'], times)
    orders_var_t   = pad_repeat(resid_data['orders_var'], times)
    drivers_var_t  = pad_repeat(resid_data['drivers_var'], times)
    revenue_mean_t = pad_repeat(resid_data['revenue_mean'], times)
    revenue_var_t  = pad_repeat(resid_data['revenue_var'], times)

    state_std_per_time = torch.sqrt(torch.stack([orders_var_t, drivers_var_t], dim=1).clamp_min(1e-12))
    reward_std_per_time = torch.sqrt(revenue_var_t.clamp_min(1e-12))

    N = int(days * times)
    E = int(num_sims)
    if E <= 0:
        raise ValueError("num_sims must be > 0")
    gen = torch.Generator(device=device); gen.manual_seed(int(seed))

    # Initial state: orders sampled from the given distribution, drivers = 50.
    probs = initial_orders_probs.to(device) / (initial_orders_probs.sum() + 1e-12)
    idx = torch.multinomial(probs, E, replacement=True, generator=gen)
    orders0 = initial_orders_values.to(device)[idx].unsqueeze(1).float()
    drivers0 = torch.full((E,1), 50.0, device=device)
    state = torch.cat([orders0, drivers0], dim=1)

    FEATURE_DIM = 8
    L_history = N
    H = torch.zeros((E, L_history, FEATURE_DIM), device=device)
    # Loop-invariant position index used to build the padding mask.
    idx_range = torch.arange(L_history, device=device).unsqueeze(0)

    prev_actions = torch.full((E,), -1.0, device=device)
    prev_step_rewards = torch.zeros((E,), device=device)
    prev_day_reward_feature = torch.zeros((E,), device=device)

    orders_hist      = torch.zeros((E, N), device=device)
    drivers_hist     = torch.zeros((E, N), device=device)
    actions_hist     = torch.zeros((E, N), device=device)
    revenue_hist     = torch.zeros((E, N), device=device)
    ordersNext_hist  = torch.zeros((E, N), device=device)
    driversNext_hist = torch.zeros((E, N), device=device)

    eps = float(epsilon)

    # Evaluate in eval mode (the caller may have left the policy in train
    # mode) and restore the previous mode afterwards.
    policy_was_training = bool(getattr(policy, "training", False))
    if policy_was_training:
        policy.eval()
    try:
        for s in range(N):
            # The history holds exactly N slots, so step s always writes slot s
            # and every later slot is still padding.
            mask = (idx_range > s).expand(E, -1)

            new_feat = make_feature_step_global(state[:,0], state[:,1], prev_actions, prev_step_rewards, prev_day_reward_feature, s, L_history)
            H[:, s, :] = new_feat

            qvals = policy(H, key_padding_mask=mask)
            greedy_a = qvals.argmax(dim=1)
            if eps <= 0.0:
                a_val = greedy_a
            elif eps >= 1.0:
                # Seeded like the mixed branch so fully random runs are reproducible.
                a_val = torch.randint(0, 2, (E,), device=device, generator=gen)
            else:
                rand_mask = (torch.rand(E, device=device, generator=gen) < eps)
                rand_a = torch.randint(0, 2, (E,), device=device, generator=gen)
                a_val = torch.where(rand_mask, rand_a, greedy_a)

            t_in_day = s % times
            acts = a_val.unsqueeze(1).float()
            t_tensor = torch.full((E,1), float(t_in_day), device=device)

            X = torch.cat([state, acts, t_tensor], dim=1)
            next_mean = trans_model(X)
            rev_mean = revenue_model(X).squeeze(1)

            std_state = state_std_per_time[t_in_day].view(1,2)
            noise_state = torch.randn((E,2), device=device, generator=gen) * std_state.expand(E,2)
            std_reward = float(reward_std_per_time[t_in_day].item())
            noise_r = torch.randn((E,), device=device, generator=gen) * std_reward

            resid_mean_state = torch.stack([orders_mean_t[t_in_day], drivers_mean_t[t_in_day]], dim=0).unsqueeze(0)
            resid_mean_r = float(revenue_mean_t[t_in_day].item())

            next_state = (next_mean + resid_mean_state + noise_state).round().clamp(min=0).float()
            reward = (rev_mean + resid_mean_r + noise_r).round().clamp(min=0).float()

            orders_hist[:, s]      = state[:, 0].view(-1)
            drivers_hist[:, s]     = state[:, 1].view(-1)
            actions_hist[:, s]     = a_val.float().view(-1)
            revenue_hist[:, s]     = reward.view(-1)
            ordersNext_hist[:, s]  = next_state[:, 0].view(-1)
            driversNext_hist[:, s] = next_state[:, 1].view(-1)

            prev_actions = a_val.clone().view(-1)
            prev_step_rewards = reward.clone().view(-1)
            state = next_state.clone()

            # End of day: refresh the day-level feature.
            if t_in_day == (times - 1):
                day_idx = s // times
                if day_idx < int(warmup_days):
                    prev_day_reward_feature = torch.zeros((E,), device=device)
                else:
                    end_col = s + 1
                    data_for_ope = {
                        'A': actions_hist[:, :end_col].clone(),
                        'revenue': revenue_hist[:, :end_col].clone(),
                        'orders': orders_hist[:, :end_col].clone(),
                        'drivers': drivers_hist[:, :end_col].clone(),
                        'ordersNext': ordersNext_hist[:, :end_col].clone(),
                        'driversNext': driversNext_hist[:, :end_col].clone()
                    }
                    ATE_est_per_env = ate_per_env(data_for_ope).to(device)

                    # Later days weigh more: weight = decay_base ** (days - day_number).
                    day_number = int(day_idx) + 1
                    exponent = max(0, int(days) - day_number)
                    weight = float(decay_base) ** exponent
                    ATE_true_tensor = torch.full_like(ATE_est_per_env, float(ATE_true_hat), device=device)
                    R_day = - weight * (ATE_est_per_env - ATE_true_tensor) ** 2

                    prev_day_reward_feature = R_day.clone().view(-1)
    finally:
        if policy_was_training:
            policy.train()

    data_for_ope_full = {
        'A': actions_hist, 'revenue': revenue_hist,
        'orders': orders_hist, 'drivers': drivers_hist,
        'ordersNext': ordersNext_hist, 'driversNext': driversNext_hist
    }
    eta_diff = ate_per_env(data_for_ope_full).cpu().numpy()

    ate_list = eta_diff.tolist()
    ate_mean = float(np.mean(eta_diff))
    ate_se = float(np.std(eta_diff, ddof=1) / math.sqrt(max(1, E))) if E > 1 else 0.0
    mse = float(np.mean((eta_diff - float(ATE_true_hat))**2))
    bias2 = float((ate_mean - float(ATE_true_hat))**2)
    variance = float(np.var(eta_diff, ddof=1)) if E > 1 else 0.0

    return {
        'ate_list': ate_list,
        'ate_mean': ate_mean,
        'ate_se': ate_se,
        'mse': mse,
        'bias2': bias2,
        'variance': variance,
        'eta_diff': eta_diff,
        'data_for_ope': data_for_ope_full
    }


# -------------------- Main Training Loop --------------------
def main_train():
    """Train the Transformer-DQN policy and periodically test/checkpoint it.

    Steps, in order:
      1. Load the simulator assets and build the environment, the policy and
         target networks, the optimizer and the replay buffer.
      2. Estimate the reference per-step ATE by paired Monte Carlo.
      3. Record baselines: a random-action run (written to
         ``random_squared_errors.csv``) and a sweep over every divisor ``f``
         of the steps-per-day (one ``mse_each_f{f}.csv`` per value).
      4. Run ``args.epochs`` training epochs. Each epoch simulates
         ``args.dates`` days of ``args.times`` steps in ``args.num_envs``
         environments, pushes every transition to replay and performs a
         double-DQN update every ``LEARN_EVERY`` environment steps once the
         buffer is large enough. The reward is zero except on the last step
         of a non-warm-up day, where it is the negated, decay-weighted
         squared error of the OPE estimate against the reference ATE.
      5. Every second epoch, test the policy and save a checkpoint plus a
         per-simulation CSV whenever the test MSE improves.
    """
    NUM_ENVS = args.num_envs
    TOTAL_TIME_STEPS = args.times
    DATES = args.dates
    L_history = TOTAL_TIME_STEPS * DATES
    FEATURE_DIM = 8
    LEARN_EVERY = 3
    env_step_counter = 0
    Seed = args.seed

    (transition_model, revenue_model, initial_orders_values, initial_orders_probs,
     orders_resid_mean, orders_resid_var, drivers_resid_mean, drivers_resid_var,
     revenue_resid_mean, revenue_resid_var) = load_required_assets_or_die()

    resid_data = {'orders_mean': orders_resid_mean, 'orders_var': orders_resid_var,
                  'drivers_mean': drivers_resid_mean, 'drivers_var': drivers_resid_var,
                  'revenue_mean': revenue_resid_mean, 'revenue_var': revenue_resid_var}

    env = GenerativeEnvGPU(transition_model, revenue_model, NUM_ENVS,
                           initial_dist_data={'values': initial_orders_values, 'probs': initial_orders_probs},
                           residuals_data=resid_data)
    policy = TransformerDQN(feature_dim=FEATURE_DIM, max_len=max(args.posenc_max_len, L_history)).to(device)
    target = TransformerDQN(feature_dim=FEATURE_DIM, max_len=max(args.posenc_max_len, L_history)).to(device)
    target.load_state_dict(policy.state_dict()); target.eval()
    optimizer = optim.AdamW(policy.parameters(), lr=args.lr)
    replay = ReplayBufferGPU(args.replay_capacity)
    transition_model.eval()
    revenue_model.eval()

    # ---- Reference ATE -------------------------------------------------
    logger.info("Estimating ATE_true (per-step) with paired Monte Carlo ...")
    ATE_true_hat, _, se_init = estimate_true_ate_paired_mc_strict(
        transition_model, revenue_model, initial_orders_values, initial_orders_probs,
        resid_data, num_trajs=args.ate_mc_sims_init, T=TOTAL_TIME_STEPS, per_step=True, max_chunk_size=args.max_chunk_mc)
    logger.info(f"Initial ATE_true_hat (per-step) = {ATE_true_hat:.6e}  SE={se_init:.6e}")

    # ---- Baseline 1: random actions -------------------------------------
    REP_count = 200
    ate_list, ate_mean, ate_se = run_one_batch_random_sims_and_estimate(
        transition_model, revenue_model,
        initial_orders_values, initial_orders_probs,
        resid_data,
        Q_eta_est_poly_tensor_batch,
        NUM_ENVS=args.num_envs,
        times=args.times,
        days=DATES,
        num_sims=REP_count,
        seed=Seed,
        device=device
    )
    random_every = (ate_list - ATE_true_hat) ** 2
    arr = (random_every.detach().cpu().numpy() if isinstance(random_every, torch.Tensor) else np.asarray(random_every, dtype=float)).ravel()
    pd.DataFrame({'sim_index': np.arange(arr.size), 'squared_error': arr}).to_csv("random_squared_errors.csv", index=False)
    logger.info(f"Random results saved to random_squared_errors.csv ({arr.size} rows)")
    random_mse = random_every.mean()
    random_bias2 = (ate_list.mean() - ATE_true_hat) ** 2
    random_vari = ate_list.var()
    logger.info(f"Random policy metrics: MSE={random_mse:.6e} | bias^2={random_bias2:.6e} | variance={random_vari:.6e}")

    # ---- Baseline 2: sweep over divisors of the steps-per-day -----------
    fs = get_divisors(TOTAL_TIME_STEPS)
    logger.info(f"TOTAL_TIME_STEPS={TOTAL_TIME_STEPS}, divisors to evaluate: {fs}")
    results = []
    for f in fs:
        logger.info(f"[SWEEP] evaluating f={f} ...")
        res = run_batch_generate_and_estimate(
            transition_model,
            revenue_model,
            initial_orders_values,
            initial_orders_probs,
            resid_data,
            Q_eta_est_poly_tensor_batch,
            num_sims=REP_count,
            times=TOTAL_TIME_STEPS,
            days=DATES,
            f=f,
            seed=Seed,
            ATE_true_hat=ATE_true_hat,
            device=device,
            save_csv_path=None
        )
        tmp = res.get('ate_list', res.get('eta_diff'))
        arr = tmp.detach().cpu().numpy() if isinstance(tmp, torch.Tensor) else np.asarray(tmp, dtype=float)
        mse_each = (arr - float(ATE_true_hat)) ** 2
        pd.DataFrame({'mse_each': mse_each}).to_csv(f"mse_each_f{int(f)}.csv", index=False)
        logger.info(f"Saved mse_each_f{int(f)}.csv ({mse_each.size} rows) | mean_mse={mse_each.mean():.6e}")

        mse = float(res.get('mse', float('nan')))
        ate_mean = float(res.get('ate_mean', float('nan')))
        variance = float(res.get('variance', float('nan')))
        ate_se = float(res.get('ate_se', float('nan')))
        logger.info(f"[SWEEP:f={f}] mse={mse:.6e} | ate_mean={ate_mean:.6e} | var={variance:.6e} | ate_se={ate_se:.6e}")
        results.append({
            'f': int(f),
            'mse': mse,
            'ate_mean': ate_mean,
            'variance': variance,
            'ate_se': ate_se,
        })

    print(results)

    # ---- Training state -------------------------------------------------
    H = torch.zeros((NUM_ENVS, L_history, FEATURE_DIM), device=device)
    # Loop-invariant position index used to build the padding masks.
    idx_range = torch.arange(L_history, device=device).unsqueeze(0)

    def padding_mask(last_valid):
        """Mask (True = padded) hiding every history slot after `last_valid`."""
        if last_valid < L_history:
            return (idx_range > last_valid).expand(NUM_ENVS, -1)
        return torch.zeros((NUM_ENVS, L_history), dtype=torch.bool, device=device)

    # This reset is redundant with the one at the start of each epoch, but it
    # draws from the environment's generator; it is kept so the random stream
    # of existing runs is unchanged.
    env.reset()
    global_step = 0
    opt_steps = 0
    decay_base = 0.95
    start_time = time.time()

    def epsilon_by_step(step):
        """Linear decay from eps_start to eps_end over eps_decay_steps."""
        if step >= args.eps_decay_steps: return args.eps_end
        frac = step / args.eps_decay_steps
        return args.eps_start + (args.eps_end - args.eps_start) * frac

    total_cols = TOTAL_TIME_STEPS * DATES
    orders_hist      = torch.zeros((NUM_ENVS, total_cols), device=device)
    drivers_hist     = torch.zeros((NUM_ENVS, total_cols), device=device)
    actions_hist     = torch.zeros((NUM_ENVS, total_cols), device=device)
    revenue_hist     = torch.zeros((NUM_ENVS, total_cols), device=device)
    ordersNext_hist  = torch.zeros((NUM_ENVS, total_cols), device=device)
    driversNext_hist = torch.zeros((NUM_ENVS, total_cols), device=device)

    best_test_mse = float('inf')
    best_test_epoch = -1

    for epoch in range(1, args.epochs + 1):
        policy.train()
        epoch_last_day_lastmoment_R = 0.0
        epoch_start = time.time()
        epoch_loss_sum = 0.0
        epoch_loss_cnt = 0

        # Fresh episode: clear the history, the carried features and the logs.
        insert_idx = -1
        H.zero_()
        prev_actions = torch.full((NUM_ENVS,), -1.0, device=device)
        prev_step_rewards = torch.zeros((NUM_ENVS,), device=device)
        prev_day_reward_feature = torch.zeros((NUM_ENVS,), device=device)
        state = env.reset()
        orders_hist.zero_()
        drivers_hist.zero_()
        actions_hist.zero_()
        revenue_hist.zero_()
        ordersNext_hist.zero_()
        driversNext_hist.zero_()
        R_day = None  # stays None only if the epoch simulates no step at all

        logger.info(f"[EPOCH START] epoch={epoch}/{args.epochs} | current_global_step={global_step} | insert_idx={insert_idx}")
        for day in range(DATES):
            start_col = day * TOTAL_TIME_STEPS
            end_col   = start_col + TOTAL_TIME_STEPS
            for t in range(TOTAL_TIME_STEPS):
                insert_idx += 1
                pos = insert_idx % L_history
                mask = padding_mask(insert_idx)

                new_feat = make_feature_step_global(state[:,0], state[:,1], prev_actions, prev_step_rewards, prev_day_reward_feature, insert_idx, L_history)
                H[:, pos, :] = new_feat

                # Epsilon-greedy action selection.
                eps = epsilon_by_step(global_step)
                with torch.no_grad():
                    qvals = policy(H, key_padding_mask=mask)
                greedy = qvals.argmax(dim=1)
                rand_mask = (torch.rand(NUM_ENVS, device=device) < eps)
                rand_a = torch.randint(0, 2, (NUM_ENVS,), device=device)
                a_val = torch.where(rand_mask, rand_a, greedy)
                next_state, reward_step, done, _ = env.step(a_val, t)

                col_idx = start_col + t
                orders_hist[:, col_idx]      = state[:, 0].view(-1)
                drivers_hist[:, col_idx]     = state[:, 1].view(-1)
                actions_hist[:, col_idx]     = a_val.float().view(-1)
                revenue_hist[:, col_idx]     = reward_step.view(-1)
                ordersNext_hist[:, col_idx]  = next_state[:, 0].view(-1)
                driversNext_hist[:, col_idx] = next_state[:, 1].view(-1)

                # Next-state history. It is built with the day-level feature as
                # it stands *before* the end-of-day update below.
                pos_next = (insert_idx + 1) % L_history
                NH = H.clone()
                next_feat = make_feature_step_global(next_state[:,0], next_state[:,1], a_val, reward_step, prev_day_reward_feature, insert_idx+1, L_history)
                NH[:, pos_next, :] = next_feat
                nmask = padding_mask(insert_idx + 1)

                # Reward: zero except on the last step of a non-warm-up day.
                if t < TOTAL_TIME_STEPS - 1:
                    R_day = torch.zeros((NUM_ENVS,), device=device)
                else:
                    if day < args.warmup_days:
                        R_day = torch.zeros((NUM_ENVS,), device=device)
                        prev_day_reward_feature = torch.zeros_like(prev_day_reward_feature)
                    else:
                        data_for_ope = {
                            'A': actions_hist[:, 0:end_col].clone(),
                            'revenue': revenue_hist[:, 0:end_col].clone(),
                            'orders': orders_hist[:, 0:end_col].clone(),
                            'drivers': drivers_hist[:, 0:end_col].clone(),
                            'ordersNext': ordersNext_hist[:, 0:end_col].clone(),
                            'driversNext': driversNext_hist[:, 0:end_col].clone()
                        }
                        eta0, _, _ = Q_eta_est_poly_tensor_batch(data_for_ope, treatment=0, device=device)
                        eta1, _, _ = Q_eta_est_poly_tensor_batch(data_for_ope, treatment=1, device=device)
                        ATE_est = (eta0 - eta1)
                        ATE_true_tensor = torch.full_like(ATE_est, float(ATE_true_hat))
                        R_day = - (float(decay_base) ** max(0, int(DATES) - int(day) - 1)) * (ATE_est - ATE_true_tensor) ** 2
                        prev_day_reward_feature = R_day.clone()

                # NOTE(review): `done` is always 0, including on the last step of
                # the epoch — confirm that bootstrapping there is intended.
                done_flag = torch.zeros((NUM_ENVS,), dtype=torch.float32, device=device)
                # Push a snapshot of H: the live tensor is overwritten in place
                # on later steps and zeroed at the next epoch, which would
                # otherwise rewrite states already sitting in the buffer.
                replay.push_batch(H.clone(), a_val.detach().float(), R_day.detach(), NH.detach(), done_flag.detach(), mask.detach(), nmask.detach())

                prev_actions = a_val.clone().view(-1)
                prev_step_rewards = reward_step.clone().view(-1)
                state = next_state.clone()

                # NOTE(review): the epsilon schedule advances 10 units per
                # environment step — confirm this is deliberate.
                global_step += 10
                env_step_counter += 1
                if len(replay) >= max(500*args.batch_size, 10000) and (env_step_counter % LEARN_EVERY == 0):
                    # Double-DQN update: online net picks the next action,
                    # target net evaluates it.
                    s_b, a_b, r_b, ns_b, d_b, sm_b, nsm_b = replay.sample(args.batch_size)
                    q_all = policy(s_b, key_padding_mask=sm_b)
                    q_pred_a = q_all.gather(1, a_b.view(-1,1)).squeeze(1)
                    with torch.no_grad():
                        non_terminal = (d_b < 0.5)
                        q_next = torch.zeros(args.batch_size, device=device)
                        if non_terminal.any():
                            q_next_online = policy(ns_b[non_terminal], key_padding_mask=nsm_b[non_terminal])
                            a_next = q_next_online.argmax(dim=1, keepdim=True)
                            q_next_target = target(ns_b[non_terminal], key_padding_mask=nsm_b[non_terminal])
                            q_next_sel = q_next_target.gather(1, a_next).squeeze(1)
                            q_next[non_terminal] = q_next_sel
                        target_vals = r_b.to(device) + (1.0 - d_b.to(device)) * args.gamma * q_next
                    loss = F.mse_loss(q_pred_a, target_vals)
                    optimizer.zero_grad(set_to_none=True)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
                    optimizer.step()
                    epoch_loss_sum += float(loss.detach().item())
                    epoch_loss_cnt += 1
                    opt_steps += 1
                    if (opt_steps % args.target_update_steps) == 0:
                        target.load_state_dict(policy.state_dict())

        # After the loops R_day holds the reward of the final step of the
        # final day (None when nothing was simulated).
        if R_day is not None:
            epoch_last_day_lastmoment_R = float(R_day.mean().item())
        avg_loss = epoch_loss_sum / max(1, epoch_loss_cnt)
        epoch_elapsed = time.time() - epoch_start
        total_elapsed = time.time() - start_time
        logger.info(
            f"[EPOCH END] epoch={epoch}/{args.epochs} | step={global_step} | replay={len(replay)} | "
            f"opt_steps={opt_steps} | epoch_time={epoch_elapsed:.1f}s | total_time={total_elapsed:.1f}s | "
            f"eps={epsilon_by_step(global_step):.3f} | insert_idx={insert_idx} | "
            f"ATE_true_hat(per-step)={ATE_true_hat:.6e} | epoch_avg_loss={avg_loss:.6e} | epoch_loss_steps={epoch_loss_cnt} | "
            f"last_day_lastmoment_R={epoch_last_day_lastmoment_R:.6e}"
        )
        if best_test_epoch > 0:
            logger.info(f"[BEST SO FAR] best_test_mse={best_test_mse:.6e} achieved_at_epoch={best_test_epoch}")
        else:
            logger.info("[BEST SO FAR] best_test_mse=inf (no successful test saved yet)")

        # ---- Periodic test + checkpoint ---------------------------------
        if (epoch % 2) == 0:
            try:
                test_res = run_policy_test(
                    transition_model, revenue_model, policy,
                    initial_orders_values, initial_orders_probs,
                    resid_data, Q_eta_est_poly_tensor_batch,
                    num_sims=REP_count, times=args.times, days=DATES,
                    seed=Seed, epsilon=0.3, ATE_true_hat=ATE_true_hat,
                    warmup_days=args.warmup_days, decay_base=decay_base, device=device
                )
                test_mse = float(test_res['mse'])
                test_bias2 = float(test_res['bias2'])
                test_variance = float(test_res['variance'])
                logger.info(
                    f"[TEST@epoch{epoch}] mse={test_mse:.6e} | bias2={test_bias2:.6e} | var={test_variance:.6e} | "
                )
                if test_mse < best_test_mse:
                    ckpt = {
                        'epoch': epoch,
                        'model_state_dict': policy.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'test_mse': test_mse
                    }
                    best_model_path=f'checkpoints/best_model_epoch_{epoch}_{test_mse}.pth'
                    Path(best_model_path).parent.mkdir(parents=True, exist_ok=True)
                    torch.save(ckpt, best_model_path)
                    # Record the new best only once the checkpoint is on disk, so
                    # a failed save cannot leave a "best" with no file behind it.
                    best_test_mse = test_mse
                    best_test_epoch = epoch
                    logger.info(f"[NEW BEST] epoch={epoch} saved best model -> {best_model_path} (test_mse={test_mse:.6e})")
                    ate_raw = test_res.get('ate_list', test_res.get('eta_diff'))
                    ate_arr = ate_raw.detach().cpu().numpy() if isinstance(ate_raw, torch.Tensor) else np.asarray(ate_raw, dtype=float)
                    mse_each = (ate_arr - float(ATE_true_hat))**2
                    out_name = f"mse_each_epoch{int(epoch)}_mse{test_mse}.csv"
                    pd.DataFrame({'mse_each': mse_each}).to_csv(out_name, index=False)
                    logger.info(f"Saved {out_name} n={mse_each.size} mean_mse={mse_each.mean():.6e}")
            except Exception as e:
                # Testing is best-effort and must not abort training, but keep
                # the traceback so the failure can be diagnosed.
                logger.warning(f"[TEST ERROR] run_policy_test failed at epoch={epoch}: {e}", exc_info=True)


# -------------- Entry --------------
if __name__ == "__main__":
    # Start training only when the file is executed as a script.
    main_train()