import argparse
import os
import pathlib
import re
import sys
import time

import gym
import numpy as np
import openai
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from gym import spaces
from sklearn.preprocessing import StandardScaler

# Make the project root importable before pulling in project-local packages.
sys.path.append(str(pathlib.Path(__file__).parent.parent))

from environments.environment import Env
from loaders.s_loader import S_Loader
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from tools import feature_list
from tools.logger import info

class ShippingEnv(gym.Env):
    """Gym environment that scores shipping-mode choices on the test orders.

    An episode walks through ``loader.test_inputs`` one order per step. The
    action is a shipping-mode code in ``{0, 1, 2, 3}``. The reward is
    ``profit + beta * on_time`` where ``profit`` comes from an exact-match
    lookup in ``cost_dic`` (falling back to ``avg_profit[action]``) and
    ``on_time`` is derived from the simulator model's last output.
    """

    def __init__(self, model, loader, cost_dic, avg_profit, beta=0.5):
        """Set up the environment.

        :param model: simulator model (S_SimDec); must expose ``embedding``
            and be callable as ``model(features, mode_embedding, rest)``.
        :param loader: data loader (S_Loader) providing ``test_inputs`` and
            ``env`` (with ``args.dataset`` and ``device``).
        :param cost_dic: 2-D array; every column but the last is a lookup key,
            the last column is the value returned as profit.
        :param avg_profit: per-action fallback profit, indexed by action.
        :param beta: weight of the on-time term in the reward.
        """
        super(ShippingEnv, self).__init__()
        self.model = model
        self.loader = loader
        self.cost_dic_data = cost_dic[:, :-1]
        self.cost_dic_y = cost_dic[:, -1]
        self.avg_profit = avg_profit
        self.beta = beta

        # Action space: four shipping modes.
        self.action_space = spaces.Discrete(4)
        self.scaler = StandardScaler()

        # Observation space: one entry per order feature.
        dataset = loader.env.args.dataset
        feature_dim = len(feature_list.product_info[dataset] +
                          feature_list.order_info[dataset] +
                          feature_list.customer_info[dataset] +
                          feature_list.shipping_info[dataset])
        self.observation_space = spaces.Box(low=0, high=1, shape=(feature_dim,), dtype=np.float32)
        self.feature_dim = feature_dim

        self.current_state = None
        self.index = 0  # position of the current order in loader.test_inputs
        self._cost_index = None  # FAISS index over cost_dic_data, built on first use

    def _get_cost_index(self):
        """Return the FAISS L2 index over ``cost_dic_data``, building it once."""
        if getattr(self, "_cost_index", None) is None:
            # Imported lazily so the module can be loaded without faiss installed.
            import faiss
            index = faiss.IndexFlatL2(self.cost_dic_data.shape[1])
            index.add(self.cost_dic_data)
            self._cost_index = index
        return self._cost_index

    def reset(self):
        """Rewind to the first test order and return it as the initial state."""
        self.index = 0
        self.current_state = self.loader.test_inputs[self.index].cpu().numpy()
        return self.current_state

    def step(self, action):
        """Apply a shipping mode to the current order and advance to the next.

        :param action: chosen shipping-mode code (0-3).
        :return: ``(next_state, reward, done, info)``. ``next_state`` is
            ``None`` once the last order has been processed; ``reward`` is a
            NumPy array with one entry per scored row (a single row here);
            ``info`` holds the scalar ``profit`` and ``on_time`` values.
        """
        dataset = self.loader.env.args.dataset

        input_data = np.expand_dims(self.current_state.copy(), axis=0)
        input_data[:, -1] = action  # the last column carries the shipping mode

        # NOTE(review): the scaler is re-fitted on this single row at every
        # step; with one sample StandardScaler centres every feature to 0.
        # Confirm whether a scaler fitted on the training data was intended.
        input_data_scaled = self.scaler.fit_transform(input_data)
        input_data_tensor = torch.FloatTensor(input_data_scaled).to(self.loader.env.device)

        index = self._get_cost_index()

        with torch.no_grad():
            features = input_data_tensor[:, :self.feature_dim]
            selected_embedding = (
                self.model.embedding.weight[action, :]
                .unsqueeze(0)
                .repeat(features.shape[0], 1)
            )
            predicted_tokens = self.model(
                features,
                selected_embedding,
                input_data_tensor[:, self.feature_dim + 1:],
            )

            # Lookup key: two retrieval columns of the raw row plus the action.
            first_col, second_col = feature_list.retrieva_index[dataset][:2]
            query_vectors = np.array(
                [[row[first_col], row[second_col], action] for row in input_data],
                dtype='float32',
            )

            _, nearest_indices = index.search(query_vectors, 1)
            nearest_samples = self.cost_dic_data[nearest_indices.flatten()]

            # Use the recorded value only on an exact key match; otherwise
            # fall back to the average profit of the chosen action.
            profit = np.array([
                self.cost_dic_y[nearest_indices[i, 0]]
                if np.array_equal(query, nearest_samples[i])
                else self.avg_profit[action]
                for i, query in enumerate(query_vectors)
            ])
            # Mean arg-max class of the model's last output.
            on_time = predicted_tokens[-1].argmax(dim=1).cpu().numpy().mean()

        reward = profit + self.beta * on_time

        self.index += 1
        done = self.index >= len(self.loader.test_inputs)
        if not done:
            self.current_state = self.loader.test_inputs[self.index].cpu().numpy()

        step_info = {"profit": profit.mean(), "on_time": on_time}
        return (None if done else self.current_state), reward, done, step_info

    def render(self, mode='human'):
        """Optional rendering hook; intentionally a no-op."""
        pass



class LLMDecisionMaker:
    """Ask an OpenAI chat model to choose a shipping mode for each order."""

    # Shipping-mode codes accepted from the model's answer.
    VALID_MODES = (0, 1, 2, 3)
    # Integer following the word "Mode", tolerating punctuation in between
    # (e.g. "Mode 2", "Mode: 2", "Mode <2>", "Mode 2.").
    _MODE_PATTERN = re.compile(r"Mode\W*?(-?\d+)")
    # Whitespace / markdown decoration that may precede "Order" on a line.
    _LINE_PREFIX_CHARS = " \t*-#>"

    def __init__(self, api_key, model="gpt-3.5-turbo"):
        """Store the model name and register the API key with ``openai``.

        :param api_key: OpenAI API key.
        :param model: chat model name (e.g. 'gpt-4' or 'gpt-3.5-turbo').
        """
        openai.api_key = api_key
        self.model = model

    def generate_batch_prompt(self, batch_order_features, features_name):
        """Build one prompt covering every order of a batch.

        :param batch_order_features: iterable of per-order feature values.
        :param features_name: list of feature names, repeated for each order.
        :return: the prompt text.
        """
        parts = [
            "You are a logistics optimization assistant responsible for selecting the best shipping mode (Standard Class, First Class, Second Class, Same Day), corresponding to the codes (0, 1, 2, 3).\n\n",
            "For each order, you need to maximize the weighted sum of the following metrics:\n1. Profit\n2. On-time delivery rate\n\n",
            "Here are the orders and their features:\n",
        ]
        parts.extend(
            f"Order {i + 1}:\nFeatures: {features_name}\nValues: {order_features}\n\n"
            for i, order_features in enumerate(batch_order_features)
        )
        parts.append(
            "For each order, provide the best shipping mode in the format:\nOrder <Order Number>: Mode <Shipping Mode Code>\n"
        )
        return "".join(parts)

    def generate_prompt(self, order_features, features_name):
        """Build the prompt for a single order.

        :param order_features: feature values of the order.
        :param features_name: names of those features.
        :return: the prompt text.
        """
        prompt = f"""
        You are a logistics optimization assistant responsible for selecting the best shipping mode (Standard Class, First Class, Second Class, Same Day), corresponding to the codes (0, 1, 2, 3).
        The current order has the following features:
        {features_name}

        Feature values are:
        {order_features}

        The decision goal is to maximize the weighted sum of the following metrics:
        1. Profit
        2. On-time delivery rate

        Please choose the best shipping mode (0, 1, 2, 3) based on the above goals.
        Output format:
        Mode: <Shipping Mode Code>
        """
        return prompt

    def make_batch_decisions(self, batch_order_features, features_name):
        """Query the chat model once for a whole batch and parse its answer.

        No retry or back-off is performed; API errors propagate to the caller.
        The returned list has one entry per ``Order`` line found in the reply,
        which is not guaranteed to equal the batch size.

        :param batch_order_features: feature values of the orders in the batch.
        :param features_name: list of feature names.
        :return: list of shipping-mode codes.
        """
        prompt = self.generate_batch_prompt(batch_order_features, features_name)
        response = openai.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}]
        )
        content = response.choices[0].message.content
        return self.parse_batch_decisions(content)

    def parse_batch_decisions(self, response_content):
        """Extract one shipping-mode code per ``Order`` line of the reply.

        Every line that starts with "Order" (after leading whitespace or
        markdown markers) yields exactly one decision. If no mode number can
        be found, or the number is not a valid code, a random mode is used.

        :param response_content: raw text returned by the model (may be None).
        :return: list of shipping-mode codes, in order of appearance.
        """
        decisions = []
        for raw_line in (response_content or "").splitlines():
            line = raw_line.lstrip(self._LINE_PREFIX_CHARS)
            if not line.startswith("Order"):
                continue

            match = self._MODE_PATTERN.search(line)
            if match is None:
                # Nothing parseable on this line: fall back to a random mode.
                print('randomly chose...')
                decisions.append(np.random.randint(0, 4))
                continue

            mode = int(match.group(1))
            if mode in self.VALID_MODES:
                decisions.append(mode)
            else:
                # Out-of-range code: fall back to a random mode.
                decisions.append(np.random.randint(0, 4))
        return decisions


def evaluate_llm_decision(env, llm_decision_maker, test_inputs, batch_size=1000):
    """Let the LLM pick shipping modes for the test set and report the results.

    Orders are sent to the LLM in batches; each returned decision is applied
    to the environment with ``env.step``. Because the environment advances one
    order per step, the decisions of a batch are forced to match the batch
    length (missing ones are filled with a random mode, surplus ones are
    dropped) so that later batches stay aligned with their orders.

    Reads the module-level ``args.dataset`` to select the feature lists.

    :param env: environment instance (reset at the start).
    :param llm_decision_maker: object exposing ``make_batch_decisions``.
    :param test_inputs: tensor of test orders, one row per order.
    :param batch_size: number of orders per LLM request.
    """
    env.reset()
    num_orders = len(test_inputs)
    profit_list = []
    on_time_list = []
    optimal_modes = []

    dataset = args.dataset
    feature_name = (feature_list.product_info[dataset] + feature_list.order_info[dataset] +
                    feature_list.customer_info[dataset] + feature_list.shipping_info[dataset])
    feature_dim = len(feature_name)

    episode_over = False
    for batch_start in range(0, num_orders, batch_size):
        batch_end = min(batch_start + batch_size, num_orders)
        expected = batch_end - batch_start
        batch_order_features = test_inputs[batch_start:batch_end].cpu().numpy()

        # Ask the LLM for the whole batch and time the request.
        t = time.time()
        batch_decisions = list(
            llm_decision_maker.make_batch_decisions(batch_order_features[:, :feature_dim], feature_name)
        )
        print(f"Batch {batch_start // batch_size + 1}: LLM 决策输出完成")
        print(f"决策时间: {time.time() - t}")

        # Keep exactly one decision per order of this batch.
        missing = expected - len(batch_decisions)
        if missing > 0:
            print(f"Warning: expected {expected} decisions, got {len(batch_decisions)}; "
                  f"filling {missing} with random modes")
            batch_decisions.extend(int(np.random.randint(0, 4)) for _ in range(missing))
        elif missing < 0:
            print(f"Warning: expected {expected} decisions, got {len(batch_decisions)}; "
                  f"dropping {-missing} surplus decisions")
            del batch_decisions[expected:]

        for action in batch_decisions:
            _, _, done, step_info = env.step(action)
            profit_list.append(step_info['profit'])
            on_time_list.append(step_info['on_time'])
            optimal_modes.append(action)
            if done:
                # The environment has no further orders; stop stepping.
                episode_over = True
                break
        if episode_over:
            break

    if not profit_list:
        print("No orders were evaluated.")
        return

    # Aggregate metrics.
    final_profit = np.mean(profit_list)
    final_on_time_ratio = np.mean(on_time_list)

    # Profit at the 10% / 20% / 30% positions of the ascending profit list.
    sorted_profits = np.sort(profit_list)
    thresholds = [0.1, 0.2, 0.3]
    profit_min_percent = {threshold: sorted_profits[int(threshold * len(sorted_profits))]
                          for threshold in thresholds}

    # Report.
    print("优化后的运输模式：", optimal_modes)
    print(f"最终利润：{final_profit:.5f}")
    print(f"最终准时率：{final_on_time_ratio:.5f}")
    print(f"最终利润 + 最终准时率：{final_profit + final_on_time_ratio:.5f}")
    print("最低利润百分比：", profit_min_percent)


def parse_args(argv=None):
    """Parse the command-line options of the experiment.

    :param argv: optional list of argument strings; when ``None`` (the
        default) the arguments are taken from ``sys.argv`` as before.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="AI4Simulation")

    # ----------------------- Device Setting
    parser.add_argument('--use_gpu', type=int, default=1)
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42)

    # ------------------------ Training Setting
    # Default checkpoint. Alternative per-dataset checkpoints used previously:
    #   LSCRW:       /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/LSCRW/ckpt/01-17-14_LSCRW_epoch57.pth
    #   GlobalStore: /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/GlobalStore/ckpt/01-17-14_GlobalStore_epoch476.pth
    #   OAS:         /home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/OAS/ckpt/01-17-14_OAS_epoch223.pth
    parser.add_argument('--ckpt', type=str, default="./exp_report/OAS/ckpt/07-14-16_OAS_epoch17_sim_True_1.pth")
    parser.add_argument('--ckpt_start_epoch', type=int, default=0)

    parser.add_argument('--dataset', type=str, default='DataCo', choices=['LSCRW', 'DataCo','GlobalStore','OAS'])
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--dm_lr', type=float, default=0.01)

    parser.add_argument('--epochs', type=int, default=10000)
    parser.add_argument('--dm_epochs', type=int, default=6000)
    parser.add_argument('--eva_interval', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=2048)

    parser.add_argument('--early_stop', type=int, default=50)

    # NOTE(review): the original help text listed "1" for both the first and
    # the second meaning; 0 is assumed for "both" since it is the default and
    # the only unlisted value -- confirm against the training script.
    parser.add_argument('--train_mode', type=int, default=0, help='0 means training both simulator and decision-maker, 1 means training simulator only, 2 means training decision-maker only')

    # ------------------------ Model Setting
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--decoder_num_layers', type=int, default=1)
    parser.add_argument('--encoder_num_layers', type=int, default=1)

    # ----------------------- Regularizer coefficient
    parser.add_argument('--decay_coeff', type=float, default=5e-4)
    parser.add_argument('--dm_decay_coeff', type=float, default=5e-4)

    parser.add_argument('--mi_coeff', type=float, default=1)
    parser.add_argument('--ma_coeff', type=float, default=0)

    parser.add_argument('--otr_reward_coeff', type=float, default=10)

    parser.add_argument('--reward_smoothing_factor', type=float, default=0.1)

    parser.add_argument('--p_coeff', type=float, default=1)

    # ----------------------- logger
    parser.add_argument('--wandb', type=int, default=0)
    parser.add_argument('--save', type=int, default=0)

    return parser.parse_args(argv)



# ----------------------------------- Env Init -----------------------------------------------------------
info('--------------------------------Een Init----------------------------------')
# `args` stays a module-level name: evaluate_llm_decision reads args.dataset.
args = parse_args()
my_env = Env(args)

# ----------------------------------- Dataset Init -----------------------------------------------------------
info('--------------------------------Dataset Init------------------------------')
my_loader = S_Loader(my_env)
my_env.feature_classes = my_loader.feature_classes

# ----------------------------------- Model Init -----------------------------------------------------------
info('--------------------------------Model Init--------------------------------')
my_model = S_SimDec(my_env)
if args.ckpt is not None:
    my_model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
v_model = ValueNetwork(my_env)

# ---------------------------------------- Main -----------------------------------------------------------
info('------------------------------------ Main --------------------------------')
cost_dic = my_loader.cost_mrp
avg_profit = my_loader.avg_profit


# LLM decision maker. The key is read from the environment instead of being
# hard-coded; it falls back to the previous empty string when unset.
api_key = os.environ.get("OPENAI_API_KEY", "")
llm_decision_maker = LLMDecisionMaker(api_key)

# Environment wrapping the simulator model and the test orders.
env = ShippingEnv(my_model, my_loader, cost_dic, avg_profit, beta=0.5)

# Evaluate the LLM's decisions on the test set.
evaluate_llm_decision(env, llm_decision_maker, my_loader.test_inputs, batch_size=25)
