import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
import torch
import argparse
import pandas as pd
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from loaders.s_loader import S_Loader
from environments.environment import Env  # Use Env, not Environment
from sessions.cb_session import CB_Session
import numpy as np

# ===================== Argument Parsing =====================
def _str2bool(value):
    """Convert a command-line string into a bool for ``argparse``.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is ``True``
    for every non-empty string -- so ``--flag False`` would silently enable the
    flag. This converter interprets the usual spellings instead.

    Args:
        value: Raw command-line string (or an already-parsed bool default).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, matching the old ``bool("")`` behaviour.
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (true/false, yes/no, 1/0), got {value!r}"
    )


def parse_args(argv=None):
    """Parse command-line options for the robustness evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate Decision Maker with Gaussian Perturbation (mu_p, sigma_p)"
    )
    parser.add_argument(
        "--sim_ckpt",
        type=str,
        required=True,
        help="Path to trained predictor checkpoint",
    )
    parser.add_argument(
        "--dm_ckpt",
        type=str,
        required=True,
        help="Path to perturbed decision maker checkpoint",
    )
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Batch size for evaluation"
    )
    # Model settings
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    parser.add_argument("--dm_decay_coeff", type=float, default=0.0005)
    parser.add_argument("--mi_coeff", type=float, default=10)
    parser.add_argument("--ma_coeff", type=float, default=1)
    parser.add_argument("--otr_reward_coeff", type=float, default=1)
    parser.add_argument("--reward_smoothing_factor", type=float, default=0.5)
    parser.add_argument("--mip_coeff", type=float, default=1)
    parser.add_argument("--mil_coeff", type=float, default=1)
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dm_lr", type=float, default=0.01)
    # Boolean options use _str2bool: with type=bool, "--flag False" parses as True.
    parser.add_argument("--use_calibration", type=_str2bool, default=False)
    parser.add_argument("--use_perturbation", type=_str2bool, default=False)
    parser.add_argument("--eta_c", type=float, default=0.0)
    parser.add_argument("--eta_p", type=float, default=0.0)
    parser.add_argument(
        "--dm_epochs", type=int, default=1, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--eva_interval", type=int, default=1, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--early_stop", type=int, default=10, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--suffix", type=str, default="eval", help="Suffix for checkpoint naming"
    )
    parser.add_argument(
        "--CKPT_PATH", type=str, default="./", help="Checkpoint directory"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
    parser.add_argument(
        "--device_id", type=str, default="cuda:0", help="Device id for torch"
    )
    parser.add_argument('--epsilon_p', type=float, default=0.0)
    parser.add_argument("--random_noise", type=_str2bool, default=True)
    return parser.parse_args(argv)

# ===================== Main Evaluation Script =====================
def _evaluate_rewards(session, epsilon, random_noise):
    """Run ``session.dm_test`` on the test split and return per-row rewards.

    The reward of each row of the detail frame returned by ``dm_test`` is the
    sum of its ``on_time`` and ``profit`` columns.

    Args:
        session: The CB_Session to evaluate.
        epsilon: Value forwarded to ``dm_test`` as ``epsilon_p``.
        random_noise: Value forwarded to ``dm_test`` as ``random_noise``.

    Returns:
        list: The per-row rewards.
    """
    print(f"Testing epsilon = {epsilon}")
    # Only the detail frame (last element) is needed here.
    *_, reward_df = session.dm_test(
        mode="test",
        epsilon_p=epsilon,
        random_noise=random_noise,
        return_detail=True,
    )
    reward_df["reward"] = reward_df["on_time"] + reward_df["profit"]
    return reward_df["reward"].tolist()


def _robustness_metrics(clean_rewards, perturbed_rewards):
    """Compute robustness statistics comparing clean and perturbed rewards.

    Args:
        clean_rewards: Per-row rewards from the unperturbed (epsilon = 0) run.
        perturbed_rewards: Per-row rewards from the perturbed run, aligned
            row-for-row with ``clean_rewards``.

    Returns:
        dict with ``clean_reward``, ``avg_perturbed_reward``,
        ``worst_case_reward``, ``reward_variance``, ``drop_rate`` and
        ``recovery_rate`` (fraction of rows whose perturbed reward is strictly
        greater than the clean one).

    Raises:
        ValueError: If the perturbed run is empty, or the two runs have
            different lengths (an element-wise comparison would otherwise
            fail or silently broadcast).
    """
    clean = np.asarray(clean_rewards)
    perturbed = np.asarray(perturbed_rewards)
    if perturbed.size == 0:
        raise ValueError(
            "Perturbed evaluation produced no rewards; cannot compute robustness metrics."
        )
    if clean.shape != perturbed.shape:
        raise ValueError(
            f"Clean and perturbed runs are not aligned: {clean.shape} vs {perturbed.shape}"
        )
    clean_mean = np.mean(clean)
    perturbed_mean = np.mean(perturbed)
    return {
        "clean_reward": clean_mean,
        "avg_perturbed_reward": perturbed_mean,
        "worst_case_reward": np.min(perturbed),
        "reward_variance": np.var(perturbed),
        "drop_rate": clean_mean - perturbed_mean,
        "recovery_rate": np.mean(perturbed > clean),
    }


def main():
    """Evaluate the decision maker on clean and perturbed test inputs.

    Loads the predictor (``--sim_ckpt``) and decision maker (``--dm_ckpt``).
    Runs ``CB_Session.dm_test`` on the test split twice: at epsilon = 0, and at
    ``--epsilon_p`` when that is positive (otherwise the historical 0.5). Then
    prints a robustness summary.

    Returns:
        dict: The robustness metrics printed in the summary.
    """
    args = parse_args()
    args.epochs = 0  # For compatibility
    args.ckpt_start_epoch = 0  # For compatibility
    args.wandb = False  # Disable wandb logging for evaluation

    # Set up environment and dataset
    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    # Build and load predictor (S_SimDec)
    model = S_SimDec(env).to(env.device)
    model.load_state_dict(torch.load(args.sim_ckpt, map_location=env.device))
    model.eval()

    # Build and load the decision maker (ValueNetwork) from --dm_ckpt
    value_network = ValueNetwork(env).to(env.device)
    value_network.load_state_dict(torch.load(args.dm_ckpt, map_location=env.device))
    value_network.eval()

    session = CB_Session(env, model, dataset)
    session.init_value_network(value_network)
    # Fit the session's scaler on its training inputs before evaluation, when
    # the session exposes both (presumably an sklearn StandardScaler -- verify).
    if hasattr(session, "scaler") and hasattr(session, "train_inputs"):
        session.scaler.fit(session.train_inputs)

    # --epsilon_p was previously parsed but ignored. Use it when positive; the
    # default of 0.0 keeps the original fixed perturbation level of 0.5.
    clean_epsilon = 0.0
    perturbed_epsilon = args.epsilon_p if args.epsilon_p > 0 else 0.5

    # Step 1: clean baseline; Step 2: perturbed run (same order as before).
    clean_rewards = _evaluate_rewards(session, clean_epsilon, args.random_noise)
    perturbed_rewards = _evaluate_rewards(session, perturbed_epsilon, args.random_noise)

    # Steps 3-4: robustness metrics and recovery rate.
    metrics = _robustness_metrics(clean_rewards, perturbed_rewards)

    # Print results (label padded to the same 34-column layout as before).
    perturbed_label = f"Perturbed Reward (ε={perturbed_epsilon}):"
    print(f"\n===== ROBUSTNESS SUMMARY (ε = {perturbed_epsilon}) =====")
    print(f"Clean Reward (ε=0):               {metrics['clean_reward']:.4f}")
    print(f"{perturbed_label:<34}{metrics['avg_perturbed_reward']:.4f}")
    print(f"Worst-case Reward:                {metrics['worst_case_reward']:.4f}")
    print(f"Reward Variance (Stability ↓):   {metrics['reward_variance']:.6f}")
    print(f"Drop Rate (ΔReward):              {metrics['drop_rate']:.4f}")
    print(f"Recovery Rate vs Ground Truth:    {metrics['recovery_rate']:.4f}")
    return metrics

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()