import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
import numpy as np

from tools import feature_list
from environments.rl_shipping_env import ShippingEnv
import argparse
import time
import torch
from tools.logger import info
from environments.environment import Env
from loaders.s_loader import S_Loader
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from models.perturbator import Perturbator
import torch.nn as nn
import torch.optim as optim
from evaluations.rl_evaluation import evaluate_model


class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    The input goes through two 128-unit hidden layers with ReLU activations
    and a final linear layer that emits one value per action.

    Only the first ``feature_dim`` entries of the first row of the input are
    used; see :meth:`forward`.
    """

    def __init__(self, state_dim, action_dim, feature_dim=None):
        """Create the layers and resolve the number of input features to use.

        :param state_dim: input width of the first linear layer.
        :param action_dim: output width of the last linear layer.
        :param feature_dim: number of leading input entries that ``forward``
            passes to the network. When ``None`` (the default) it is computed
            from the ``feature_list`` tables for the dataset named by the
            module-level ``args.dataset``; in that case ``args`` must already
            be defined when the model is constructed.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

        if feature_dim is None:
            # Backward-compatible default: derive the width from the
            # per-dataset feature tables selected by the global CLI namespace.
            feature_dim = len(
                feature_list.product_info[args.dataset]
                + feature_list.order_info[args.dataset]
                + feature_list.customer_info[args.dataset]
                + feature_list.shipping_info[args.dataset]
            )
        self.feature_dim = feature_dim

    def forward(self, x):
        """Return the network output for the first row of ``x``.

        :param x: indexable tensor; only ``x[0]`` is read, truncated to its
            first ``feature_dim`` entries. Any further rows are ignored.
        :return: output of the last linear layer, with a leading dimension of
            size 1 added by ``unsqueeze(0)``.
        """
        # NOTE(review): fc1 was built with in_features=state_dim, so the slice
        # below must end up with exactly state_dim entries.
        features = x[0][: self.feature_dim].unsqueeze(0)
        hidden = torch.relu(self.fc1(features))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)


def train_dqn(env, train_inputs, episodes=500, gamma=0.99, epsilon=0.3, lr=1e-3):
    """
    Train a DQN model with an epsilon-greedy, one-step Q-learning loop.

    Every episode resets ``env``, starts from ``train_inputs[0]`` and runs at
    most ``len(train_inputs)`` environment steps. After each step the model is
    updated once with an MSE loss between the Q-value of the chosen action and
    the target computed from the same model. A summary line is printed per
    episode.

    :param env: reinforcement-learning environment; must expose
        ``observation_space``, ``action_space``, ``reset`` and ``step``
    :param train_inputs: training data; its first element seeds the state and
        its length bounds the number of steps per episode
    :param episodes: total number of training episodes
    :param gamma: discount factor
    :param epsilon: exploration rate
    :param lr: learning rate of the Adam optimizer
    :return: the trained model
    """
    model = DQN(env.observation_space.shape[0], env.action_space.n)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for episode in range(episodes):
        env.reset()  # Re-initialise the environment.
        # The initial state always comes from the first training sample.
        state = train_inputs[0].cpu().numpy()
        total_reward = 0

        # One pass handles at most as many steps as there are training items.
        for _step in range(len(train_inputs)):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)

            # Epsilon-greedy action selection.
            explore = np.random.rand() < epsilon
            if not explore:
                with torch.no_grad():
                    action = torch.argmax(model(state_tensor)).item()
            else:
                action = env.action_space.sample()

            next_state, reward, done, _info = env.step(action)
            total_reward += reward

            # Target Q-value; a missing next state contributes no future value.
            with torch.no_grad():
                if next_state is None:
                    max_next_q = 0
                else:
                    next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
                    max_next_q = model(next_state_tensor).max().item()
                target_q = reward + gamma * max_next_q

            # Gradient step on the Q-value of the action that was taken.
            # NOTE(review): indexing with [0] assumes env.step returns the
            # reward as an array-like with at least one element -- confirm.
            predicted_q = model(state_tensor)[0, action]
            target = torch.tensor(target_q[0]).float()
            loss = criterion(predicted_q, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state
            if done:
                break

        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward[0]:.2f}")

    return model


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token (including ``"False"`` and ``"0"``) becomes ``True``.
    This converter interprets the text instead.

    :param value: a bool, or a string such as ``"true"``/``"false"``,
        ``"yes"``/``"no"``, ``"1"``/``"0"`` (case-insensitive).
    :return: the corresponding bool.
    :raises argparse.ArgumentTypeError: if the token is not recognised.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: optional list of argument strings. When ``None`` (the
        default) the arguments are taken from ``sys.argv``, as before.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="AI4Simulation")

    # ----------------------- Device Setting
    parser.add_argument("--use_gpu", type=int, default=1)
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)

    # ------------------------ Training Setting
    # DataCo
    parser.add_argument(
        "--ckpt",
        type=str,
        default="./exp_report/OAS/ckpt/07-14-16_OAS_epoch17_sim_True_1.pth",
    )

    # LSCRW
    # parser.add_argument('--ckpt', type=str, default='/home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/LSCRW/ckpt/01-17-14_LSCRW_epoch57.pth')

    # GlobalStore
    # parser.add_argument('--ckpt', type=str, default='/home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/GlobalStore/ckpt/01-17-14_GlobalStore_epoch476.pth')

    # OAS
    # parser.add_argument('--ckpt', type=str, default='/home/local/ASURITE/haoyueba/AI4Simulation_SuppluChain/exp_report/OAS/ckpt/01-17-14_OAS_epoch223.pth')

    # parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument("--ckpt_start_epoch", type=int, default=0)

    parser.add_argument(
        "--dataset",
        type=str,
        default="OAS",
        choices=["LSCRW", "DataCo", "GlobalStore", "OAS"],
    )
    parser.add_argument("--lr", type=float, default=0.01)

    # parser.add_argument('--mi_lr', type=float, default=0.0001)
    parser.add_argument("--dm_lr", type=float, default=0.01)

    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--dm_epochs", type=int, default=6000)
    parser.add_argument("--eva_interval", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2048)

    parser.add_argument("--early_stop", type=int, default=50)

    # NOTE(review): the help text lists "1" twice; the first entry presumably
    # should describe a different mode value -- confirm before rewording.
    parser.add_argument(
        "--train_mode",
        type=int,
        default=2,
        help="1 means traning both simulator and decision-maker, 1 means training simulator only, 2 means training decision-maker only",
    )

    # ------------------------ Model Setting
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    # parser.add_argument('--teacher_forcing_ratio', type=float, default=0.5)

    # ----------------------- Regularizer coefficient
    parser.add_argument("--decay_coeff", type=float, default=5e-4)
    parser.add_argument("--dm_decay_coeff", type=float, default=5e-4)

    # parser.add_argument('--gl_coeff', type=float, default=1)

    parser.add_argument("--mi_coeff", type=float, default=1)
    parser.add_argument("--ma_coeff", type=float, default=0)

    parser.add_argument("--otr_reward_coeff", type=float, default=10)

    parser.add_argument("--reward_smoothing_factor", type=float, default=0.1)

    parser.add_argument("--p_coeff", type=float, default=1)

    # ----------------------- logger
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--epsilon_p", type=float, default=0.0)
    # type=bool would turn "--random_noise False" into True, so the value is
    # parsed from its text instead.
    parser.add_argument("--random_noise", type=_str2bool, default=False)
    return parser.parse_args(argv)


# Script body: build the project environment, data loader and simulator model,
# wrap them in the RL shipping environment, then load a saved DQN checkpoint
# and evaluate it on the test inputs.
# `t` is only read by the commented-out timing print further down.
t = time.time()
# ----------------------------------- Env Init -----------------------------------------------------------
info("--------------------------------Een Init----------------------------------")
# `args` must stay a module-level name: DQN.__init__ reads it as a global.
args = parse_args()
my_env = Env(args)

# ----------------------------------- Dataset Init -----------------------------------------------------------
info("--------------------------------Dataset Init------------------------------")
my_loader = S_Loader(my_env)
my_env.feature_classes = my_loader.feature_classes

# ----------------------------------- Model Init -----------------------------------------------------------
info("--------------------------------Model Init--------------------------------")
my_model = S_SimDec(my_env)
# Identity comparison is the idiomatic way to test against None.
if args.ckpt is not None:
    my_model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
v_model = ValueNetwork(my_env)

# ---------------------------------------- Main -----------------------------------------------------------
info("------------------------------------ Main --------------------------------")
cost_dic = my_loader.cost_mrp
avg_profit = my_loader.avg_profit

# NOTE(review): `perturbator` is not referenced again in the visible part of
# this file; it is kept in case its construction matters elsewhere.
perturbator = Perturbator(
    predictor=my_model,
    policy=v_model,
    env=my_env,
    M=getattr(my_env.args, "perturb_M", 8),
    device=my_env.device,
    cost_dic=cost_dic,
    avg_profit=my_loader.avg_profit,
    feature_list=feature_list,
    otr_reward_coeff=getattr(my_env.args, "otr_reward_coeff", 1.0),
    retrieve_index=feature_list.retrieve_index[my_env.args.dataset],
    action_dim=4,
)

# Initialise the reinforcement-learning environment.
env = ShippingEnv(
    my_model,
    my_loader,
    cost_dic,
    avg_profit,
    beta=0.3,
    random_noise=args.random_noise,
    epsilon_p=args.epsilon_p,
)

# Training mode (disabled): train the model on the training data and save it.
# trained_model = train_dqn(
#     env, my_loader.train_inputs, episodes=15, gamma=0.99, epsilon=0.3, lr=1e-4
# )

# torch.save(trained_model.state_dict(), f"./exp_report/{args.dataset}/ckpt/dqn_model.pth")
# print("Model saved to dqn_model.pth")
# print(time.time() - t)


# Evaluation mode: rebuild the network with the environment's dimensions and
# restore the weights from the per-dataset checkpoint file.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

trained_model = DQN(state_dim, action_dim)
trained_model.load_state_dict(
    torch.load(f"./exp_report/{args.dataset}/ckpt/dqn_model.pth", map_location="cpu")
)
evaluate_model(
    env,
    trained_model,
    my_loader.test_inputs,
    # random_noise=args.random_noise,
    # epsilon_p=args.epsilon_p,
)
