import gym
import torch
from gym import spaces
import sys, pathlib
import numpy as np

sys.path.append(str(pathlib.Path(__file__).parent.parent))
from tools import feature_list
from sklearn.preprocessing import StandardScaler

class ShippingEnv(gym.Env):
    """Gym environment for choosing a shipping mode for one order at a time.

    Each episode walks through ``loader.test_inputs`` in order.  At every step
    the agent picks one of four shipping modes; the reward combines a profit
    value looked up in a historical cost table with the on-time prediction of
    ``model``, weighted by ``beta``.
    """

    def __init__(self, model, loader, cost_dic, avg_profit, beta=0.5, random_noise=False, epsilon_p=0.0):
        """Set up the shipping-decision environment.

        :param model: prediction model (S_SimDec); must expose ``embedding.weight``
            and be callable as ``model(state, embedding, rest)``
        :param loader: data loader (S_Loader); ``loader.test_inputs``,
            ``loader.env.args.dataset`` and ``loader.env.device`` are read
        :param cost_dic: historical cost table, 2-D; the last column is the
            profit value and the remaining columns are the lookup key
        :param avg_profit: fallback profit, indexed by action, used when the
            cost table has no exact match for the query
        :param beta: weight balancing profit against the on-time rate
        :param random_noise: when true, Gaussian noise is added to the scaled
            model input on every step
        :param epsilon_p: standard deviation multiplier of that noise
        """
        super().__init__()
        self.model = model
        self.loader = loader
        self.cost_dic_data = cost_dic[:, :-1]
        self.cost_dic_y = cost_dic[:, -1]
        self.avg_profit = avg_profit
        self.beta = beta

        # Action space: 4 shipping modes.
        self.action_space = spaces.Discrete(4)
        self.scaler = StandardScaler()

        # Observation space: one entry per order feature of the selected dataset.
        dataset = loader.env.args.dataset
        feature_dim = len(
            feature_list.product_info[dataset]
            + feature_list.order_info[dataset]
            + feature_list.customer_info[dataset]
            + feature_list.shipping_info[dataset]
        )
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(feature_dim,), dtype=np.float32
        )
        self.feature_dim = feature_dim

        # Episode state.
        self.current_state = None
        self.index = 0  # index of the current order in loader.test_inputs
        self.random_noise = random_noise
        self.epsilon_p = epsilon_p

        # Nearest-neighbour index over the cost table, built lazily on first use.
        self._cost_index = None
        self._cost_index_source = None

    def __getstate__(self):
        """Drop the cached FAISS index so the environment can be copied or pickled."""
        state = self.__dict__.copy()
        # The index is rebuilt on demand; native FAISS handles may not pickle.
        state["_cost_index"] = None
        state["_cost_index_source"] = None
        return state

    def _get_cost_index(self):
        """Return the L2 index over ``cost_dic_data``, building it when needed."""
        # Rebuild only when nothing is cached yet or the table was rebound;
        # in-place edits of the same array are not detected.
        if self._cost_index is None or self._cost_index_source is not self.cost_dic_data:
            import faiss  # kept local, as before, so importing this module does not require FAISS

            index = faiss.IndexFlatL2(self.cost_dic_data.shape[1])
            # The column slice taken in __init__ is a strided view; FAISS needs
            # contiguous float32 data.
            index.add(np.ascontiguousarray(self.cost_dic_data, dtype=np.float32))
            self._cost_index = index
            self._cost_index_source = self.cost_dic_data
        return self._cost_index

    def add_noise(self, data: torch.Tensor):
        """Return ``data`` plus standard-normal noise scaled by ``epsilon_p``."""
        noise = self.epsilon_p * torch.randn_like(data)
        return data + noise

    def reset(self):
        """Rewind to the first order and return it as the initial state."""
        self.index = 0
        self.current_state = self.loader.test_inputs[self.index].cpu().numpy()
        return self.current_state

    def step(self, action):
        """Apply a shipping mode to the current order and advance to the next one.

        :param action: chosen shipping mode (0-3)
        :return: ``(next_state, reward, done, info)``.  ``next_state`` is ``None``
            once the last order has been processed.  ``reward`` is
            ``profit + beta * on_time`` and is a NumPy array with one entry per
            query row (a single entry here, since one order is processed per
            step).  ``info`` holds the mean ``"profit"`` and ``"on_time"``.
        :raises RuntimeError: if called before ``reset()``
        """
        if self.current_state is None:
            raise RuntimeError("reset() must be called before step()")

        # Work on a (1, n_features) copy so the stored state is not modified.
        input_data = np.expand_dims(self.current_state.copy(), axis=0)
        input_data[:, -1] = action  # the last column carries the shipping mode

        # NOTE(review): the scaler is re-fitted on this single row, which makes
        # every scaled feature 0 (mean equals the row, zero variance).  It
        # presumably should reuse the scaler fitted on the training data --
        # confirm against the training pipeline before changing.
        input_data_scaled = self.scaler.fit_transform(input_data)
        input_data_tensor = torch.FloatTensor(input_data_scaled).to(
            self.loader.env.device
        )
        if self.random_noise:
            input_data_tensor = self.add_noise(input_data_tensor)

        with torch.no_grad():
            state = input_data_tensor[:, : self.feature_dim]
            # One copy of the chosen mode's embedding per input row.
            selected_embedding = (
                self.model.embedding.weight[action, :]
                .unsqueeze(0)
                .repeat(state.shape[0], 1)
            )
            predicted_tokens = self.model(
                state,
                selected_embedding,
                input_data_tensor[:, self.feature_dim + 1 :],
            )
            # The last model output is reduced to a class index per row and
            # averaged to give the on-time term.
            on_time = predicted_tokens[-1].argmax(dim=1).cpu().numpy().mean()

        # Cost-table key: two unscaled features selected by retrieve_index,
        # followed by the action.
        retrieve_idx = feature_list.retrieve_index[self.loader.env.args.dataset]
        query_vectors = np.array(
            [[row[retrieve_idx[0]], row[retrieve_idx[1]], action] for row in input_data],
            dtype="float32",
        )

        _, nearest_indices = self._get_cost_index().search(query_vectors, 1)
        nearest_samples = self.cost_dic_data[nearest_indices.flatten()]

        # Use the recorded profit only on an exact key match; otherwise fall
        # back to the average profit of the chosen mode.
        profit = np.array(
            [
                (
                    self.cost_dic_y[nearest_indices[i, 0]]
                    if np.array_equal(query, nearest_samples[i])
                    else self.avg_profit[action]
                )
                for i, query in enumerate(query_vectors)
            ]
        )

        reward = profit + self.beta * on_time

        self.index += 1
        done = self.index >= len(self.loader.test_inputs)
        if not done:
            self.current_state = self.loader.test_inputs[self.index].cpu().numpy()

        return (
            self.current_state if not done else None,
            reward,
            done,
            {"profit": profit.mean(), "on_time": on_time},
        )

    def render(self, mode="human"):
        """Optional rendering hook for inspecting the environment; currently a no-op."""
        pass