import random
import numpy as np


class EnvWrapper:
    """Adapt a dict-per-agent multi-agent environment to stacked NumPy arrays.

    The wrapped ``env`` must expose:
    - ``reset(seed=...)`` returning ``(obs_dict, info)``;
    - ``step(action_dict)`` returning ``(obs, rew, terminated, truncated, info)``,
      with the first four being dicts keyed by agent;
    - ``agents``, ``action_space(agent)``, ``observation_space(agent)`` and
      ``close()``.

    This matches the PettingZoo Parallel API, presumably the intended backend
    (TODO confirm).

    All agents must share the same action dimension and the same observation
    dimension (``shape[0]`` of their spaces). They must also receive the same
    per-step reward. Per-agent values are stacked in the order of
    ``self.agent_id``.
    """

    def __init__(self, env, seed=None):
        """Wrap ``env`` and cache its agent list and common dimensions.

        Args:
            env: the multi-agent environment to wrap.
            seed: seed forwarded to ``env.reset`` on every :meth:`reset` call.

        Raises:
            ValueError: if the environment reports no agents, or if agents
                disagree on action or observation dimension.
        """
        self.env = env
        self.seed = seed

        # A reset is needed so that ``env.agents`` is populated. This one
        # always uses seed 0, regardless of ``seed``.
        env.reset(seed=0)
        # Copy the list: an env may remove finished agents from its own
        # ``agents`` list in place. Without a copy, that would silently change
        # this wrapper's agent order and count.
        self.agent_id = list(env.agents)
        self.agent_num = len(self.agent_id)
        if self.agent_num == 0:
            raise ValueError("Environment reports no agents.")

        self.action_dim = self._common_dim(
            [env.action_space(agt).shape[0] for agt in self.agent_id],
            "Action dimensions are not equal across agents.",
        )
        self.obs_dim = self._common_dim(
            [env.observation_space(agt).shape[0] for agt in self.agent_id],
            "Obs dimensions are not equal across agents.",
        )

    @staticmethod
    def _common_dim(dims, error_msg):
        """Return the value shared by every entry of non-empty ``dims``.

        Raises:
            ValueError: with ``error_msg`` if the entries differ.
        """
        first = dims[0]
        if any(d != first for d in dims):
            raise ValueError(error_msg)
        return first

    def reset(self):
        """Reset the environment with ``self.seed`` and return stacked observations.

        NOTE(review): the same ``self.seed`` is passed on every call. With a
        non-None seed, every episode may therefore start from the same state.
        Confirm this is intended.

        Returns:
            np.ndarray of per-agent observations, ordered like ``self.agent_id``.
        """
        obs_dict, _ = self.env.reset(seed=self.seed)
        return self.dict2np(obs_dict)

    def step(self, action_np):
        """Apply one joint action and return team-level results.

        Args:
            action_np: indexable of per-agent actions; ``action_np[i]`` is
                sent to ``self.agent_id[i]``.

        Returns:
            ``(obs, rew, terminated, truncated, info)``, where:
            - ``obs`` is the stacked per-agent observation array;
            - ``rew`` is the single shared reward;
            - ``terminated`` / ``truncated`` are ``True`` if any agent
              reports so;
            - ``info`` is the env's info object, passed through unchanged.

        Raises:
            ValueError: if agents received different rewards.
        """
        obs_dict, rew_dict, terminated_dict, truncated_dict, info = self.env.step(
            self.np2dict(action_np)
        )

        obs = self.dict2np(obs_dict)
        rew = self.dict2np(rew_dict)
        terminated = self.dict2np(terminated_dict)
        truncated = self.dict2np(truncated_dict)

        if not np.all(rew == rew[0]):
            raise ValueError("Error: rew contains different values!")
        rew = rew[0]

        # Disagreement between agents is reported but tolerated: the episode
        # counts as over as soon as any single agent is done.
        if np.any(terminated) != np.all(terminated):
            print("Error: terminated contains different values!")
        if np.any(truncated) != np.all(truncated):
            print("Error: truncated contains different values!")

        return obs, rew, bool(np.any(terminated)), bool(np.any(truncated)), info

    def dict2np(self, item_dict):
        """Stack ``item_dict`` values into an array ordered by ``self.agent_id``.

        Raises:
            KeyError: if an agent in ``self.agent_id`` is missing from
                ``item_dict``.
        """
        return np.array([item_dict[agt] for agt in self.agent_id])

    def np2dict(self, item_np):
        """Return a dict mapping ``self.agent_id[i]`` to ``item_np[i]``."""
        return {agt: item_np[i] for i, agt in enumerate(self.agent_id)}

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
