import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
from stable_baselines3.common.utils import set_random_seed
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate
import random
import time
import torch
import os



# —— 锁线程数，防止并行库抢核 —— #
# NOTE(review): numpy and torch are imported above this point, so these variables
# may come too late to shape thread pools those libraries already created; they
# still apply to child processes. The explicit torch calls below take effect
# immediately for PyTorch itself.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_thread_env_var, "1")
# (Optional) restrict PyTorch on CPU to a single intra-op and inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)




# ============== 实用函数：每次新模型训练时调用，统一设定随机性 ==============
def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    SB3's ``set_random_seed`` is called first. Python ``random``, NumPy's global RNG
    and PyTorch's CPU RNG are then seeded explicitly, plus all CUDA devices when
    CUDA is available. Call this before building each new model/environment.

    NOTE: ``PYTHONHASHSEED`` is read at interpreter startup, so setting it here only
    influences subprocesses launched afterwards, not the current process.
    """
    set_random_seed(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

# ==== Step 1: 扩展 SB3 policy 以便拿到动作概率 ====
from stable_baselines3.ppo.policies import MlpPolicy
import torch

def get_action_prob(self, obs):
    """Return the policy's action probabilities for ``obs`` as a NumPy array.

    A 1-D ``np.ndarray`` observation is treated as one sample and gets a leading batch
    axis. The result is ``squeeze``d, so a single observation yields a 1-D vector.

    Relies on SB3 ``ActorCriticPolicy`` internals (``extract_features``,
    ``mlp_extractor``, ``_get_action_dist_from_latent``). It also needs the returned
    distribution to expose ``.distribution.probs``; presumably that means a
    categorical distribution, i.e. a Discrete action space (confirm).
    """
    single_sample = isinstance(obs, np.ndarray) and obs.ndim == 1
    batch = obs[np.newaxis, :] if single_sample else obs
    obs_tensor = torch.as_tensor(batch, dtype=torch.float32, device=self.device)
    with torch.no_grad():
        latent_pi, _latent_vf = self.mlp_extractor(self.extract_features(obs_tensor))
        action_probs = self._get_action_dist_from_latent(latent_pi).distribution.probs
    return action_probs.detach().cpu().numpy().squeeze()


# Attach the helper to SB3's MlpPolicy class so every PPO policy built from it can
# call ``policy.get_action_prob(obs)`` (used by the population-entropy bonus).
MlpPolicy.get_action_prob = get_action_prob

# ==== Step 2: Population Entropy 奖励 ====
def compute_population_entropy_bonus(state, action, population, alpha):
    """Population-entropy reward bonus for taking ``action`` in ``state``.

    Each member's ``policy.get_action_prob(state)`` is queried and the probability of
    ``action`` is collected. A member whose lookup raises is reported with a printed
    warning and skipped. The bonus is ``-alpha * log(mean_prob + 1e-8)``, so actions
    the population rarely takes earn a larger bonus.

    Returns:
        ``0.0`` when no member produced a probability, otherwise the bonus above.
    """
    action_probs = []
    for member in population:
        try:
            action_probs.append(member.policy.get_action_prob(state)[action])
        except Exception as err:
            # Best-effort: one misbehaving member must not abort the environment step.
            print(f"⚠️ Error computing PE bonus: {err}")
    if not action_probs:
        return 0.0
    mean_prob = float(np.mean(action_probs))
    return -alpha * np.log(mean_prob + 1e-8)

# ==== Step 3:（可选）基于难度的优先采样 ====
def _evaluate_pair(agent_A, partner, eval_env):
    """Play one deterministic episode (agent_A in seat 0, partner in seat 1); return seat 0's summed reward."""
    total_reward = 0.0
    obs = eval_env.reset()
    done = False
    while not done:
        a_action, _ = agent_A.predict(obs[0], deterministic=True)
        b_action, _ = partner.predict(obs[1], deterministic=True)
        # _computeLowLevelActions returns (primary_actions, executed_macro_actions);
        # only the first element is the input for env.step (SingleAgentWrapper.step
        # unpacks it the same way).
        low_level_actions, _ = eval_env._computeLowLevelActions([a_action, b_action])
        obs, reward, done, info = eval_env.step(low_level_actions)
        eval_env._checkMacroActionDone()
        eval_env._checkCollision(info)
        obs = eval_env._get_macro_obs()
        total_reward += float(reward[0])  # only agent A's (seat 0) return matters
    return total_reward


def _difficulty_ranks(returns):
    """Return 1-based float ranks where the lowest return (hardest partner) gets the highest rank."""
    returns = np.asarray(returns, dtype=np.float64)
    # Highest return (easiest partner) first; stable so ties keep population order.
    order = np.argsort(-returns, kind="stable")
    ranks = np.empty(len(returns), dtype=np.float64)
    ranks[order] = np.arange(1, len(returns) + 1)
    return ranks


def prioritized_sampling(agent_A, population, eval_env, beta=1.0):
    """Sample a training partner for ``agent_A``, favouring partners it does worst with.

    Each partner is paired with ``agent_A`` for one deterministic evaluation episode
    on ``eval_env``. A partner's score is agent A's summed reward. Partners are ranked
    by that score: the lowest return (hardest) gets rank ``n`` and the highest return
    gets rank 1. Each partner is sampled with probability proportional to
    ``rank ** beta``.

    Ranking returns directly, instead of ranking ``1 / return``, keeps the order
    correct when returns are zero or negative (e.g. step penalties).

    Args:
        agent_A: Model with an SB3-style ``predict(obs, deterministic=...)``.
        population: Non-empty sequence of candidate partner models.
        eval_env: Two-agent environment used for the evaluation episodes.
        beta: Prioritization exponent; 0 gives uniform sampling.

    Returns:
        The sampled element of ``population``.

    Raises:
        ValueError: If ``population`` is empty.
    """
    if len(population) == 0:
        raise ValueError("population must contain at least one partner")
    returns = [_evaluate_pair(agent_A, partner, eval_env) for partner in population]
    weights = _difficulty_ranks(returns) ** beta
    probs = weights / np.sum(weights)
    selected_idx = np.random.choice(len(population), p=probs)
    return population[selected_idx]

# ====== 任务/动作名等 ======
# Object names indexed by the integer codes stored in env.map
# (check_action_benevolence reads them via ITEMNAME[env.map[x][y]]).
ITEMNAME = [
    "space", "counter", "agent", "tomato", "lettuce", "plate",
    "knife", "delivery", "onion", "dirtyplate", "badlettuce",
]

# Macro-action names in action-id order: the position in this tuple is the id.
# NOTE(review): there is no "get dirty plate" entry, although the "get plate"
# branch of intelligently_find_item_number looks one up.
_MACRO_ACTION_NAMES = (
    "stay",
    "get lettuce 1", "get lettuce 2", "get badlettuce",
    "get plate 1", "get plate 2",
    "go to knife 1", "go to knife 2",
    "deliver 1", "deliver 2",
    "chop", "go to counter",
    "right", "down", "left", "up",
)
macroActionDict = {name: action_id for action_id, name in enumerate(_MACRO_ACTION_NAMES)}

def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3):
    """Return the index (0, 1 or 2) of the nearest reachable candidate, or ``False``.

    A candidate counts as unreachable when its ``can_reach_*`` value equals 4. Among
    reachable candidates the smallest distance wins, and ties go to the lowest index.
    Distances of unreachable candidates are never compared.

    Warning: the "nothing reachable" sentinel is ``False``, and ``False == 0`` in
    Python, so callers must test the result with ``is False``, not ``== 0``.
    """
    pairs = ((can_reach_1, distance_1), (can_reach_2, distance_2), (can_reach_3, distance_3))
    candidates = [(index, distance) for index, (reach, distance) in enumerate(pairs) if reach != 4]
    if not candidates:
        return False
    best_index, _ = min(candidates, key=lambda candidate: candidate[1])
    return best_index

def intelligently_find_item_number(env, agent_item, raw_name):
    """Pick which numbered variant of a "get ..." macro-action ``agent_item`` should take.

    For ``raw_name`` (e.g. ``"get lettuce"``) the candidates are ``raw_name + " 1"`` and
    ``raw_name + " 2"``. For ``"get plate"``, a ``"get dirty plate"`` candidate is added
    when that name exists in ``macroActionDict``. For each candidate:

    * ``env._findPOitem`` locates the target;
    * ``env._navigate`` checks it (a result of 4 is treated as unreachable);
    * ``env._calDistance`` measures it.

    The nearest reachable candidate wins; ties go to the lower-numbered one.

    Returns:
        The chosen macro-action name, or ``"stay"`` if no candidate is reachable.
    """
    def probe(action_name):
        # Locate, path-check and measure one candidate (same call order as before).
        target_x, target_y = env._findPOitem(agent_item, macroActionDict[action_name])
        can_reach = env._navigate(agent_item, target_x, target_y)
        distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
        return can_reach, distance

    candidates = [raw_name + " 1", raw_name + " 2"]
    # "get dirty plate" may be absent from macroActionDict; probe it only when it
    # exists instead of raising KeyError.
    if raw_name == "get plate" and "get dirty plate" in macroActionDict:
        candidates.append("get dirty plate")

    probes = [probe(name) for name in candidates]
    # Pad to three slots; 4 marks a slot unreachable, so its distance is never compared.
    probes += [(4, None)] * (3 - len(probes))
    (reach_1, dist_1), (reach_2, dist_2), (reach_3, dist_3) = probes

    best_idx = find_best_reachable_index(reach_1, reach_2, reach_3, dist_1, dist_2, dist_3)
    # find_best_reachable_index signals "nothing reachable" with False, and False == 0,
    # so compare by identity to avoid mistaking it for candidate 0.
    if best_idx is False:
        return "stay"
    return candidates[best_idx]

def check_benevolence(env, best_action, action):
    """Give 20 if ``action`` equals the id of ``best_action`` and that id is not 0 ("stay").

    Side effect: ``env.reward`` is overwritten. It is reset to 0, raised by 20 on a
    match, and that value is returned.
    """
    env.reward = 0
    expected_id = macroActionDict[best_action]
    matched = action == expected_id and expected_id != 0
    if matched:
        env.reward += 20
    return env.reward

def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter):
    """Compute a one-time reward-shaping bonus for cooperative macro-actions.

    ``env.agent[0]`` is the "left"/human agent and ``env.agent[1]`` the "right" agent.
    Cells (4, 2)..(4, 13) of ``env.map`` are read as the counter strip (hard-coded for
    this layout).

    * If a lettuce lies on the counter, the left agent earns the bonus for choosing
      the "get lettuce N" action picked by ``intelligently_find_item_number``.
    * Otherwise, the right agent earns it for "go to counter" while holding a
      ``Lettuce``. The left agent earns it for "go to counter".

    Each bonus is paid at most once: the matching ``firsttime_*`` flag is cleared as
    soon as ``check_benevolence`` returns 20. Bonuses earned in the same call are
    summed, so one agent's bonus is never overwritten by another agent's check.
    ``firsttime_right_go_to_knife`` is passed through unchanged.

    Returns:
        ``(reward_shaping_bonus, firsttime_left_go_to_counter, firsttime_left_get_lettuce,
        firsttime_right_go_to_knife, firsttime_right_go_to_counter)`` with updated flags.

    Note:
        ``check_benevolence`` overwrites ``env.reward``. After this call it holds the
        result of the last check that ran.
    """
    agent_item = env.agent[1]
    human_agent = env.agent[0]
    counter_coords = [(4, y) for y in range(2, 14)]
    counters = [ITEMNAME[env.map[x][y]] for (x, y) in counter_coords]
    # Exact element match. A bare ("lettuce") is a string, not a tuple, so
    # `x in ("lettuce")` would be a substring test.
    lettuce_on_counter = "lettuce" in counters
    reward_shaping_bonus = 0

    if lettuce_on_counter:
        best_action = intelligently_find_item_number(env, human_agent, "get lettuce")
        if firsttime_left_get_lettuce:
            bonus = check_benevolence(env, best_action, action_left)
            reward_shaping_bonus += bonus
            if bonus == 20:
                firsttime_left_get_lettuce = False
    else:
        if agent_item.holding and isinstance(agent_item.holding, Lettuce):
            if firsttime_right_go_to_counter:
                bonus = check_benevolence(env, "go to counter", action_right)
                reward_shaping_bonus += bonus
                if bonus == 20:
                    firsttime_right_go_to_counter = False

        if firsttime_left_go_to_counter:
            bonus = check_benevolence(env, "go to counter", action_left)
            reward_shaping_bonus += bonus
            if bonus == 20:
                firsttime_left_go_to_counter = False

    return reward_shaping_bonus, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter

class SingleAgentWrapper(gym.Wrapper):
    """Single-agent view of the two-agent Overcooked environment.

    The learner controls seat ``agent_index``. The other seat is driven by
    ``other_agent_model.predict``. ``step`` returns the learner's macro observation
    and the team reward: both agents' rewards plus the optional population-entropy
    bonus.

    Args:
        env: Two-agent environment exposing ``_computeLowLevelActions`` and
            ``_get_macro_obs``.
        agent_index: Seat (0 or 1) controlled by the learner.
        other_agent_model: Partner model with an SB3-style ``predict``; must be set
            before ``step`` is called.
        population: Models for ``compute_population_entropy_bonus``; a falsy value
            disables the bonus.
        alpha: Bonus weight forwarded to ``compute_population_entropy_bonus``.
    """

    def __init__(self, env, agent_index, other_agent_model=None, population=None, alpha=0.01):
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model
        self._reset_shaping_flags()
        self.obs = None  # latest observations of both agents; set by reset()/step()
        self.population = population
        self.alpha = alpha

    def _reset_shaping_flags(self):
        """Re-arm the one-shot reward-shaping flags (named after check_action_benevolence's)."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self, **kwargs):
        """Reset the wrapped env and return the learner's observation.

        Keyword arguments are forwarded to the wrapped ``reset``. If it raises
        ``TypeError`` (older Gym signatures), it is retried without them.
        """
        try:
            self.obs = self.env.reset(**kwargs)
        except TypeError:
            self.obs = self.env.reset()
        self._reset_shaping_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        """Advance both agents one step; return ``(obs, team_reward, dones, info)``."""
        me = self.agent_index
        partner = 1 - me
        partner_prediction = self.other_agent_model.predict(self.obs[partner])
        joint_actions = [0, 0]
        joint_actions[me] = action
        joint_actions[partner] = partner_prediction[0]

        pe_bonus = (
            compute_population_entropy_bonus(self.obs[me], action, self.population, self.alpha)
            if self.population
            else 0.0
        )

        primary_actions, _executed_macro_actions = self.env._computeLowLevelActions(joint_actions)
        self.obs, rewards, dones, info = self.env.step(primary_actions)
        # Replace the raw step observation with the env's macro-level observation.
        self.obs = self.env._get_macro_obs()

        team_reward = rewards[me] + rewards[partner] + pe_bonus
        return self.obs[me], team_reward, dones, info

class EpisodeRewardCallback(BaseCallback):
    """SB3 callback that records episode returns, checkpoints the model and plots progress.

    Only the first vectorized environment (index 0) is tracked. Every ``save_freq``
    callback steps, the model is saved to ``<save_path>/model_<step>.zip`` and a
    moving-average reward curve is written to ``<save_path>/avg_reward_<step>.png``.
    The step counter and episode history live on the instance. They therefore keep
    accumulating when one callback is passed to several ``model.learn()`` calls.

    Args:
        save_path: Directory for checkpoints and plots (created if missing).
        save_freq: Save and plot every this many callback steps.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    # Maximum number of episodes averaged per point of the progress curve.
    _MAX_WINDOW = 100

    def __init__(self, save_path, save_freq=100000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        # With SB3's default reset_num_timesteps=True, each learn() call starts from a
        # fresh env.reset(), abandoning any unfinished episode. Drop its partial return
        # so it is not merged into the next episode's total.
        self.current_episode_reward = 0.0

    def _on_step(self) -> bool:
        self.step_counter += 1
        self.current_episode_reward += float(self.locals['rewards'][0])

        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq == 0:
            model_path = os.path.join(self.save_path, f'model_{self.step_counter}.zip')
            self.model.save(model_path)
            print(f"Step {self.step_counter}: Model saved at {model_path}")
            if self.episode_rewards:
                self._save_reward_plot()
        return True

    @staticmethod
    def _moving_average(values, window):
        """Trailing mean over up to ``window`` values ending at each index (shorter at the start)."""
        arr = np.asarray(values, dtype=np.float64)
        csum = np.concatenate(([0.0], np.cumsum(arr)))
        ends = np.arange(1, len(arr) + 1)
        starts = np.maximum(0, ends - window)
        return (csum[ends] - csum[starts]) / (ends - starts)

    def _save_reward_plot(self):
        """Write the moving-average reward curve for the current step to ``save_path``."""
        window = min(self._MAX_WINDOW, len(self.episode_rewards))
        moving_avg = self._moving_average(self.episode_rewards, window)

        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
            ax.set_xlabel("Episode")
            ax.set_ylabel("Average Reward")
            ax.set_title("Training Progress")
            ax.legend()
            ax.grid(True)
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png'))
        finally:
            # Always release the figure so long runs don't accumulate open figures.
            plt.close(fig)

# ====== 奖励与环境参数 ======
# Per-agent reward table passed to the environment. Both agents currently use the
# same values; each agent gets its own dict copy.
_AGENT_REWARD_TABLE = {
    "minitask finished": 0,
    "minitask failed": 0,
    "metatask finished": 0,
    "metatask failed": 0,
    "goodtask finished": 10,
    "goodtask failed": 0,
    "subtask finished": 20,
    "subtask failed": 0,
    "correct delivery": 200,
    "wrong delivery": -50,
    "step penalty": -1,
    "penalize using dirty plate": 0,
    "penalize using bad lettuce": 0,
    "pick up bad lettuce": 0,
}
rewardList = [dict(_AGENT_REWARD_TABLE), dict(_AGENT_REWARD_TABLE)]

mac_env_id = 'Overcooked-MA-v2'
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="A_lowuncertainty",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=True,
)

ppo_params = dict(
    learning_rate=3e-4,
    n_steps=256,
    batch_size=128,
    n_epochs=10,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.3,
    ent_coef=0.02,
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=0,
)

# Separate actor (pi) and critic (vf) MLP heads.
# NOTE(review): newer SB3 releases prefer net_arch=dict(pi=..., vf=...); the list
# form is kept here — confirm against the installed SB3 version.
policy_kwargs = {"net_arch": [{"pi": [256, 128, 64], "vf": [256, 128, 64]}]}

# ====== Build the initial population (each model is created with its own seed and its own seeded env) ======
population = []
seeds = [15, 25, 35, 45, 55]  # one seed per population member; also used by the training loop

for s in seeds:
    seed_everything(s)
    # Create a temporary env for this model and seed it (Gym reset(seed=...) API first,
    # falling back to the older env.seed() API on TypeError).
    env_tmp = gym.make(mac_env_id, **env_params)
    try:
        env_tmp.reset(seed=s)
    except TypeError:
        env_tmp.seed(s)
        env_tmp.reset()
    env_tmp.action_space.seed(s)
    env_tmp.observation_space.seed(s)

    # No partner model is attached here, so this wrapper cannot step. Presumably PPO
    # only needs it for the spaces at construction; the training loop below calls
    # set_env() with a fully configured wrapper before learn().
    env_agent_0_tmp = SingleAgentWrapper(env_tmp, agent_index=0)
    model = PPO("MlpPolicy", env_agent_0_tmp, policy_kwargs=policy_kwargs, seed=s, **ppo_params)
    population.append(model)

# (Optional debugging) Print the first model's action distribution on a freshly seeded env.
# NOTE(review): this runs at import time and advances env/RNG state before training;
# the training loop re-seeds with seed_everything() per model, so this is presumably harmless.
tmp_env = gym.make(mac_env_id, **env_params)
try:
    tmp_env.reset(seed=seeds[0])
except TypeError:
    tmp_env.seed(seeds[0]); tmp_env.reset()
tmp_env.action_space.seed(seeds[0]); tmp_env.observation_space.seed(seeds[0])
tmp_env_agent0 = SingleAgentWrapper(tmp_env, agent_index=0)
obs = tmp_env_agent0.reset()
print("obs shape:", obs.shape)
print(population[0].policy.get_action_prob(obs))

# ====== 训练配置 ======
total_alternate_steps = 3_000_000  # upper bound of the alternating-training loop, per model
alternate_interval = 10_000        # timesteps passed to each learn() call, per seat

def format_time(seconds):
    """Format a duration as "<minutes>分<seconds>秒" (minutes are not folded into hours)."""
    whole_minutes, leftover_seconds = divmod(seconds, 60)
    return f"{int(whole_minutes)}分{int(leftover_seconds)}秒"

# ====== For each model in the population: reset the seed and rebuild the training env with that seed ======
global_start_time = time.time()

for agent_idx, agent_to_train in enumerate(population):
    s = seeds[agent_idx]
    print(f"\n============= Start training model #{agent_idx} with seed {s}")

    # A) Randomness for this round.
    seed_everything(s)

    # B) Fresh environment for this round, seeded with the model's seed
    #    (Gym reset(seed=...) API first, older env.seed() API as fallback).
    # NOTE(review): shared_env is never closed; consider shared_env.close() once this model is done.
    shared_env = gym.make(mac_env_id, **env_params)
    try:
        shared_env.reset(seed=s)
    except TypeError:
        shared_env.seed(s)
        shared_env.reset()
    shared_env.action_space.seed(s)
    shared_env.observation_space.seed(s)

    # C) Per-round logger and callback. The same callback instance is reused by every
    #    learn() call below, so its step counter and episode history span the whole round.
    logger = configure(f'./logs/seed_{s}/', ["csv", "tensorboard"])
    agent_to_train.set_logger(logger)
    this_agent_callback = EpisodeRewardCallback(f'final_trained_models/[MapA_lowuncertainty]MEP_seed{s}_teamReward')

    # D) Alternating training (the seed is not changed inside a round). Each iteration
    #    trains the same model in seat 0 and then in seat 1. In both seats the partner
    #    is the model itself (self-play), and the population-entropy bonus (alpha=1.0)
    #    is computed over `population`, which includes the model being trained.
    # NOTE(review): learn() uses SB3's default reset_num_timesteps (True), which
    #    restarts model.num_timesteps and the logger's x-axis on every call. Confirm
    #    whether cumulative step counts are wanted in ./logs.
    for i in range(0, total_alternate_steps, alternate_interval):
        print(f"Training Agent 0 (Steps {i} → {i + alternate_interval})")
        env_agent_0 = SingleAgentWrapper(
            shared_env,
            agent_index=0,
            other_agent_model=agent_to_train,
            population=population,
            alpha=1.0
        )
        agent_to_train.set_env(env_agent_0)
        agent_to_train.learn(total_timesteps=alternate_interval, callback=this_agent_callback)

        print(f"Training Agent 1 (Steps {i} → {i + alternate_interval})")
        env_agent_1 = SingleAgentWrapper(
            shared_env,
            agent_index=1,
            other_agent_model=agent_to_train,
            population=population,
            alpha=1.0
        )
        agent_to_train.set_env(env_agent_1)
        agent_to_train.learn(total_timesteps=alternate_interval, callback=this_agent_callback)

        # Wall-clock time since the first model started training (printed as minutes/seconds).
        elapsed = time.time() - global_start_time
        print(f"[🕒 累计训练时间] {format_time(elapsed)}")
