import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
import torch
import numpy as np
from environments.rl_shipping_env import ShippingEnv
from models.perturbator import Perturbator


def evaluate_model(
    env: ShippingEnv,
    model,
    test_inputs,
    epsilon_p=0.0,
    random_noise=False,
):
    """
    Evaluate a trained RL policy (DQN or PPO) on a set of test orders.

    Each test input is fed to the model, the chosen action is passed to
    ``env.step``, and the ``"profit"`` and ``"on_time"`` entries of the
    returned ``info`` dict are collected. Summary metrics (mean profit, mean
    on-time value, their sum, and the profit found at the 10%/20%/30%
    positions of the ascending-sorted profits) are printed to stdout.
    Nothing is returned.

    :param env: environment providing ``reset()`` and ``step(action)``;
        ``step`` must return ``(next_state, reward, done, info)`` where
        ``info`` contains the keys ``"profit"`` and ``"on_time"``
    :param model: trained model. It is treated as a DQN when its ``fc1``
        attribute is a ``torch.nn.Linear`` and as a PPO model when its
        ``actor`` attribute is a ``torch.nn.Linear``. If both hold, the DQN
        path is used.
    :param test_inputs: iterable of tensors, one per order
    :param epsilon_p: scale applied to the standard-normal noise; only used
        when ``random_noise`` is true
    :param random_noise: when true, add ``randn * epsilon_p`` to every state
        before it is passed to the model
    :raises AssertionError: if the model is neither DQN-like nor PPO-like
    """
    # Detect the model family from its layer attributes.
    is_dqn = isinstance(getattr(model, "fc1", None), torch.nn.Linear)
    is_ppo = isinstance(getattr(model, "actor", None), torch.nn.Linear)
    if not (is_dqn or is_ppo):
        # Raised explicitly (instead of `assert`) so the check is not stripped
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError("Model should be either DQN or PPO")

    env.reset()
    profit_list = []
    on_time_list = []
    optimal_modes = []

    for sample in test_inputs:
        # float32 CPU batch of size 1; detach() lets inputs that require grad
        # be evaluated as well.
        state_tensor = sample.detach().cpu().float().unsqueeze(0)

        # Perturbation. The addition is out-of-place on purpose: for CPU
        # float32 inputs `state_tensor` is a view of the caller's data, so an
        # in-place `+=` would silently corrupt `test_inputs`.
        if random_noise:
            state_tensor = state_tensor + torch.randn_like(state_tensor) * epsilon_p

        # Decision-making.
        with torch.no_grad():
            if is_dqn:
                # DQN: pick the action with the largest Q-value.
                action = torch.argmax(model(state_tensor)).item()
            else:
                # PPO: sample an action from the policy distribution.
                logits, _ = model(state_tensor)
                action_probs = torch.softmax(logits, dim=-1)
                action = torch.distributions.Categorical(action_probs).sample().item()

        _, _, done, info = env.step(action)

        # Record the profit and on-time value reported for this order.
        profit_list.append(info["profit"])
        on_time_list.append(info["on_time"])
        optimal_modes.append(action)

        if done:
            break

    if not profit_list:
        # No order was evaluated: the means would be NaN and the percentile
        # lookup below would raise IndexError, so report and stop here.
        print("没有可评估的订单。")
        return

    # Final metrics.
    final_profit = np.mean(profit_list)
    final_on_time_ratio = np.mean(on_time_list)

    sorted_profits = np.sort(profit_list)
    thresholds = [0.1, 0.2, 0.3]
    # threshold < 1, so int(threshold * n) is always a valid index for n >= 1.
    profit_min_percent = {
        threshold: sorted_profits[int(threshold * len(sorted_profits))]
        for threshold in thresholds
    }

    # Report.
    print("优化后的运输模式：", optimal_modes)
    print(f"最终利润：{final_profit:.5f}")
    print(f"最终准时率：{final_on_time_ratio:.5f}")
    print(f"最终利润 + 最终准时率：{final_profit + final_on_time_ratio:.5f}")
    print("最低利润百分比：", profit_min_percent)
