import numpy as np
import multiprocessing as mp
import time
import sys
from enum import Enum
from copy import deepcopy

from gym import logger
from gym.vector.vector_env import VectorEnv
from gym.error import (
    AlreadyPendingCallError,
    NoAsyncCallError,
    ClosedEnvironmentError,
    CustomSpaceError,
)
from gym.vector.utils import (
    create_shared_memory,
    create_empty_array,
    write_to_shared_memory,
    read_from_shared_memory,
    concatenate,
    CloudpickleWrapper,
    clear_mpi_env_vars,
)

__all__ = ["AsyncVectorEnv"]


class AsyncState(Enum):
    """State of the command/reply exchange between the main process and workers.

    The value of every non-default member is the prefix of a ``*_wait`` method
    name (``"reset"`` -> ``reset_wait``); `close_extras` builds that name from
    the value to finish a pending call.
    """

    DEFAULT = "default"  # no call pending; new commands may be sent
    WAITING_RESET = "reset"  # `reset_async` sent, replies not yet collected
    WAITING_STEP = "step"  # `step_async` sent, replies not yet collected
    WAITING_CALL = "call"  # `call_async` / `call_each` sent, replies not yet collected


class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel. It
    uses `multiprocessing` processes, and pipes for communication.

    On top of the upstream gym implementation this class adds:

    * detection of dead worker processes, followed by a restart of *all*
      workers (rebuilt from ``env_fns``);
    * tolerance of `BrokenPipeError` / `EOFError` on the command pipes;
    * a bounded retry loop (``max_retries`` retries spaced ``retry_interval``
      seconds apart) around the commands that do not wait on a pending call:
      `seed`, `reset_async`, `step_async`, `call_async`, `call_each` and
      `set_attr`.

    The ``*_wait`` methods are not retried: a restart throws away the pending
    call, so on a pipe failure they restart the workers and raise
    `RuntimeError`. Restarted sub-environments are freshly constructed, so
    callers will usually need to `reset()` before stepping again.

    Extra parameters compared to gym's ``AsyncVectorEnv``:

    retry_interval : float
        Seconds to sleep between a failed attempt and the next retry. It is
        also passed to every worker process as its last positional argument
        (so a custom ``worker`` must accept it).
    max_retries : int
        Retries allowed after the first failed attempt, i.e. up to
        ``max_retries + 1`` attempts in total. Negative values act like 0.
    """

    def __init__(
        self,
        env_fns,
        dummy_env_fn=None,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
        retry_interval=120,  # seconds to wait between attempts (default: 2 minutes)
        max_retries=1,       # retries allowed after the first failed attempt
    ):
        ctx = mp.get_context(context)
        self.env_fns = env_fns  # kept so workers can be rebuilt on restart
        self.shared_memory = shared_memory
        self.copy = copy
        self.retry_interval = retry_interval
        self.max_retries = max_retries
        self.ctx = ctx  # kept so restarted workers use the same start method
        self.daemon = daemon  # kept so restarted workers get the same daemon flag

        # Build a throwaway environment only to read metadata and default spaces.
        if dummy_env_fn is None:
            dummy_env_fn = env_fns[0]
        dummy_env = dummy_env_fn()
        self.metadata = dummy_env.metadata

        if (observation_space is None) or (action_space is None):
            observation_space = observation_space or dummy_env.observation_space
            action_space = action_space or dummy_env.action_space
        dummy_env.close()
        del dummy_env
        super(AsyncVectorEnv, self).__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        if self.shared_memory:
            try:
                # Stored on the instance: restarted workers reuse the same buffer.
                self._obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    self._obs_buffer, self.single_observation_space, n=self.num_envs
                )
            except CustomSpaceError:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` "
                    "is incompatible with non-standard Gym observation spaces "
                    "(i.e. custom spaces inheriting from `gym.Space`), and is "
                    "only compatible with default Gym spaces (e.g. `Box`, "
                    "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
                    "if you use custom observation spaces."
                )
        else:
            self._obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        self.target_worker = worker or (
            _worker_shared_memory if self.shared_memory else _worker
        )
        self._start_workers(daemon=daemon)

        self._state = AsyncState.DEFAULT
        self._check_observation_spaces()

    # ------------------------------ worker lifecycle ------------------------------
    def _start_workers(self, daemon=True):
        """(Re)create one worker process and pipe per entry of ``self.env_fns``.

        Any previous generation of workers is torn down first: pipes are
        closed, live processes are terminated, and every old process is joined.
        """
        for pipe in self.parent_pipes:
            if pipe is not None and not pipe.closed:
                pipe.close()
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            # Join dead workers too, so they are reaped instead of left as zombies.
            process.join(timeout=5)

        self.parent_pipes.clear()
        self.processes.clear()
        # Errors reported by the old workers belong to processes that no longer
        # exist; drop them so `_raise_if_errors` cannot misattribute them later.
        self._drain_error_queue()

        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = self.ctx.Pipe()
                process = self.ctx.Process(
                    target=self.target_worker,
                    name="Worker<{0}>-{1}".format(type(self).__name__, idx),
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        self._obs_buffer,
                        self.error_queue,
                        self.retry_interval,
                    ),
                )
                process.daemon = daemon
                process.start()
                child_pipe.close()  # only the worker keeps the child end

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)
        logger.info(f"Successfully started {len(self.processes)} workers")

    def _drain_error_queue(self):
        """Discard every error report currently sitting in ``self.error_queue``."""
        import queue

        while True:
            try:
                self.error_queue.get_nowait()
            except queue.Empty:
                return

    def _restart_workers(self):
        """Replace all worker processes and forget any pending call."""
        self._start_workers(daemon=self.daemon)
        self._state = AsyncState.DEFAULT

    def _check_worker_alive(self):
        """Return the indices of workers that can no longer serve commands.

        A worker counts as dead when its process has exited or when its pipe
        was closed / cleared (``_raise_if_errors`` sets it to ``None``).
        """
        dead_indices = []
        for idx, (process, pipe) in enumerate(zip(self.processes, self.parent_pipes)):
            if not process.is_alive() or pipe is None or pipe.closed:
                dead_indices.append(idx)
                logger.warn(f"Worker-{idx} is dead, need to restart")
        return dead_indices

    def _retry_with_worker_restart(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)``, restarting workers and retrying on pipe failures.

        Dead workers are restarted before the first attempt. A `BrokenPipeError`
        or `EOFError` leaves the command/reply stream out of sync (some workers
        may already hold the command, others not), so every retry begins with
        a full restart. `mp.TimeoutError` is deliberately not retried: the wait
        implementations reset ``_state`` before raising it, so a retry could
        only fail with `NoAsyncCallError`.

        Raises:
            ClosedEnvironmentError: if the environment is closed. This is checked
                before any restart, so a closed environment is never revived.
            RuntimeError: if every attempt failed with a pipe error; the last
                pipe error is chained as ``__cause__``.
        """
        self._assert_is_running()
        max_retries = max(self.max_retries, 0)
        last_error = None
        for attempt in range(max_retries + 1):
            dead_indices = self._check_worker_alive()
            if attempt > 0:
                logger.warn(f"Restarting all workers before retry {attempt}/{max_retries}")
                self._restart_workers()
            elif dead_indices:
                logger.warn(f"Restarting all workers (due to dead workers: {dead_indices})")
                self._restart_workers()

            try:
                return func(*args, **kwargs)
            except (BrokenPipeError, EOFError) as e:
                last_error = e
                if attempt == max_retries:
                    break  # no retry left: do not sleep before giving up
                logger.error(
                    f"Encountered error: {type(e).__name__}: {e}\n"
                    f"Retry {attempt + 1}/{max_retries} after {self.retry_interval} seconds"
                )
                time.sleep(self.retry_interval)

        raise RuntimeError(
            f"Failed after {max_retries} retries. All workers cannot be recovered."
        ) from last_error

    def _wait_with_recovery(self, impl, timeout):
        """Run a ``*_wait`` implementation; on a pipe failure restart workers and raise.

        The pending call cannot be resumed once workers are restarted, so
        instead of retrying, this surfaces a `RuntimeError` (chained to the pipe
        error) after leaving fresh workers and a DEFAULT state behind.
        """
        try:
            return impl(timeout)
        except (BrokenPipeError, EOFError) as e:
            logger.error(
                f"Encountered error: {type(e).__name__}: {e}\n"
                "The pending call is lost; restarting all workers"
            )
            self._restart_workers()
            raise RuntimeError(
                "A worker failed while a call was pending; all workers were "
                "restarted and the pending call was discarded. Call `reset` "
                "before stepping again."
            ) from e

    # ------------------------------ pipe helpers ------------------------------
    @staticmethod
    def _require_open_pipe(pipe, message):
        """Return ``pipe`` if usable, else raise `BrokenPipeError(message)`.

        ``None`` (a pipe cleared by ``_raise_if_errors``) counts as broken.
        """
        if pipe is None or pipe.closed:
            raise BrokenPipeError(message)
        return pipe

    def _receive_all(self, message):
        """Receive one ``(result, success)`` reply per worker, in worker order.

        Returns two lists: the results and the success flags.
        """
        results, successes = [], []
        for pipe in self.parent_pipes:
            result, success = self._require_open_pipe(pipe, message).recv()
            results.append(result)
            successes.append(success)
        return results, successes

    # ------------------------------ public API ------------------------------
    def seed(self, seeds=None):
        def _seed():
            self._assert_is_running()
            if seeds is None:
                seeds_list = [None for _ in range(self.num_envs)]
            elif isinstance(seeds, int):
                seeds_list = [seeds + i for i in range(self.num_envs)]
            else:
                seeds_list = seeds
            assert len(seeds_list) == self.num_envs

            if self._state != AsyncState.DEFAULT:
                raise AlreadyPendingCallError(
                    "Calling `seed` while waiting "
                    f"for a pending call to `{self._state.value}` to complete.",
                    self._state.value,
                )

            for pipe, seed in zip(self.parent_pipes, seeds_list):
                self._require_open_pipe(
                    pipe, "Parent pipe is closed, cannot send seed command"
                ).send(("seed", seed))
            _, successes = self._receive_all(
                "Parent pipe is closed, cannot receive seed result"
            )
            self._raise_if_errors(successes)

        return self._retry_with_worker_restart(_seed)

    def reset_async(self):
        def _reset_async():
            self._assert_is_running()
            if self._state != AsyncState.DEFAULT:
                raise AlreadyPendingCallError(
                    "Calling `reset_async` while waiting "
                    f"for a pending call to `{self._state.value}` to complete",
                    self._state.value,
                )

            for pipe in self.parent_pipes:
                self._require_open_pipe(
                    pipe, "Parent pipe is closed, cannot send reset command"
                ).send(("reset", None))
            self._state = AsyncState.WAITING_RESET

        return self._retry_with_worker_restart(_reset_async)

    def reset_wait(self, timeout=None):
        return self._wait_with_recovery(self._reset_wait_impl, timeout)

    def _reset_wait_impl(self, timeout=None):
        """Collect the replies to a pending `reset_async` (no recovery)."""
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        results, successes = self._receive_all(
            "Parent pipe is closed, cannot receive reset result"
        )
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        if not self.shared_memory:
            self.observations = concatenate(
                results, self.observations, self.single_observation_space
            )

        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        def _step_async():
            self._assert_is_running()
            if self._state != AsyncState.DEFAULT:
                raise AlreadyPendingCallError(
                    "Calling `step_async` while waiting "
                    f"for a pending call to `{self._state.value}` to complete.",
                    self._state.value,
                )

            for pipe, action in zip(self.parent_pipes, actions):
                self._require_open_pipe(
                    pipe, "Parent pipe is closed, cannot send step command"
                ).send(("step", action))
            self._state = AsyncState.WAITING_STEP

        return self._retry_with_worker_restart(_step_async)

    def step_wait(self, timeout=None):
        return self._wait_with_recovery(self._step_wait_impl, timeout)

    def _step_wait_impl(self, timeout=None):
        """Collect the replies to a pending `step_async` (no recovery)."""
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        results, successes = self._receive_all(
            "Parent pipe is closed, cannot receive step result"
        )
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        observations_list, rewards, dones, infos = zip(*results)
        if not self.shared_memory:
            self.observations = concatenate(
                observations_list, self.observations, self.single_observation_space
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def close_extras(self, timeout=None, terminate=False):
        """Shut the workers down, forcibly if asked to or if a pending call fails."""
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    "Calling `close` while waiting for a pending "
                    f"call to `{self._state.value}` to complete."
                )
                # Use the raw wait implementation: the public wrapper would
                # restart every worker on a pipe failure just before we stop them.
                function = getattr(self, "_{0}_wait_impl".format(self._state.value))
                function(timeout)
        except (mp.TimeoutError, BrokenPipeError, EOFError):
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            notified = []
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    try:
                        pipe.send(("close", None))
                    except OSError:
                        continue  # BrokenPipeError: this worker is already gone
                    notified.append(pipe)
            for pipe in notified:
                try:
                    pipe.recv()
                except (EOFError, OSError):
                    pass  # worker exited without acknowledging; joined below

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll(self, timeout=None):
        """Return True if every worker has a reply ready within ``timeout`` seconds.

        ``timeout=None`` skips polling and returns True immediately. A cleared or
        closed pipe raises `BrokenPipeError` instead of being reported as a timeout.
        """
        self._assert_is_running()
        if timeout is None:
            return True
        end_time = time.perf_counter() + timeout
        for pipe in self.parent_pipes:
            pipe = self._require_open_pipe(
                pipe, "Parent pipe is closed, cannot poll for results"
            )
            delta = max(end_time - time.perf_counter(), 0)
            if not pipe.poll(delta):
                return False
        return True

    def _check_observation_spaces(self):
        self._assert_is_running()
        for pipe in self.parent_pipes:
            self._require_open_pipe(
                pipe, "Parent pipe is closed, cannot send observation space check"
            ).send(("_check_observation_space", self.single_observation_space))
        same_spaces, successes = self._receive_all(
            "Parent pipe is closed, cannot receive observation space check"
        )
        self._raise_if_errors(successes)
        if not all(same_spaces):
            raise RuntimeError(
                "Some environments have an observation space "
                f"different from `{self.single_observation_space}`. In order to batch observations, the "
                "observation spaces from all environments must be "
                "equal."
            )

    def _assert_is_running(self):
        if self.closed:
            raise ClosedEnvironmentError(
                f"Trying to operate on `{type(self).__name__}`, after a "
                "call to `close()`."
            )

    def _raise_if_errors(self, successes):
        """Re-raise the last worker error if any reply reported failure.

        For each failed worker the ``(index, exc_type, message)`` report is read
        from ``error_queue``, its pipe is closed and replaced by ``None``.
        """
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        exctype, value = None, None
        for _ in range(num_errors):
            index, exctype, value = self.error_queue.get()
            logger.error(
                f"Received error from Worker-{index}: {exctype.__name__}: {value}"
            )
            logger.error(f"Shutting down Worker-{index}")
            if self.parent_pipes[index] and not self.parent_pipes[index].closed:
                self.parent_pipes[index].close()
            self.parent_pipes[index] = None

        logger.error("Raising the last exception to main process")
        raise exctype(value)

    def call_async(self, name: str, *args, **kwargs):
        return self._retry_with_worker_restart(self._call_async_impl, name, *args, **kwargs)

    def _call_async_impl(self, name: str, *args, **kwargs):
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe in self.parent_pipes:
            self._require_open_pipe(
                pipe, "Parent pipe is closed, cannot send call command"
            ).send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

    def call_wait(self, timeout=None):
        return self._wait_with_recovery(self._call_wait_impl, timeout)

    def _call_wait_impl(self, timeout=None):
        """Collect the replies to a pending `call_async` as a tuple (no recovery)."""
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = self._receive_all(
            "Parent pipe is closed, cannot receive call result"
        )
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        return tuple(results)

    def call(self, name: str, *args, **kwargs):
        self.call_async(name, *args, **kwargs)
        return self.call_wait()

    def call_each(self, name: str, args_list: list = None, kwargs_list: list = None, timeout=None):
        return self._retry_with_worker_restart(
            self._call_each_impl, name, args_list, kwargs_list, timeout=timeout
        )

    def _call_each_impl(self, name: str, args_list: list = None, kwargs_list: list = None, timeout=None):
        """Call ``name`` on every env with per-env arguments; return results as a tuple."""
        n_envs = len(self.parent_pipes)
        args_list = args_list or [[]] * n_envs
        assert len(args_list) == n_envs
        kwargs_list = kwargs_list or [dict()] * n_envs
        assert len(kwargs_list) == n_envs

        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for i, (pipe, args, kwargs) in enumerate(zip(self.parent_pipes, args_list, kwargs_list)):
            self._require_open_pipe(
                pipe, f"Parent pipe-{i} is closed, cannot send call_each command"
            ).send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = self._receive_all(
            "Parent pipe is closed, cannot receive call_each result"
        )
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        return tuple(results)

    def set_attr(self, name: str, values):
        return self._retry_with_worker_restart(self._set_attr_impl, name, values)

    def _set_attr_impl(self, name: str, values):
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
            values = [values] * self.num_envs
        if len(values) != self.num_envs:
            raise ValueError(
                f"Values must be list/tuple with length {self.num_envs}. Got {len(values)} values."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe, value in zip(self.parent_pipes, values):
            self._require_open_pipe(
                pipe, "Parent pipe is closed, cannot send set_attr command"
            ).send(("_setattr", (name, value)))
        _, successes = self._receive_all(
            "Parent pipe is closed, cannot receive set_attr result"
        )
        self._raise_if_errors(successes)

    def render(self, *args, **kwargs):
        return self.call("render", *args, **kwargs)


# ------------------------------ 子进程Worker函数修改 ------------------------------
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue, retry_interval=60):
    """修改：子进程增加管道状态检测，避免BrokenPipe崩溃"""
    assert shared_memory is None
    env = None
    parent_pipe.close()
    try:
        env = env_fn()  # 初始化环境
        while True:
            # 关键修改1：检测管道是否已关闭（避免recv崩溃）
            if pipe.closed:
                logger.warn(f"Worker-{index}: Pipe is closed, trying to reconnect after {retry_interval}s")
                time.sleep(retry_interval)
                # 管道关闭后无法恢复，退出循环（主进程会重启子进程）
                break

            # 关键修改2：用非阻塞poll+超时，避免CPU调度导致的永久阻塞
            if not pipe.poll(timeout=retry_interval):  # 超时时间=重试间隔
                logger.warn(f"Worker-{index}: No command received for {retry_interval}s, checking pipe again")
                continue

            # 接收命令（此时管道已确认可用）
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                pipe.send((observation, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                pipe.send(((observation, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break  # 正常关闭，退出循环
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(f"Trying to call `{name}` with `_call`. Use `{name}` directly.")
                function = getattr(env, name)
                result = function(*args, **kwargs) if callable(function) else function
                pipe.send((result, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_observation_space":
                pipe.send((data == env.observation_space, True))
            else:
                raise RuntimeError(f"Unknown command `{command}`")
    except EOFError:
        # 关键修改3：捕获EOFError（管道中断），不崩溃，等待主进程重启
        logger.error(f"Worker-{index}: EOFError (pipe broken), exiting to restart")
        error_queue.put((index, EOFError, "Pipe broken (EOF received)"))
    except BrokenPipeError:
        logger.error(f"Worker-{index}: BrokenPipeError, exiting to restart")
        error_queue.put((index, BrokenPipeError, "Pipe broken (send failed)"))
    except Exception as e:
        # 其他异常（如环境错误），仍需上报
        logger.error(f"Worker-{index}: Unexpected error: {type(e).__name__}: {e}")
        error_queue.put((index, type(e), str(e)))
    finally:
        # 清理环境资源
        if env:
            env.close()
        # 关闭管道（避免资源泄漏）
        if not pipe.closed:
            pipe.close()
        logger.info(f"Worker-{index}: Exited gracefully, waiting for restart")


def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue, retry_interval=60):
    """修改：共享内存版本的子进程，同_worker的容错逻辑"""
    assert shared_memory is not None
    env = None
    parent_pipe.close()
    try:
        env = env_fn()
        observation_space = env.observation_space
        while True:
            if pipe.closed:
                logger.warn(f"Worker-{index}(shared): Pipe closed, reconnect after {retry_interval}s")
                time.sleep(retry_interval)
                break

            if not pipe.poll(timeout=retry_interval):
                logger.warn(f"Worker-{index}(shared): No command for {retry_interval}s")
                continue

            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                write_to_shared_memory(index, observation, shared_memory, observation_space)
                pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                write_to_shared_memory(index, observation, shared_memory, observation_space)
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(f"Trying to call `{name}` with `_call`. Use `{name}` directly.")
                function = getattr(env, name)
                result = function(*args, **kwargs) if callable(function) else function
                pipe.send((result, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_observation_space":
                pipe.send((data == observation_space, True))
            else:
                raise RuntimeError(f"Unknown command `{command}`")
    except EOFError:
        logger.error(f"Worker-{index}(shared): EOFError (pipe broken)")
        error_queue.put((index, EOFError, "Pipe broken (EOF received)"))
    except BrokenPipeError:
        logger.error(f"Worker-{index}(shared): BrokenPipeError")
        error_queue.put((index, BrokenPipeError, "Pipe broken (send failed)"))
    except Exception as e:
        logger.error(f"Worker-{index}(shared): Unexpected error: {type(e).__name__}: {e}")
        error_queue.put((index, type(e), str(e)))
    finally:
        if env:
            env.close()
        if not pipe.closed:
            pipe.close()
        logger.info(f"Worker-{index}(shared): Exited gracefully, waiting for restart")
