import numpy as np
import gymnasium_robotics
import threading
from queue import Queue


class ThreadedEnvWrapper:
    """Run several multi-agent environments side by side, one thread per call.

    Each wrapped environment is expected to expose ``agents``,
    ``action_space(agent)``, ``observation_space(agent)``,
    ``reset(seed=...) -> (obs_dict, info)`` and
    ``step(action_dict) -> (obs, rew, terminated, truncated, info)`` where the
    first four step outputs are dicts keyed by agent.

    Results are always returned in environment order: row ``i`` of every
    returned array belongs to ``self.envs[i]``.
    """

    def __init__(self, env_fn, num_envs, seed=None):
        """Build ``num_envs`` environments and record agent/space dimensions.

        Args:
            env_fn: Zero-argument callable that creates one environment.
            num_envs: Number of environments to create (must be >= 1).
            seed: Seed forwarded to every ``env.reset`` call.

        Raises:
            ValueError: If ``num_envs`` is smaller than 1, or if the action or
                observation dimensions differ between agents.
        """
        if num_envs < 1:
            raise ValueError("num_envs must be at least 1.")

        self.num_envs = num_envs
        self.envs = [env_fn() for _ in range(num_envs)]
        self.seed = seed

        # Initialize all environments.
        # NOTE(review): every environment receives the same seed, here and in
        # reset(); if the environments are stochastic they will presumably
        # produce identical rollouts -- confirm this is intended.
        for env in self.envs:
            env.reset(seed=seed)

        # Agent list and space dimensions are taken from the first environment.
        first_env = self.envs[0]
        self.agent_id = first_env.agents
        self.agent_num = len(self.agent_id)

        # All agents must share one action dimension ...
        action_dim_list = [
            first_env.action_space(agt).shape[0] for agt in self.agent_id
        ]
        if not all(x == action_dim_list[0] for x in action_dim_list):
            raise ValueError("Action dimensions are not equal across agents.")
        self.action_dim = action_dim_list[0]

        # ... and one observation dimension, so results can be stacked.
        obs_dim_list = [
            first_env.observation_space(agt).shape[0] for agt in self.agent_id
        ]
        if not all(x == obs_dim_list[0] for x in obs_dim_list):
            raise ValueError("Observation dimensions are not equal across agents.")
        self.obs_dim = obs_dim_list[0]

    def reset(self):
        """Reset every environment in its own thread.

        Returns:
            Array of observations shaped ``[num_envs, agent_num, obs_dim]``,
            ordered like ``self.envs``.
        """
        obs_list = self._run_parallel(
            self._reset_env, [(env,) for env in self.envs]
        )
        return np.stack(obs_list, axis=0)  # [num_envs, agent_num, obs_dim]

    def step(self, actions):
        """Step every environment in its own thread.

        Args:
            actions: Array-like shaped ``[num_envs, agent_num, action_dim]``.

        Returns:
            Tuple ``(obs, rew, terminated, truncated, info)`` where ``obs`` is
            ``[num_envs, agent_num, obs_dim]``, ``rew``, ``terminated`` and
            ``truncated`` are ``[num_envs]``, and ``info`` is an empty dict.

        Raises:
            ValueError: If agents of one environment received different
                rewards.
        """
        actions = np.asarray(actions)
        expected_shape = (self.num_envs, self.agent_num, self.action_dim)
        assert actions.shape == expected_shape, (
            f"actions has shape {actions.shape}, expected {expected_shape}"
        )

        # One argument tuple per environment: (env, {agent: action}).
        call_args = [
            (env, {agt: actions[i, j] for j, agt in enumerate(self.agent_id)})
            for i, env in enumerate(self.envs)
        ]
        results = self._run_parallel(self._step_env, call_args)

        # Transpose the per-environment 4-tuples into four per-field lists.
        obs_list, rew_list, terminated_list, truncated_list = zip(*results)

        obs = np.stack(obs_list, axis=0)  # [num_envs, agent_num, obs_dim]
        rew = np.stack(rew_list, axis=0)  # [num_envs, agent_num]
        terminated = np.stack(terminated_list, axis=0)  # [num_envs, agent_num]
        truncated = np.stack(truncated_list, axis=0)  # [num_envs, agent_num]

        # Check if rewards, terminated, and truncated are consistent across agents
        if not np.all(rew == rew[:, [0]], axis=1).all():
            raise ValueError("Error: rew contains different values across agents!")

        if not np.all(terminated == terminated[:, [0]], axis=1).all():
            print("Warning: terminated contains different values across agents!")

        if not np.all(truncated == truncated[:, [0]], axis=1).all():
            print("Warning: truncated contains different values across agents!")

        # Collapse the agent axis: rewards are identical per environment, and
        # an environment counts as done as soon as any of its agents is done.
        rew = rew[:, 0]  # [num_envs]
        terminated = np.any(terminated, axis=1)  # [num_envs]
        truncated = np.any(truncated, axis=1)  # [num_envs]

        return obs, rew, terminated, truncated, {}

    def _run_parallel(self, target, args_list):
        """Call ``target(*args)`` for each entry of ``args_list`` in a thread.

        Every worker writes into the slot matching its position in
        ``args_list``, so the returned list keeps the input order no matter
        which thread finishes first. The first exception raised by a worker
        (in input order) is re-raised here after all threads have finished.
        """
        results = [None] * len(args_list)
        errors = [None] * len(args_list)

        def worker(slot, args):
            try:
                results[slot] = target(*args)
            except Exception as exc:  # re-raised in the calling thread below
                errors[slot] = exc

        threads = [
            threading.Thread(target=worker, args=(slot, args))
            for slot, args in enumerate(args_list)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        for exc in errors:
            if exc is not None:
                raise exc
        return results

    def _reset_env(self, env):
        """Reset one environment and return its observations as an array."""
        obs_dict, _ = env.reset(seed=self.seed)
        return self.dict2np(obs_dict)

    def _step_env(self, env, action_dict):
        """Step one environment and return (obs, rew, terminated, truncated) arrays."""
        obs_dict, rew_dict, terminated_dict, truncated_dict, _ = env.step(action_dict)
        return (
            self.dict2np(obs_dict),
            self.dict2np(rew_dict),
            self.dict2np(terminated_dict),
            self.dict2np(truncated_dict),
        )

    def dict2np(self, item_dict):
        """Convert an agent-keyed dict to an array ordered like ``self.agent_id``."""
        return np.array([item_dict[agt] for agt in self.agent_id])

    def np2dict(self, item_np):
        """Convert an ``[agent_num, ...]`` array to a dict keyed by agent."""
        return {agt: item_np[i] for i, agt in enumerate(self.agent_id)}