import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
import torch
import argparse
import pandas as pd
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from models.recalibrator import Recalibrator
from loaders.s_loader import S_Loader
from environments.environment import Env  # Use Env, not Environment
from sessions.cb_session import CB_Session

# ===================== Argument Parsing =====================
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is the
            behaviour callers relied on before this parameter existed.

    Returns:
        argparse.Namespace: One attribute per option declared below.
    """

    def str2bool(value):
        """Convert a command-line token to a real boolean.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because every non-empty string is truthy, so a flag could
        never be switched off by passing a value.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in {"true", "t", "yes", "y", "1", "on"}:
            return True
        # The empty string is kept as False: it was the only value that the
        # previous ``type=bool`` declaration mapped to False.
        if token in {"false", "f", "no", "n", "0", "off", ""}:
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false), got {value!r}"
        )

    parser = argparse.ArgumentParser(
        description="Evaluate Decision Maker with Gaussian Perturbation (mu_p, sigma_p)"
    )

    # ----- Checkpoints and data -----
    parser.add_argument(
        "--sim_ckpt",
        type=str,
        required=True,
        help="Path to trained predictor checkpoint",
    )
    parser.add_argument(
        "--dm_ckpt_perturbed",
        type=str,
        required=True,
        help="Path to perturbed decision maker checkpoint",
    )
    parser.add_argument(
        "--dm_ckpt_unperturbed",
        type=str,
        required=True,
        help="Path to unperturbed decision maker checkpoint",
    )
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Batch size for evaluation"
    )

    # ----- Model settings -----
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    parser.add_argument("--dm_decay_coeff", type=float, default=0.0005)
    parser.add_argument("--mi_coeff", type=float, default=10)
    parser.add_argument("--ma_coeff", type=float, default=1)
    parser.add_argument("--otr_reward_coeff", type=float, default=1)
    parser.add_argument("--reward_smoothing_factor", type=float, default=0.5)
    parser.add_argument("--mip_coeff", type=float, default=1)
    parser.add_argument("--mil_coeff", type=float, default=1)
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dm_lr", type=float, default=0.01)
    # Boolean options take an explicit value, e.g. ``--use_calibration true``.
    parser.add_argument("--use_calibration", type=str2bool, default=False)
    parser.add_argument("--use_perturbation", type=str2bool, default=False)
    parser.add_argument("--eta_c", type=float, default=0.0)
    parser.add_argument("--eta_p", type=float, default=0.0)

    # ----- Options declared only for compatibility (see help texts) -----
    parser.add_argument(
        "--dm_epochs", type=int, default=1, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--eva_interval", type=int, default=1, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--early_stop", type=int, default=10, help="Dummy for compatibility"
    )
    parser.add_argument(
        "--suffix", type=str, default="eval", help="Suffix for checkpoint naming"
    )
    parser.add_argument(
        "--CKPT_PATH", type=str, default="./", help="Checkpoint directory"
    )

    # ----- Runtime -----
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
    parser.add_argument(
        "--device_id", type=str, default="cuda:0", help="Device id for torch"
    )
    parser.add_argument("--epsilon_p", type=float, default=0.0)
    parser.add_argument("--random_noise", type=str2bool, default=False)
    return parser.parse_args(argv)



# ===================== Main Evaluation Script =====================
def _load_value_network(env, ckpt_path):
    """Build a ValueNetwork on ``env.device``, load *ckpt_path* into it, set eval mode."""
    network = ValueNetwork(env).to(env.device)
    network.load_state_dict(torch.load(ckpt_path, map_location=env.device))
    network.eval()
    return network


def _fit_scaler_if_present(session):
    """Fit ``session.scaler`` on ``session.train_inputs`` when both attributes exist."""
    if hasattr(session, "scaler") and hasattr(session, "train_inputs"):
        session.scaler.fit(session.train_inputs)


def _evaluate_and_report(session, label, banner, mode, args, csv_dir):
    """Run ``session.dm_test`` on one split, save its detail records, print metrics.

    Args:
        session: Object exposing ``dm_test``; called with the split name plus
            ``epsilon_p``, ``random_noise`` and ``return_detail=True``.
        label: Upper-case tag used in the log lines; its lower-case form is
            the CSV file-name prefix (``"PERTURBED"`` -> ``perturbed_...``).
        banner: Line printed before the evaluation starts.
        mode: Split name passed to ``dm_test`` (``"val"`` or ``"test"``).
        args: Parsed options; ``epsilon_p`` and ``random_noise`` are read.
        csv_dir: ``pathlib.Path`` of an existing directory for the CSV files.

    Returns:
        tuple: ``(profit, on_time_ratio)`` as returned by ``dm_test``.
    """
    print(banner)
    profit, on_time_ratio, profit_min_percent, test_time, record_df = session.dm_test(
        mode,
        epsilon_p=args.epsilon_p,
        random_noise=args.random_noise,
        return_detail=True,
    )

    stem = f"{label.lower()}_{args.epsilon_p}"
    # One file per split, so the "test" pass no longer destroys the "val" records.
    record_df.to_csv(csv_dir / f"{stem}_{mode}.csv", index=False)
    # Legacy un-suffixed path, still written for existing consumers. It is
    # overwritten on every split and therefore ends up holding the records of
    # the last split evaluated, exactly as before.
    record_df.to_csv(csv_dir / f"{stem}.csv", index=False)

    tag = f"[{label}][{mode.upper()}]"
    print(
        f"{tag} Profit: {profit:.5f}, On-Time Ratio: {on_time_ratio:.5f}, Time: {test_time:.2f}s"
    )
    print(f"{tag} Profit Min Percentiles: {profit_min_percent}")
    print(f"{tag} Profit + On-Time Ratio: {profit + on_time_ratio:.5f}")
    return profit, on_time_ratio


def main():
    """Compare a perturbed and an unperturbed decision maker on "val" and "test".

    Loads the predictor and both decision-maker checkpoints named on the
    command line, wraps each decision maker in its own ``CB_Session``, runs
    ``dm_test`` on both splits, writes the detail records under
    ``./exp_report/<dataset>/csv/`` and prints the metrics and their deltas.
    """
    args = parse_args()
    args.epochs = 0  # For compatibility
    args.ckpt_start_epoch = 0  # For compatibility
    args.wandb = False  # Disable wandb logging for evaluation

    # Set up environment and dataset
    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    # Build and load predictor (S_SimDec)
    model = S_SimDec(env).to(env.device)
    model.load_state_dict(torch.load(args.sim_ckpt, map_location=env.device))
    model.eval()

    # Build and load both decision makers (ValueNetwork)
    value_network_perturbed = _load_value_network(env, args.dm_ckpt_perturbed)
    value_network_unperturbed = _load_value_network(env, args.dm_ckpt_unperturbed)

    # Session for the perturbed decision maker
    session_perturbed = CB_Session(env, model, dataset)
    session_perturbed.init_value_network(value_network_perturbed)

    # Session for the unperturbed decision maker: perturbation explicitly disabled
    session_unperturbed = CB_Session(env, model, dataset)
    session_unperturbed.init_value_network(value_network_unperturbed)
    session_unperturbed.use_perturbation = False
    session_unperturbed.perturbator = None

    # Ensure the scaler (when the session has one) is fitted before evaluation
    _fit_scaler_if_present(session_perturbed)
    _fit_scaler_if_present(session_unperturbed)

    # to_csv does not create missing directories, so make sure the target exists
    csv_dir = pathlib.Path("./exp_report") / args.dataset / "csv"
    csv_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate both decision makers on validation and test sets
    for mode in ["val", "test"]:
        profit_p, on_time_ratio_p = _evaluate_and_report(
            session_perturbed,
            "PERTURBED",
            f"\nEvaluating Decision Maker (PERTURBED, epsilon_p = {args.epsilon_p}) in {mode.upper()}...",
            mode,
            args,
            csv_dir,
        )
        profit_u, on_time_ratio_u = _evaluate_and_report(
            session_unperturbed,
            "UNPERTURBED",
            f"\nEvaluating Decision Maker (UNPERTURBED) in {mode.upper()}...",
            mode,
            args,
            csv_dir,
        )

        # Print comparison
        print(
            f"\n[COMPARISON][{mode.upper()}] ΔProfit: {profit_p - profit_u:.5f}, ΔOn-Time Ratio: {on_time_ratio_p - on_time_ratio_u:.5f}, Δ(Profit+On-Time): {(profit_p + on_time_ratio_p) - (profit_u + on_time_ratio_u):.5f}"
        )


# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effect beyond the definitions above.
if __name__ == "__main__":
    main()
