from faiss import rand
import numpy as np
from pulp import LpMaximize, LpProblem, LpVariable, lpSum

import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
from tools import feature_list
import argparse
import time
import torch
import wandb
from tools.logger import info
from environments.environment import Env
from loaders.s_loader import S_Loader
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from sklearn.preprocessing import StandardScaler
from models.perturbator import Perturbator


class LinearProgrammingDecision:
    """Pick one shipping mode per order by solving a binary program.

    For every candidate action the class estimates a per-order profit (via an
    exact-match lookup in a historical cost table, backed by a FAISS index)
    and an on-time score (via ``self.model``), then maximizes a weighted sum
    of both with PuLP.
    """

    # Number of candidate shipping modes evaluated for every order.
    NUM_ACTIONS = 4

    def __init__(self, model, env, cost_dic, avg_profit, perturbator=None):
        """
        Store the collaborators and split the cost table into keys and profit.

        :param model: predictor; must expose ``embedding.weight`` and be callable
        :param env: environment object; ``env.device`` and ``env.args.dataset`` are read
        :param cost_dic: 2-D array; the last column is the profit, the other
            columns are the lookup key matched against query vectors
        :param avg_profit: fallback profit, indexed by action
        :param perturbator: optional state perturbator. When omitted, a
            module-level global named ``perturbator`` is used if one exists
            (this is how the driver script historically supplied it).
        """
        self.model = model
        self.env = env
        self.scaler = StandardScaler()
        self.cost_dic_data = cost_dic[:, :-1]
        self.cost_dic_y = cost_dic[:, -1]
        self.avg_profit = avg_profit
        self.perturbator = perturbator

    def _resolve_perturbator(self):
        """Return the perturbator to use, or ``None`` when none is available."""
        own = getattr(self, "perturbator", None)
        if own is not None:
            return own
        # Backward compatibility: older callers only define a module-level
        # ``perturbator``. Looking it up through globals() avoids a NameError
        # when that global does not exist.
        return globals().get("perturbator")

    def build_faiss_index(self):
        """
        Build a flat L2 FAISS index over the key columns of the cost table.

        :return: a ``faiss.IndexFlatL2`` holding ``self.cost_dic_data``
        """
        import faiss

        # FAISS works on C-contiguous float32 matrices; ``cost_dic[:, :-1]`` is
        # a sliced view and may be neither, so convert explicitly.
        keys = np.ascontiguousarray(self.cost_dic_data, dtype=np.float32)
        index = faiss.IndexFlatL2(keys.shape[1])
        index.add(keys)
        return index

    def get_query_vectors(self, input_data, action, feature_list_key):
        """
        Build the FAISS query vectors ``[feature_a, feature_b, action]``.

        :param input_data: torch tensor of order features (one row per order)
        :param action: shipping-mode id appended as the third component
        :param feature_list_key: dataset name. NOTE: currently unused; the
            lookup always uses ``self.env.args.dataset``.
        :return: float32 array of shape ``(num_orders, 3)``
        """
        retrieve_index = feature_list.retrieve_index[self.env.args.dataset]
        # Move each retrieval column to the host once instead of calling
        # ``.cpu().item()`` for every single element.
        first = input_data[:, retrieve_index[0]].detach().cpu().numpy()
        second = input_data[:, retrieve_index[1]].detach().cpu().numpy()
        action_column = np.full(first.shape, action)
        stacked = np.stack([first, second, action_column], axis=1)
        return np.ascontiguousarray(stacked, dtype=np.float32)

    def compute_profit_with_nn(self, index, query_vectors, action):
        """
        Estimate the profit of each query from its nearest cost-table row.

        A query takes the stored profit only when its nearest neighbour is an
        exact match; otherwise it falls back to ``self.avg_profit[action]``.

        :param index: object with a FAISS-style ``search(queries, k)`` method
        :param query_vectors: float32 array, one query per row
        :param action: shipping-mode id used to pick the fallback profit
        :return: 1-D array with one profit per query
        """
        _, nearest_indices = index.search(query_vectors, 1)
        nearest_rows = nearest_indices[:, 0]

        # Compare in float32 on both sides: the index itself stores float32,
        # so a float64 table row would otherwise never equal its own query
        # for values that are not exactly representable.
        queries = np.asarray(query_vectors, dtype=np.float32)
        nearest_keys = np.asarray(self.cost_dic_data[nearest_rows], dtype=np.float32)
        exact_hit = (queries == nearest_keys).all(axis=1)
        # FAISS labels missing neighbours with -1; never treat those as hits.
        exact_hit &= nearest_rows >= 0

        return np.where(
            exact_hit, self.cost_dic_y[nearest_rows], self.avg_profit[action]
        )

    def prepare_input_for_action(self, input_data, decision_index, action):
        """
        Return a standardized copy of ``input_data`` with the decision column
        set to ``action``.

        The caller's ``input_data`` is left untouched. The scaler is re-fitted
        on every call.

        :param input_data: raw order features (torch tensor or array-like)
        :param decision_index: column index of the shipping-mode field
        :param action: shipping-mode id written into that column
        :return: float tensor on ``self.env.device``
        """
        # Work on a private host copy so the decision column of the caller's
        # data is not overwritten as a side effect.
        if torch.is_tensor(input_data):
            working = input_data.detach().cpu().numpy().copy()
        else:
            working = np.array(input_data, copy=True)
        working[:, decision_index] = action
        input_data_scaled = self.scaler.fit_transform(working)
        return torch.FloatTensor(input_data_scaled).to(self.env.device)

    def model_predict(self, state, action, decision_index, input_data):
        """
        Run the model for one action and return its predicted class per order.

        :param state: (possibly perturbed) scaled features before the decision column
        :param action: shipping-mode id; selects the row of ``model.embedding.weight``
        :param decision_index: column index of the shipping-mode field
        :param input_data: raw input; the columns after ``decision_index`` are
            passed to the model unchanged
        :return: numpy array with the argmax (dim 1) of the model's last output
            (presumably the on-time label - confirm against the model)
        """
        with torch.no_grad():
            # Broadcast the chosen action's embedding to every order in the batch.
            selected_embedding = (
                self.model.embedding.weight[action, :]
                .unsqueeze(0)
                .repeat(state.shape[0], 1)
            )
            predicted_tokens = self.model(
                state, selected_embedding, input_data[:, decision_index + 1 :]
            )
            on_time = predicted_tokens[-1].argmax(dim=1).cpu().numpy()
        return on_time

    def update_matrices(self, profit_matrix, on_time_matrix, action, profit, on_time):
        """
        Write the results of one action into column ``action`` of both matrices.

        ``profit`` is stored per order; the on-time column is filled with the
        mean of ``on_time`` over all orders, so every row gets the same value.

        :return: the two (in-place updated) matrices
        """
        profit_matrix[:, action] = profit
        on_time_matrix[:, action] = on_time.mean()
        return profit_matrix, on_time_matrix

    def calculate_profit_and_on_time(
        self, input_data, random_noise=False, epsilon_p=0.0
    ):
        """
        Evaluate every shipping mode for every order.

        :param input_data: order features (torch.Tensor)
        :param random_noise: when true, add Gaussian noise scaled by
            ``epsilon_p`` to the standardized input
        :param epsilon_p: noise / perturbation strength; the perturbator is
            only applied when this is greater than zero
        :return: ``(profit_matrix, on_time_matrix)``, both shaped
            ``(num_orders, NUM_ACTIONS)``
        """
        num_orders = input_data.shape[0]
        num_actions = self.NUM_ACTIONS

        profit_matrix = np.zeros((num_orders, num_actions))
        on_time_matrix = np.zeros((num_orders, num_actions))

        index = self.build_faiss_index()

        dataset = self.env.args.dataset
        # The decision column sits right after these four feature groups.
        decision_index = len(
            feature_list.product_info[dataset]
            + feature_list.order_info[dataset]
            + feature_list.customer_info[dataset]
            + feature_list.shipping_info[dataset]
        )

        perturbator = self._resolve_perturbator()

        for action in range(num_actions):
            input_data_scaled = self.prepare_input_for_action(
                input_data, decision_index, action
            )

            if random_noise:
                noise = torch.randn_like(input_data_scaled) * epsilon_p
                input_data_scaled += noise
            state = input_data_scaled[:, :decision_index]
            # Snapshot taken before perturbation, used only for the diagnostic below.
            state_org = state.clone().detach()

            query_vectors = self.get_query_vectors(input_data, action, dataset)
            profit = self.compute_profit_with_nn(index, query_vectors, action)

            if perturbator is not None and epsilon_p > 0.0:
                info("-------perturbation---------")
                with torch.no_grad():
                    z, sigma = perturbator.encode(state)
                    zs = perturbator.sample_perturbations(z, sigma, epsilon_p=epsilon_p)
                    perturbed_state = perturbator.decode(zs[1])
                    perturbed_state = perturbator.inverse_transform(perturbed_state)
                    state = perturbed_state.to(self.env.device)

            # Diagnostic: original values at the positions the perturbation changed.
            diff_mask = ~torch.eq(state_org, state)
            diff_elements = state_org[diff_mask]
            print("Different elements in state after perturbation:", diff_elements)

            on_time = self.model_predict(state, action, decision_index, input_data)

            profit_matrix, on_time_matrix = self.update_matrices(
                profit_matrix, on_time_matrix, action, profit, on_time
            )

        return profit_matrix, on_time_matrix

    def optimize_decisions(
        self, input_data, epsilon_p=0.0, random_noise=False, alpha=0.1, beta=0.5
    ):
        """
        Choose one action per order by maximizing ``alpha * profit + beta * on_time``.

        :param input_data: order features (torch.Tensor)
        :param epsilon_p: forwarded to ``calculate_profit_and_on_time``
        :param random_noise: forwarded to ``calculate_profit_and_on_time``
        :param alpha: weight of the profit term in the objective
        :param beta: weight of the on-time term in the objective
        :return: ``(optimal_modes, final_profit, final_on_time_ratio,
            profit_min_percent)`` where ``profit_min_percent`` maps 0.1 / 0.2 /
            0.3 to the chosen-profit value at that position of the sorted profits
        """
        profit_matrix, on_time_matrix = self.calculate_profit_and_on_time(
            input_data, random_noise=random_noise, epsilon_p=epsilon_p
        )

        problem = LpProblem("Maximize_Profit_and_OnTime", LpMaximize)
        num_orders, num_actions = profit_matrix.shape

        # One binary variable per (order, action) pair.
        decision_vars = [
            [LpVariable(f"decision_{i}_{j}", cat="Binary") for j in range(num_actions)]
            for i in range(num_orders)
        ]

        # Each order must be assigned exactly one action.
        for i in range(num_orders):
            problem += lpSum(decision_vars[i]) == 1, f"OneActionPerOrder_{i}"

        # Objective: weighted sum of profit and on-time score of the chosen actions.
        score = alpha * profit_matrix + beta * on_time_matrix
        problem += lpSum(
            float(score[i, j]) * decision_vars[i][j]
            for i in range(num_orders)
            for j in range(num_actions)
        )

        problem.solve()

        optimal_modes = [
            np.argmax([decision_vars[i][j].varValue for j in range(num_actions)])
            for i in range(num_orders)
        ]

        order_ids = np.arange(num_orders)
        chosen_profit = profit_matrix[order_ids, optimal_modes]
        final_profit = chosen_profit.mean()
        final_on_time_ratio = on_time_matrix[order_ids, optimal_modes].mean()

        # Profit found 10% / 20% / 30% of the way into the ascending profits.
        sorted_profits = np.sort(chosen_profit)
        thresholds = [0.1, 0.2, 0.3]
        profit_min_percent = {
            threshold: sorted_profits[int(threshold * len(sorted_profits))]
            for threshold in thresholds
        }

        return optimal_modes, final_profit, final_on_time_ratio, profit_min_percent


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    :return: ``argparse.Namespace`` with device, training, model,
        regularizer and logging options
    """

    def str2bool(value):
        """Convert a command-line string such as ``"False"`` to a real bool.

        ``type=bool`` would treat every non-empty string, including
        ``"False"``, as ``True``.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="AI4Simulation")

    # ----------------------- Device Setting
    parser.add_argument("--use_gpu", type=int, default=1)
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)

    # ------------------------ Training Setting
    # The default checkpoint below is an OAS checkpoint. Checkpoints that were
    # previously used as defaults for the individual datasets:
    #   DataCo:      /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/DataCo/ckpt/01-17-15_DataCo_epoch817.pth
    #   GlobalStore: /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/GlobalStore/ckpt/01-17-14_GlobalStore_epoch476.pth
    #   OAS:         /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/OAS/ckpt/01-17-14_OAS_epoch223.pth
    parser.add_argument(
        "--ckpt",
        type=str,
        default="./exp_report/OAS/ckpt/07-14-16_OAS_epoch17_sim_True_1.pth",
    )
    parser.add_argument("--ckpt_start_epoch", type=int, default=0)

    parser.add_argument(
        "--dataset",
        type=str,
        default="OAS",
        choices=["LSCRW", "DataCo", "GlobalStore", "OAS"],
    )
    parser.add_argument("--lr", type=float, default=0.01)

    # parser.add_argument('--mi_lr', type=float, default=0.0001)
    parser.add_argument("--dm_lr", type=float, default=0.01)

    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--dm_epochs", type=int, default=6000)
    parser.add_argument("--eva_interval", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2048)

    parser.add_argument("--early_stop", type=int, default=50)

    parser.add_argument(
        "--train_mode",
        type=int,
        default=0,
        help="0 means training both simulator and decision-maker, 1 means training simulator only, 2 means training decision-maker only",
    )

    # ------------------------ Model Setting
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    # parser.add_argument('--teacher_forcing_ratio', type=float, default=0.5)

    # ----------------------- Regularizer coefficient
    parser.add_argument("--decay_coeff", type=float, default=5e-4)
    parser.add_argument("--dm_decay_coeff", type=float, default=5e-4)

    # parser.add_argument('--gl_coeff', type=float, default=1)

    parser.add_argument("--mi_coeff", type=float, default=1)
    parser.add_argument("--ma_coeff", type=float, default=0)

    parser.add_argument("--otr_reward_coeff", type=float, default=10)

    parser.add_argument("--reward_smoothing_factor", type=float, default=0.1)

    parser.add_argument("--p_coeff", type=float, default=1)

    # ----------------------- logger
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--epsilon_p", type=float, default=0.0)
    # Parsed with str2bool so that "--random_noise False" really disables it.
    parser.add_argument("--random_noise", type=str2bool, default=False)
    return parser.parse_args()


# Driver script: builds the environment, data loader and models, runs the
# linear-programming decision maker on the test inputs and prints the results.
# Wall-clock start; the elapsed time is printed at the very end.
t = time.time()
# ----------------------------------- Env Init -----------------------------------------------------------
info("--------------------------------Een Init----------------------------------")
args = parse_args()
my_env = Env(args)

# ----------------------------------- Dataset Init -----------------------------------------------------------
info("--------------------------------Dataset Init------------------------------")
my_loader = S_Loader(my_env)
my_env.feature_classes = my_loader.feature_classes

# ----------------------------------- Model Init -----------------------------------------------------------
info("--------------------------------Model Init--------------------------------")
my_model = S_SimDec(my_env)
# Load pretrained simulator weights unless the checkpoint was disabled.
if args.ckpt is not None:
    my_model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
v_model = ValueNetwork(my_env)

# ---------------------------------------- Main -----------------------------------------------------------
info("------------------------------------ Main --------------------------------")
cost_dic = my_loader.cost_mrp
avg_profit = my_loader.avg_profit


# Decision maker built from the model, environment and historical cost data.
optimizer = LinearProgrammingDecision(my_model, my_env, cost_dic, avg_profit)

# NOTE: this must stay a module-level name called ``perturbator``;
# LinearProgrammingDecision.calculate_profit_and_on_time looks it up as a
# global when it evaluates the actions.
perturbator = Perturbator(
    predictor=my_model,
    policy=v_model,
    env=my_env,
    M=getattr(my_env.args, "perturb_M", 8),
    device=my_env.device,
    cost_dic=cost_dic,
    avg_profit=my_loader.avg_profit,
    feature_list=feature_list,
    otr_reward_coeff=getattr(my_env.args, "otr_reward_coeff", 1.0),
    retrieve_index=feature_list.retrieve_index[my_env.args.dataset],
    action_dim=4,
)
# Optimize the shipping mode of every order in the loader's test inputs.
optimal_modes, final_profit, final_on_time_ratio, profit_min_percent = (
    optimizer.optimize_decisions(
        my_loader.test_inputs, epsilon_p=args.epsilon_p, random_noise=args.random_noise
    )
)

print("优化后的运输模式：", optimal_modes)
print(f"最终利润：{final_profit:.5f}")
print(f"最终准时率：{final_on_time_ratio:.5f}")
print(f"最终利润 + 最终准时率：{final_profit+final_on_time_ratio:.5f}")
print("最低利润百分比：", profit_min_percent)
# Total elapsed wall-clock time in seconds.
print(time.time() - t)
