import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import os
import torch
import argparse
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from models.recalibrator import Recalibrator
from loaders.s_loader import S_Loader
from environments.environment import Env  # Use Env, not Environment
from sessions.cb_session import CB_Session
from tools import feature_list

# ===================== Argument Parsing =====================
def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate Decision Maker with Calibrated Predictor")
    parser.add_argument('--predictor_ckpt', type=str, required=True, help='Path to trained predictor checkpoint')
    parser.add_argument('--dm_ckpt', type=str, required=True, help='Path to trained decision maker checkpoint')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for evaluation')
    # ------------------------ Model Setting
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    # parser.add_argument('--teacher_forcing_ratio', type=float, default=0.5)

    # ----------------------- Regularizer coefficient
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    parser.add_argument("--dm_decay_coeff", type=float, default=0.0005)

    # parser.add_argument('--gl_coeff', type=float, default=1)

    parser.add_argument("--mi_coeff", type=float, default=10)
    parser.add_argument("--ma_coeff", type=float, default=1)

    parser.add_argument("--otr_reward_coeff", type=float, default=1)

    parser.add_argument("--reward_smoothing_factor", type=float, default=0.5)

    parser.add_argument("--mip_coeff", type=float, default=1)
    parser.add_argument("--mil_coeff", type=float, default=1)

    # ----------------------- logger
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)

    parser.add_argument("--lr", type=float, default=0.01)

    # parser.add_argument('--mi_lr', type=float, default=0.0001)
    parser.add_argument("--dm_lr", type=float, default=0.01)
    
    parser.add_argument("--use_calibration", type=bool, default=False)
    parser.add_argument("--use_perturbation", type=bool, default=False)
    parser.add_argument("--eta_c", type=float, default=1)
    parser.add_argument("--eta_p", type=float, default=1)
    parser.add_argument('--dm_epochs', type=int, default=1, help='Dummy for compatibility')
    parser.add_argument('--eva_interval', type=int, default=1, help='Dummy for compatibility')
    parser.add_argument('--early_stop', type=int, default=10, help='Dummy for compatibility')
    parser.add_argument('--suffix', type=str, default='eval', help='Suffix for checkpoint naming')
    parser.add_argument('--CKPT_PATH', type=str, default='./', help='Checkpoint directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--use_gpu', action='store_true', help='Use GPU if available')
    parser.add_argument('--device_id', type=str, default='cuda:0', help='Device id for torch')
    return parser.parse_args()

# ===================== Main Evaluation Script =====================
def main():
    """Restore a trained predictor and decision maker, then report val/test metrics.

    Steps:
      1. Parse CLI args and neutralise training-only settings (no epochs,
         no wandb logging).
      2. Build ``Env`` and ``S_Loader``, then restore the ``S_SimDec`` predictor
         from ``--predictor_ckpt``.
      3. If ``--use_calibration`` is set, build a ``Recalibrator`` and attach it
         to the session. No recalibrator checkpoint is loaded, and a warning is
         printed to stderr to say so.
      4. Restore the ``ValueNetwork`` decision maker from ``--dm_ckpt`` and wire
         everything into a ``CB_Session``.
      5. Call ``session.dm_test`` for the "val" and "test" modes and print the
         metrics it returns.
    """
    args = parse_args()
    args.epochs = 0  # For compatibility
    args.ckpt_start_epoch = 0  # For compatibility
    args.wandb = False  # Disable wandb logging for evaluation

    # Set up environment and dataset
    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    # Build and load predictor (S_SimDec).
    # NOTE: depending on the torch version / weights_only default, torch.load may
    # unpickle arbitrary objects, so only point these flags at trusted checkpoints.
    model = S_SimDec(env).to(env.device)
    model.load_state_dict(torch.load(args.predictor_ckpt, map_location=env.device))
    model.eval()

    # If calibration is used, build recalibrator and attach to session
    recalibrator = None
    if args.use_calibration:
        feature_dim = len(
            feature_list.product_info[args.dataset]
            + feature_list.order_info[args.dataset]
            + feature_list.customer_info[args.dataset]
            + feature_list.shipping_info[args.dataset]
        )
        action_dim = 4  # Should match your setup
        recalibrator = Recalibrator(K=action_dim, input_dim=feature_dim).to(env.device)
        recalibrator.eval()
        # No recalibrator checkpoint is loaded, so it runs with whatever parameters
        # its constructor produced. Say so explicitly rather than silently
        # reporting these numbers as "calibrated".
        # TODO: load trained weights here, e.g.
        # recalibrator.load_state_dict(torch.load(<recalibrator ckpt>, map_location=env.device))
        print(
            "WARNING: --use_calibration is set but no recalibrator weights are loaded; "
            "the recalibrator runs with the parameters set by its constructor.",
            file=sys.stderr,
        )

    # Build and load decision maker (ValueNetwork)
    value_network = ValueNetwork(env).to(env.device)
    value_network.load_state_dict(torch.load(args.dm_ckpt, map_location=env.device))
    value_network.eval()

    # Set up CB_Session
    session = CB_Session(env, model, dataset)
    session.init_value_network(value_network)
    if recalibrator is not None:
        session.recalibrator = recalibrator
        session.use_calibration = True

    # ===== Ensure StandardScaler is fitted before evaluation =====
    # Intended to mirror cb_session.py: fit the scaler on training inputs before
    # val/test. Skipped silently if the session lacks either attribute.
    if hasattr(session, 'scaler') and hasattr(session, 'train_inputs'):
        session.scaler.fit(session.train_inputs)

    # Evaluate decision maker on validation and test sets
    print("Evaluating Decision Maker with Calibrated Predictor...")
    for mode in ["val", "test"]:
        profit, on_time_ratio, profit_min_percent, test_time = session.dm_test(mode)
        print(f"[{mode.upper()}] Profit: {profit:.5f}, On-Time Ratio: {on_time_ratio:.5f}, Time: {test_time:.2f}s")
        print(f"[{mode.upper()}] Profit Min Percentiles: {profit_min_percent}")


if __name__ == "__main__":
    main()