import torch
import numpy as np
import pandas as pd
import time
import random
from torch.distributions.normal import Normal
import statsmodels.api as sm

# ==============================================================================
# 0. Seeding for Reproducibility
# ==============================================================================
def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    put into deterministic, non-benchmarking mode for reproducible runs.
    The chosen seed is printed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic, cudnn.benchmark = True, False

    print(f"Seed fixed at: {seed}")

# ==============================================================================
# 1. Global Constants
# ==============================================================================
# Size of the simulation grid in tiles along x and y.
NUMBER_OF_GRID_TILES_X = 9
NUMBER_OF_GRID_TILES_Y = 9
# Number of decision steps simulated per episode.
NUMBER_OF_TIME_STEPS = 20
# Orders and drivers generated per episode.
NUMBER_OF_ORDERS = 100
NUMBER_OF_DRIVERS = 25
# Largest Manhattan (L1) pickup distance at which an order may be matched to a driver.
MAX_MANHATTAN_DISTANCE = 2
# Per-step discount factor (applied as DISCOUNT_FACTOR ** trip_duration by the simulator).
DISCOUNT_FACTOR = 0.9
# Trip revenue = BASE_REWARD_PER_TRIP + REWARD_FOR_DISTANCE_PARAMETER * trip length.
BASE_REWARD_PER_TRIP = 1.0
REWARD_FOR_DISTANCE_PARAMETER = 1.0

# Shape of the zero-initialised value table used when no pre-trained table is found.
V_DIMS = (NUMBER_OF_GRID_TILES_X, NUMBER_OF_GRID_TILES_Y, NUMBER_OF_TIME_STEPS)

# ==============================================================================
# 2. GPU-based Data Generators
# ==============================================================================
def truncated_normal_rvs_pt(mean, std, lower_bound, upper_bound, size):
    """Sample a truncated normal distribution by inverse-CDF sampling.

    Draws values from N(mean, std**2) restricted to
    ``[lower_bound, upper_bound]``. ``mean``, ``std`` and the bounds broadcast
    against ``size``.

    Args:
        mean: Location tensor; its device is used for every intermediate.
        std: Scale tensor (or number).
        lower_bound: Lower truncation bound (tensor or number).
        upper_bound: Upper truncation bound (tensor or number).
        size: Shape of the uniform draws (and hence of the result).

    Returns:
        Tensor of samples on ``mean.device``, every entry finite and inside
        ``[lower_bound, upper_bound]``.
    """
    device = mean.device
    normal_dist = Normal(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device))
    alpha = (lower_bound - mean) / std
    beta = (upper_bound - mean) / std
    cdf_alpha = normal_dist.cdf(alpha)
    cdf_beta = normal_dist.cdf(beta)
    p = cdf_alpha + torch.rand(size, device=device) * (cdf_beta - cdf_alpha)

    # Keep p strictly inside (0, 1) so icdf stays finite. The margin must be
    # at least one ulp of p's dtype: in float32, 1 - 1e-9 rounds to exactly
    # 1.0, and icdf(1.0) is +inf.
    eps = torch.finfo(p.dtype).eps
    p = p.clamp(eps, 1 - eps)
    samples = normal_dist.icdf(p) * std + mean

    # CDF saturation, the clamp above and float rounding can put a sample
    # slightly outside the truncation interval; pull it back in.
    lower = torch.as_tensor(lower_bound, dtype=samples.dtype, device=device)
    upper = torch.as_tensor(upper_bound, dtype=samples.dtype, device=device)
    return torch.minimum(torch.maximum(samples, lower), upper)

def generate_initial_orders_from_mixture_pt(BATCH_SIZE, NUMBER_OF_ORDERS, device):
    """Draw integer order 3-vectors from a two-component truncated-normal mixture.

    Each of the ``BATCH_SIZE * NUMBER_OF_ORDERS`` samples uses component 0
    (means (3, 3, 2)) with probability 1/3 or component 1 (means (6, 6, 14))
    with probability 2/3; both components use std 2 in every coordinate.
    Coordinates are truncated to [0, 8] x [0, 8] x [0, 19] and rounded.

    Returns:
        ``long`` tensor of shape ``(BATCH_SIZE, NUMBER_OF_ORDERS, 3)``.
    """
    n_samples = BATCH_SIZE * NUMBER_OF_ORDERS

    def as_vec(*values):
        return torch.tensor(list(values), device=device)

    component = torch.multinomial(as_vec(1./3, 2./3), n_samples, replacement=True)
    use_first = (component == 0).unsqueeze(1)

    means = torch.where(use_first, as_vec(3.0, 3.0, 2.0), as_vec(6.0, 6.0, 14.0))
    stds = torch.where(use_first, as_vec(2.0, 2.0, 2.0), as_vec(2.0, 2.0, 2.0))
    lower = as_vec(0.0, 0.0, 0.0).expand_as(means)
    upper = as_vec(8.0, 8.0, 19.0).expand_as(means)

    samples = truncated_normal_rvs_pt(means, stds, lower, upper, size=(n_samples, 3))
    return samples.round().long().view(BATCH_SIZE, NUMBER_OF_ORDERS, 3)

def generate_initial_waiting_times_pt(BATCH_SIZE, NUMBER_OF_ORDERS, device):
    """Draw integer waiting times from N(1.5, 0.5**2) truncated to [0, 3], rounded.

    Returns:
        ``long`` tensor of shape ``(BATCH_SIZE, NUMBER_OF_ORDERS)``.
    """
    def scalar(value):
        return torch.tensor(value, device=device)

    flat = truncated_normal_rvs_pt(
        scalar(1.5), scalar(0.5), scalar(0.0), scalar(3.0),
        size=(BATCH_SIZE * NUMBER_OF_ORDERS,),
    )
    return flat.round().long().view(BATCH_SIZE, NUMBER_OF_ORDERS)

# ==============================================================================
# 3. Core Matching Algorithm
# ==============================================================================
def greedy_decode_from_sinkhorn_batched(T):
    """Greedily extract a one-to-one matching from a batch of score matrices.

    For every batch element, the largest remaining entry of ``T`` is picked.
    Its row and column are then disabled, and the step repeats up to
    ``min(n, m)`` times. Entries ``<= 1e-9`` count as "no match", so a batch
    element stops receiving pairs once all its remaining entries are at or
    below that threshold.

    Args:
        T: Tensor of shape ``(b, n, m)``. Scores are expected to be
            non-negative, with zeros for forbidden pairs.

    Returns:
        ``(matched_rows, matched_cols)``: ``long`` tensors of shape
        ``(b, min(n, m))``. Slot ``i`` holds the row/column chosen at greedy
        step ``i``, or ``-1`` where no pair was selected.
    """
    b, n, m = T.shape
    k = min(n, m)

    matched_rows = torch.full((b, k), -1, dtype=torch.long, device=T.device)
    matched_cols = torch.full((b, k), -1, dtype=torch.long, device=T.device)

    # When k == 0 there is nothing to decode and the loop is skipped.
    # reshape (unlike view) also handles b == 0 and non-contiguous inputs,
    # whose strides clone() preserves.
    T_copy = T.clone()

    for i in range(k):
        max_probs, flat_indices = torch.max(T_copy.reshape(b, n * m), dim=1)
        active_mask = max_probs > 1e-9
        if not active_mask.any():
            break

        row_indices_active = flat_indices[active_mask] // m
        col_indices_active = flat_indices[active_mask] % m

        matched_rows[active_mask, i] = row_indices_active
        matched_cols[active_mask, i] = col_indices_active

        # Disable the chosen row and column so neither can be matched again.
        T_copy[active_mask, row_indices_active, :] = -1.0
        T_copy[active_mask, :, col_indices_active] = -1.0

    return matched_rows, matched_cols

# ==============================================================================
# 4. Off-Policy Evaluation (OPE) Function
# ==============================================================================
def Q_eta_est_poly_tensor_batch(data, treatment, device="cuda"):
    """Fully tensorized LSTDQ off-policy evaluator.

    For every batch row, fits a linear differential value function on cubic
    polynomial state features ``[S, S**2, S**3]`` with
    ``S = (orders, drivers)``. Only transitions whose action equals
    ``treatment`` are used. Features and revenue are centered over those
    transitions, and the coefficients solve the ridge-regularised LSTD
    fixed point

        Phi_c^T (Phi_c - Phi'_c) beta = Phi_c^T r_c.

    Args:
        data: Mapping with tensors 'A', 'revenue', 'orders', 'drivers',
            'ordersNext' and 'driversNext', each expected to be (B, N).
        treatment: Action value whose transitions are used.
        device: Device on which the computation runs.

    Returns:
        ``(eta_est, TD_error, beta_a)`` with shapes (B,), (B, N) and (B, 6, 1).
        ``TD_error`` is zero at transitions with a different action. If the
        batched solve raises ``LinAlgError``, a warning is printed and
        ``beta_a`` is all zeros.
    """
    B, N = data['A'].shape

    # Move A to `device` like every other input so the mask lives on the same
    # device as the features it multiplies.
    mask = (data['A'].to(device) == treatment)
    mask_f = mask.float().unsqueeze(-1)

    revenue = data['revenue'].to(device).unsqueeze(-1)
    orders = data['orders'].to(device).unsqueeze(-1)
    drivers = data['drivers'].to(device).unsqueeze(-1)
    ordersNext = data['ordersNext'].to(device).unsqueeze(-1)
    driversNext = data['driversNext'].to(device).unsqueeze(-1)

    S = torch.cat([orders, drivers], dim=-1)
    Next_S = torch.cat([ordersNext, driversNext], dim=-1)

    # Feature expansion: [x, y, x^2, y^2, x^3, y^3]
    phi_S = torch.cat([S, S**2, S**3], dim=-1)
    phi_next_S = torch.cat([Next_S, Next_S**2, Next_S**3], dim=-1)

    # Centering over the selected transitions (others are zeroed by mask_f)
    mask_sum = mask_f.sum(dim=1, keepdim=True).clamp_min(1.0)
    revenue_mean = (revenue * mask_f).sum(dim=1, keepdim=True) / mask_sum
    revenue_c = (revenue - revenue_mean) * mask_f

    phi_mean = (phi_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_S_c = (phi_S - phi_mean) * mask_f

    phi_next_mean = (phi_next_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_next_S_c = (phi_next_S - phi_next_mean) * mask_f

    diff_phi_S_c = phi_S_c - phi_next_S_c

    # LSTD system: A = Phi_c^T (Phi_c - Phi'_c), b = Phi_c^T r_c. The
    # features multiply from the left; (Phi_c - Phi'_c)^T Phi_c is A's
    # transpose and does not satisfy the TD fixed point.
    lhs = torch.matmul(phi_S_c.transpose(1, 2), diff_phi_S_c)
    rhs = torch.matmul(phi_S_c.transpose(1, 2), revenue_c)

    regularization = 1e-5 * torch.eye(lhs.shape[1], device=lhs.device).unsqueeze(0)
    try:
        beta_a = torch.linalg.solve(lhs + regularization, rhs)
    except torch.linalg.LinAlgError:
        print("Warning: linear system solve failed, possibly singular matrix. Returning zero vector.")
        beta_a = torch.zeros_like(rhs)

    Q_diff_vec = torch.matmul(diff_phi_S_c, beta_a)
    # NOTE(review): diff_phi_S_c has zero mean over the selected rows, so
    # Q_diff_vec sums to zero there and eta_est reduces to the plain mean
    # revenue of those rows; beta_a only affects TD_error.
    eta_est = ((revenue - Q_diff_vec) * mask_f).sum(dim=1) / mask_sum.squeeze(-1)
    eta_est = eta_est.squeeze(-1)

    TD_error = (revenue - Q_diff_vec - eta_est.unsqueeze(1).unsqueeze(-1)) * mask_f
    TD_error = TD_error.squeeze(-1)

    return eta_est, TD_error, beta_a


# ==============================================================================
# 5. Main Simulator (GPU Version)
# ==============================================================================
def run_simulation_and_ope(
    BATCH_SIZE, NUM_EPISODES, V_tensor, device, method, 
    run_ope=False
):
    """Simulate order-to-driver dispatch and optionally run off-policy evaluation.

    ``BATCH_SIZE * NUM_EPISODES`` independent episodes of
    ``NUMBER_OF_TIME_STEPS`` steps are simulated in parallel. At every step,
    each episode scores all (active order, idle driver) pairs that are at
    most ``MAX_MANHATTAN_DISTANCE`` apart and greedily matches the
    best-scoring ones. The score depends on ``method``:

    * ``'mdp'``: discounted, per-step-spread trip reward plus discounted
      ``V`` at the destination minus ``V`` at the driver's current tile
      (logged as A = 0).
    * ``'distance'``: negative pickup distance (logged as A = 1).
    * ``'random'``: at every step, each episode independently uses one of
      the two scores with probability 0.5.

    Args:
        BATCH_SIZE: Number of batches (``simu_time`` column, 1-based).
        NUM_EPISODES: Episodes per batch (``n`` column).
        V_tensor: Value table, read through the flat index in ``v_lookup``.
        device: Torch device used for the simulation.
        method: One of ``'mdp'``, ``'distance'``, ``'random'``.
        run_ope: If True, also evaluate both actions with
            ``Q_eta_est_poly_tensor_batch`` on the logged data.

    Returns:
        ``df_sorted``: one row per (batch, episode, step) with columns
        ``simu_time, n, T, orders, drivers, A, revenue, ordersNext,
        driversNext``. When ``run_ope`` is True, returns
        ``(df_sorted, ope_results)``, where ``ope_results`` has keys
        ``'eta_mdp'``, ``'eta_dist'`` and ``'ate'``.

    Raises:
        ValueError: If ``method`` is not a supported policy name.
    """
    if method not in ('mdp', 'distance', 'random'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'mdp', 'distance' or 'random'."
        )
    print(f"Running simulation: BATCH_SIZE={BATCH_SIZE}, NUM_EPISODES={NUM_EPISODES}, METHOD={method.upper()}...")
    B, E = BATCH_SIZE, NUM_EPISODES
    X, Y, T = NUMBER_OF_GRID_TILES_X, NUMBER_OF_GRID_TILES_Y, NUMBER_OF_TIME_STEPS
    BE = B * E

    # Every episode shares the same value table, so index one flat copy
    # instead of materialising B * E expanded copies of it.
    V_flat = V_tensor.to(device=device, dtype=torch.float32).reshape(-1)

    def v_lookup(xs, ys, ts):
        # NOTE(review): x varies fastest in this linear index, i.e. V is read
        # as a C-ordered (T, Y, X) table, while V_DIMS is declared (X, Y, T).
        # Confirm which layout the pre-trained table uses.
        lin = xs.long() + X * ys.long() + (X * Y) * ts.long()
        return V_flat[lin]

    # Order columns: 0 = assigned flag, 1-2 = origin (x, y), 3 = release step,
    # 4 = waiting time, 5-6 = destination (x, y).
    orders = torch.zeros(B, E, NUMBER_OF_ORDERS, 7, device=device, dtype=torch.float32)
    st_data = generate_initial_orders_from_mixture_pt(BE, NUMBER_OF_ORDERS, device).view(B, E, NUMBER_OF_ORDERS, 3).float()
    wait_data = generate_initial_waiting_times_pt(BE, NUMBER_OF_ORDERS, device).view(B, E, NUMBER_OF_ORDERS).float()
    dest_data = torch.stack([
        torch.randint(0, X, size=(B, E, NUMBER_OF_ORDERS), device=device),
        torch.randint(0, Y, size=(B, E, NUMBER_OF_ORDERS), device=device)
    ], dim=-1).float()

    orders[..., 1:4] = st_data
    orders[..., 4] = wait_data
    orders[..., 5:7] = dest_data

    # Driver columns: 0 = step at which the driver is next free, 1-2 = position (x, y).
    drivers = torch.zeros(B, E, NUMBER_OF_DRIVERS, 3, device=device, dtype=torch.float32)
    drivers[..., 1] = torch.randint(0, X, size=drivers[..., 1].shape, device=device).float()
    drivers[..., 2] = torch.randint(0, Y, size=drivers[..., 2].shape, device=device).float()

    b_grid, e_grid = torch.meshgrid(torch.arange(B, device=device), torch.arange(E, device=device), indexing='ij')
    all_rows_list = []

    for t in range(T):
        # Unassigned orders already released whose waiting window has not expired.
        active_orders_mask = (orders[..., 0] == 0) & (orders[..., 3] <= t) & ((orders[..., 3] + orders[..., 4]) >= t)
        available_drv_mask = drivers[..., 0] <= t

        o_pos = orders[..., 1:3]
        d_pos = drivers[..., 1:3]
        dstpos = orders[..., 5:7]
        # Manhattan (p=1) pickup distance between every order and every driver.
        dist_mat = torch.cdist(o_pos.reshape(BE, NUMBER_OF_ORDERS, 2), d_pos.reshape(BE, NUMBER_OF_DRIVERS, 2), p=1).view(B, E, NUMBER_OF_ORDERS, NUMBER_OF_DRIVERS)
        allowed = dist_mat <= MAX_MANHATTAN_DISTANCE

        valid_mask = active_orders_mask.unsqueeze(-1) & available_drv_mask.unsqueeze(-2) & allowed

        trip_dist = (o_pos - dstpos).abs().sum(dim=-1)
        delta_t = (1 + dist_mat + trip_dist.unsqueeze(-1)).long()
        reward = BASE_REWARD_PER_TRIP + REWARD_FOR_DISTANCE_PARAMETER * trip_dist

        curV = v_lookup(d_pos[..., 0], d_pos[..., 1], torch.full_like(d_pos[..., 0], t)).unsqueeze(-2)
        fut_t = (t + delta_t).clamp_max(T - 1)

        dst_x_expanded = dstpos[..., 0].unsqueeze(-1).expand_as(fut_t)
        dst_y_expanded = dstpos[..., 1].unsqueeze(-1).expand_as(fut_t)
        futV = v_lookup(dst_x_expanded, dst_y_expanded, fut_t)

        # Reward spread evenly over the delta_t steps of the trip and
        # discounted per step, plus the discounted destination value, minus
        # the value at the driver's current tile.
        delta_t_f = delta_t.clamp_min(1).float()
        gamma = torch.tensor(DISCOUNT_FACTOR, device=device, dtype=delta_t_f.dtype)
        disc_dt = torch.pow(gamma, delta_t_f)
        imm = (1 - disc_dt).clamp_min(1e-9) / (1 - gamma).clamp_min(1e-9) * reward.unsqueeze(-1).to(delta_t_f.dtype) / delta_t_f
        adv_mdp = imm + disc_dt * futV - curV

        adv_mdp_masked = torch.where(valid_mask, adv_mdp, -1e9)
        adv_dist_masked = torch.where(valid_mask, -dist_mat.float(), -1e9)

        # A_val_policy is True where the MDP score was used for this step.
        if method == 'mdp':
            adv = adv_mdp_masked
            A_val_policy = torch.ones(B, E, device=device, dtype=torch.bool)
        elif method == 'distance':
            adv = adv_dist_masked
            A_val_policy = torch.zeros(B, E, device=device, dtype=torch.bool)
        else:  # 'random'
            bern = (torch.rand(B, E, 1, 1, device=device) < 0.5)
            adv = torch.where(bern, adv_mdp_masked, adv_dist_masked)
            A_val_policy = bern.view(B, E)

        # The greedy decoder only uses the ordering of scores plus a "> 1e-9"
        # test to tell valid pairs from forbidden ones. An exponential kernel
        # exp(adv / 0.1) overflows to +inf in float32 once adv > ~8.9, turning
        # distinct pairs into ties. It also drops below 1e-9 once
        # adv < ~-2.1, silently discarding valid pairs. Instead, shift each
        # matrix's valid scores so the smallest becomes 1; this keeps their
        # order and makes every valid pair strictly positive.
        adv_min = torch.where(valid_mask, adv, torch.full_like(adv, float('inf'))).amin(dim=(-2, -1), keepdim=True)
        K = torch.where(valid_mask, adv - adv_min + 1.0, torch.zeros_like(adv))
        T_flat = K.view(BE, NUMBER_OF_ORDERS, NUMBER_OF_DRIVERS)
        mr, mc = greedy_decode_from_sinkhorn_batched(T_flat)

        valid_pairs_mask = mr >= 0
        num_matches = valid_pairs_mask.sum()
        matched_revenue_per_be = torch.zeros(BE, device=device, dtype=torch.float32)

        if num_matches > 0:
            be_idx_flat = torch.arange(BE, device=device).unsqueeze(1).expand_as(mr)
            be_sel = be_idx_flat[valid_pairs_mask]
            r_sel = mr[valid_pairs_mask]
            c_sel = mc[valid_pairs_mask]
            b_idx = be_sel // E
            e_idx = be_sel % E

            trip = (orders[b_idx, e_idx, r_sel, 1:3] - orders[b_idx, e_idx, r_sel, 5:7]).abs().sum(dim=-1).long()
            revenues = BASE_REWARD_PER_TRIP + REWARD_FOR_DISTANCE_PARAMETER * trip.float()
            matched_revenue_per_be.scatter_add_(0, be_sel, revenues)

            # Mark orders as assigned; each driver becomes busy for
            # 1 + pickup + trip steps and ends at the order's destination.
            orders[b_idx, e_idx, r_sel, 0] = 1.0
            pickup = (orders[b_idx, e_idx, r_sel, 1:3] - drivers[b_idx, e_idx, c_sel, 1:3]).abs().sum(dim=-1).long()
            dt = (1 + pickup + trip).long()
            drivers[b_idx, e_idx, c_sel, 0] = t + dt.float()
            drivers[b_idx, e_idx, c_sel, 1:3] = orders[b_idx, e_idx, r_sel, 5:7]

        orders_now_cnt = active_orders_mask.sum(dim=-1)
        drivers_now_cnt = available_drv_mask.sum(dim=-1)
        t_next = t + 1
        if t_next < T:
            orders_next_cnt = ((orders[..., 0] == 0) & (orders[..., 3] <= t_next) & ((orders[..., 3] + orders[..., 4]) >= t_next)).sum(dim=-1)
            drivers_next_cnt = (drivers[..., 0] <= t_next).sum(dim=-1)
        else:
            orders_next_cnt, drivers_next_cnt = torch.zeros_like(orders_now_cnt), torch.zeros_like(drivers_now_cnt)

        # Logged action: 0 = MDP score, 1 = distance score.
        A_val_final = (~A_val_policy).long().view(BE)

        rows = torch.stack([
            (b_grid + 1).reshape(-1), e_grid.reshape(-1), torch.full((BE,), t, device=device),
            orders_now_cnt.reshape(-1), drivers_now_cnt.reshape(-1), A_val_final,
            matched_revenue_per_be.reshape(-1),
            orders_next_cnt.reshape(-1), drivers_next_cnt.reshape(-1),
        ], dim=1)
        all_rows_list.append(rows)

    all_rows = torch.cat(all_rows_list, dim=0)
    df = pd.DataFrame(all_rows.cpu().numpy(), columns=[
        'simu_time', 'n', 'T', 'orders', 'drivers', 'A', 'revenue', 'ordersNext', 'driversNext'
    ])
    df_sorted = df.sort_values(['simu_time', 'n', 'T']).reset_index(drop=True)

    ope_results = {}
    if run_ope:
        print("\n--- Simulation completed, starting Off-Policy Evaluation (OPE) ---")
        B_ope = BATCH_SIZE
        N_ope = NUM_EPISODES * NUMBER_OF_TIME_STEPS

        # df_sorted is ordered by (simu_time, n, T), so each batch occupies
        # one contiguous run of N_ope rows.
        data_for_ope = {}
        ope_columns = ['A', 'revenue', 'orders', 'drivers', 'ordersNext', 'driversNext']

        for col in ope_columns:
            tensor_col = torch.tensor(df_sorted[col].values, device=device, dtype=torch.float32)
            data_for_ope[col] = tensor_col.view(B_ope, N_ope)

        print("Evaluating MDP policy (A=0)...")
        eta_est_mdp, _, _ = Q_eta_est_poly_tensor_batch(data_for_ope, treatment=0, device=device)
        print("Evaluating Distance policy (A=1)...")
        eta_est_dist, _, _ = Q_eta_est_poly_tensor_batch(data_for_ope, treatment=1, device=device)

        ope_results['eta_mdp'] = eta_est_mdp.mean().item()
        ope_results['eta_dist'] = eta_est_dist.mean().item()
        ope_results['ate'] = ope_results['eta_mdp'] - ope_results['eta_dist']
        print("--- OPE evaluation completed ---")

    if run_ope:
        return df_sorted, ope_results
    else:
        return df_sorted

def calculate_dynamic_ate(revenue_df, transition_df):
    """Combine per-step linear models into a dynamic average treatment effect.

    Rows of both frames are taken in positional order as consecutive time
    steps t = 0..T-1. With state S_t = (orders, drivers) and the action
    assumed coded as +1 / -1, the models are

        revenue_t = ... + beta_t . S_t   + gamma_t A_t
        S_{t+1}   = ... + Phi_t   S_t    + Gamma_t A_t

    Switching every action from -1 to +1 changes total revenue by

        2 * sum_t gamma_t + 2 * sum_{t>=1} beta_t . I_t,
        I_t = sum_{k<t} Phi_{t-1} ... Phi_{k+1} Gamma_k,

    where I_t follows the recursion I_0 = 0, I_t = Phi_{t-1} I_{t-1} + Gamma_{t-1}.

    Args:
        revenue_df: One row per step with 'coef_orders', 'coef_drivers', 'coef_A'.
        transition_df: One row per step with 'orders_coef_orders',
            'orders_coef_drivers', 'drivers_coef_orders',
            'drivers_coef_drivers', 'orders_coef_A' and 'drivers_coef_A'.

    Returns:
        The estimated ATE as a float (0.0 for empty inputs).
    """
    T = len(revenue_df)

    # Direct effect of the action on same-step revenue (coef_A of the revenue model).
    gamma = revenue_df['coef_A'].to_numpy(dtype=float)
    beta = revenue_df[['coef_orders', 'coef_drivers']].to_numpy(dtype=float)
    # Row-major [[orders<-orders, orders<-drivers], [drivers<-orders, drivers<-drivers]].
    Phi = transition_df[['orders_coef_orders', 'orders_coef_drivers',
                         'drivers_coef_orders', 'drivers_coef_drivers']].to_numpy(dtype=float).reshape(-1, 2, 2)
    Gamma = transition_df[['orders_coef_A', 'drivers_coef_A']].to_numpy(dtype=float)

    total_direct_effect = float(np.sum(gamma))
    total_indirect_effect = 0.0

    # Cumulative effect on S_t of all actions at steps 0..t-1, including step 0.
    impact = np.zeros(2)
    for t in range(1, T):
        impact = Phi[t - 1] @ impact + Gamma[t - 1]
        total_indirect_effect += float(beta[t] @ impact)

    return 2 * total_direct_effect + 2 * total_indirect_effect

# ==============================================================================
# 6. Main Execution Script
# ==============================================================================
if __name__ == '__main__':
    # Script entry point:
    #   Phase 1: simulate the mixed "random" policy and run OPE on its data.
    #   Phase 2: fit per-time-step OLS models for revenue and next-step state.
    #   Phase 3: combine those models into a dynamic ATE and print a summary.
    start_main = time.time()
    SEED = 42
    set_seed(SEED)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    NUM_EPISODES = 200
    BATCH_SIZE = 100

    try:
        # NOTE(review): the file name suggests a table learned with T=10,
        # while NUMBER_OF_TIME_STEPS is 20. If the stored array holds fewer
        # than X*Y*T entries, the simulator's V lookups index out of range.
        # Confirm its shape/layout matches V_DIMS.
        V_FILENAME = 'Value_function_gpu_learned_T10.npz'
        pre_trained_V = torch.from_numpy(np.load(V_FILENAME)['arr_0']).to(device=device, dtype=torch.float32)
        print(f"Successfully loaded pre-trained V-function from '{V_FILENAME}'.")
    except FileNotFoundError:
        print(f"Warning: '{V_FILENAME}' not found. Using zero-initialized V-function.")
        pre_trained_V = torch.zeros(V_DIMS, device=device)
    
    print("\n" + "="*50)
    print("--- Phase 1: Run mixed policy (Random) to generate analysis data ---")
    print("="*50)
    df_for_analysis, ope_results = run_simulation_and_ope(
        BATCH_SIZE=BATCH_SIZE,
        NUM_EPISODES=NUM_EPISODES,
        V_tensor=pre_trained_V,
        device=device,
        method="random",
        run_ope=True
    )

    print("\n" + "="*50)
    print("--- Phase 2: Run dynamic linear regression model ---")
    print("="*50)

    df_analysis = df_for_analysis.copy()
    # Recode the logged action for regression: A=0 (MDP arm) -> +1, A=1 (distance arm) -> -1.
    df_analysis['A'] = df_analysis['A'].replace({0: 1, 1: -1})

    revenue_model_results = []
    transition_model_results = []
    timesteps = sorted(df_analysis['T'].unique())

    for t in timesteps:
        df_t = df_analysis[df_analysis['T'] == t]
        # NOTE(review): skipping a step leaves a gap, but calculate_dynamic_ate
        # treats result rows positionally as consecutive steps. Confirm no
        # step is ever skipped here.
        if len(df_t) < 15: continue
        # Revenue model: revenue ~ const + orders + drivers + A at this step.
        Y_revenue = df_t['revenue']
        X_revenue = sm.add_constant(df_t[['orders', 'drivers', 'A']], has_constant='add')
        results_revenue = sm.OLS(Y_revenue, X_revenue).fit()
        revenue_model_results.append({
            'T': t, 'intercept': results_revenue.params['const'],
            'coef_orders': results_revenue.params['orders'],
            'coef_drivers': results_revenue.params['drivers'],
            'coef_A': results_revenue.params['A'],
            'R_squared': results_revenue.rsquared
        })

        # Transition models: ordersNext and driversNext on the same regressors.
        X_transition = sm.add_constant(df_t[['orders', 'drivers', 'A']], has_constant='add')
        results_orders = sm.OLS(df_t['ordersNext'], X_transition).fit()
        results_drivers = sm.OLS(df_t['driversNext'], X_transition).fit()
        transition_model_results.append({
            'T': t, 'orders_intercept': results_orders.params['const'],
            'orders_coef_orders': results_orders.params['orders'],
            'orders_coef_drivers': results_orders.params['drivers'],
            'orders_coef_A': results_orders.params['A'], 'orders_R_squared': results_orders.rsquared,
            'drivers_intercept': results_drivers.params['const'],
            'drivers_coef_orders': results_drivers.params['orders'],
            'drivers_coef_drivers': results_drivers.params['drivers'],
            'drivers_coef_A': results_drivers.params['A'], 'drivers_R_squared': results_drivers.rsquared
        })

    revenue_results_df = pd.DataFrame(revenue_model_results)
    transition_results_df = pd.DataFrame(transition_model_results)

    print("\n" + "="*50)
    print("--- Phase 3: Calculate dynamic ATE ---")
    print("="*50)
    
    dynamic_ate_value = calculate_dynamic_ate(revenue_results_df, transition_results_df)

    print("\n\n" + "="*80)
    print("--- Final Results Summary ---")
    print("="*80)
    
    print("\n--- OPE Results (from Phase 1) ---")
    if ope_results:
        print(f"OPE estimated mean reward for MDP policy: {ope_results['eta_mdp']:.4f}")
        print(f"OPE estimated mean reward for Distance policy: {ope_results['eta_dist']:.4f}")
        print(f"OPE estimated ATE (MDP - Dist): {ope_results['ate']:.4f}")

    print("\n--- Dynamic ATE Results (from Phase 2/3) ---")
    print(f"Final computed dynamic ATE value: {dynamic_ate_value:.4f}")
    
    end_main = time.time()
    total_time = end_main - start_main
    print("\n" + "-"*80)
    print(f"Total elapsed time: {total_time:.2f} seconds")
    print("="*80)
    
    print("\n\n--- Appendix: Detailed regression coefficients ---")
    print("\n--- Revenue Model Regression Results ---")
    print(revenue_results_df.to_string())
    print("\n--- State Transition Model Regression Results ---")
    print(transition_results_df.to_string())

# NOTE(review): the statements below run at module level, outside the
# `if __name__ == '__main__':` guard. They read revenue_model_results,
# transition_model_results, ope_results and start_main, which are only
# assigned inside that guard, so importing this file as a module fails with
# NameError. When run as a script, they repeat Phase 3 and the summary the
# guarded block has already printed.
# 4. Convert results list to DataFrame for display
revenue_results_df = pd.DataFrame(revenue_model_results)
transition_results_df = pd.DataFrame(transition_model_results)

print("\n" + "="*50)
print("--- Phase 3: Calculate dynamic ATE ---")
print("="*50)

dynamic_ate_value = calculate_dynamic_ate(revenue_results_df, transition_results_df)

print("\n\n" + "="*80)
print("--- Final Results Summary ---")
print("="*80)

print("\n--- OPE Evaluation Results (from Phase 1) ---")
if ope_results:
    print(f"OPE estimated mean reward for MDP policy: {ope_results['eta_mdp']:.4f}")
    print(f"OPE estimated mean reward for Distance policy: {ope_results['eta_dist']:.4f}")
    print(f"OPE estimated ATE (MDP - Dist): {ope_results['ate']:.4f}")

print("\n--- Dynamic ATE Evaluation Results (from Phase 2/3) ---")
print(f"Final computed dynamic ATE value: {dynamic_ate_value:.4f}")

end_main = time.time()
total_time = end_main - start_main
print("\n" + "-"*80)
print(f"Total elapsed time: {total_time:.2f} seconds")
print("="*80)

print("\n\n--- Appendix: Detailed regression coefficients ---")
print("\n--- Revenue Model Regression Results ---")
print(revenue_results_df.to_string())
print("\n--- State Transition Model Regression Results ---")
print(transition_results_df.to_string())


def Q_eta_est_poly_tensor_batch(data, treatment, device="cuda"):
    """
    Fully tensorized LSTDQ policy evaluator.

    For every batch row, fits a linear differential value function on cubic
    polynomial state features ``[S, S**2, S**3]`` with
    ``S = (orders, drivers)``. Only transitions whose action equals
    ``treatment`` are used. Features and revenue are centered over those
    transitions, and the coefficients solve the ridge-regularised LSTD
    fixed point

        Phi_c^T (Phi_c - Phi'_c) beta = Phi_c^T r_c.

    Args:
        data: Mapping with tensors 'A', 'revenue', 'orders', 'drivers',
            'ordersNext' and 'driversNext', each expected to be (B, N).
        treatment: Action value whose transitions are used.
        device: Device on which the computation runs.

    Returns:
        ``(eta_est, TD_error, beta_a)`` with shapes (B,), (B, N) and (B, 6, 1).
        ``TD_error`` is zero at transitions with a different action. If the
        batched solve raises ``LinAlgError``, a warning is printed and
        ``beta_a`` is all zeros.
    """
    B, N = data['A'].shape

    # Move A to `device` like every other input so the mask lives on the same
    # device as the features it multiplies.
    mask = (data['A'].to(device) == treatment)
    mask_f = mask.float().unsqueeze(-1)

    revenue = data['revenue'].to(device).unsqueeze(-1)
    orders = data['orders'].to(device).unsqueeze(-1)
    drivers = data['drivers'].to(device).unsqueeze(-1)
    ordersNext = data['ordersNext'].to(device).unsqueeze(-1)
    driversNext = data['driversNext'].to(device).unsqueeze(-1)

    S = torch.cat([orders, drivers], dim=-1)
    Next_S = torch.cat([ordersNext, driversNext], dim=-1)

    # Feature expansion: [x, y, x^2, y^2, x^3, y^3]
    phi_S = torch.cat([S, S**2, S**3], dim=-1)
    phi_next_S = torch.cat([Next_S, Next_S**2, Next_S**3], dim=-1)

    # Centering over the selected transitions (others are zeroed by mask_f)
    mask_sum = mask_f.sum(dim=1, keepdim=True).clamp_min(1.0)
    revenue_mean = (revenue * mask_f).sum(dim=1, keepdim=True) / mask_sum
    revenue_c = (revenue - revenue_mean) * mask_f

    phi_mean = (phi_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_S_c = (phi_S - phi_mean) * mask_f

    phi_next_mean = (phi_next_S * mask_f).sum(dim=1, keepdim=True) / mask_sum
    phi_next_S_c = (phi_next_S - phi_next_mean) * mask_f

    diff_phi_S_c = phi_S_c - phi_next_S_c

    # LSTD system: A = Phi_c^T (Phi_c - Phi'_c), b = Phi_c^T r_c. The
    # features multiply from the left; (Phi_c - Phi'_c)^T Phi_c is A's
    # transpose and does not satisfy the TD fixed point.
    lhs = torch.matmul(phi_S_c.transpose(1, 2), diff_phi_S_c)
    rhs = torch.matmul(phi_S_c.transpose(1, 2), revenue_c)

    regularization = 1e-5 * torch.eye(lhs.shape[1], device=lhs.device).unsqueeze(0)
    try:
        beta_a = torch.linalg.solve(lhs + regularization, rhs)
    except torch.linalg.LinAlgError:
        print("Warning: Linear system solve failed, returning zero vector.")
        beta_a = torch.zeros_like(rhs)

    Q_diff_vec = torch.matmul(diff_phi_S_c, beta_a)
    # NOTE(review): diff_phi_S_c has zero mean over the selected rows, so
    # Q_diff_vec sums to zero there and eta_est reduces to the plain mean
    # revenue of those rows; beta_a only affects TD_error.
    eta_est = ((revenue - Q_diff_vec) * mask_f).sum(dim=1) / mask_sum.squeeze(-1)
    eta_est = eta_est.squeeze(-1)

    TD_error = (revenue - Q_diff_vec - eta_est.unsqueeze(1).unsqueeze(-1)) * mask_f
    TD_error = TD_error.squeeze(-1)
    return eta_est, TD_error, beta_a


# NOTE(review): module-level code outside the `__main__` guard. BATCH_SIZE,
# NUM_EPISODES, pre_trained_V, device and start_main are only assigned inside
# that guard, so this fails with NameError on import. When run as a script,
# it performs a second full "random" simulation (continuing from the RNG
# state left by the first run, since the seed is set only once). It writes
# the result to simulation.csv, which the neural-network section reads back.
df_random, ope_results = run_simulation_and_ope(
    BATCH_SIZE=BATCH_SIZE,
    NUM_EPISODES=NUM_EPISODES,
    V_tensor=pre_trained_V,
    device=device,
    method="random",
    run_ope=True
)
output_filename = "simulation.csv"
df_random.to_csv(output_filename, index=False)
print(f"\n--- Data successfully saved to file: {output_filename} ---")

print("\n--- Results (OPE exact evaluation) ---")
if ope_results:
    print(f"OPE estimated mean reward (eta_mdp): {ope_results['eta_mdp']:.4f}")
    print(f"OPE estimated mean reward (eta_dist): {ope_results['eta_dist']:.4f}")
    print(f"OPE estimated ATE (MDP - Dist): {ope_results['ate']:.4f}")

end_main = time.time()
total_time = end_main - start_main
print("\n" + "="*50)
print(f"Total elapsed time: {total_time:.2f} seconds")
print("="*50)


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import logging
import os
import time

# ==============================================================================
# 0. Setup
# ==============================================================================
def setup_logging():
    """Send root-logger records at INFO and above to the console and ``nn_training.log``.

    Handlers already attached to the root logger are detached first, so
    repeated calls do not duplicate output. Both new handlers share the
    ``'time - LEVEL - message'`` format.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # The root logger has no parent, so clearing its own list removes every
    # handler hasHandlers() could report (clearing an empty list is a no-op).
    root.handlers.clear()

    shared_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    for handler in (logging.FileHandler('nn_training.log'), logging.StreamHandler()):
        handler.setFormatter(shared_format)
        root.addHandler(handler)

# Configure logging, then pick the GPU when one is visible to torch.
setup_logging()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logging.info("Using device: %s", device)

# ==============================================================================
# 1. Define Neural Network Model
# ==============================================================================
class MLP(nn.Module):
    """Fully connected regressor: three Linear layers with ReLU in between.

    Layout: Linear(input_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> output_size). The layers live in ``self.layers``
    at indices 0, 2 and 4, which fixes the state-dict keys of saved
    checkpoints.
    """

    def __init__(self, input_size, output_size, hidden_size=64):
        super().__init__()
        widths = [input_size, hidden_size, hidden_size, output_size]
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension = ``input_size``)."""
        return self.layers(x)

# ==============================================================================
# 2. Define Training and Evaluation Functions
# ==============================================================================
def train_model(model, train_loader, val_loader, model_name, epochs=100, learning_rate=1e-3):
    """Fit ``model`` in place with Adam on an MSE loss, logging and checkpointing.

    Each epoch makes one gradient pass over ``train_loader`` and one
    ``torch.no_grad`` pass over ``val_loader``, then logs the mean per-batch
    losses. The ``state_dict`` is written to
    ``checkpoints/<model_name>_epoch_<N>.pth`` every 10 epochs and after the
    final epoch. The model is moved to the module-level ``device``. Returns
    nothing; total wall time is logged at the end.
    """
    t_start = time.time()
    ckpt_dir = 'checkpoints'
    os.makedirs(ckpt_dir, exist_ok=True)

    model.to(device)
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    logging.info(f"Training model: '{model_name}'...")
    for epoch in range(epochs):
        # Training pass
        model.train()
        running_train = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_loss = criterion(model(inputs), targets)
            running_train += batch_loss.item()
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
        train_loss = running_train / len(train_loader)

        # Validation pass
        model.eval()
        running_val = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                running_val += criterion(model(inputs), targets).item()
        val_loss = running_val / len(val_loader)

        epoch_no = epoch + 1
        logging.info(f'Model: {model_name: <25} | Epoch {epoch_no:03d}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

        is_last_epoch = epoch == epochs - 1
        if epoch_no % 10 == 0 or is_last_epoch:
            checkpoint_path = os.path.join(ckpt_dir, f'{model_name}_epoch_{epoch_no}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            logging.info(f"===> Model checkpoint saved to {checkpoint_path}")

    elapsed_time = time.time() - t_start
    logging.info(f"Model '{model_name}' training complete. Total time: {elapsed_time:.2f} seconds.")

def calculate_r2_torch(model, data_loader):
    """Return one R^2 score per output column of ``model`` over ``data_loader``.

    R^2 = 1 - mean squared residual / population variance of the target,
    computed per column on all batches concatenated. The model is put in
    eval mode and inputs are moved to the module-level ``device``. A constant
    target column gives a zero variance, so its score is inf or nan.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            predictions.append(model(X))
            targets.append(y)
    y_true = torch.cat(targets)
    y_pred = torch.cat(predictions)

    def column_r2(col):
        residual = torch.mean((y_true[:, col] - y_pred[:, col]) ** 2)
        total = torch.var(y_true[:, col], unbiased=False)
        return (1 - residual / total).item()

    return [column_r2(col) for col in range(y_true.shape[1])]

# ==============================================================================
# 3. Main Execution Flow
# ==============================================================================
# Load the simulated transitions and train two MLPs on
# (orders, drivers, A, T): a state-transition model predicting
# (ordersNext, driversNext) and a revenue model.
# NOTE(review): this runs at module level with no `__main__` guard, while the
# DataLoaders below use num_workers=4. Under the "spawn" start method
# (default on Windows and macOS), worker processes re-import the main module
# and may re-execute this script. Confirm the target platform or add a guard.
INPUT_CSV = "simulation.csv"
df = pd.read_csv(INPUT_CSV, encoding="utf-8")

df_random = df
logging.info("Preparing data for neural network training...")

features = ['orders', 'drivers', 'A', 'T']
target_transition = ['ordersNext', 'driversNext']
target_revenue = ['revenue']

X_all = torch.tensor(df_random[features].values, dtype=torch.float32)
y_transition = torch.tensor(df_random[target_transition].values, dtype=torch.float32)
y_revenue = torch.tensor(df_random[target_revenue].values, dtype=torch.float32)

# Split in file order (first 80% train, rest validation), without shuffling.
# NOTE(review): if the CSV is sorted by simulation batch, validation is made
# of whole later batches. Confirm this blocked split is intended.
train_size = int(0.8 * len(X_all))
X_train, X_val = X_all[:train_size], X_all[train_size:]
logging.info(f"Data split into training ({train_size} samples) and validation ({len(X_all) - train_size} samples).")
logging.info(f"Model input features: {features}")

logging.info("="*50)
logging.info("--- Task 1: Train state transition model ---")
logging.info(f"Model: ordersNext, driversNext ~ f({', '.join(features)})")

y_train_trans, y_val_trans = y_transition[:train_size], y_transition[train_size:]
train_dataset_trans = TensorDataset(X_train, y_train_trans)
val_dataset_trans = TensorDataset(X_val, y_val_trans)

train_loader_trans = DataLoader(train_dataset_trans, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
val_loader_trans = DataLoader(val_dataset_trans, batch_size=512, num_workers=4, pin_memory=True)

transition_model = MLP(input_size=4, output_size=2)
train_model(transition_model, train_loader_trans, val_loader_trans, model_name="transition_model_with")

r2_transition = calculate_r2_torch(transition_model, val_loader_trans)
logging.info("--- State Transition Model Evaluation ---")
logging.info(f"  - R² score for ordersNext: {r2_transition[0]:.4f}")
logging.info(f"  - R² score for driversNext: {r2_transition[1]:.4f}")


logging.info("="*50)
logging.info("--- Task 2: Train revenue model ---")
logging.info(f"Model: revenue ~ f({', '.join(features)})")

y_train_rev, y_val_rev = y_revenue[:train_size], y_revenue[train_size:]
train_dataset_rev = TensorDataset(X_train, y_train_rev)
val_dataset_rev = TensorDataset(X_val, y_val_rev)

train_loader_rev = DataLoader(train_dataset_rev, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
val_loader_rev = DataLoader(val_dataset_rev, batch_size=512, num_workers=4, pin_memory=True)

revenue_model = MLP(input_size=4, output_size=1)
train_model(revenue_model, train_loader_rev, val_loader_rev, model_name="revenue_model_with")

r2_revenue = calculate_r2_torch(revenue_model, val_loader_rev)
logging.info("--- Revenue Model Evaluation ---")
logging.info(f"  - R² score for revenue: {r2_revenue[0]:.4f}")