import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
from stable_baselines3.common.utils import set_random_seed
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate
import random
import time
import torch
import os



# Limit native thread pools to a single thread. os.environ.setdefault only fills in
# variables that are not already set, so values exported by the caller still win.
# NOTE(review): numpy and torch are imported above this point, so libraries that
# read these variables only at load time may not pick them up here; the explicit
# torch.set_num_*_threads calls below are what pin PyTorch. Consider moving the
# environment variables above the imports — TODO confirm intended effect.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_thread_env_var, "1")
del _thread_env_var  # keep the module namespace free of the loop variable

# PyTorch intra-op and inter-op parallelism: one thread each.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)



# Item category names, in a fixed order.
ITEMNAME = [
    "space", "counter", "agent", "tomato", "lettuce", "plate",
    "knife", "delivery", "onion", "dirtyplate", "badlettuce",
]

# Macro-action name -> integer action id. Ids are assigned consecutively from 0
# in the order the names are listed here.
macroActionDict = {
    name: action_id
    for action_id, name in enumerate((
        "stay",
        "get lettuce 1", "get lettuce 2", "get badlettuce",
        "get plate 1", "get plate 2",
        "go to knife 1", "go to knife 2",
        "deliver 1", "deliver 2",
        "chop", "go to counter",
        "right", "down", "left", "up",
    ))
}



class SingleAgentWrapper(gym.Wrapper):
    """Expose one seat of the two-agent environment as a single-agent gym env.

    The learner controls agent ``agent_index``. The other agent's action is taken
    from ``other_agent_model.predict`` applied to that agent's observation. The
    reward returned by :meth:`step` is the sum of both agents' rewards (a shared
    team reward).

    Args:
        env: The shared two-agent environment. Its ``reset()`` and
            ``_get_macro_obs()`` results are indexed by agent, and its ``step()``
            returns one reward per agent.
        agent_index: Seat controlled by the learner (0 or 1).
        other_agent_model: Object with a ``predict(obs)`` method whose first
            returned element is the partner's action. May be None at
            construction, but must be set before :meth:`step` is called.
    """

    def __init__(self, env, agent_index, other_agent_model=None):
        super(SingleAgentWrapper, self).__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model

        # Per-episode flags; they are re-armed on every reset() but are not read
        # anywhere in this class (presumably consumed by external code — verify).
        self._reset_firsttime_flags()

        # Most recent per-agent observations, indexed by agent.
        self.obs = None

    def _reset_firsttime_flags(self):
        """Set every ``firsttime_*`` flag back to True."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self, **kwargs):
        """Reset the shared env and return this agent's observation.

        Only the ``seed`` keyword is forwarded to the wrapped env; any other
        keyword arguments are ignored.
        """
        if "seed" in kwargs:
            self.obs = self.env.reset(seed=kwargs["seed"])
        else:
            self.obs = self.env.reset()

        self._reset_firsttime_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        """Advance the shared env one macro step with the learner's ``action``.

        Returns:
            ``(obs, team_reward, dones, info)`` where ``obs`` is this agent's
            macro observation and ``team_reward`` is the sum of both agents'
            rewards. ``dones`` and ``info`` are passed through from the env.

        Raises:
            RuntimeError: if no ``other_agent_model`` has been provided.
        """
        if self.other_agent_model is None:
            raise RuntimeError(
                "SingleAgentWrapper.step() requires other_agent_model to choose "
                "the partner agent's action"
            )

        other_index = 1 - self.agent_index
        partner_action = self.other_agent_model.predict(self.obs[other_index])[0]

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[other_index] = partner_action

        # The env converts the macro actions into the actions its step() takes;
        # the second returned value is not needed here.
        primary_actions, _ = self.env._computeLowLevelActions(actions)
        _, rewards, dones, info = self.env.step(primary_actions)
        # Observations for both agents are re-read at macro level instead of
        # using the observation returned by step().
        self.obs = self.env._get_macro_obs()

        team_reward = rewards[self.agent_index] + rewards[other_index]
        return self.obs[self.agent_index], team_reward, dones, info

class EpisodeRewardCallback(BaseCallback):
    """Track per-episode rewards, checkpoint the model and plot a moving average.

    Every ``save_freq`` calls to ``_on_step`` the model is saved to
    ``<save_path>/model_<step>.zip``. Once at least one episode has finished, a
    trailing moving-average reward curve is also written to
    ``<save_path>/avg_reward_<step>.png``.

    Only the first environment of the vectorised env is tracked (index ``0`` of
    ``rewards`` / ``dones``). The same instance may be reused across several
    ``learn()`` calls; the step counter and episode history accumulate across them.

    Args:
        save_path: Directory for checkpoints and plots; created if missing.
        save_freq: Number of ``_on_step`` calls between checkpoints.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, save_path, save_freq=100000, verbose=0):
        super(EpisodeRewardCallback, self).__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        # Counts _on_step calls across every learn() call this callback is used in.
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        # Each learn() call in this script starts from a freshly reset env.
        # NOTE(review): this relies on SB3 resetting the env at the start of
        # learn() (reset_num_timesteps=True, the default) — confirm if reused
        # elsewhere. Any episode still running when the previous learn() stopped
        # is therefore abandoned; drop its partial reward so it is not credited to
        # the first episode of this run.
        self.current_episode_reward = 0.0

    @staticmethod
    def _moving_average(values, window):
        """Return the trailing mean of ``values`` over up to ``window`` items.

        Entry ``i`` is the mean of ``values[max(0, i - window + 1):i + 1]``, so
        the first ``window - 1`` entries average over fewer items. Requires
        ``window >= 1`` and a non-empty ``values``. Uses a cumulative sum, so the
        cost is O(n) rather than O(n * window).
        """
        data = np.asarray(values, dtype=float)
        prefix = np.cumsum(data)
        window_sums = prefix.copy()
        # For i >= window, sum(values[i-window+1 .. i]) = prefix[i] - prefix[i-window].
        window_sums[window:] = prefix[window:] - prefix[:-window]
        counts = np.minimum(np.arange(1, data.size + 1), window)
        return window_sums / counts

    def _save_reward_plot(self):
        """Write the moving-average reward curve for all episodes seen so far."""
        window = min(100, len(self.episode_rewards))
        moving_avg = self._moving_average(self.episode_rewards, window)

        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
            ax.set_xlabel("Episode")
            ax.set_ylabel("Average Reward")
            ax.set_title("Training Progress")
            ax.legend()
            ax.grid()
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png'))
        finally:
            # Always release the figure so repeated checkpoints do not accumulate
            # open figures, even if saving fails.
            plt.close(fig)

    def _on_step(self) -> bool:
        """Accumulate reward, checkpoint periodically; always returns True (continue)."""
        self.step_counter += 1
        self.current_episode_reward += self.locals['rewards'][0]
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq == 0:
            model_path = os.path.join(self.save_path, f'model_{self.step_counter}.zip')
            self.model.save(model_path)
            print(f"Step {self.step_counter}: Model saved at {model_path}")

            if self.episode_rewards:
                self._save_reward_plot()
        return True


# Reward table for each of the two agents. Both agents currently share identical
# values, but each gets its own dict object (built fresh per comprehension step).
rewardList = [
    {
        "minitask finished": 0,
        "minitask failed": 0,
        "metatask finished": 0,
        "metatask failed": 0,
        "goodtask finished": 10,
        "goodtask failed": 0,
        "subtask finished": 20,
        "subtask failed": 0,
        "correct delivery": 200,
        "wrong delivery": -50,
        "step penalty": -1,
        "penalize using dirty plate": 0,
        "penalize using bad lettuce": 0,
        "pick up bad lettuce": 0,
    }
    for _agent in range(2)
]

# Registered gym id of the environment that is trained on.
mac_env_id = 'Overcooked-MA-v1'

# Keyword arguments forwarded as gym.make(mac_env_id, **env_params).
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="B",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=True,
)

# PPO hyperparameters, unpacked into the PPO(...) constructor.
ppo_params = dict(
    learning_rate=3e-4,
    n_steps=256,
    batch_size=128,
    n_epochs=10,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.3,
    ent_coef=0.02,
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=0,
)

# pi / vf: hidden-layer sizes for the policy and value networks
# (stable-baselines3 net_arch convention).
# NOTE(review): the list-of-dict form is the older SB3 convention; newer releases
# expect net_arch=dict(pi=..., vf=...). Verify against the installed SB3 version
# before changing it.
policy_kwargs = {"net_arch": [{"pi": [256, 128, 64], "vf": [256, 128, 64]}]}

# Alternation schedule: the training loop iterates over
# range(0, total_alternate_steps, alternate_interval) and, in each iteration,
# trains for alternate_interval timesteps in each of the two agent seats.
total_alternate_steps = 3_000_000
alternate_interval = 10_000

def format_time(seconds):
    """Render a duration as ``"<M>minute<S>second"``.

    ``M`` is the number of whole minutes and ``S`` the leftover whole seconds
    (fractions are truncated). Minutes are not rolled over into hours, so long
    runs print e.g. ``"122minute5second"``.
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}minute{int(leftover)}second"


# Train one self-play PPO policy per seed. A single model plays both seats: in each
# alternation it is trained as agent 0 and then as agent 1, with the same model
# also choosing the partner's actions.
for seed_number in [15, 25, 35, 45, 55]:
    print('============= training with seed', seed_number)

    # Seed every RNG source used in this run (SB3 helper, CUDA, torch, numpy, random).
    set_random_seed(seed_number)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_number)
    torch.manual_seed(seed_number)
    np.random.seed(seed_number)
    random.seed(seed_number)

    shared_env = gym.make(mac_env_id, **env_params)
    try:
        # Prefer reset(seed=...); fall back to the legacy seed() API when the
        # env's reset() does not accept a seed argument.
        try:
            shared_env.reset(seed=seed_number)
        except TypeError:
            shared_env.seed(seed_number)
            shared_env.reset()
        shared_env.action_space.seed(seed_number)
        shared_env.observation_space.seed(seed_number)

        # Best-effort debug probe: any failure is reported and training continues.
        try:
            print('=====================', shared_env._findItem(9, 10, "tomato"))
        except Exception as e:
            print("Debug _findItem skipped:", e)

        # The wrapper handed to PPO() has no partner model yet; both wrappers are
        # rebuilt with the model as partner inside the loop before any learn().
        env_agent_0 = SingleAgentWrapper(shared_env, agent_index=0)
        env_agent_1 = SingleAgentWrapper(shared_env, agent_index=1)

        model_agent_0 = PPO(
            "MlpPolicy",
            env_agent_0,
            seed=seed_number,
            policy_kwargs=policy_kwargs,
            **ppo_params
        )
        model_agent_0.set_logger(configure(f'./logs/seed_{seed_number}/', ["csv", "tensorboard"]))
        reward_callback_0 = EpisodeRewardCallback(f'final_trained_models/[MapB]SP_agent_teamReward_seed{seed_number}')

        global_start_time = time.time()
        for i in range(0, total_alternate_steps, alternate_interval):
            print(f"Training Agent 0 (Steps {i} to {i + alternate_interval})")
            env_agent_0 = SingleAgentWrapper(shared_env, agent_index=0, other_agent_model=model_agent_0)
            model_agent_0.set_env(env_agent_0)
            model_agent_0.learn(total_timesteps=alternate_interval, callback=reward_callback_0)

            print(f"Training Agent 1 (Steps {i} to {i + alternate_interval})")
            env_agent_1 = SingleAgentWrapper(shared_env, agent_index=1, other_agent_model=model_agent_0)
            model_agent_0.set_env(env_agent_1)
            model_agent_0.learn(total_timesteps=alternate_interval, callback=reward_callback_0)

            phase_end_time = time.time()
            total_duration = phase_end_time - global_start_time
            print(f"[Training time:] {format_time(total_duration)}")

        # The callback only checkpoints every save_freq steps, so the last stretch
        # of training would otherwise never be written to disk.
        final_model_path = os.path.join(reward_callback_0.save_path, 'model_final.zip')
        model_agent_0.save(final_model_path)
        print(f"Seed {seed_number}: final model saved at {final_model_path}")
    finally:
        # Release this seed's environment before the next seed creates a new one.
        shared_env.close()
