import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate, BadLettuce
import random
import time
import torch
import os

# ====== Global random seed ======
SEED = 42  # Change this number to run with a different random stream.

# Seed every RNG this script imports: Python's `random`, NumPy and PyTorch
# (plus all CUDA devices when a GPU is available).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)




# ==== Step 1: Extend SB3 policy to expose action probabilities ====
from stable_baselines3.ppo.policies import MlpPolicy
import torch



def get_action_prob(self, obs):
    """Return the policy's action probabilities for ``obs`` as a NumPy array.

    Written as a free function taking ``self`` so that it can be attached to
    a policy class as a method.

    Args:
        self: Policy object exposing ``device``, ``extract_features``,
            ``mlp_extractor`` and ``_get_action_dist_from_latent``.
        obs: NumPy observation. A 1-D array is treated as a single
            observation and gets a leading batch axis; anything else is used
            as given.

    Returns:
        ``dist.distribution.probs`` moved to the CPU, converted to NumPy and
        squeezed (so a single observation yields a 1-D vector).
    """
    batch = obs if obs.ndim != 1 else obs[np.newaxis, :]
    batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Raw features -> policy-branch latent -> action distribution.
        features = self.extract_features(batch_tensor)
        policy_latent = self.mlp_extractor(features)[0]
        action_dist = self._get_action_dist_from_latent(policy_latent)
        probabilities = action_dist.distribution.probs

    return probabilities.cpu().numpy().squeeze()

# Attach the helper to the imported MlpPolicy class so it is callable as
# `policy.get_action_prob(obs)` on instances of that class.
MlpPolicy.get_action_prob = get_action_prob

# ==== Step 2: Define PE bonus calculation ====
def compute_population_entropy_bonus(state, action, population, alpha):
    """Return the population-entropy bonus for taking ``action`` in ``state``.

    The bonus is ``-alpha * log(mean_p + 1e-8)`` where ``mean_p`` is the
    average probability the population members assign to ``action``.

    Args:
        state: Observation passed to each member's
            ``policy.get_action_prob``.
        action: Index into the probability vector returned by each member.
        population: Iterable of objects exposing
            ``policy.get_action_prob(state)``.
        alpha: Scale factor of the bonus.

    Returns:
        The bonus, or ``0.0`` when no member produced a usable probability.
    """
    probs = []
    for agent in population:
        # Best effort: a member that cannot be evaluated is reported and
        # skipped instead of aborting the whole computation.
        try:
            probs.append(agent.policy.get_action_prob(state)[action])
        except Exception as e:
            print(f"⚠️ Error computing PE bonus: {e}")

    if not probs:
        return 0.0

    avg_prob = np.mean(probs)
    # The epsilon keeps the logarithm finite when the average is zero.
    return -alpha * np.log(avg_prob + 1e-8)


def compute_kl_divergence_bonus(state, current_policy, population, alpha=0.1):
    """Return ``alpha * KL(current || population mean)`` for ``state``.

    Args:
        state: Observation passed to every ``get_action_prob`` call.
        current_policy: Object exposing ``get_action_prob(state)``; errors
            raised by this call propagate to the caller.
        population: Iterable of objects exposing
            ``policy.get_action_prob(state)``.
        alpha: Scale factor of the bonus.

    Returns:
        The scaled KL divergence, or ``0.0`` when no population member
        produced a distribution.
    """
    p_current = current_policy.get_action_prob(state)

    probs_list = []
    for agent in population:
        # Best effort, reported the same way as the PE bonus: a failing
        # member is announced and skipped rather than silently ignored.
        try:
            probs_list.append(agent.policy.get_action_prob(state))
        except Exception as e:
            print(f"⚠️ Error computing KL bonus: {e}")

    if not probs_list:
        return 0.0

    p_mean = np.mean(probs_list, axis=0)
    # The epsilon keeps both logarithms finite for zero-probability actions.
    log_ratio = np.log(p_current + 1e-8) - np.log(p_mean + 1e-8)
    return alpha * np.sum(p_current * log_ratio)





# Item names; used in this file as ITEMNAME[env.map[x][y]] to name a map cell.
ITEMNAME = [
    "space",
    "counter",
    "agent",
    "tomato",
    "lettuce",
    "plate",
    "knife",
    "delivery",
    "onion",
    "dirtyplate",
    "badlettuce",
]

# Macro-action name -> integer action id.
macroActionDict = {
    "stay": 0,
    "get lettuce 1": 1,
    "get lettuce 2": 2,
    "get badlettuce": 3,
    "get plate 1": 4,
    "get plate 2": 5,
    "go to knife 1": 6,
    "go to knife 2": 7,
    "deliver 1": 8,
    "deliver 2": 9,
    "chop": 10,
    "go to counter": 11,
    "right": 12,
    "down": 13,
    "left": 14,
    "up": 15,
}


def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3):
    """Pick the nearest reachable candidate among three.

    A candidate counts as reachable when its ``can_reach_*`` value is not 4.

    Returns:
        The 0-based position (0, 1 or 2) of the reachable candidate with the
        smallest distance (the earliest one on ties), or ``False`` when none
        is reachable. Because ``False == 0`` in Python, callers must test the
        "none reachable" case with ``is False`` before comparing to 0.
    """
    slots = (
        (can_reach_1, distance_1),
        (can_reach_2, distance_2),
        (can_reach_3, distance_3),
    )
    reachable = [
        (position, distance)
        for position, (status, distance) in enumerate(slots)
        if status != 4
    ]

    if not reachable:
        return False  # nothing can be reached

    if len(reachable) == 1:
        # A lone candidate wins without looking at its distance.
        return reachable[0][0]

    # min() keeps the first of several equally small distances.
    nearest_position, _ = min(reachable, key=lambda entry: entry[1])
    return nearest_position


def _probe_macro_target(env, agent_item, macro_id):
    """Return ``(navigation_result, distance)`` for the target of ``macro_id``."""
    target_x, target_y = env._findPOitem(agent_item, macro_id)
    navigation_result = env._navigate(agent_item, target_x, target_y)
    distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
    return navigation_result, distance


def intelligently_find_item_number(env, agent_item, raw_name):
    """Resolve an un-numbered macro name to the best concrete macro action.

    ``raw_name + " 1"`` and ``raw_name + " 2"`` are probed through the
    environment. For ``"get plate"`` a third candidate,
    ``"get dirty plate"``, is considered only when that name exists in
    ``macroActionDict``. It is not defined in the table visible in this file,
    and looking it up unconditionally raised ``KeyError`` on every
    ``"get plate"`` call.

    Args:
        env: Environment exposing ``_findPOitem``, ``_navigate`` and
            ``_calDistance``.
        agent_item: Agent object with ``x`` and ``y`` attributes.
        raw_name: Macro name without its numeric suffix, e.g. ``"get plate"``.

    Returns:
        The name of the nearest reachable candidate (the earlier candidate
        wins ties), or ``"stay"`` when no candidate is reachable.

    Raises:
        KeyError: If ``raw_name + " 1"`` or ``raw_name + " 2"`` is not a key
            of ``macroActionDict``.
    """
    candidate_names = [raw_name + " 1", raw_name + " 2"]
    can_reach_1, distance_1 = _probe_macro_target(env, agent_item, macroActionDict[candidate_names[0]])
    can_reach_2, distance_2 = _probe_macro_target(env, agent_item, macroActionDict[candidate_names[1]])

    # 4 is the navigation result this file treats as "not reachable"; the
    # third slot stays unreachable unless a dirty-plate macro really exists.
    can_reach_3, distance_3 = 4, None
    if raw_name == "get plate":
        dirty_plate_id = macroActionDict.get("get dirty plate")
        if dirty_plate_id is not None:
            can_reach_3, distance_3 = _probe_macro_target(env, agent_item, dirty_plate_id)
            candidate_names.append("get dirty plate")

    best_index = find_best_reachable_index(
        can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3
    )

    # The "nothing reachable" result is False, and False == 0, so it must be
    # ruled out by identity before the value is used as an index.
    if best_index is False or best_index is None:
        return "stay"
    return candidate_names[best_index]



def check_benevolence(env, best_action, action):
    """Score whether ``action`` is the recommended macro action.

    Overwrites ``env.reward``: it is reset to 0 and raised to 20 when
    ``action`` equals the id of ``best_action`` and that id is not 0.

    Args:
        env: Object whose ``reward`` attribute is overwritten.
        best_action: Key of ``macroActionDict`` naming the recommended action.
        action: Action id actually chosen.

    Returns:
        The resulting ``env.reward`` (20 on a match, otherwise 0).
    """
    env.reward = 0
    expected = macroActionDict[best_action]
    if not (action == expected) or expected == 0:
        return env.reward
    env.reward += 20
    return env.reward



def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter):
    """Compute one-off shaping bonuses for the left and right agents.

    Each bonus is paid only while its ``firsttime_*`` flag is still true and
    only when ``check_benevolence`` confirms that the agent chose the
    recommended macro action; the flag is then cleared. Note that
    ``check_benevolence`` overwrites ``env.reward`` whenever it is called.

    Args:
        env: Environment exposing ``agent`` and ``map``.
        action_left: Macro action id of the left agent.
        action_right: Macro action id of the right agent (``env.agent[1]``).
        firsttime_left_go_to_counter: Flag for the left "go to counter" bonus.
        firsttime_left_get_lettuce: Flag for the left "get badlettuce" bonus.
        firsttime_right_go_to_knife: Flag for the right "get badlettuce"
            bonus (the name is historical).
        firsttime_right_go_to_counter: Flag for the right "go to counter"
            bonus.

    Returns:
        ``(reward_bonus_left, reward_bonus_right,
        firsttime_left_go_to_counter, firsttime_left_get_lettuce,
        firsttime_right_go_to_knife, firsttime_right_go_to_counter)`` with the
        flags updated.
    """
    agent_item = env.agent[1]

    # Names of the items in the twelve cells env.map[4][2] .. env.map[4][13].
    counter_x = 4
    counters = [ITEMNAME[env.map[counter_x][counter_y]] for counter_y in range(2, 14)]

    reward_bonus_left = 0
    reward_bonus_right = 0

    # Exact comparison: the former `counter in ("badlettuce")` was a substring
    # test against a plain string, so a cell holding "lettuce" matched too.
    bad_lettuce_on_counter = any(counter == "badlettuce" for counter in counters)

    # Right-side high benevolence.
    if bad_lettuce_on_counter:
        if firsttime_left_get_lettuce and check_benevolence(env, "get badlettuce", action_left) == 20:
            reward_bonus_left = 50
            firsttime_left_get_lettuce = False
    else:
        if not agent_item.holding:
            if firsttime_right_go_to_knife and check_benevolence(env, "get badlettuce", action_right) == 20:
                reward_bonus_right = 100
                firsttime_right_go_to_knife = False
        elif isinstance(agent_item.holding, BadLettuce):
            if firsttime_right_go_to_counter and check_benevolence(env, "go to counter", action_right) == 20:
                reward_bonus_right = 50
                firsttime_right_go_to_counter = False

        if firsttime_left_go_to_counter and check_benevolence(env, "go to counter", action_left) == 20:
            reward_bonus_left = 50
            firsttime_left_go_to_counter = False

    return reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter




class SingleAgentWrapper(gym.Wrapper):
    """Expose one agent of the shared two-agent environment as a single-agent env.

    The action passed to ``step`` drives the wrapped agent; the other agent's
    action comes from ``other_agent_model.predict`` on that agent's latest
    observation.
    """

    def __init__(self, env, agent_index, other_agent_model=None):
        """
        Args:
            env: The shared multi-agent environment.
            agent_index: Index (0 or 1) of the agent this wrapper controls.
            other_agent_model: Object with a ``predict(obs)`` method that
                controls the other agent. It may be omitted while the wrapper
                is only used to construct a model, but it is required before
                ``step`` is called.
        """
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model

        self._reset_first_time_flags()

        # Latest per-agent observations, indexed by agent.
        self.obs = None

    def _reset_first_time_flags(self):
        """Re-arm the one-off shaping bonuses."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self):
        """Reset the shared environment and return the wrapped agent's observation."""
        self.obs = self.env.reset()
        self._reset_first_time_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        """Advance the shared environment by one step.

        Args:
            action: Macro action chosen for the wrapped agent.

        Returns:
            ``(observation, reward, dones, info)`` for the wrapped agent,
            where ``reward`` is the environment reward plus that agent's
            shaping bonus.

        Raises:
            RuntimeError: If the wrapper was built without
                ``other_agent_model``.
        """
        if self.other_agent_model is None:
            raise RuntimeError(
                "SingleAgentWrapper.step() requires other_agent_model to be set"
            )

        other_index = 1 - self.agent_index
        other_agent_action = self.other_agent_model.predict(self.obs[other_index])

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[other_index] = other_agent_action[0]

        primary_actions, real_execute_macro_actions = self.env._computeLowLevelActions(actions)

        # Shaping is computed from the second value returned above
        # (presumably the macro actions actually executed - confirm against
        # the environment), not from the requested actions.
        (
            benevolence_reward_left,
            benevolence_reward_right,
            self.firsttime_left_go_to_counter,
            self.firsttime_left_get_lettuce,
            self.firsttime_right_go_to_knife,
            self.firsttime_right_go_to_counter,
        ) = check_action_benevolence(
            self.env,
            real_execute_macro_actions[0],
            real_execute_macro_actions[1],
            self.firsttime_left_go_to_counter,
            self.firsttime_left_get_lettuce,
            self.firsttime_right_go_to_knife,
            self.firsttime_right_go_to_counter,
        )

        # The observation returned by env.step is discarded in favour of the
        # macro-level observation.
        _, rewards, dones, info = self.env.step(primary_actions)
        self.obs = self.env._get_macro_obs()

        if self.agent_index == 0:
            return self.obs[self.agent_index], rewards[self.agent_index] + benevolence_reward_left, dones, info

        # NOTE(review): agent 1 is paid the OTHER agent's environment reward
        # (index 1 - agent_index == 0). Kept as in the original; confirm this
        # is intended rather than a typo for rewards[self.agent_index].
        return self.obs[self.agent_index], rewards[1 - self.agent_index] + benevolence_reward_right, dones, info



class EpisodeRewardCallback(BaseCallback):
    """Track episode returns and periodically checkpoint the model.

    Every ``save_freq`` callback steps the model is saved under ``save_path``
    and a moving-average plot of the episode returns recorded so far is
    written next to it. Only index 0 of the ``rewards`` and ``dones`` entries
    in ``self.locals`` is read.
    """

    def __init__(self, save_path, save_freq=100000, verbose=0):
        """
        Args:
            save_path: Directory for checkpoints and plots (created if missing).
            save_freq: Number of callback steps between two saves.
            verbose: Passed through to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0  # steps seen by this callback object
        self.episode_rewards = []  # return of every finished episode
        self.current_episode_reward = 0.0  # return of the running episode

        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Accumulate rewards; save and plot on every ``save_freq``-th step."""
        self.step_counter += 1
        self.current_episode_reward += self.locals['rewards'][0]

        # Close the running episode when it has ended.
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq != 0:
            return True

        self._save_checkpoint()
        if self.episode_rewards:
            self._plot_moving_average()
        return True

    def _save_checkpoint(self):
        """Save the model under a name carrying the current step count."""
        model_path = os.path.join(self.save_path, f'model_{self.step_counter}.zip')
        self.model.save(model_path)
        print(f"Step {self.step_counter}: Model saved at {model_path}")

    def _plot_moving_average(self):
        """Write a plot of the trailing mean (up to 100 episodes) of returns."""
        returns = self.episode_rewards
        window = min(100, len(returns))

        moving_avg = []
        for last in range(len(returns)):
            first = max(0, last - window + 1)
            recent = returns[first:last + 1]
            moving_avg.append(sum(recent) / len(recent))

        plt.figure(figsize=(10, 5))
        plt.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
        plt.xlabel("Episode")
        plt.ylabel("Average Reward")
        plt.title("Training Progress")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png'))
        plt.close()



# One reward table per agent. Both agents use the same values; the
# comprehension builds two independent dicts, so one agent's table can be
# changed later without affecting the other.
rewardList = [
    {
        "minitask finished": 0,
        "minitask failed": 0,
        "metatask finished": 0,
        "metatask failed": 0,
        "goodtask finished": 10,
        "goodtask failed": 0,
        "subtask finished": 20,
        "subtask failed": 0,
        "correct delivery": 200,
        "wrong delivery": -50,
        "step penalty": -1,
        "penalize using dirty plate": 0,
        "penalize using bad lettuce": 20,
        "pick up bad lettuce": 0,
    }
    for _ in range(2)
]


# Environment id and the keyword arguments forwarded to gym.make below.
mac_env_id = 'Overcooked-MA-v0'
env_params = {
    'grid_dim': [15, 15],
    'task': ["badlettuce salad"],
    'rewardList': rewardList,
    'map_type': "A",
    'n_agent': 2,
    'obs_radius': 0,
    'mode': "vector",
    'debug': True
}


# Shared SB3 logger writing CSV and TensorBoard output under ./logs/.
new_logger = configure('./logs/', ["csv", "tensorboard"])  # "stdout" is left out of the formats to prevent console logging


# Initialize the single environment instance shared by both agents and seed it.
shared_env = gym.make(mac_env_id, **env_params)
shared_env.seed(SEED)
shared_env.action_space.seed(SEED)
shared_env.observation_space.seed(SEED)

# Debug output of the environment's item lookup.
print('=====================', shared_env._findItem(9, 10, "tomato"))

# Wrap each agent. These wrappers carry no other_agent_model; the training
# loop replaces them with fully configured wrappers before learning starts.
env_agent_0 = SingleAgentWrapper(shared_env, agent_index=0)
env_agent_1 = SingleAgentWrapper(shared_env, agent_index=1)

# NOTE(review): ppo_params is defined but not passed to either PPO(...) call
# below, so none of these hyperparameters take effect in the visible code.
# Confirm whether they were meant to be applied (note that 'verbose' would
# clash with the explicit verbose=1 argument).
ppo_params = {
    'learning_rate': 3e-4,
    'n_steps': 256,
    'batch_size': 128,
    'n_epochs': 10,
    'gamma': 0.95,
    'gae_lambda': 0.95,
    'clip_range': 0.3,
    'ent_coef': 0.02,
    # 'ent_coef': 0.05,
    'vf_coef': 0.5,
    'max_grad_norm': 0.5,
    'verbose': 0,
}


# Separate 256-128-64 MLPs for the policy (pi) and value (vf) branches.
policy_kwargs = dict(
    net_arch=[dict(pi=[256, 128, 64], vf=[256, 128, 64])]
)


# One PPO learner per agent, identically configured and identically seeded.
model_agent_0 = PPO(
    "MlpPolicy",
    env_agent_0,
    verbose=1,
    policy_kwargs=policy_kwargs,
    seed=SEED
)

model_agent_1 = PPO(
    "MlpPolicy",
    env_agent_1,
    verbose=1,
    policy_kwargs=policy_kwargs,
    seed=SEED
)


# Both models write to the same logger object.
model_agent_0.set_logger(new_logger)
model_agent_1.set_logger(new_logger)

# One checkpoint/plot callback per agent, each with its own output directory.
reward_callback_0 = EpisodeRewardCallback('final_trained_models/[MapA]trustee_agent_highB_lowI_agent0')
reward_callback_1 = EpisodeRewardCallback('final_trained_models/[MapA]trustee_agent_highB_lowI_agent1')


# Training configuration
total_alternate_steps = 5000000  # Total training steps
alternate_interval = 10000  # Each agent trains for this many steps before switching


global_start_time = time.time()  # Wall-clock time at which the whole training run starts

def format_time(seconds):
    """Format a duration in seconds as ``"<minutes>分<seconds>秒"``.

    Minutes are not folded into hours, and both parts are truncated to
    integers.
    """
    whole_minutes, remainder = divmod(seconds, 60)
    return f"{int(whole_minutes)}分{int(remainder)}秒"




# Alternate training loop: in every round agent 0 learns against agent 1's
# current model, then agent 1 learns against agent 0's just-updated model.
# Each agent trains for `alternate_interval` timesteps per round.
for i in range(0, total_alternate_steps, alternate_interval):
    print(f"Training Agent 0 (Steps {i} to {i + alternate_interval})")
    # Fresh wrapper around the shared env; agent 1 acts through its model's predict().
    env_agent_0 = SingleAgentWrapper(shared_env, agent_index=0, other_agent_model=model_agent_1)  # Agent 1 is fixed
    model_agent_0.set_env(env_agent_0)
    model_agent_0.learn(total_timesteps=alternate_interval, callback=reward_callback_0)

    print(f"Training Agent 1 (Steps {i} to {i + alternate_interval})")
    # Roles swapped: agent 0 now acts through its model's predict().
    env_agent_1 = SingleAgentWrapper(shared_env, agent_index=1, other_agent_model=model_agent_0)  # Agent 0 is fixed
    model_agent_1.set_env(env_agent_1)
    model_agent_1.learn(total_timesteps=alternate_interval, callback=reward_callback_1)

    # Report the wall-clock time elapsed since training began.
    phase_end_time = time.time()
    total_duration = phase_end_time - global_start_time
    print(f"[🕒 累计训练时间] {format_time(total_duration)}")

