from typing import Callable

import gym
import numpy as np

from gym import spaces


class SynchronousEnvParallelizer(gym.Env):
    """
    Meta-environment that "parallelizes" several synchronous environments.

    Each wrapped environment is expected to produce, per agent, observations of shape
    ``(1, *agent_observation_shape)``. ``reset``/``step`` are run sequentially on every environment
    and the per-environment results are stacked along axis 1, so observations returned by this class
    have shape ``(num_agents, num_envs, *agent_observation_shape)``.

    (synchronous = all environments are done at exactly the same timesteps, regardless of the actions)
    """
    def __init__(self, make_environment_function: Callable[[int], gym.Env], num_envs: int):
        """
        :param make_environment_function: factory called once per environment index ``0 .. num_envs - 1``.
            The created environments must expose ``num_agents`` and ``num_actions`` on their inner ``.env``.
        :param num_envs: number of environments to create and run in lockstep.
        """
        self.envs = [make_environment_function(i) for i in range(0, num_envs)]
        # Agent/action counts and the observation space are taken from the first environment only;
        # all environments are assumed to be configured identically.
        self.num_agents = self.envs[0].env.num_agents
        self.num_actions = self.envs[0].env.num_actions
        self.num_envs = num_envs
        # Forwarded to every wrapped environment that has an ``epoch`` attribute on each reset().
        self.epoch = None

        assert self.envs[0].observation_space[0].shape[0] == 1, \
            f"Invalid observation space shape {self.envs[0].observation_space[0]}"

        single_obs_space = self.envs[0].observation_space[0]
        # The leading size-1 axis of a single env's observation becomes the env axis of size num_envs.
        # Bounds are collapsed to the global scalar min/max of the single-env space.
        new_obs_space = spaces.Box(
            low=single_obs_space.low.min(), high=single_obs_space.high.max(),
            shape=(num_envs, *single_obs_space.shape[1:]), dtype=single_obs_space.dtype
        )
        self.observation_space = [new_obs_space for _ in range(0, self.num_agents)]

    def _empty_observations(self) -> np.ndarray:
        """Allocate a zeroed ``(num_agents, num_envs, *agent_observation_shape)`` observation buffer."""
        obs_space = self.observation_space[0]
        # Use the declared space dtype so the returned observations actually belong to ``observation_space``.
        return np.zeros((self.num_agents, *obs_space.shape), dtype=obs_space.dtype)

    def step(self, action):
        """
        Step every wrapped environment with its slice of ``action`` and stack the results.

        :param action: environment ``i`` receives ``action[:, i]``; presumably shaped
            ``(num_agents, num_envs)`` — confirm with callers.
        :return: ``(obs, reward, done, info)`` where

            - ``obs`` has shape ``(num_agents, num_envs, *agent_observation_shape)``;
            - ``reward`` is a float array of shape ``(num_agents, num_envs)``;
            - ``done`` is the done flag shared by all environments;
            - ``info`` may contain ``'alive_mask'`` (float, ``(num_agents, num_envs)``) together with
              ``'agent_done'`` (bool, ``(num_agents, num_envs)``), and ``'success'`` (bool, ``(num_envs,)``).
              A key is present if any environment reported it; environments that did not report it
              contribute zeros / ``False``.
        :raises AssertionError: if the environments disagree on ``done``, or if ``'alive_mask'`` is reported
            without ``'agent_done'``.
        """
        all_obs = self._empty_observations()
        all_reward = np.zeros((self.num_agents, self.num_envs), dtype=float)
        all_alive_mask = None
        all_agent_done = None
        all_success = None
        all_done = None

        for i, env in enumerate(self.envs):
            # ``[i]`` (not ``i``) keeps the env axis, so per-env results shaped (num_agents, 1, ...) fit in.
            all_obs[:, [i]], all_reward[:, [i]], done, info = env.step(action[:, i])

            if 'alive_mask' in info:
                if all_alive_mask is None:
                    all_alive_mask = np.zeros((self.num_agents, self.num_envs), dtype=float)

                all_alive_mask[:, [i]] = info['alive_mask']

            if 'agent_done' in info:
                if all_agent_done is None:
                    all_agent_done = np.zeros((self.num_agents, self.num_envs), dtype=bool)

                all_agent_done[:, [i]] = info['agent_done']

            if 'success' in info:
                if all_success is None:
                    all_success = np.zeros(self.num_envs, dtype=bool)

                all_success[i] = info['success']

            if all_done is None:
                all_done = done
            else:
                assert all_done == done, f"Environments are not synchronous. Expected done={all_done} but got {done}"

        all_info = {}
        if all_alive_mask is not None:
            assert all_agent_done is not None, "Individual dones are required if agents can be alive in selected steps"
            all_info['alive_mask'] = all_alive_mask
            all_info['agent_done'] = all_agent_done

        if all_success is not None:
            all_info['success'] = all_success

        return all_obs, all_reward, all_done, all_info

    def reset(self):
        """
        Reset every wrapped environment and stack their initial observations.

        Before each reset, ``self.epoch`` is copied onto the inner ``env.env`` of every environment
        that defines an ``epoch`` attribute.

        :return: observations of shape ``(num_agents, num_envs, *agent_observation_shape)``.
        """
        all_obs = self._empty_observations()

        for i, env in enumerate(self.envs):
            # forward epoch
            if hasattr(env.env, 'epoch'):
                env.env.epoch = self.epoch

            all_obs[:, [i]] = env.reset()

        return all_obs

    def render(self, mode='human'):
        pass

    def close(self):
        """Close every wrapped environment."""
        # getattr: stay safe if __init__ failed before ``self.envs`` was assigned.
        for env in getattr(self, 'envs', ()):
            env.close()
