import numpy as np
import gym
from gym import spaces
import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
from tools import feature_list
import argparse
import time
import torch
import wandb
from tools.logger import info
from environments.environment import Env
from loaders.s_loader import S_Loader
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
import torch.nn as nn
import torch.optim as optim
from evaluations.rl_evaluation import evaluate_model
from environments.rl_shipping_env import ShippingEnv
from models.perturbator import Perturbator


import torch.nn as nn
import torch.nn.functional as F


class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for a discrete action space.

    A two-layer MLP trunk (128 hidden units, ReLU) feeds two linear heads:
    the actor head returns unnormalized action logits and the critic head
    returns a scalar state-value estimate.

    Only the first ``feature_dim`` entries of the last axis of an observation
    are fed to the network; any trailing entries are ignored.

    Args:
        state_dim: Input width of the first trunk layer. The sliced width,
            ``min(feature_dim, observation width)``, must equal this value.
        action_dim: Number of discrete actions (width of the actor head).
        feature_dim: Number of leading observation features to use. When
            ``None`` (the default, kept for backward compatibility) it is
            derived from the module-level ``args.dataset`` via ``feature_list``,
            which requires the script's ``__main__`` section to have created
            ``args``.
    """

    def __init__(self, state_dim, action_dim, feature_dim=None):
        super().__init__()

        # Shared feature-extraction trunk.
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU()
        )

        # Actor head: logits of a categorical distribution over actions.
        self.actor = nn.Linear(128, action_dim)

        # Critic head: scalar state-value estimate.
        self.critic = nn.Linear(128, 1)

        if feature_dim is None:
            # Backward-compatible fallback: relies on the global `args`
            # created in the script's __main__ section.
            dataset = args.dataset
            feature_dim = len(
                feature_list.product_info[dataset]
                + feature_list.order_info[dataset]
                + feature_list.customer_info[dataset]
                + feature_list.shipping_info[dataset]
            )
        self.feature_dim = feature_dim

    def forward(self, x):
        """Return ``(logits, value)`` for one observation or a batch.

        Args:
            x: Tensor of shape ``(D,)`` (single observation) or ``(B, D)``
                (batch). The slice is taken along the last axis, so every row
                of a batch is processed independently.

        Returns:
            ``logits`` of shape ``(B, action_dim)`` and ``value`` of shape
            ``(B, 1)``; a 1-D input is treated as a batch of one.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # Slice the feature axis (not the batch axis) so that a minibatch in
        # the PPO update is evaluated row by row instead of collapsing to
        # its first element.
        hidden = self.shared(x[..., : self.feature_dim])
        return self.actor(hidden), self.critic(hidden)


def _to_scalar(value):
    """Return the first element of ``value`` as a Python float.

    The environment's rewards are indexable (progress reporting used
    ``total_reward[0]``), so they are presumably 1-element arrays/tensors;
    plain numbers and bools are accepted too. TODO confirm the reward layout.
    """
    if torch.is_tensor(value):
        value = value.detach().cpu().numpy()
    return float(np.asarray(value, dtype=np.float32).reshape(-1)[0])


def _compute_gae(rewards, values, dones, last_value, gamma, gae_lambda):
    """Compute GAE advantages and value targets for one trajectory.

    Args:
        rewards, values, dones: 1-D float tensors of equal length T
            (``dones`` holds 1.0 for terminal steps, 0.0 otherwise).
        last_value: V(s_T), the value of the state reached after the final
            step; masked out when the final step is terminal.
        gamma: Discount factor.
        gae_lambda: GAE smoothing parameter.

    Returns:
        ``(advantages, returns)``, each a 1-D tensor of length T, with
        ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    next_value = last_value
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
        next_value = values[t]
    return advantages, advantages + values


def _ppo_update(
    model,
    optimizer,
    states,
    actions,
    old_log_probs,
    advantages,
    returns,
    eps_clip,
    epochs,
    batch_size,
):
    """Run ``epochs`` passes of clipped-PPO minibatch updates over one trajectory.

    Minibatches are taken in chronological order (no shuffling). The loss is
    ``actor + 0.5 * critic - 0.01 * entropy``.
    """
    for _ in range(epochs):
        for start in range(0, len(states), batch_size):
            end = start + batch_size
            logits, values_pred = model(states[start:end])
            dist = torch.distributions.Categorical(logits=logits)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(actions[start:end])

            # Advantages stay 1-D so they line up element-wise with `ratio`;
            # unsqueezing them would broadcast to a (B, B) surrogate.
            batch_advantages = advantages[start:end]
            ratio = torch.exp(new_log_probs - old_log_probs[start:end])

            # PPO clipped surrogate objective.
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * batch_advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # Critic regression onto the GAE returns.
            critic_loss = F.mse_loss(values_pred.squeeze(-1), returns[start:end])

            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def train_ppo(
    env,
    train_inputs,
    episodes=500,
    gamma=0.99,
    gae_lambda=0.95,
    eps_clip=0.2,
    lr=3e-4,
    epochs=10,
    batch_size=64,
):
    """Train a fresh ``ActorCritic`` policy on ``env`` with PPO and GAE.

    Each episode calls ``env.reset()`` and rolls out the current policy for
    at most ``len(train_inputs)`` steps, stopping early when ``env.step``
    reports ``done``. The collected trajectory is then used for ``epochs``
    passes of clipped-PPO minibatch updates.

    Args:
        env: Environment with ``observation_space.shape``,
            ``action_space.n``, ``reset() -> obs`` and
            ``step(action) -> (obs, reward, done, info)``.
        train_inputs: Only its length is used, as the per-episode step budget.
        episodes: Number of rollout/update cycles.
        gamma: Discount factor.
        gae_lambda: GAE smoothing parameter.
        eps_clip: PPO ratio clipping range.
        lr: Adam learning rate.
        epochs: Update passes over each trajectory.
        batch_size: Minibatch size for the updates.

    Returns:
        The trained ``ActorCritic`` model.
    """
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    model = ActorCritic(state_dim, action_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    max_steps = len(train_inputs)

    for episode in range(episodes):
        states, actions, log_probs, rewards, dones, values = [], [], [], [], [], []

        state = env.reset()
        total_reward = 0.0
        done = False

        for _ in range(max_steps):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                logits, value = model(state_tensor)
                dist = torch.distributions.Categorical(logits=logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)

            next_state, reward, done, _ = env.step(action.item())
            # Normalize possibly array-shaped reward/done to plain scalars so
            # the trajectory tensors below are 1-D.
            reward = _to_scalar(reward)
            done = bool(_to_scalar(done))
            total_reward += reward

            # Store the transition.
            states.append(state)
            actions.append(action.item())
            log_probs.append(log_prob.item())
            rewards.append(reward)
            dones.append(float(done))
            values.append(value.item())

            state = next_state
            if done:
                break

        # An empty step budget collects nothing; there is nothing to update.
        if rewards:
            if done:
                last_value = 0.0
            else:
                # Step budget exhausted before a terminal state: bootstrap from
                # V(s_T) of the state reached after the final step.
                with torch.no_grad():
                    _, next_value = model(torch.FloatTensor(state).unsqueeze(0))
                last_value = next_value.item()

            advantages, returns = _compute_gae(
                torch.FloatTensor(rewards),
                torch.FloatTensor(values),
                torch.FloatTensor(dones),
                last_value,
                gamma,
                gae_lambda,
            )
            _ppo_update(
                model,
                optimizer,
                torch.FloatTensor(np.array(states)),
                torch.LongTensor(actions),
                torch.FloatTensor(log_probs),
                advantages,
                returns,
                eps_clip,
                epochs,
                batch_size,
            )

        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward:.2f}")

    return model


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this converter interprets the text instead.
    An empty string maps to ``False``, as ``bool("")`` did.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the AI4Simulation CLI and parse its arguments.

    Args:
        argv: Argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with device, training, model, regularizer and
        logging options.
    """
    parser = argparse.ArgumentParser(description="AI4Simulation")

    # Device Setting
    parser.add_argument("--use_gpu", type=int, default=1)
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)

    # Training Setting
    parser.add_argument(
        "--ckpt",
        type=str,
        default="./exp_report/OAS/ckpt/07-14-16_OAS_epoch17_sim_True_1.pth",
    )
    parser.add_argument("--ckpt_start_epoch", type=int, default=0)
    parser.add_argument(
        "--dataset",
        type=str,
        default="OAS",
        choices=["LSCRW", "DataCo", "GlobalStore", "OAS"],
    )
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--dm_epochs", type=int, default=6000)
    parser.add_argument("--eva_interval", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--early_stop", type=int, default=50)
    parser.add_argument(
        "--train_mode", type=int, default=2, help="1: both, 2: decision-maker only"
    )

    # Model Setting
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)

    # Regularizer coefficient
    parser.add_argument("--decay_coeff", type=float, default=5e-4)
    parser.add_argument("--dm_decay_coeff", type=float, default=5e-4)
    parser.add_argument("--mi_coeff", type=float, default=1)
    parser.add_argument("--ma_coeff", type=float, default=0)
    parser.add_argument("--otr_reward_coeff", type=float, default=10)
    parser.add_argument("--reward_smoothing_factor", type=float, default=0.1)
    parser.add_argument("--p_coeff", type=float, default=1)

    # Logger
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--epsilon_p", type=float, default=0.0)
    # `type=bool` would treat "--random_noise False" as True; parse the text.
    parser.add_argument("--random_noise", type=_str2bool, default=False)
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: build the simulation environment, data loader and
    # predictor, wrap them in the RL shipping environment, load a previously
    # trained PPO policy from disk and evaluate it on my_loader.test_inputs.
    #
    # NOTE(review): `args` must stay a module-level global here, because
    # ActorCritic.__init__ reads `args.dataset` to size its feature slice.
    t = time.time()

    info("---------------------------- Env Init ------------------------------")
    args = parse_args()
    my_env = Env(args)

    info("---------------------------- Dataset Init --------------------------")
    my_loader = S_Loader(my_env)
    # Attach the loader's feature classes to the environment object;
    # presumably read by the models constructed below — confirm.
    my_env.feature_classes = my_loader.feature_classes

    info("---------------------------- Model Init ----------------------------")
    my_model = S_SimDec(my_env)
    if args.ckpt:
        # Load pretrained predictor weights onto the CPU.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        my_model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    v_model = ValueNetwork(my_env)

    info("---------------------------- Main ----------------------------------")
    cost_dic = my_loader.cost_mrp
    avg_profit = my_loader.avg_profit

    # Initialize the reinforcement-learning environment around the predictor.
    env = ShippingEnv(
        my_model,
        my_loader,
        cost_dic,
        avg_profit,
        beta=0.3,
        random_noise=args.random_noise,
        epsilon_p=args.epsilon_p,
    )

    # # Train the policy with PPO (disabled; re-enable to retrain and save)
    # trained_model = train_ppo(
    #     env, my_loader.train_inputs, episodes=15, gamma=0.99,
    #     gae_lambda=0.95, eps_clip=0.2, lr=3e-4, epochs=10, batch_size=64
    # )

    # torch.save(trained_model.state_dict(), f"./exp_report/{args.dataset}/ckpt/ppo_model.pth")
    # print("Model saved to ppo_model.pth")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Rebuild the policy network and load the saved PPO weights (the file the
    # commented-out training step above writes) onto the CPU.
    trained_model = ActorCritic(state_dim, action_dim)
    trained_model.load_state_dict(
        torch.load(
            f"./exp_report/{args.dataset}/ckpt/ppo_model.pth", map_location="cpu"
        )
    )

    # NOTE(review): `perturbator` is constructed but not referenced again in
    # this block — confirm whether it was meant to be passed to evaluate_model.
    perturbator = Perturbator(
        predictor=my_model,
        policy=v_model,
        env=my_env,
        # parse_args defines no --perturb_M, so this falls back to 8 unless
        # Env(args) supplies the attribute.
        M=getattr(my_env.args, "perturb_M", 8),
        device=my_env.device,
        cost_dic=cost_dic,
        avg_profit=my_loader.avg_profit,
        feature_list=feature_list,
        otr_reward_coeff=getattr(my_env.args, "otr_reward_coeff", 1.0),
        retrieve_index=feature_list.retrieve_index[my_env.args.dataset],
        # NOTE(review): hard-coded, whereas the policy above uses
        # env.action_space.n — confirm the two agree.
        action_dim=4,
    )
    # Evaluate the loaded policy on the test split.
    evaluate_model(
        env,
        trained_model,
        my_loader.test_inputs,
        # random_noise=args.random_noise,
        # epsilon_p=args.epsilon_p,
    )

    # Wall-clock seconds for the whole run.
    print("Total Time Cost:", time.time() - t)
