#!/usr/bin/env python3
import os
import sys
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import logging
from argparse import ArgumentParser
from safetensors.torch import safe_open
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
import json
import glob

# ========= Reproducibility =========
def set_seed(seed=42):
    """Make runs reproducible by seeding every RNG the script touches.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG
    and the RNGs of all CUDA devices. Also forces deterministic cuDNN kernels
    and disables cuDNN autotuning (which can pick different kernels per run).

    Args:
        seed: integer seed shared by all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# ========= Logger =========
def setup_logger(save_dir, log_name="train_log.txt", level=logging.INFO):
    """Create (or reset) a logger writing to ``save_dir/log_name`` and to stdout.

    The logger is named after ``save_dir``, so calling this again for the same
    directory reconfigures the same logger object instead of stacking handlers.

    Args:
        save_dir: directory for the log file; created if missing.
        log_name: log file name inside ``save_dir`` (opened in append mode).
        level: logging level applied to the logger.

    Returns:
        The configured ``logging.Logger`` (propagation to the root logger is off).
    """
    os.makedirs(save_dir, exist_ok=True)
    log_path = os.path.join(save_dir, log_name)
    logger = logging.getLogger(f"DualAscentLogger_{save_dir}")
    logger.setLevel(level)
    logger.propagate = False
    # Close (not just detach) handlers from a previous call; merely clearing the
    # list would leave the old FileHandler's file descriptor open.
    for old_handler in list(logger.handlers):
        old_handler.close()
        logger.removeHandler(old_handler)
    formatter = logging.Formatter("%(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S")
    # Explicit UTF-8 so non-ASCII messages do not depend on the platform locale.
    file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger

# ========= Args =========
def get_args():
    """Parse the command line for training and OOD evaluation.

    Returns:
        ``argparse.Namespace`` holding the model/dataset identifiers,
        optimisation hyper-parameters, robustness temperatures, OOD dataset
        list, parallelism settings, and policy-loading / embedding switches.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # Identifiers that must always be supplied.
    for required_flag in ("--model_name_instruct", "--model_name_reasoning", "--dataset_name"):
        add(required_flag, type=str, required=True)

    # Output location and the budget sweep.
    add("--save_dir", type=str, default="../res_robust_multi_budget")
    add("--budgets", type=str, default="1.0,1.2,1.5,2.0,3.0",
        help="Comma-separated budget values")

    # Optimisation hyper-parameters.
    add("--entropy_coef", type=float, default=0.005)
    add("--epochs", type=int, default=20)
    add("--batch_size", type=int, default=128)
    add("--lr", type=float, default=1e-3)
    add("--dual_lr", type=float, default=1e-2)
    add("--dataset_base_path", type=str,
        default="../prepare_embedding/embeddings/bge-m3")
    add("--replications", type=int, default=5)

    # Robustness temperatures (defaults are the ints 1 and 50 unless overridden).
    add("--tau_r", type=float, default=1)
    add("--tau_g", type=float, default=50)

    # Out-of-distribution evaluation datasets.
    add("--ood_datasets", type=str, default="reward_bench_v2,rm_bench,judgebench")

    # Parallelism.
    add("--num_gpus", type=int, default=4)
    add("--jobs_per_gpu", type=int, default=5)

    # Reusing previously trained policies.
    add("--policy_dir", type=str, default=None,
        help="Path to existing policy directory (e.g., ../res_robust_multi_budget/robust_combined_20251109_141309). If provided, skip training and only do OOD evaluation.")
    add("--delete_policy_after_eval", action="store_true",
        help="Delete policy model files after OOD evaluation to save disk space")

    # Embedding layout.
    add("--use_separate_embeddings", action="store_true",
        help="Use separate embeddings (prompt + answer_a + answer_b) instead of combined prompt_answer embedding")

    return parser.parse_args()

# ========= Load dataset =========
def load_embedding_dfs(path: str, use_separate_embeddings: bool = False):
    """Read a safetensors embedding file into a per-example DataFrame.

    Args:
        path: safetensors file to open (read on CPU).
        use_separate_embeddings: if True, the feature vector is the
            concatenation (along dim 1) of ``embeddings_prompt``,
            ``embeddings_answer_a`` and ``embeddings_answer_b``; otherwise
            the stored ``embeddings_prompt_answer`` tensor is used.

    Returns:
        DataFrame with columns ``embedding`` (one list per row),
        ``correct_label``, ``consistent_labels``, ``num_tokens_instruct``,
        ``num_tokens_reasoning``, ``token_ratio`` (reasoning / instruct token
        counts, with 1e-8 added to the denominator), ``correct_instruct`` and
        ``correct_reasoning``.
    """
    with safe_open(path, framework="pt", device="cpu") as reader:
        if use_separate_embeddings:
            pieces = [reader.get_tensor(name)
                      for name in ("embeddings_prompt", "embeddings_answer_a", "embeddings_answer_b")]
            features = torch.cat(pieces, dim=1)
        else:
            features = reader.get_tensor("embeddings_prompt_answer")

        label_correct = reader.get_tensor("correct_labels")
        label_consistent = reader.get_tensor("consist_labels")
        tokens_instruct = reader.get_tensor("num_tokens_instruct")
        tokens_reasoning = reader.get_tensor("num_tokens_reasoning")
        hit_instruct = reader.get_tensor("correct_instruct")
        hit_reasoning = reader.get_tensor("correct_reasoning")

    ratio = tokens_reasoning / (tokens_instruct + 1e-8)
    columns = {
        "embedding": features.tolist(),  # same nested lists as converting row by row
        "correct_label": label_correct.tolist(),
        "consistent_labels": label_consistent.tolist(),
        "num_tokens_instruct": tokens_instruct.tolist(),
        "num_tokens_reasoning": tokens_reasoning.tolist(),
        "token_ratio": ratio.tolist(),
        "correct_instruct": hit_instruct.tolist(),
        "correct_reasoning": hit_reasoning.tolist()
    }
    return pd.DataFrame(columns)


# ========= Model =========
class PolicyNet(nn.Module):
    """MLP that maps an embedding to the probability of choosing the reasoning model.

    Layout: ``input_dim -> 256 -> 128 -> 64 -> 1`` with ReLU after each hidden
    Linear and a sigmoid on the output. The Linear layers sit at indices
    0, 2, 4 and 6 of ``self.net``; those indices appear in saved
    ``state_dict`` keys (``net.0.weight`` ...), so the layout must stay fixed.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = []
        fan_in = input_dim
        for fan_out in (256, 128, 64):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            fan_in = fan_out
        layers.append(nn.Linear(fan_in, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid probabilities; the last dimension has size 1."""
        logits = self.net(x)
        return logits.sigmoid()

# ========= Eval =========
def evaluate(policy, loader, device):
    """Measure a routing policy's expected reward, cost and entropy over a loader.

    Each batch is ``(xb, r0, r1, cost_ratio)``: features, reward of the instruct
    model, reward of the reasoning model, and the reasoning/instruct cost ratio.
    ``policy(xb)`` is read as the probability of routing to the reasoning model.
    The instruct model's cost is fixed at 1.0.

    Args:
        policy: routing module; switched to eval mode (left that way on return).
        loader: iterable of batches as described above.
        device: device the batches are moved to.

    Returns:
        ``(mean_reward, mean_cost, mean_entropy)``, all averaged per sample, so a
        short final batch is weighted by its size. All three are NaN when the
        loader yields no samples, instead of raising ZeroDivisionError.
    """
    policy.eval()
    total_r, total_g, total_entropy, n = 0.0, 0.0, 0.0, 0
    with torch.no_grad():
        for xb, r0, r1, cost_ratio in loader:
            xb, r0, r1, cost_ratio = xb.to(device), r0.to(device), r1.to(device), cost_ratio.to(device)
            r0 = r0.view(-1)
            r1 = r1.view(-1)
            cost_ratio = cost_ratio.view(-1)
            pi1 = policy(xb).squeeze(-1)
            pi0 = 1 - pi1
            # Per-sample Bernoulli entropy; summed (not batch-averaged) so it is
            # normalised per sample like reward and cost.
            entropy = -(pi1 * torch.log(pi1 + 1e-8) + pi0 * torch.log(pi0 + 1e-8))
            reward = pi0 * r0 + pi1 * r1
            cost = pi0 * 1.0 + pi1 * cost_ratio
            total_r += reward.sum().item()
            total_g += cost.sum().item()
            total_entropy += entropy.sum().item()
            n += xb.size(0)
    if n == 0:
        nan = float("nan")
        return nan, nan, nan
    return total_r / n, total_g / n, total_entropy / n

# ========= Robust Dual Ascent =========
def train_dual_ascent(policy, loaders, budget, entropy_coef, epochs, lr, dual_lr, 
                     tau_r, tau_g, use_robust, device):
    """Train a routing policy by primal-dual ascent under an expected-cost budget.

    Maximises expected reward subject to ``E[cost] <= budget`` through the
    Lagrangian
    ``exp_r - lambda * (exp_g - B) + H_coef * entropy + 0.5 * H_coef * lambda**2``.
    The policy takes one AdamW step per training batch. The dual variable
    ``lambda`` takes one projected step (clamped at 0) per epoch, using the
    epoch-average training cost.

    Args:
        policy: module whose (squeezed) output is the probability ``pi1`` of
            routing to the reasoning model.
        loaders: ``(train_loader, val_loader)``. Each batch is
            ``(xb, r0, r1, cost_ratio)``: features, instruct reward, reasoning
            reward, reasoning/instruct cost ratio. The instruct cost is 1.0.
        budget: cost budget ``B``.
        entropy_coef: weight of the entropy bonus. Also the coefficient of the
            ``0.5 * lambda**2`` term and of the dual decay.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate for the policy.
        dual_lr: step size for the dual variable.
        tau_r: softmax temperature of the robust reward weights (floored at 1e-6).
        tau_g: accepted but not used anywhere in this function.
        use_robust: if True, the mean reward is replaced by a softmax-reweighted
            mean that up-weights samples with below-batch-mean reward.
        device: torch device used for training.

    Returns:
        ``(policy, lambda_dual, history, best_model)``.
        - ``policy``: has the weights of ``best_model`` loaded.
        - ``lambda_dual``: the value after the final epoch, not the best epoch.
        - ``history``: one dict per epoch, each holding a CPU copy of the weights.
        - ``best_model``: the selected entry of ``history``.

    Raises:
        ZeroDivisionError: if ``train_loader`` is empty.
        ValueError: if ``epochs`` is 0 (``min`` over an empty history).
    """
    train_loader, val_loader = loaders
    optimizer = optim.AdamW(policy.parameters(), lr=lr)
    lambda_dual = torch.tensor(0.0, device=device)
    B = budget
    H_coef = entropy_coef
    history = []

    for epoch in range(epochs):
        policy.train()
        total_r, total_g, total_entropy = 0.0, 0.0, 0.0

        for xb, r0, r1, cost_ratio in train_loader:
            xb, r0, r1, cost_ratio = xb.to(device), r0.to(device), r1.to(device), cost_ratio.to(device)
            r0 = r0.view(-1); r1 = r1.view(-1); cost_ratio = cost_ratio.view(-1)

            # Soft routing: expected reward/cost under pi1 = P(reasoning).
            pi1 = policy(xb).squeeze(-1)
            pi0 = 1.0 - pi1
            reward = pi0 * r0 + pi1 * r1
            cost   = pi0 * 1.0 + pi1 * cost_ratio
            entropy = - (pi1 * torch.log(pi1 + 1e-8) + pi0 * torch.log(pi0 + 1e-8)).mean()

            if use_robust:
                # DRO for reward: weights softmax((mean - reward) / tau_r) put more
                # mass on samples whose reward is below the batch mean. The weights
                # are detached, so gradients flow only through `reward`.
                s_r = reward.mean().detach()
                w_r = torch.softmax((s_r - reward.detach()) / max(tau_r, 1e-6), dim=0)
                exp_r = (w_r * reward).sum()
                
                # Cost is not reweighted (tau_g is unused).
                exp_g = cost.mean()
            else:
                exp_r = reward.mean()
                exp_g = cost.mean()

            # The 0.5 * H_coef * lambda**2 term does not depend on the policy
            # parameters, so it leaves the policy gradient unchanged. It matches
            # the dual decay term in the lambda update below.
            lagrangian = exp_r - lambda_dual * (exp_g - B) + H_coef * entropy + 0.5 * H_coef * (lambda_dual ** 2)

            # Ascent on the Lagrangian == descent on its negative.
            optimizer.zero_grad()
            (-lagrangian).backward()
            optimizer.step()

            # Logged metrics use the unweighted batch means.
            total_r += reward.mean().item()
            total_g += cost.mean().item()
            total_entropy += entropy.item()

        # Calculate averages (mean of per-batch means)
        avg_r = total_r / len(train_loader)
        avg_g = total_g / len(train_loader)
        avg_entropy = total_entropy / len(train_loader)

        # Update dual variable: projected gradient descent on lambda, since
        # d(lagrangian)/d(lambda) = -(g - B) + H_coef * lambda. Clamped at 0.
        with torch.no_grad():
            lambda_dual = torch.clamp(
                lambda_dual + dual_lr * (avg_g - B) - dual_lr * H_coef * lambda_dual, 
                min=0.0
            )

        val_r, val_g, val_entropy = evaluate(policy, val_loader, device)

        # Per-epoch record; the weights are snapshotted to CPU so any epoch can be restored.
        history.append({
            'epoch': epoch + 1,
            'train_reward': avg_r,
            'train_cost': avg_g,
            'train_entropy': avg_entropy,
            'val_reward': val_r,
            'val_cost': val_g,
            'val_entropy': val_entropy,
            'lambda': lambda_dual.item(),
            'state_dict': {k: v.cpu().clone() for k, v in policy.state_dict().items()}
        })

    # Model selection: the best validation reward among budget-feasible epochs.
    # If no epoch is feasible, take the epoch whose validation cost is closest to B.
    valid_models = [h for h in history if h['val_cost'] <= B]
    best_model = max(valid_models, key=lambda h: h['val_reward']) if valid_models else min(history, key=lambda h: abs(h['val_cost'] - B))
    policy.load_state_dict({k: v.to(device) for k, v in best_model['state_dict'].items()})

    return policy, lambda_dual, history, best_model


def expected_eval_policy(policy, df, device):
    """Score a routing policy on a whole DataFrame using its soft (expected) decisions.

    Args:
        policy: module whose (squeezed) output is P(route to reasoning model).
            It is switched to eval mode before inference.
        df: DataFrame with ``embedding`` (equal-length sequences per row),
            ``correct_instruct``, ``correct_reasoning`` and ``token_ratio`` columns.
        device: device used for the forward pass.

    Returns:
        dict with ``policy_reward`` and ``policy_cost`` (instruct cost = 1.0),
        ``p_reasoning`` (mean routing probability), and the reference values
        ``instruct`` / ``reasoning`` (mean correctness of each model) and
        ``avg_token_ratio``.
    """
    def as_vector(column):
        # Flatten so the result is 1-D (N,), whatever per-row nesting the column has.
        return torch.tensor(np.array(df[column].tolist()), dtype=torch.float, device=device).view(-1)

    X = torch.tensor(np.stack(df['embedding'].values), dtype=torch.float).to(device)
    r0 = as_vector('correct_instruct')
    r1 = as_vector('correct_reasoning')
    token_ratio = as_vector('token_ratio')

    # Inference mode, as evaluate() already does: disables dropout and uses
    # running batch-norm statistics if the policy has such layers.
    policy.eval()
    with torch.no_grad():
        pi1 = policy(X).squeeze(-1)
    pi0 = 1 - pi1
    reward = (pi0 * r0 + pi1 * r1).mean().item()
    cost   = (pi0 * 1.0 + pi1 * token_ratio).mean().item()
    return {
        "policy_reward": reward,
        "policy_cost": cost,
        "p_reasoning": pi1.mean().item(),
        "instruct": r0.mean().item(),
        "reasoning": r1.mean().item(),
        "avg_token_ratio": token_ratio.mean().item(),
    }

# ========= Random baseline =========
def random_sample_eval_on_df(df, p_reasoning: float, device: str, seed: int = 42):
    """Evaluate a random router that picks the reasoning model with prob ``p_reasoning``.

    Each row independently draws a Bernoulli(``p_reasoning``) action from a
    dedicated generator seeded with ``seed``, so results do not depend on the
    global RNG state.

    Args:
        df: DataFrame with ``correct_instruct``, ``correct_reasoning`` and
            ``token_ratio`` columns.
        p_reasoning: probability of choosing the reasoning model.
        device: device for the generator and tensors.
        seed: generator seed.

    Returns:
        ``(mean_reward, mean_cost)`` as Python floats (instruct cost = 1.0).
    """
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    def column_vector(name):
        return torch.tensor(np.array(df[name].tolist()), dtype=torch.float, device=device).view(-1)

    acc_instruct = column_vector('correct_instruct')
    acc_reasoning = column_vector('correct_reasoning')
    ratio = column_vector('token_ratio')

    probs = torch.full_like(acc_instruct, fill_value=p_reasoning, device=device)
    pick_reasoning = torch.bernoulli(probs, generator=generator)
    pick_instruct = 1 - pick_reasoning

    mean_reward = (pick_instruct * acc_instruct + pick_reasoning * acc_reasoning).mean().item()
    mean_cost = (pick_instruct * 1.0 + pick_reasoning * ratio).mean().item()
    return mean_reward, mean_cost

def load_df(dataset_base_path, model_name_instruct, model_name_reasoning, dataset_name, use_separate_embeddings):
    """Load the embedding DataFrame for a model pair on one dataset.

    The file is located at
    ``{dataset_base_path}/{model_name_instruct}_{model_name_reasoning}_{dataset_name}.safetensors``
    and parsed by ``load_embedding_dfs``.
    """
    file_stem = f"{model_name_instruct}_{model_name_reasoning}_{dataset_name}"
    return load_embedding_dfs(f"{dataset_base_path}/{file_stem}.safetensors", use_separate_embeddings)

# ========= Single training job (reads pre-loaded *_shared tensors from args) =========
def train_single_job(job_config):
    """Train, test-evaluate and checkpoint one robust routing policy for a (budget, rep) pair.

    Args:
        job_config: dict with keys
            'budget'  -- cost budget passed to ``train_dual_ascent``;
            'rep'     -- replication index; all seeds here are ``42 + rep`` and it
                         fixes the train/val/test split;
            'gpu_id'  -- CUDA device index (CPU is used when CUDA is unavailable);
            'args'    -- settings dict that also carries the pre-loaded tensors
                         ``X_shared``, ``r_instruct_shared``, ``r_reason_shared``
                         and ``token_ratio_shared``;
            'exp_dir' -- experiment directory; checkpoints go to ``exp_dir/policies``.

    Returns:
        list of four result dicts (methods 'RACER', 'random', 'instruct',
        'reasoning') measured on this replication's test split.

    Side effects:
        writes ``policies/policy_robust_B{budget:.2f}_rep{rep}.pt`` under ``exp_dir``.
    """
    budget = job_config['budget']
    rep = job_config['rep']
    gpu_id = job_config['gpu_id']
    args_dict = job_config['args']
    exp_dir = job_config['exp_dir']
    
    # Set device
    device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
    set_seed(42 + rep)
    
    X = args_dict['X_shared']
    r_instruct = args_dict['r_instruct_shared']
    r_reason = args_dict['r_reason_shared']
    token_ratio = args_dict['token_ratio_shared']
    
    # Reconstruct dataset
    dataset = TensorDataset(X, r_instruct, r_reason, token_ratio)
    
    # 60/20/20 train/val/test split. The dedicated seeded generator makes the
    # split depend only on N and rep; eval_testset_job rebuilds it the same way.
    N = len(dataset)
    n_train = int(0.6 * N)
    n_val   = int(0.2 * N)
    n_test  = N - n_train - n_val
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(42 + rep)
    )
    
    batch_size = args_dict['batch_size']
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=256, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=256, shuffle=False)
    
    results = []
    policies_dir = os.path.join(exp_dir, "policies")
    os.makedirs(policies_dir, exist_ok=True)
    
    # Only the robust (DRO-reward) variant is trained here.
    use_robust = True
    robust_tag = "robust"
    # Re-seed right before model construction so initial weights depend only on rep.
    set_seed(42 + rep)
    
    # Train policy (tau_g is forwarded, but train_dual_ascent does not use it)
    policy = PolicyNet(X.shape[1]).to(device)
    policy, lambda_dual, history, best_model = train_dual_ascent(
        policy, (train_loader, val_loader), 
        budget=budget,
        entropy_coef=args_dict['entropy_coef'],
        epochs=args_dict['epochs'],
        lr=args_dict['lr'],
        dual_lr=args_dict['dual_lr'],
        tau_r=args_dict['tau_r'],
        tau_g=args_dict['tau_g'],
        use_robust=use_robust,
        device=device
    )
    
    # Test evaluation: gather per-sample routing probabilities and targets on CPU.
    policy.eval()
    all_pi1, all_r0, all_r1, all_cost_ratio = [], [], [], []
    with torch.no_grad():
        for xb, r0, r1, cost_ratio in test_loader:
            xb, r0, r1, cost_ratio = xb.to(device), r0.to(device), r1.to(device), cost_ratio.to(device)
            r0 = r0.view(-1); r1 = r1.view(-1); cost_ratio = cost_ratio.view(-1)
            pi1 = policy(xb).squeeze(-1)
            all_pi1.append(pi1.cpu())
            all_r0.append(r0.cpu())
            all_r1.append(r1.cpu())
            all_cost_ratio.append(cost_ratio.cpu())
    
    pi1_all = torch.cat(all_pi1)
    r0_all  = torch.cat(all_r0)
    r1_all  = torch.cat(all_r1)
    cost_ratio_all = torch.cat(all_cost_ratio)
    
    # Expected policy (soft routing; instruct cost = 1.0)
    reward_exp = ((1 - pi1_all) * r0_all + pi1_all * r1_all).mean().item()
    cost_exp   = ((1 - pi1_all) * 1.0 + pi1_all * cost_ratio_all).mean().item()
    p_reasoning = pi1_all.mean().item()
    
    # Random baseline with the same average reasoning rate. It draws from the
    # global torch RNG, whose state depends on the seeding above and on
    # everything drawn from it during training.
    actions_random = torch.bernoulli(torch.full_like(pi1_all, p_reasoning))
    r_random = ((1 - actions_random) * r0_all + actions_random * r1_all).mean().item()
    c_random = ((1 - actions_random) * 1.0 + actions_random * cost_ratio_all).mean().item()
    
    # Save model: CPU weights plus training metadata
    model_save_path = os.path.join(policies_dir, f"policy_{robust_tag}_B{budget:.2f}_rep{rep}.pt")
    torch.save({
        'state_dict': {k: v.cpu().clone() for k, v in policy.state_dict().items()},
        'lambda_dual': lambda_dual.item(),
        'best_epoch': best_model['epoch'],
        'best_val_reward': best_model['val_reward'],
        'best_val_cost': best_model['val_cost'],
        'budget': budget,
        'replication': rep,
        'use_robust': use_robust,
        'tau_r': args_dict['tau_r'],
        'tau_g': args_dict['tau_g'],
    }, model_save_path)
    
    # Policy
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': args_dict['dataset_name'],
        'method': 'RACER',
        'use_robust': use_robust,
        'acc': reward_exp,
        'cost': cost_exp,
        'p_reasoning': p_reasoning,
        'model_path': model_save_path
    })
    
    # Random
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': args_dict['dataset_name'],
        'method': 'random',
        'use_robust': use_robust,
        'acc': r_random,
        'cost': c_random,
        'p_reasoning': p_reasoning,
    })
    
    # Baselines: always-instruct (cost 1.0) and always-reasoning on the same test split
    instruct_acc = r0_all.mean().item()
    instruct_cost = 1.0
    reasoning_acc = r1_all.mean().item()
    reasoning_cost = cost_ratio_all.mean().item()
    
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': args_dict['dataset_name'],
        'method': 'instruct',
        'use_robust': None,
        'acc': instruct_acc,
        'cost': instruct_cost,
        'p_reasoning': 0.0,
    })
    
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': args_dict['dataset_name'],
        'method': 'reasoning',
        'use_robust': None,
        'acc': reasoning_acc,
        'cost': reasoning_cost,
        'p_reasoning': 1.0,
    })
    
    return results

# ========= OOD evaluation job =========
def _load_ood_frame(args_dict, ood_dataset):
    """Read an OOD safetensors file into the DataFrame layout used by the evaluators.

    The file is ``{dataset_base_path}/{model_name_instruct}_{model_name_reasoning}_{ood_dataset}.safetensors``.
    The resulting columns are ``embedding`` (one list per row), ``correct_instruct``,
    ``correct_reasoning`` and ``token_ratio`` (reasoning / instruct token
    counts, with 1e-8 added to the denominator). Any failure is printed with a
    traceback and re-raised.
    """
    try:
        file_path = f"{args_dict['dataset_base_path']}/{args_dict['model_name_instruct']}_{args_dict['model_name_reasoning']}_{ood_dataset}.safetensors"

        with safe_open(file_path, framework="pt", device="cpu") as reader:
            if args_dict.get('use_separate_embeddings', False):
                pieces = [reader.get_tensor(name)
                          for name in ("embeddings_prompt", "embeddings_answer_a", "embeddings_answer_b")]
                features = torch.cat(pieces, dim=1)
            else:
                features = reader.get_tensor("embeddings_prompt_answer")

            hit_instruct = reader.get_tensor("correct_instruct").reshape(-1)
            hit_reasoning = reader.get_tensor("correct_reasoning").reshape(-1)
            tokens_instruct = reader.get_tensor("num_tokens_instruct").reshape(-1)
            tokens_reasoning = reader.get_tensor("num_tokens_reasoning").reshape(-1)

        ratio = tokens_reasoning.float() / (tokens_instruct.float() + 1e-8)

        return pd.DataFrame({
            "embedding": [row.tolist() for row in features],
            "correct_instruct": hit_instruct.tolist(),
            "correct_reasoning": hit_reasoning.tolist(),
            "token_ratio": ratio.tolist()
        })
    except Exception as e:
        print(f"❌ [eval_ood_job] Failed to load OOD dataset {ood_dataset}: {e}")
        import traceback
        traceback.print_exc()
        raise


def eval_ood_job(eval_config):
    """Evaluate one saved robust policy on an out-of-distribution dataset.

    Args:
        eval_config: dict with 'budget', 'rep', 'gpu_id', 'model_path_robust',
            'args', and optionally 'ood_dataset'. If ``args`` contains
            ``ood_df_shared``, that DataFrame is used directly. Otherwise the
            file named by 'ood_dataset' is loaded.

    Returns:
        list of four result dicts (methods 'RACER', 'random', 'instruct',
        'reasoning'). Each is tagged with the OOD dataset name, falling back to
        ``args['current_ood_dataset']`` and then to 'unknown'.
    """
    budget = eval_config['budget']
    rep = eval_config['rep']
    gpu_id = eval_config['gpu_id']
    checkpoint_path = eval_config['model_path_robust']
    args_dict = eval_config['args']

    device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
    set_seed(42 + rep)

    if 'ood_df_shared' in args_dict:
        df_ood = args_dict['ood_df_shared']
    else:
        df_ood = _load_ood_frame(args_dict, eval_config['ood_dataset'])

    emb_dim = np.array(df_ood['embedding'].iloc[0]).shape[0]
    dataset_label = eval_config.get('ood_dataset', args_dict.get('current_ood_dataset', 'unknown'))

    # Restore the trained weights into a freshly built network on the target device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    policy = PolicyNet(emb_dim).to(device)
    policy.load_state_dict(checkpoint['state_dict'])

    metrics = expected_eval_policy(policy, df_ood, device)
    p_reasoning = metrics['p_reasoning']
    rand_acc, rand_cost = random_sample_eval_on_df(df_ood, p_reasoning, device=device, seed=42+rep)

    # (method, use_robust, acc, cost, p_reasoning); instruct cost is fixed at 1.0.
    table = [
        ('RACER', True, metrics['policy_reward'], metrics['policy_cost'], p_reasoning),
        ('random', True, rand_acc, rand_cost, p_reasoning),
        ('instruct', None, metrics['instruct'], 1.0, 0.0),
        ('reasoning', None, metrics['reasoning'], metrics['avg_token_ratio'], 1.0),
    ]
    return [
        {
            'budget': budget,
            'rep': rep,
            'dataset': dataset_label,
            'method': method,
            'use_robust': robust,
            'acc': acc,
            'cost': cost,
            'p_reasoning': p_route,
        }
        for method, robust, acc, cost, p_route in table
    ]

# ========= Test set evaluation job =========
def eval_testset_job(eval_config):
    """Re-evaluate a saved robust policy on its replication's held-out test split.

    Rebuilds the 60/20/20 split with a generator seeded by ``42 + rep``, as in
    ``train_single_job``, so the same test indices are recovered when the
    dataset (and hence N) is the same as at training time. Reports the same
    four rows as training: RACER, random, instruct and reasoning.

    Args:
        eval_config: dict with 'budget', 'rep', 'gpu_id', 'model_path_robust'
            and 'args'. When ``args`` has no ``X_shared``, 'train_dataset' names
            the safetensors file to load instead.

    Returns:
        list of four result dicts, tagged with the training dataset name.
    """
    budget = eval_config['budget']
    rep = eval_config['rep']
    gpu_id = eval_config['gpu_id']
    model_path_robust = eval_config['model_path_robust']
    args_dict = eval_config['args']
    
    device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
    set_seed(42 + rep)
    
    # Using shared training data
    if 'X_shared' in args_dict:
        X = args_dict['X_shared']
        r_instruct = args_dict['r_instruct_shared']
        r_reason = args_dict['r_reason_shared']
        token_ratio = args_dict['token_ratio_shared']
        train_dataset = args_dict['dataset_name']
    else:
        # Fall back to reading the training file from disk (all tensors cast to float).

        train_dataset = eval_config['train_dataset']
        try:
            train_data_path = f"{args_dict['dataset_base_path']}/{args_dict['model_name_instruct']}_{args_dict['model_name_reasoning']}_{train_dataset}.safetensors"
            
            with safe_open(train_data_path, framework="pt", device="cpu") as f:
                if args_dict.get('use_separate_embeddings', False):
                    embeddings_prompt = f.get_tensor("embeddings_prompt")
                    embeddings_answer_a = f.get_tensor("embeddings_answer_a")
                    embeddings_answer_b = f.get_tensor("embeddings_answer_b")
                    X = torch.cat([embeddings_prompt, embeddings_answer_a, embeddings_answer_b], dim=1).float()
                else:
                    X = f.get_tensor("embeddings_prompt_answer").float()
                
                r_instruct = f.get_tensor("correct_instruct").reshape(-1).float()
                r_reason = f.get_tensor("correct_reasoning").reshape(-1).float()
                num_instr = f.get_tensor("num_tokens_instruct").reshape(-1).float()
                num_reason = f.get_tensor("num_tokens_reasoning").reshape(-1).float()
            
            token_ratio = num_reason / (num_instr + 1e-8)
                
        except Exception as e:
            print(f"❌ [eval_testset_job] Failed to load training dataset {train_dataset}: {e}")
            import traceback
            traceback.print_exc()
            raise
    
    dataset = TensorDataset(X, r_instruct, r_reason, token_ratio)
    
    # Reconstruct test set (same sizes and generator seed as train_single_job)
    N = len(dataset)
    n_train = int(0.6 * N)
    n_val   = int(0.2 * N)
    n_test  = N - n_train - n_val
    _, _, test_set = random_split(
        dataset, [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(42 + rep)
    )
    
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False)
    
    results = []
    
    # Evaluate robust model
    use_robust = True
    model_path = model_path_robust
    
    # Load model
    checkpoint = torch.load(model_path, map_location='cpu')
    policy = PolicyNet(X.shape[1]).to(device)
    policy.load_state_dict(checkpoint['state_dict'])
    
    # Evaluate on test set: gather per-sample routing probabilities and targets on CPU
    policy.eval()
    all_pi1, all_r0, all_r1, all_cost_ratio = [], [], [], []
    with torch.no_grad():
        for xb, r0, r1, cost_ratio in test_loader:
            xb, r0, r1, cost_ratio = xb.to(device), r0.to(device), r1.to(device), cost_ratio.to(device)
            r0 = r0.view(-1); r1 = r1.view(-1); cost_ratio = cost_ratio.view(-1)
            pi1 = policy(xb).squeeze(-1)
            all_pi1.append(pi1.cpu())
            all_r0.append(r0.cpu())
            all_r1.append(r1.cpu())
            all_cost_ratio.append(cost_ratio.cpu())
    
    pi1_all = torch.cat(all_pi1)
    r0_all  = torch.cat(all_r0)
    r1_all  = torch.cat(all_r1)
    cost_ratio_all = torch.cat(all_cost_ratio)
    
    # Expected policy (soft routing; instruct cost = 1.0)
    reward_exp = ((1 - pi1_all) * r0_all + pi1_all * r1_all).mean().item()
    cost_exp   = ((1 - pi1_all) * 1.0 + pi1_all * cost_ratio_all).mean().item()
    p_reasoning = pi1_all.mean().item()
    
    # Random baseline. It draws from the global RNG after PolicyNet init has
    # consumed some of it, so its samples generally differ from those drawn in
    # train_single_job for the same (budget, rep).
    actions_random = torch.bernoulli(torch.full_like(pi1_all, p_reasoning))
    r_random = ((1 - actions_random) * r0_all + actions_random * r1_all).mean().item()
    c_random = ((1 - actions_random) * 1.0 + actions_random * cost_ratio_all).mean().item()
    
    # Policy
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': train_dataset,
        'method': 'RACER',
        'use_robust': use_robust,
        'acc': reward_exp,
        'cost': cost_exp,
        'p_reasoning': p_reasoning,
    })
    
    # Random
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': train_dataset,
        'method': 'random',
        'use_robust': use_robust,
        'acc': r_random,
        'cost': c_random,
        'p_reasoning': p_reasoning,
    })
    
    # Baselines: always-instruct (cost 1.0) and always-reasoning on the same test split
    instruct_acc = r0_all.mean().item()
    instruct_cost = 1.0
    reasoning_acc = r1_all.mean().item()
    reasoning_cost = cost_ratio_all.mean().item()
    
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': train_dataset,
        'method': 'instruct',
        'use_robust': None,
        'acc': instruct_acc,
        'cost': instruct_cost,
        'p_reasoning': 0.0,
    })
    
    results.append({
        'budget': budget,
        'rep': rep,
        'dataset': train_dataset,
        'method': 'reasoning',
        'use_robust': None,
        'acc': reasoning_acc,
        'cost': reasoning_cost,
        'p_reasoning': 1.0,
    })
    
    return results

# ========= Discover policies =========
def discover_policies_from_dir(policy_dir):
    """Index saved robust policy checkpoints under ``{policy_dir}/policies``.

    Files must be named ``policy_robust_B{budget}_rep{rep}.pt``. Matching names
    whose budget (float) or rep (int) part does not parse are skipped.

    Args:
        policy_dir: experiment directory containing a ``policies`` sub-directory.
            It may contain glob metacharacters such as ``[``; they are escaped.

    Returns:
        ``(sorted budgets, sorted replication ids, {(budget, rep, 'robust'): path})``.

    Raises:
        FileNotFoundError: the ``policies`` sub-directory does not exist.
        ValueError: no parseable policy file was found.
    """
    policies_path = os.path.join(policy_dir, "policies")
    if not os.path.exists(policies_path):
        raise FileNotFoundError(f"Policies directory not found: {policies_path}")
    
    policy_files = {}
    budgets_set = set()
    reps_set = set()
    
    robust_tag = 'robust'
    # Escape the directory part so characters like '[' in the path are matched literally.
    pattern = os.path.join(glob.escape(policies_path), f"policy_{robust_tag}_B*_rep*.pt")
    for fpath in glob.glob(pattern):
        fname = os.path.basename(fpath)
        parts = fname.replace(f"policy_{robust_tag}_B", "").replace(".pt", "").split("_rep")
        try:
            budget = float(parts[0])
            rep = int(parts[1])
        except (ValueError, IndexError):
            # Name matches the glob but not the expected numeric layout: not one of ours.
            continue
        policy_files[(budget, rep, robust_tag)] = fpath
        budgets_set.add(budget)
        reps_set.add(rep)
    
    if not policy_files:
        raise ValueError(f"No valid policy files found in {policies_path}")
    
    return sorted(budgets_set), sorted(reps_set), policy_files

# ========= Plot functions =========
def _mean_ci95(values):
    """Return ``(mean, half-width of a normal-approximation 95% CI)`` for a pandas Series.

    The half-width is ``1.96 * std / sqrt(n)`` using the sample std (ddof=1).
    With fewer than two values pandas' std is NaN, which would make the
    confidence band undrawable, so 0.0 is returned instead.
    """
    mean = values.mean()
    n = len(values)
    if n < 2:
        return mean, 0.0
    return mean, 1.96 * (values.std() / np.sqrt(n))


def plot_results(results_df, exp_dir, dataset_name, budget_list):
    """Plot accuracy vs. cost for RACER, random routing and single-model baselines.

    Args:
        results_df: DataFrame with 'method', 'budget', 'acc' and 'cost' columns.
            Methods used: 'instruct', 'reasoning', 'RACER', 'random'.
        exp_dir: directory the figure is saved to.
        dataset_name: mapped to a display name for the title, and used in the filename.
        budget_list: budgets to plot, in order. Each gets its own color shade and label.

    Returns:
        Path of the saved PNG, ``{exp_dir}/{dataset_name}_error_bars.png``.
    """
    plt.figure(figsize=(12, 7))
    
    instruct_data = results_df[results_df['method'] == 'instruct']
    reasoning_data = results_df[results_df['method'] == 'reasoning']
    
    inst_acc_mean = instruct_data['acc'].mean() if len(instruct_data) > 0 else None
    inst_cost_mean = instruct_data['cost'].mean() if len(instruct_data) > 0 else None
    reas_acc_mean = reasoning_data['acc'].mean() if len(reasoning_data) > 0 else None
    reas_cost_mean = reasoning_data['cost'].mean() if len(reasoning_data) > 0 else None
    
    # All-instruct reference point and guide lines
    if inst_acc_mean is not None and inst_cost_mean is not None:
        plt.axhline(y=inst_acc_mean, color='green', linestyle='--', linewidth=1.5, 
                   alpha=0.4, zorder=1)
        plt.axvline(x=inst_cost_mean, color='green', linestyle='--', linewidth=1.5, 
                   alpha=0.4, zorder=1)
        plt.scatter(inst_cost_mean, inst_acc_mean, s=150, color='green', 
                   marker='o', label='All-Instruct', zorder=6, edgecolors='darkgreen', linewidths=2)
    
    # All-reasoning reference point and guide lines
    if reas_acc_mean is not None and reas_cost_mean is not None:
        plt.axhline(y=reas_acc_mean, color='red', linestyle='--', linewidth=1.5, 
                   alpha=0.4, zorder=1)
        plt.axvline(x=reas_cost_mean, color='red', linestyle='--', linewidth=1.5, 
                   alpha=0.4, zorder=1)
        plt.scatter(reas_cost_mean, reas_acc_mean, s=150, color='red', 
                   marker='s', label='All-Reasoning', zorder=6, edgecolors='darkred', linewidths=2)
    
    # Each entry: (cost_mean, cost_ci, acc_mean, acc_ci, budget)
    racer_points = []
    random_points = []
    
    n_budgets = len(budget_list)
    racer_colors = plt.cm.Blues(np.linspace(0.4, 0.9, n_budgets))
    random_colors = plt.cm.Oranges(np.linspace(0.4, 0.9, n_budgets))
    
    # Add each legend entry only once across budgets
    racer_legend_added = False
    random_legend_added = False
    
    for i, budget in enumerate(budget_list):
        # RACER policy: mean over replications with a 95% CI
        racer_data = results_df[(results_df['method'] == 'RACER') & 
                                 (results_df['budget'] == budget)]
        if len(racer_data) > 0:
            racer_acc_mean, racer_acc_ci = _mean_ci95(racer_data['acc'])
            racer_cost_mean, racer_cost_ci = _mean_ci95(racer_data['cost'])
            
            plt.scatter(racer_cost_mean, racer_acc_mean, s=100, color=racer_colors[i],
                       marker='s', label='RACER' if not racer_legend_added else '',
                       zorder=5, edgecolors='blue', linewidths=1.5)
            racer_legend_added = True
            
            plt.text(racer_cost_mean - 0.05, racer_acc_mean - 0.003, f'{budget:.1f}', 
                    fontsize=9, color=racer_colors[i], fontweight='bold')
            
            racer_points.append((racer_cost_mean, racer_cost_ci, racer_acc_mean, racer_acc_ci, budget))
        
        # Random routing at the same average reasoning rate
        random_data = results_df[(results_df['method'] == 'random') & 
                                 (results_df['budget'] == budget)]
        if len(random_data) > 0:
            rand_acc_mean, rand_acc_ci = _mean_ci95(random_data['acc'])
            rand_cost_mean, rand_cost_ci = _mean_ci95(random_data['cost'])
            
            plt.scatter(rand_cost_mean, rand_acc_mean, s=100, color=random_colors[i],
                       marker='x', label='Random' if not random_legend_added else '',
                       zorder=4, linewidths=2)
            random_legend_added = True
            
            random_points.append((rand_cost_mean, rand_cost_ci, rand_acc_mean, rand_acc_ci, budget))
    
    # RACER frontier: points sorted by cost
    if len(racer_points) > 1:
        racer_points_sorted = sorted(racer_points, key=lambda x: x[0])
        costs_sorted = [p[0] for p in racer_points_sorted]
        accs_sorted = [p[2] for p in racer_points_sorted]
        acc_cis_sorted = [p[3] for p in racer_points_sorted]
        
        plt.plot(costs_sorted, accs_sorted, 
                color='blue', linestyle='-', linewidth=2.5, alpha=0.7,
                label='RACER Frontier', zorder=3)
        
        # 95% confidence band on accuracy
        acc_upper = [acc + ci for acc, ci in zip(accs_sorted, acc_cis_sorted)]
        acc_lower = [acc - ci for acc, ci in zip(accs_sorted, acc_cis_sorted)]
        plt.fill_between(costs_sorted, acc_lower, acc_upper, 
                        color='blue', alpha=0.2, zorder=2)
    
    # Random curve: points sorted by cost
    if len(random_points) > 1:
        random_points_sorted = sorted(random_points, key=lambda x: x[0])
        costs_sorted_r = [p[0] for p in random_points_sorted]
        accs_sorted_r = [p[2] for p in random_points_sorted]
        acc_cis_sorted_r = [p[3] for p in random_points_sorted]
        
        plt.plot(costs_sorted_r, accs_sorted_r, 
                color='darkorange', linestyle='-', linewidth=2, alpha=0.5,
                zorder=3)
        
        # 95% confidence band on accuracy
        acc_upper_r = [acc + ci for acc, ci in zip(accs_sorted_r, acc_cis_sorted_r)]
        acc_lower_r = [acc - ci for acc, ci in zip(accs_sorted_r, acc_cis_sorted_r)]
        plt.fill_between(costs_sorted_r, acc_lower_r, acc_upper_r, 
                        color='orange', alpha=0.15, zorder=2)
    
    plt.xlabel("Actual Cost", fontsize=13, fontweight='bold')
    plt.ylabel("Accuracy", fontsize=13, fontweight='bold')
   
    dataset_name_map = {
        'reward_bench': 'RewardBench',
        'reward_bench_v2': 'RewardBench 2',
        'judgebench': 'JudgeBench'
    }
    display_name = dataset_name_map.get(dataset_name, dataset_name)
    plt.title(f"Performance on {display_name}", fontsize=15, fontweight='bold')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()
    
    plot_path = os.path.join(exp_dir, f"{dataset_name}_error_bars.png")
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"[plot_results] Saved plot to: {plot_path}")
    return plot_path

def plot_results_row(results_dict, exp_dir, filename, budget_list, figsize_per_plot=(2.5, 2.0)):
    """Draw one accuracy-vs-cost panel per dataset in a single row and save the figure.

    Each panel shows:
      * All-Instruct / All-Reasoning: mean (cost, acc) point plus dashed reference lines.
      * RACER / Random: one marker per budget (shaded by budget), joined in cost order,
        with a band of +/- 1.96 * standard error of the mean accuracy.
    A shared legend and two budget colour bars (Blues = RACER, Oranges = Random) are
    placed to the right of the last panel.

    Args:
        results_dict: Ordered mapping ``dataset key -> DataFrame``. Each frame needs the
            columns ``method``, ``acc`` and ``cost``. Rows with method ``'RACER'`` or
            ``'random'`` are matched against ``budget_list`` through a ``budget`` column.
            Key ``"ID"`` is titled "ID (Training)" and gets a wider panel when it comes
            first. Key ``"ID (Test)"`` is titled "ID (Test)". Every other key is titled
            "OOD (<display name>)".
        exp_dir: Directory the image is written to.
        filename: Image file name inside ``exp_dir``.
        budget_list: Budgets to plot; its min/max also define the colour-bar range.
        figsize_per_plot: ``(width, height)`` in inches allotted to each panel.

    Returns:
        Path of the saved image.

    Raises:
        ValueError: If ``results_dict`` is empty.
    """
    import matplotlib.gridspec as gridspec
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable
    from matplotlib.transforms import Bbox

    if not results_dict:
        raise ValueError("plot_results_row: results_dict is empty, nothing to plot")

    dataset_name_map = {
        'reward_bench': 'RewardBench',
        'reward_bench_v2': 'RewardBench 2',
        'judgebench': 'JudgeBench'
    }
    # Fixed legend order, independent of which panel first produced each entry.
    legend_order = ['All-Instruct', 'All-Reasoning', 'RACER', 'Random']

    scatter_size_base = 20
    line_width = 1.0
    edge_width = 0.7
    title_fontsize = 9
    label_fontsize = 7
    tick_fontsize = 6
    legend_fontsize = 6

    # Per-budget marker colours are the same for every panel, so compute them once.
    n_budgets = len(budget_list)
    racer_colors = plt.cm.Blues(np.linspace(0.4, 0.9, n_budgets))
    random_colors = plt.cm.Oranges(np.linspace(0.4, 0.9, n_budgets))

    def _mean_and_ci(values):
        """Return (mean, 95% CI half-width) of a Series; the CI is NaN with < 2 values."""
        return values.mean(), 1.96 * (values.std() / np.sqrt(len(values)))

    def _plot_baseline(ax, data, color, edgecolor, marker, label):
        """Plot a single-policy baseline point plus reference lines; return its handle or None."""
        if len(data) == 0:
            return None
        acc_mean = data['acc'].mean()
        cost_mean = data['cost'].mean()
        ax.axhline(y=acc_mean, color=color, linestyle='--', linewidth=line_width,
                   alpha=0.4, zorder=1)
        ax.axvline(x=cost_mean, color=color, linestyle='--', linewidth=line_width,
                   alpha=0.4, zorder=1)
        return ax.scatter(cost_mean, acc_mean, s=scatter_size_base * 1.2, color=color,
                          marker=marker, zorder=6, edgecolors=edgecolor, linewidths=edge_width,
                          label=label)

    def _plot_curve(ax, points, line_color, band_color, line_w, line_alpha, band_alpha):
        """Join (cost, acc, acc_ci) points in cost order and shade acc +/- ci (needs >= 2 points)."""
        if len(points) < 2:
            return
        points = sorted(points, key=lambda p: p[0])
        costs = [p[0] for p in points]
        accs = [p[1] for p in points]
        cis = [p[2] for p in points]
        ax.plot(costs, accs, color=line_color, linestyle='-', linewidth=line_w,
                alpha=line_alpha, zorder=3)
        ax.fill_between(costs,
                        [acc - ci for acc, ci in zip(accs, cis)],
                        [acc + ci for acc, ci in zip(accs, cis)],
                        color=band_color, alpha=band_alpha, zorder=2,
                        linewidth=0, edgecolor='none')

    def _panel_title(name):
        """Map a results_dict key to its panel title."""
        if name == "ID":
            return "ID (Training)"
        if name == "ID (Test)":
            return "ID (Test)"
        return f"OOD ({dataset_name_map.get(name, name)})"

    n_datasets = len(results_dict)
    dataset_names = list(results_dict)
    width_ratios = [1.2] + [1] * (n_datasets - 1) if dataset_names[0] == "ID" else [1] * n_datasets
    fig = plt.figure(figsize=(figsize_per_plot[0] * n_datasets + 0.5, figsize_per_plot[1]))

    try:
        gs = gridspec.GridSpec(1, n_datasets, figure=fig,
                               width_ratios=width_ratios,
                               wspace=0.3, hspace=0)
        axes = [fig.add_subplot(gs[0, i]) for i in range(n_datasets)]

        # label -> first handle seen in any panel (not only the first panel).
        legend_entries = {}

        for ax, (dataset_name, results_df) in zip(axes, results_dict.items()):
            h = _plot_baseline(ax, results_df[results_df['method'] == 'instruct'],
                               'green', 'darkgreen', 'o', 'All-Instruct')
            if h is not None:
                legend_entries.setdefault('All-Instruct', h)
            h = _plot_baseline(ax, results_df[results_df['method'] == 'reasoning'],
                               'red', 'darkred', 's', 'All-Reasoning')
            if h is not None:
                legend_entries.setdefault('All-Reasoning', h)

            # Tolerant float matching of budgets; a missing column simply matches nothing.
            if 'budget' in results_df.columns:
                budget_values = results_df['budget'].to_numpy(dtype=float)
            else:
                budget_values = np.full(len(results_df), np.nan)
            is_racer = (results_df['method'] == 'RACER').to_numpy()
            is_random = (results_df['method'] == 'random').to_numpy()

            racer_points = []
            random_points = []
            for i, budget in enumerate(budget_list):
                at_budget = np.isclose(budget_values, budget)

                racer_data = results_df[is_racer & at_budget]
                if len(racer_data) > 0:
                    acc_mean, acc_ci = _mean_and_ci(racer_data['acc'])
                    cost_mean = racer_data['cost'].mean()
                    h = ax.scatter(cost_mean, acc_mean, s=scatter_size_base, color=racer_colors[i],
                                   marker='s', zorder=5, edgecolors='blue',
                                   linewidths=edge_width * 0.8)
                    legend_entries.setdefault('RACER', h)
                    racer_points.append((cost_mean, acc_mean, acc_ci))

                random_data = results_df[is_random & at_budget]
                if len(random_data) > 0:
                    acc_mean, acc_ci = _mean_and_ci(random_data['acc'])
                    cost_mean = random_data['cost'].mean()
                    h = ax.scatter(cost_mean, acc_mean, s=scatter_size_base, color=random_colors[i],
                                   marker='x', zorder=4, linewidths=edge_width)
                    legend_entries.setdefault('Random', h)
                    random_points.append((cost_mean, acc_mean, acc_ci))

            _plot_curve(ax, racer_points, 'blue', 'blue', line_width * 1.5, 0.7, 0.2)
            _plot_curve(ax, random_points, 'darkorange', 'orange', line_width * 1.2, 0.5, 0.15)

            ax.set_title(_panel_title(dataset_name), fontsize=title_fontsize, fontweight='bold')
            ax.set_xlabel("Actual Cost", fontsize=label_fontsize)
            ax.set_ylabel("Accuracy", fontsize=label_fontsize)
            ax.grid(True, linestyle='--', alpha=0.3)
            ax.tick_params(labelsize=tick_fontsize, pad=1)

        # Finalize the panel layout BEFORE anchoring the legend/colour bars to it.
        # Running tight_layout afterwards would move the panels away from these anchors
        # (and it cannot handle the manually placed colour-bar axes anyway).
        fig.tight_layout()
        fig.canvas.draw()
        pos = axes[-1].get_position()

        labels = [label for label in legend_order if label in legend_entries]
        handles = [legend_entries[label] for label in labels]
        legend = fig.legend(
            handles, labels,
            loc='upper left',
            bbox_to_anchor=(pos.x1 + 0.01, pos.y1),
            fontsize=legend_fontsize,
            frameon=False
        )

        # The colour bars hang below the legend, so measure its rendered extent
        # in figure coordinates.
        fig.canvas.draw()
        legend_box: Bbox = legend.get_window_extent(
            fig.canvas.get_renderer()
        ).transformed(fig.transFigure.inverted())

        if budget_list:
            bar_width = 0.0075
            bar_height = 0.4
            gap = 0.005
            bar_x0 = legend_box.x0 + 0.01
            bar_y0 = legend_box.y0 - bar_height
            norm = Normalize(vmin=min(budget_list), vmax=max(budget_list))

            # RACER bar: unlabeled and tickless; it shares the Random bar's scale.
            racer_cbar_ax = fig.add_axes([bar_x0, bar_y0, bar_width, bar_height])
            sm_racer = ScalarMappable(cmap=plt.cm.Blues, norm=norm)
            sm_racer.set_array([])
            cbar_racer = fig.colorbar(sm_racer, cax=racer_cbar_ax, orientation='vertical')
            cbar_racer.set_ticks([])
            cbar_racer.ax.set_yticks([])
            cbar_racer.ax.tick_params(
                left=False, right=False, labelleft=False, labelright=False
            )
            cbar_racer.set_label("")

            # Random bar carries the budget ticks and the shared "Budget" label.
            random_cbar_ax = fig.add_axes([bar_x0 + bar_width + gap, bar_y0, bar_width, bar_height])
            sm_random = ScalarMappable(cmap=plt.cm.Oranges, norm=norm)
            sm_random.set_array([])
            cbar_random = fig.colorbar(sm_random, cax=random_cbar_ax)
            cbar_random.set_label("  Budget", fontsize=legend_fontsize, rotation=0,
                                  labelpad=12)
            cbar_random.ax.yaxis.set_label_position("right")
            cbar_random.ax.tick_params(labelsize=tick_fontsize)

        plot_path = os.path.join(exp_dir, filename)
        # bbox_inches='tight' grows the saved image to include the legend and colour
        # bars, which sit to the right of the panel area.
        fig.savefig(plot_path, dpi=600, bbox_inches='tight')
    finally:
        # Always release the figure; callers catch plotting errors and keep going.
        plt.close(fig)

    print(f"[plot_results_row] Saved combined plot to: {plot_path}")
    return plot_path


# ========= Main =========
def main():
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    
    args = get_args()
    set_seed(42)
    
    # Parse budgets
    budget_list = [float(b.strip()) for b in args.budgets.split(',')]
    
    # Parse OOD datasets
    ood_datasets = [d.strip() for d in args.ood_datasets.split(',') if d.strip()]
    
    # Check if we should load existing policies or train new ones
    if args.policy_dir is not None:
        # ========= MODE: Load existing policies and only do OOD evaluation =========
        logger = setup_logger(args.policy_dir, "ood_evaluation.log")
        logger.info("=" * 70)
        logger.info("📂 LOADING EXISTING POLICIES MODE")
        logger.info("=" * 70)
        logger.info(f"📂 Policy directory: {args.policy_dir}")
        logger.info(f"🔗 Embedding mode: {'Separate (prompt+answer_a+answer_b)' if args.use_separate_embeddings else 'Combined (prompt_answer)'}")
        
        # Discover policies
        try:
            discovered_budgets, discovered_reps, policy_files = discover_policies_from_dir(args.policy_dir)
            logger.info(f"✅ Discovered {len(policy_files)} policy files")
            logger.info(f"💰 Budgets found: {discovered_budgets}")
            logger.info(f"🔄 Replications found: {discovered_reps}")
            
            budget_list = discovered_budgets
            replications = max(discovered_reps) + 1
            
        except Exception as e:
            logger.error(f"❌ Failed to discover policies: {e}")
            return
        
        # Load training dataset once for test set evaluation
        logger.info("=" * 70)
        logger.info("📥 Loading training dataset once for test set evaluation")
        logger.info("=" * 70)
        
        try:
            data_path = f'{args.dataset_base_path}/{args.model_name_instruct}_{args.model_name_reasoning}_{args.dataset_name}.safetensors'
            logger.info(f"📥 Loading from: {data_path}")
            
            # Load data according to mode
            with safe_open(data_path, framework="pt", device="cpu") as f:
                if args.use_separate_embeddings:
                    embeddings_prompt = f.get_tensor("embeddings_prompt")
                    embeddings_answer_a = f.get_tensor("embeddings_answer_a")
                    embeddings_answer_b = f.get_tensor("embeddings_answer_b")
                    X_shared = torch.cat([embeddings_prompt, embeddings_answer_a, embeddings_answer_b], dim=1).float()
                    logger.info(f"🔗 Using separate embeddings: prompt ({embeddings_prompt.shape[1]}) + answer_a ({embeddings_answer_a.shape[1]}) + answer_b ({embeddings_answer_b.shape[1]}) = {X_shared.shape[1]}")
                else:
                    X_shared = f.get_tensor("embeddings_prompt_answer").float()
                    logger.info(f"🔗 Using combined embedding: {X_shared.shape[1]}")
                
                r_instruct_shared = f.get_tensor("correct_instruct").reshape(-1).float()
                r_reason_shared = f.get_tensor("correct_reasoning").reshape(-1).float()
                num_instr = f.get_tensor("num_tokens_instruct").reshape(-1).float()
                num_reason = f.get_tensor("num_tokens_reasoning").reshape(-1).float()
            
            X_shared.share_memory_()
            r_instruct_shared.share_memory_()
            r_reason_shared.share_memory_()
            
            token_ratio_shared = num_reason / (num_instr + 1e-8)
            token_ratio_shared.share_memory_()
            
            logger.info(f"✅ Loaded {len(X_shared)} samples with embedding dim {X_shared.shape[1]}")
            logger.info(f"💾 Shared memory size: ~{X_shared.element_size() * X_shared.nelement() / 1024 / 1024:.2f} MB")
            
        except Exception as e:
            logger.error(f"❌ Failed to load training data: {e}")
            import traceback
            traceback.print_exc()
            return
        
        # Setup args dict for evaluation
        args_dict_testset = {
            'dataset_base_path': args.dataset_base_path,
            'model_name_instruct': args.model_name_instruct,
            'model_name_reasoning': args.model_name_reasoning,
            'dataset_name': args.dataset_name,
            'use_separate_embeddings': args.use_separate_embeddings,
            'X_shared': X_shared,
            'r_instruct_shared': r_instruct_shared,
            'r_reason_shared': r_reason_shared,
            'token_ratio_shared': token_ratio_shared,
        }
        
        # Create results directory
        results_dir = os.path.join(args.policy_dir, "ood_results")
        os.makedirs(results_dir, exist_ok=True)
        
        max_workers = args.num_gpus * args.jobs_per_gpu
        
        # ===== Evaluate on Training Dataset Test Set =====
        logger.info("=" * 70)
        logger.info(f"🎯 Evaluating on Training Dataset Test Set: {args.dataset_name}")
        logger.info("=" * 70)
        
        # Create test set evaluation jobs
        testset_jobs = []
        for budget in discovered_budgets:
            for rep in discovered_reps:
                robust_path = policy_files.get((budget, rep, 'robust'))
                if robust_path:
                    gpu_id = (len(testset_jobs) // args.jobs_per_gpu) % args.num_gpus
                    job = {
                        'budget': budget,
                        'rep': rep,
                        'gpu_id': gpu_id,
                        'model_path_robust': robust_path,
                        'train_dataset': args.dataset_name,
                        'args': args_dict_testset,
                    }
                    testset_jobs.append(job)
        
        logger.info(f"📋 Total test set eval jobs: {len(testset_jobs)}")
        
        # Parallel evaluation on test set
        testset_results = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(eval_testset_job, job): job for job in testset_jobs}
            for i, future in enumerate(as_completed(futures)):
                job = futures[future]
                try:
                    result = future.result()
                    testset_results.extend(result)
                    logger.info(f"✅ [{i+1}/{len(testset_jobs)}] Test set eval B={job['budget']:.2f}, rep={job['rep']}")
                except Exception as e:
                    logger.error(f"❌ [{i+1}/{len(testset_jobs)}] Test set eval failed B={job['budget']:.2f}, rep={job['rep']}: {e}")
                    import traceback
                    traceback.print_exc()
        
        # Save test set results
        logger.info(f"📊 Collected {len(testset_results)} result entries for test set")
        if testset_results:
            testset_df = pd.DataFrame(testset_results)
            testset_csv_path = os.path.join(results_dir, f"{args.dataset_name}_testset_res.csv")
            testset_df.to_csv(testset_csv_path, index=False)
            logger.info(f"💾 Saved test set results to {testset_csv_path}")
            
            testset_output_dir = os.path.dirname(results_dir)
            try:
                saved_plot_path = plot_results(testset_df, testset_output_dir, f"{args.dataset_name}_testset", budget_list)
                logger.info(f"✅ Successfully saved test set plot to {saved_plot_path}")
            except Exception as e:
                logger.error(f"❌ Failed to save test set plot: {e}")
        
        # ===== OOD Evaluation with existing policies =====
        logger.info("=" * 70)
        logger.info("🔍 Starting OOD Evaluation with Loaded Policies")
        logger.info("=" * 70)
        
        all_results_dict = {}
        if testset_results:
            all_results_dict["ID"] = testset_df
        
        for ood_dataset in ood_datasets:
            logger.info(f"📊 Evaluating on {ood_dataset}")
            
            # Load OOD dataset once for sharing
            logger.info(f"📥 Loading OOD dataset {ood_dataset} once for sharing...")
            try:
                ood_data_path = f"{args.dataset_base_path}/{args.model_name_instruct}_{args.model_name_reasoning}_{ood_dataset}.safetensors"
                
                with safe_open(ood_data_path, framework="pt", device="cpu") as f:
                    if args.use_separate_embeddings:
                        embeddings_prompt = f.get_tensor("embeddings_prompt")
                        embeddings_answer_a = f.get_tensor("embeddings_answer_a")
                        embeddings_answer_b = f.get_tensor("embeddings_answer_b")
                        embeddings_combined = torch.cat([embeddings_prompt, embeddings_answer_a, embeddings_answer_b], dim=1)
                    else:
                        embeddings_combined = f.get_tensor("embeddings_prompt_answer")
                    
                    correct_instruct = f.get_tensor("correct_instruct").reshape(-1)
                    correct_reasoning = f.get_tensor("correct_reasoning").reshape(-1)
                    num_instr = f.get_tensor("num_tokens_instruct").reshape(-1)
                    num_reason = f.get_tensor("num_tokens_reasoning").reshape(-1)
                
                token_ratio = num_reason.float() / (num_instr.float() + 1e-8)
                
                # Create DataFrame
                df_ood = pd.DataFrame({
                    "embedding": [e.tolist() for e in embeddings_combined],
                    "correct_instruct": correct_instruct.tolist(),
                    "correct_reasoning": correct_reasoning.tolist(),
                    "token_ratio": token_ratio.tolist()
                })
                
                logger.info(f"✅ Loaded {len(df_ood)} samples from {ood_dataset}")
                logger.info(f"🔗 Combined embedding dim: {embeddings_combined.shape[1]}")
            except Exception as e:
                logger.error(f"❌ Failed to load OOD dataset {ood_dataset}: {e}")
                import traceback
                traceback.print_exc()
                continue
            
            # Prepare args dict for OOD evaluation
            args_dict_ood = {
                'dataset_base_path': args.dataset_base_path,
                'model_name_instruct': args.model_name_instruct,
                'model_name_reasoning': args.model_name_reasoning,
                'use_separate_embeddings': args.use_separate_embeddings,
                'ood_df_shared': df_ood,
                'current_ood_dataset': ood_dataset,
            }
            
            # Create evaluation jobs
            eval_jobs = []
            for budget in discovered_budgets:
                for rep in discovered_reps:
                    robust_path = policy_files.get((budget, rep, 'robust'))
                    if robust_path:
                        gpu_id = (len(eval_jobs) // args.jobs_per_gpu) % args.num_gpus
                        job = {
                            'budget': budget,
                            'rep': rep,
                            'gpu_id': gpu_id,
                            'model_path_robust': robust_path,
                            'ood_dataset': ood_dataset,
                            'args': args_dict_ood
                        }
                        eval_jobs.append(job)
            
            logger.info(f"📋 Total OOD eval jobs for {ood_dataset}: {len(eval_jobs)}")
            
            # Parallel evaluation
            ood_results = []
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                futures = {executor.submit(eval_ood_job, job): job for job in eval_jobs}
                for i, future in enumerate(as_completed(futures)):
                    job = futures[future]
                    try:
                        result = future.result()
                        ood_results.extend(result)
                        logger.info(f"✅ [{i+1}/{len(eval_jobs)}] OOD eval B={job['budget']:.2f}, rep={job['rep']} on {ood_dataset}")
                    except Exception as e:
                        logger.error(f"❌ [{i+1}/{len(eval_jobs)}] OOD eval failed B={job['budget']:.2f}, rep={job['rep']} on {ood_dataset}: {e}")
                        import traceback
                        traceback.print_exc()
            
            # Save OOD results
            if ood_results:
                ood_df = pd.DataFrame(ood_results)
                all_results_dict[ood_dataset] = ood_df
                
                ood_csv_path = os.path.join(results_dir, f"{ood_dataset}_res.csv")
                ood_df.to_csv(ood_csv_path, index=False)
                logger.info(f"💾 Saved OOD results to {ood_csv_path}")
                
                ood_output_dir = os.path.dirname(results_dir)
                try:
                    saved_plot_path = plot_results(ood_df, ood_output_dir, ood_dataset, budget_list)
                    logger.info(f"✅ Successfully saved OOD plot to {saved_plot_path}")
                except Exception as e:
                    logger.error(f"❌ Failed to save plot: {e}")
        
        # Generate combined plot
        if len(all_results_dict) > 0:
            logger.info("=" * 70)
            logger.info("🎨 Creating combined plot with all datasets")
            logger.info("=" * 70)
            
            combined_output_dir = os.path.dirname(results_dir)
            try:
                combined_plot_path = plot_results_row(
                    all_results_dict, 
                    combined_output_dir, 
                    "combined_all_datasets.png",
                    budget_list,
                    figsize_per_plot=(2.5, 2.0)
                )
                logger.info(f"✅ Successfully saved combined plot to {combined_plot_path}")
            except Exception as e:
                logger.error(f"❌ Failed to create combined plot: {e}")
        
        logger.info("=" * 70)
        logger.info("🎉 OOD evaluation completed!")
        logger.info("=" * 70)
        
    else:
        # ========= MODE: Train new policies and do full evaluation =========
        emb_mode_str = "sep" if args.use_separate_embeddings else "comb"
        exp_name = f"dual_{args.dataset_name}_{emb_mode_str}_tr{args.tau_r}_tg{args.tau_g}"
        exp_dir = os.path.join(args.save_dir, exp_name)
        
        # Check existing results to decide whether to skip training/evaluation
        results_dir = os.path.join(exp_dir, "results")
        skip_training = False
        skip_ood_eval = False
        
        if os.path.exists(exp_dir):
            logger = setup_logger(exp_dir, "experiment.log")
            logger.info("=" * 70)
            logger.info("📂 Found existing experiment directory!")
            logger.info("=" * 70)
            logger.info(f"📂 Directory: {exp_dir}")
            
            # Check training results
            train_csv_path = os.path.join(results_dir, f"{args.dataset_name}_res.csv")
            testset_csv_path = os.path.join(results_dir, f"{args.dataset_name}_testset_res.csv")
            
            if os.path.exists(train_csv_path) and os.path.exists(testset_csv_path):
                logger.info(f"✅ Found existing training results: {train_csv_path}")
                logger.info(f"✅ Found existing test set results: {testset_csv_path}")
                skip_training = True
            else:
                if not os.path.exists(train_csv_path):
                    logger.info(f"⚠️  Training results not found, will train")
                if not os.path.exists(testset_csv_path):
                    logger.info(f"⚠️  Test set results not found, will evaluate")
                skip_training = False
            
            # Check OOD results
            ood_datasets = [d.strip() for d in args.ood_datasets.split(',') if d.strip()]
            all_ood_exist = True
            for ood_dataset in ood_datasets:
                ood_csv_path = os.path.join(results_dir, f"{ood_dataset}_res.csv")
                if not os.path.exists(ood_csv_path):
                    logger.info(f"⚠️  OOD results not found for {ood_dataset}")
                    all_ood_exist = False
                    break
            
            if all_ood_exist and len(ood_datasets) > 0:
                logger.info(f"✅ Found all OOD evaluation results")
                skip_ood_eval = True
            else:
                logger.info(f"⚠️  Some OOD results missing, will evaluate")
            
            if skip_training and skip_ood_eval:
                logger.info("=" * 70)
                logger.info("🎨 All results exist! Skipping to plotting...")
                logger.info("=" * 70)
        else:
            os.makedirs(exp_dir, exist_ok=True)
            logger = setup_logger(exp_dir, "experiment.log")
        
        results_dir = os.path.join(exp_dir, "results")
        os.makedirs(results_dir, exist_ok=True)
        
        logger.info(f"📂 Experiment directory: {exp_dir}")
        logger.info(f"🎯 Training dataset: {args.dataset_name}")
        logger.info(f"💰 Budgets: {budget_list}")
        logger.info(f"🔄 Replications: {args.replications}")
        logger.info(f"🔍 OOD datasets: {ood_datasets}")
        logger.info(f"🎛️  DRO parameters: tau_r={args.tau_r}, tau_g={args.tau_g}")
        logger.info(f"🔗 Embedding mode: {'Separate (prompt+answer_a+answer_b)' if args.use_separate_embeddings else 'Combined (prompt_answer)'}")
        logger.info(f"🗑️  Delete policy after eval: {args.delete_policy_after_eval}")
        
        # Save config
        config = vars(args)
        config['budgets_list'] = budget_list
        config['ood_datasets_list'] = ood_datasets
        config_path = os.path.join(exp_dir, "config.json")
        if not os.path.exists(config_path):
            with open(config_path, 'w') as f:
                json.dump(config, f, indent=2)
        
        all_results_dict = {}
        
        # ===== Phase 1: Training or Load Results =====
        if skip_training:
            logger.info("=" * 70)
            logger.info("📥 Phase 1: Loading existing training results")
            logger.info("=" * 70)
            
            train_csv_path = os.path.join(results_dir, f"{args.dataset_name}_res.csv")
            train_df = pd.read_csv(train_csv_path)
            logger.info(f"✅ Loaded {len(train_df)} training result entries from {train_csv_path}")
            all_results_dict["ID"] = train_df
            
            # Load test set results if available
            testset_csv_path = os.path.join(results_dir, f"{args.dataset_name}_testset_res.csv")
            if os.path.exists(testset_csv_path):
                testset_df = pd.read_csv(testset_csv_path)
                logger.info(f"✅ Loaded {len(testset_df)} test set result entries from {testset_csv_path}")
                all_results_dict["ID (Test)"] = testset_df
            else:
                logger.warning(f"⚠️  Test set results not found: {testset_csv_path}")
        else:
            # Load training data once in main process
            data_path = f'{args.dataset_base_path}/{args.model_name_instruct}_{args.model_name_reasoning}_{args.dataset_name}.safetensors'
            logger.info("=" * 70)
            logger.info("🚀 Phase 1: Training policies (RACER only)")
            logger.info("=" * 70)
            logger.info(f"📥 Loading training data from: {data_path}")
            logger.info("🔑 Loading data ONCE in main process for shared memory...")
            
            # Load data into shared memory
            with safe_open(data_path, framework="pt", device="cpu") as f:
                if args.use_separate_embeddings:
                    embeddings_prompt = f.get_tensor("embeddings_prompt")
                    embeddings_answer_a = f.get_tensor("embeddings_answer_a")
                    embeddings_answer_b = f.get_tensor("embeddings_answer_b")
                    X_shared = torch.cat([embeddings_prompt, embeddings_answer_a, embeddings_answer_b], dim=1).float()
                    logger.info(f"🔗 Using separate embeddings: prompt ({embeddings_prompt.shape[1]}) + answer_a ({embeddings_answer_a.shape[1]}) + answer_b ({embeddings_answer_b.shape[1]}) = {X_shared.shape[1]}")
                else:
                    X_shared = f.get_tensor("embeddings_prompt_answer").float()
                    logger.info(f"🔗 Using combined embedding: {X_shared.shape[1]}")
                
                r_instruct_shared = f.get_tensor("correct_instruct").reshape(-1).float()
                r_reason_shared = f.get_tensor("correct_reasoning").reshape(-1).float()
                num_instr = f.get_tensor("num_tokens_instruct").reshape(-1).float()
                num_reason = f.get_tensor("num_tokens_reasoning").reshape(-1).float()
            
            X_shared.share_memory_()
            r_instruct_shared.share_memory_()
            r_reason_shared.share_memory_()
            
            token_ratio_shared = num_reason / (num_instr + 1e-8)
            token_ratio_shared.share_memory_()
            
            logger.info(f"✅ Loaded {len(X_shared)} samples with embedding dim {X_shared.shape[1]}")
            logger.info(f"💾 Shared memory size: ~{X_shared.element_size() * X_shared.nelement() / 1024 / 1024:.2f} MB")
            
            # Prepare args dict for shared data
            args_dict = {
                'batch_size': args.batch_size,
                'entropy_coef': args.entropy_coef,
                'epochs': args.epochs,
                'lr': args.lr,
                'dual_lr': args.dual_lr,
                'tau_r': args.tau_r,
                'tau_g': args.tau_g,
                'dataset_name': args.dataset_name,
                'dataset_base_path': args.dataset_base_path,
                'model_name_instruct': args.model_name_instruct,
                'model_name_reasoning': args.model_name_reasoning,
                'use_separate_embeddings': args.use_separate_embeddings,
                'X_shared': X_shared,
                'r_instruct_shared': r_instruct_shared,
                'r_reason_shared': r_reason_shared,
                'token_ratio_shared': token_ratio_shared,
            }
            
            train_jobs = []
            for budget in budget_list:
                for rep in range(args.replications):
                    gpu_id = (len(train_jobs) // args.jobs_per_gpu) % args.num_gpus
                    job = {
                        'budget': budget,
                        'rep': rep,
                        'gpu_id': gpu_id,
                        'args': args_dict,
                        'exp_dir': exp_dir
                    }
                    train_jobs.append(job)
            
            logger.info(f"📋 Total training jobs: {len(train_jobs)}")
            logger.info(f"🖥️  Using {args.num_gpus} GPUs with {args.jobs_per_gpu} jobs per GPU")
            logger.info(f"🔢 Max workers: {args.num_gpus * args.jobs_per_gpu}")
            
            # Parallel training
            train_results = []
            max_workers = args.num_gpus * args.jobs_per_gpu
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                futures = {executor.submit(train_single_job, job): job for job in train_jobs}
                for i, future in enumerate(as_completed(futures)):
                    job = futures[future]
                    try:
                        result = future.result()
                        train_results.extend(result)
                        logger.info(f"✅ [{i+1}/{len(train_jobs)}] Completed B={job['budget']:.2f}, rep={job['rep']}")
                    except Exception as e:
                        logger.error(f"❌ [{i+1}/{len(train_jobs)}] Failed B={job['budget']:.2f}, rep={job['rep']}: {e}")
                        import traceback
                        traceback.print_exc()
            
            # Save training results
            train_df = pd.DataFrame(train_results)
            train_csv_path = os.path.join(results_dir, f"{args.dataset_name}_res.csv")
            train_df.to_csv(train_csv_path, index=False)
            logger.info(f"💾 Saved training results to {train_csv_path}")
            
            # Plot training results
            train_plot_path = plot_results(train_df, exp_dir, args.dataset_name, budget_list)
            logger.info(f"📊 Saved training plot to {train_plot_path}")
            
            all_results_dict["ID"] = train_df
            
            # 🆕 ===== Evaluate on Test Set =====
            logger.info("=" * 70)
            logger.info("🎯 Evaluating trained policies on Test Set")
            logger.info("=" * 70)
            
            # Prepare test set evaluation jobs
            testset_jobs = []
            for budget in budget_list:
                for rep in range(args.replications):
                    robust_path = os.path.join(exp_dir, "policies", f"policy_robust_B{budget:.2f}_rep{rep}.pt")
                    if os.path.exists(robust_path):
                        gpu_id = (len(testset_jobs) // args.jobs_per_gpu) % args.num_gpus
                        job = {
                            'budget': budget,
                            'rep': rep,
                            'gpu_id': gpu_id,
                            'model_path_robust': robust_path,
                            'train_dataset': args.dataset_name,
                            'args': args_dict
                        }
                        testset_jobs.append(job)
            
            logger.info(f"📋 Total test set eval jobs: {len(testset_jobs)}")
            
            # Parallel evaluation on test set
            testset_results = []
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                futures = {executor.submit(eval_testset_job, job): job for job in testset_jobs}
                for i, future in enumerate(as_completed(futures)):
                    job = futures[future]
                    try:
                        result = future.result()
                        testset_results.extend(result)
                        logger.info(f"✅ [{i+1}/{len(testset_jobs)}] Test set eval B={job['budget']:.2f}, rep={job['rep']}")
                    except Exception as e:
                        logger.error(f"❌ [{i+1}/{len(testset_jobs)}] Test set eval failed B={job['budget']:.2f}, rep={job['rep']}: {e}")
                        import traceback
                        traceback.print_exc()
            
            # Save test set results
            if testset_results:
                testset_df = pd.DataFrame(testset_results)
                testset_csv_path = os.path.join(results_dir, f"{args.dataset_name}_testset_res.csv")
                testset_df.to_csv(testset_csv_path, index=False)
                logger.info(f"💾 Saved test set results to {testset_csv_path}")
                
                # Plot test set results
                try:
                    testset_plot_path = plot_results(testset_df, exp_dir, f"{args.dataset_name}_testset", budget_list)
                    logger.info(f"📊 Saved test set plot to {testset_plot_path}")
                except Exception as e:
                    logger.error(f"❌ Failed to save test set plot: {e}")
                
                all_results_dict["ID (Test)"] = testset_df
            else:
                logger.warning("⚠️  No test set results collected")
        
        # ===== Phase 2: OOD Evaluation or Load Results =====
        if skip_ood_eval:
            logger.info("=" * 70)
            logger.info("📥 Phase 2: Loading existing OOD evaluation results")
            logger.info("=" * 70)
            
            for ood_dataset in ood_datasets:
                ood_csv_path = os.path.join(results_dir, f"{ood_dataset}_res.csv")
                if os.path.exists(ood_csv_path):
                    ood_df = pd.read_csv(ood_csv_path)
                    all_results_dict[ood_dataset] = ood_df
                    logger.info(f"✅ Loaded {len(ood_df)} result entries from {ood_csv_path}")
                else:
                    logger.warning(f"⚠️  File not found: {ood_csv_path}")
        else:
            logger.info("=" * 70)
            logger.info("🔍 Phase 2: OOD Evaluation")
            logger.info("=" * 70)
            
            max_workers = args.num_gpus * args.jobs_per_gpu
            
            for ood_dataset in ood_datasets:
                logger.info(f"📊 Evaluating on {ood_dataset}")
                
                # Load OOD data once in main process
                logger.info(f"📥 Loading OOD dataset {ood_dataset} once for sharing...")
                try:
                    ood_data_path = f"{args.dataset_base_path}/{args.model_name_instruct}_{args.model_name_reasoning}_{ood_dataset}.safetensors"
                    
                    with safe_open(ood_data_path, framework="pt", device="cpu") as f:
                        if args.use_separate_embeddings:
                            embeddings_prompt = f.get_tensor("embeddings_prompt")
                            embeddings_answer_a = f.get_tensor("embeddings_answer_a")
                            embeddings_answer_b = f.get_tensor("embeddings_answer_b")
                            embeddings_combined = torch.cat([embeddings_prompt, embeddings_answer_a, embeddings_answer_b], dim=1)
                        else:
                            embeddings_combined = f.get_tensor("embeddings_prompt_answer")
                        
                        correct_instruct = f.get_tensor("correct_instruct").reshape(-1)
                        correct_reasoning = f.get_tensor("correct_reasoning").reshape(-1)
                        num_instr = f.get_tensor("num_tokens_instruct").reshape(-1)
                        num_reason = f.get_tensor("num_tokens_reasoning").reshape(-1)
                    
                    token_ratio = num_reason.float() / (num_instr.float() + 1e-8)
                    
                    # Create DataFrame
                    df_ood = pd.DataFrame({
                        "embedding": [e.tolist() for e in embeddings_combined],
                        "correct_instruct": correct_instruct.tolist(),
                        "correct_reasoning": correct_reasoning.tolist(),
                        "token_ratio": token_ratio.tolist()
                    })
                    
                    logger.info(f"✅ Loaded {len(df_ood)} samples from {ood_dataset}")
                except Exception as e:
                    logger.error(f"❌ Failed to load OOD dataset {ood_dataset}: {e}")
                    import traceback
                    traceback.print_exc()
                    continue
                
                # Prepare args_dict for OOD evaluation
                args_dict_ood = {
                    'dataset_base_path': args.dataset_base_path,
                    'model_name_instruct': args.model_name_instruct,
                    'model_name_reasoning': args.model_name_reasoning,
                    'use_separate_embeddings': args.use_separate_embeddings,
                    'ood_df_shared': df_ood, 
                    'current_ood_dataset': ood_dataset, 
                }
                
                # Create evaluation jobs
                eval_jobs = []
                for budget in budget_list:
                    for rep in range(args.replications):
                        robust_path = os.path.join(exp_dir, "policies", f"policy_robust_B{budget:.2f}_rep{rep}.pt")
                        if os.path.exists(robust_path):
                            gpu_id = (len(eval_jobs) // args.jobs_per_gpu) % args.num_gpus
                            job = {
                                'budget': budget,
                                'rep': rep,
                                'gpu_id': gpu_id,
                                'model_path_robust': robust_path,
                                'ood_dataset': ood_dataset,
                                'args': args_dict_ood
                            }
                            eval_jobs.append(job)
                
                logger.info(f"📋 Total OOD eval jobs for {ood_dataset}: {len(eval_jobs)}")
                
                # Parallel evaluation
                ood_results = []
                with ProcessPoolExecutor(max_workers=max_workers) as executor:
                    futures = {executor.submit(eval_ood_job, job): job for job in eval_jobs}
                    for i, future in enumerate(as_completed(futures)):
                        job = futures[future]
                        try:
                            result = future.result()
                            ood_results.extend(result)
                            logger.info(f"✅ [{i+1}/{len(eval_jobs)}] OOD eval B={job['budget']:.2f}, rep={job['rep']} on {ood_dataset}")
                        except Exception as e:
                            logger.error(f"❌ [{i+1}/{len(eval_jobs)}] OOD eval failed B={job['budget']:.2f}, rep={job['rep']} on {ood_dataset}: {e}")
                            import traceback
                            traceback.print_exc()
                
                # Save OOD results
                if ood_results:
                    ood_df = pd.DataFrame(ood_results)
                    all_results_dict[ood_dataset] = ood_df
                    
                    ood_csv_path = os.path.join(results_dir, f"{ood_dataset}_res.csv")
                    ood_df.to_csv(ood_csv_path, index=False)
                    logger.info(f"💾 Saved OOD results to {ood_csv_path}")
                    
                    try:
                        ood_plot_path = plot_results(ood_df, results_dir, ood_dataset, budget_list)
                        logger.info(f"📊 Saved OOD plot to {ood_plot_path}")
                    except Exception as e:
                        logger.error(f"❌ Failed to save plot for {ood_dataset}: {e}")
                        import traceback
                        traceback.print_exc()
                else:
                    logger.warning(f"⚠️  No results collected for {ood_dataset}")
        
        # Phase 3: Combined row plot
        if len(all_results_dict) > 0:
            logger.info("=" * 70)
            logger.info("🎨 Creating combined plot with ID + OOD datasets")
            logger.info("=" * 70)
            try:
                combined_plot_path = plot_results_row(
                    all_results_dict,
                    exp_dir,
                    "combined_all_datasets.png",
                    budget_list,
                    figsize_per_plot=(2.5, 2.0)
                )
                logger.info(f"✅ Saved combined plot to {combined_plot_path}")
            except Exception as e:
                logger.error(f"❌ Failed to create combined plot: {e}")
                import traceback
                traceback.print_exc()
        
        # Phase 4: Optionally delete policy files to save disk space
        if args.delete_policy_after_eval:
            logger.info("=" * 70)
            logger.info("🗑️  Deleting policy files to save disk space")
            logger.info("=" * 70)
            
            policies_dir = os.path.join(exp_dir, "policies")
            if os.path.exists(policies_dir):
                import shutil
                try:
                    total_size = 0
                    num_files = 0
                    for root, dirs, files in os.walk(policies_dir):
                        for file in files:
                            if file.endswith('.pt'):
                                filepath = os.path.join(root, file)
                                total_size += os.path.getsize(filepath)
                                num_files += 1
                    
                    # Delete the directory
                    shutil.rmtree(policies_dir)
                    
                    logger.info(f"✅ Deleted {num_files} policy files")
                    logger.info(f"💾 Freed ~{total_size / 1024 / 1024:.2f} MB of disk space")
                except Exception as e:
                    logger.error(f"❌ Failed to delete policy files: {e}")
            else:
                logger.warning("⚠️  Policies directory not found, nothing to delete")
        
        logger.info("=" * 70)
        logger.info("🎉 All experiments completed!")
        logger.info("=" * 70)


if __name__ == "__main__":
    # Run the pipeline only when this file is executed as a script, not on import.
    # NOTE(review): this guard also matters because main() launches
    # ProcessPoolExecutor workers; under the "spawn" start method the main module
    # is re-imported in each worker, and without the guard that re-import would
    # start the whole pipeline again.
    main()