from __future__ import annotations
from dataclasses import astuple, dataclass
from enum import Enum
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import numpy as np
import torch


# One info dict per environment, in environment order. The tuple is variadic because its
# length follows the number of environments (the in-process env contributes one entry and
# each child process contributes one more).
Info = Tuple[Dict[str, Any], ...]
# What a batched reset returns: (observations, infos).
ResetOutput = Tuple[torch.Tensor, Info]
# What a batched step returns: (observations, rewards, dones, truncateds, infos).
StepOutput = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Info]


class Env:
    """Batch of ``num_envs`` environments exposed through torch tensors.

    The first environment runs in the current process (``MainProcessEnv``); the remaining
    ``num_envs - 1`` run in child processes (``MultiProcessEnv``). Outputs of both are
    concatenated along the first axis, with the in-process environment at index 0.
    """

    def __init__(self, env_fn: Callable, num_envs: int = 1, device: Optional[torch.device] = None) -> None:
        """Create the environments.

        Args:
            env_fn: zero-argument callable building one environment.
            num_envs: total number of environments (at least 1).
            device: device on which returned tensors are created; defaults to CPU.

        Raises:
            ValueError: if ``num_envs`` is smaller than 1.
        """
        # One env always lives in this process, so fewer than one env cannot be represented
        # and would leave the done tracker out of sync with the step outputs.
        if num_envs < 1:
            raise ValueError(f'num_envs must be at least 1, got {num_envs}')
        self.device = device if device is not None else torch.device('cpu')
        self._num_envs = num_envs
        self.main_env = MainProcessEnv(env_fn)
        self.mp_env = MultiProcessEnv(env_fn, num_envs - 1) if num_envs > 1 else None
        # The number of actions is read from the in-process env's action space.
        self.num_actions = self.main_env._env.action_space.n
        self.done_tracker = DoneTracker(num_envs, self.device)

    @property
    def num_envs(self) -> int:
        """Total number of environments in the batch."""
        return self._num_envs

    @property
    def all_done(self) -> bool:
        """Whether every environment is marked done by the done tracker."""
        # The tracker may hand back a 0-dim tensor; callers are promised a plain bool.
        return bool(self.done_tracker.all_done)

    @property
    def is_alive(self) -> torch.Tensor:
        """Per-environment boolean tensor, as reported by the done tracker."""
        return self.done_tracker.is_alive

    def reset(self) -> ResetOutput:
        """Reset the done tracker and every environment.

        Returns:
            The batched observations as a tensor and one info dict per environment.
        """
        self.done_tracker.reset()
        obs, info = self.main_env.reset()
        if self.mp_env is not None:
            obs_mp, info_mp = self.mp_env.reset()
            obs, info = np.concatenate((obs, obs_mp)), info + info_mp
        return self._to_tensor(obs), info

    def step(self, actions: torch.LongTensor) -> StepOutput:
        """Step every environment and update the done tracker.

        Args:
            actions: one action per environment; ``actions[0]`` goes to the in-process
                environment and ``actions[1:]`` to the child-process environments.

        Returns:
            Observations, rewards, dones and truncateds as tensors, plus the infos.
        """
        *o_r_d_t, info = self.main_env.step(actions[0].item())  # o_r_d_t: obs, rew, done, trunc
        if self.mp_env is not None:
            *o_r_d_t_mp, info_mp = self.mp_env.step(actions[1:].cpu().numpy())
            o_r_d_t = tuple(np.concatenate((x, y)) for x, y in zip(o_r_d_t, o_r_d_t_mp))
            info = info + info_mp
        obs, rew, done, trunc = (self._to_tensor(x) for x in o_r_d_t)
        self.done_tracker.update(done, trunc)
        return obs, rew, done, trunc, info

    def _to_tensor(self, x: np.ndarray) -> torch.Tensor:
        """Convert a NumPy array to a tensor on ``self.device``.

        - 4-dimensional arrays are divided by 255, mapped through ``2 * x - 1`` and
          permuted from axes (0, 1, 2, 3) to (0, 3, 1, 2). Presumably these are batches
          of channel-last 8-bit images -- confirm against the environments used.
        - Boolean arrays become ``uint8`` tensors.
        - Everything else becomes a ``float32`` tensor.
        """
        if x.ndim == 4:
            scaled = torch.tensor(x, device=self.device).div(255).mul(2).sub(1)
            return scaled.permute(0, 3, 1, 2).contiguous()
        # Compare dtypes by equality rather than identity.
        if x.dtype == np.bool_:
            return torch.tensor(x, dtype=torch.uint8, device=self.device)
        return torch.tensor(x, dtype=torch.float32, device=self.device)


class DoneTracker:
    """Monitor env dones: 0 when not done, 1 when done, 2 when already done.

    The state is one ``uint8`` value per environment, stored on the given device.
    """

    def __init__(self, num_envs: int, device: torch.device) -> None:
        """Start with every environment in state 0 (not done)."""
        self._tracker = torch.zeros(num_envs, dtype=torch.uint8, device=device)

    def reset(self) -> None:
        """Put every environment back in state 0 (not done)."""
        self._tracker = torch.zeros_like(self._tracker)

    def update(self, done: torch.Tensor, truncated: torch.Tensor) -> None:
        """Advance the state from the latest ``done`` / ``truncated`` flags of each env."""
        d_or_t = torch.logical_or(done, truncated)
        # 0 -> 0 or 1 depending on the new flag; 1 -> 2; 2 stays 2 thanks to the clip.
        self._tracker = (2 * self._tracker + d_or_t).clip(0, 2)

    @property
    def all_done(self) -> bool:
        """True when every environment is in state 1 or 2."""
        return self.num_envs_done == self._tracker.size(0)

    @property
    def num_envs_done(self) -> int:
        """Number of environments in state 1 or 2, as a Python int."""
        # .item() turns the 0-dim tensor into the plain int promised by the annotation.
        return int((self._tracker > 0).sum().item())

    @property
    def is_alive(self) -> torch.Tensor:
        """Boolean tensor with one entry per environment: True where the state is 0."""
        return torch.logical_not(self._tracker)

    @property
    def mask_new_dones(self) -> torch.Tensor:
        """Boolean tensor over the environments whose state is 0 or 1.

        Environments in state 2 are dropped, so the length can be smaller than the number
        of environments. Entries are True where the state is 0 and False where it is 1.
        """
        return torch.logical_not(self._tracker[self._tracker <= 1])


class MainProcessEnv:
    """Single environment running in the current process, with a leading batch axis of 1.

    Once an episode ends (done or truncated) the environment is no longer stepped:
    ``step`` keeps returning the last output until ``reset`` is called.
    """

    def __init__(self, env_fn: Callable) -> None:
        """Build the environment by calling ``env_fn`` with no arguments."""
        self._env = env_fn()
        # _can_step is None until the first reset, then True/False.
        self._can_step, self._step_output = None, None

    def reset(self) -> Tuple[np.ndarray, Info]:
        """Reset the environment.

        Returns:
            The observation with a new leading axis of size 1, and a 1-tuple holding the info.
        """
        obs, info = self._env.reset()
        self._can_step = True
        return obs[None, ...], tuple([info])

    def step(self, action: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Info]:
        """Step the environment, or replay the last output if the episode already ended.

        Returns:
            ``(obs, rew, done, trunc, info)`` where the first four are arrays with a leading
            axis of size 1 and ``info`` is a 1-tuple.

        Raises:
            RuntimeError: if called before the first ``reset``.
        """
        # Without this guard the method would return None and the caller would fail
        # later with an unrelated unpacking error.
        if self._can_step is None:
            raise RuntimeError('reset() must be called before step()')
        if self._can_step:
            obs, rew, done, trunc, info = self._env.step(action)
            obs, rew, done, trunc = (np.array([x]) for x in (obs, rew, done, trunc))
            self._step_output = (obs, rew, done, trunc, tuple([info]))
            if done or trunc:
                # Freeze on the terminal output until the next reset.
                self._can_step = False
        return self._step_output
    

class MultiProcessEnv:
    """Group of environments, each running ``child_env`` in its own daemon process.

    The parent talks to every child over a dedicated pipe, sending one ``Message`` per
    child and then reading one reply from each, in child order.
    """

    def __init__(self, env_fn: Callable, num_envs: int) -> None:
        """Start ``num_envs`` child processes, each building its own env with ``env_fn``."""
        self._num_envs = num_envs
        self._processes, self._parent_conns = [], []
        for child_id in range(num_envs):
            parent_conn, child_conn = Pipe()
            self._parent_conns.append(parent_conn)
            p = Process(target=child_env, args=(child_id, env_fn, child_conn), daemon=True)
            self._processes.append(p)
            p.start()

    def reset(self) -> Tuple[np.ndarray, Info]:
        """Reset every child environment.

        Returns:
            The stacked observations and a tuple with one info per child.
        """
        self._send([Message(MessageType.RESET) for _ in range(self._num_envs)])
        content = self._receive(check_type=MessageType.RESET_RETURN)
        obs, info = zip(*content)
        return np.stack(obs), info

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Info]:
        """Send one action to each child and gather the results.

        Returns:
            Stacked observations, rewards, dones and truncateds, plus a tuple of infos.
        """
        self._send([Message(MessageType.STEP, action) for action in actions])
        content = self._receive(check_type=MessageType.STEP_RETURN)
        obs, rew, done, truncated, info = zip(*content)
        return *(np.stack(x) for x in (obs, rew, done, truncated)), info

    def _send(self, messages: List[Message]) -> None:
        """Send ``messages[i]`` to child ``i``; exactly one message per child is required."""
        assert len(messages) == self._num_envs
        for message, parent_conn in zip(messages, self._parent_conns):
            parent_conn.send(message)

    def _receive(self, check_type: Optional[MessageType] = None) -> List[Any]:
        """Read one message from every child and return their contents in child order.

        When ``check_type`` is given, every received message must have that type.
        """
        messages = [parent_conn.recv() for parent_conn in self._parent_conns]
        assert check_type is None or all(m.type == check_type for m in messages)
        # A real list (not a one-shot generator) so the result matches the annotation
        # and can be inspected more than once.
        return [m.content for m in messages]

    def close(self) -> None:
        """Ask every child to exit, wait for the processes, then close the parent pipe ends."""
        self._send([Message(MessageType.CLOSE) for _ in self._parent_conns])
        for p in self._processes:
            p.join()
        for parent_conn in self._parent_conns:
            parent_conn.close()


class MessageType(Enum):
    """Kinds of messages exchanged between the parent process and a child env process."""
    RESET = 0  # parent -> child: reset the environment
    RESET_RETURN = 1  # child -> parent: reply to RESET, carrying (obs, info)
    STEP = 2  # parent -> child: step the environment with the action in the content
    STEP_RETURN = 3  # child -> parent: reply to STEP, carrying (obs, rew, done, truncated, info)
    CLOSE = 4  # parent -> child: close the connection and exit the worker loop


@dataclass
class Message:
    """Typed message sent over a pipe between the parent process and a child env process.

    Iterating a message yields ``type`` then ``content``, so it can be unpacked as
    ``message_type, content = message``.
    """
    type: MessageType
    content: Optional[Any] = None

    def __iter__(self) -> Iterator:
        """Yield the ``type`` and ``content`` fields, unchanged."""
        # dataclasses.astuple would deep-copy the payload and recursively turn any nested
        # dataclass into a tuple; unpacking must hand back the very objects that were stored.
        return iter((self.type, self.content))


def child_env(child_id: int, env_fn: Callable, child_conn: Connection) -> None:
    """Worker loop run in a child process: serve RESET / STEP / CLOSE requests for one env.

    Args:
        child_id: index of this child, added to the random seed.
        env_fn: zero-argument callable building the environment.
        child_conn: this child's end of the pipe to the parent.

    Raises:
        NotImplementedError: on a message type other than RESET, STEP or CLOSE.
    """
    # Seed the global NumPy RNG with a value drawn from its current state, shifted by child_id.
    seed_offset = np.random.randint(0, 2 ** 31 - 1)
    np.random.seed(child_id + seed_offset)
    env = env_fn()
    can_step = True
    while True:
        message_type, content = child_conn.recv()

        if message_type == MessageType.CLOSE:
            child_conn.close()
            return

        if message_type == MessageType.RESET:
            obs, info = env.reset()
            can_step = True
            reply = Message(MessageType.RESET_RETURN, (obs, info))
        elif message_type == MessageType.STEP:
            if can_step:
                obs, rew, done, truncated, info = env.step(content)
            if done or truncated:
                # Episode over: keep replaying the last output until the next RESET,
                # i.e. wait for all envs to be done before starting again.
                can_step = False
            reply = Message(MessageType.STEP_RETURN, (obs, rew, done, truncated, info))
        else:
            raise NotImplementedError

        child_conn.send(reply)
