from __future__ import annotations

import traceback
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from parllel import Array, dict_map
from parllel.cages.collections import ActionType, EnvResetType, EnvStepType
from parllel.tree.array_dict import ArrayDict
from parllel.types import BatchSpec


class DummyEnv(gym.Env):
    """Environment that emits random samples and records what it emitted.

    Observations and rewards are drawn from the configured spaces with
    ``Space.sample()``. Every value returned by a "real" ``step`` call, and
    every reset observation, is also written into pre-allocated buffers that
    are exposed through the ``samples`` and ``resets`` properties.

    Args:
        action_space: Space used to allocate the ``env_info["action"]``
            buffer.
        observation_space: Space that observations are sampled from.
        episode_length: ``step`` reports ``terminated`` once the internal
            trajectory counter has reached this value.
        batch_spec: Batch specification; only its ``T`` attribute is read
            here.
        n_batches: Number of batches to allocate room for. Each buffer has a
            leading dimension of ``n_batches * batch_spec.T``.
        multireward: If True, the reward is a dict with the keys ``"alice"``
            and ``"bob"``; otherwise it is a single scalar. All rewards are
            sampled from ``Box(-10, 10, shape=())``.
        reset_automatically: If False, ``reset`` moves the write position to
            the start of the next batch before storing the reset observation.
    """

    def __init__(
        self,
        action_space: gym.Space,
        observation_space: gym.Space,
        episode_length: int,
        batch_spec: BatchSpec,
        n_batches: int,
        multireward: bool = False,
        reset_automatically: bool = True,
    ) -> None:
        self.observation_space = observation_space
        self.action_space = action_space
        if multireward:
            self.reward_space = spaces.Dict(
                {
                    "alice": spaces.Box(-10, 10, shape=()),
                    "bob": spaces.Box(-10, 10, shape=()),
                }
            )
        else:
            self.reward_space = spaces.Box(-10, 10, shape=())

        self.episode_length = episode_length
        self.batch_spec = batch_spec
        self.reset_automatically = reset_automatically

        # Index of the next recorded step within the sample buffers.
        self._step_ctr = 0
        # Same value that reset() assigns; set here so that the attribute
        # exists even if step() is called before the first reset().
        self._traj_counter = 1

        # allocate sample tree to store data generated by this env
        batch_shape = (n_batches * batch_spec.T,)
        self._samples = ArrayDict()
        # step() writes the next observation at `_step_ctr + 1`, so the
        # observation buffer needs one slot more than the batch shape
        # (presumably what `padding=1` provides -- confirm against Array).
        self._samples["observation"] = dict_map(
            Array.from_numpy,
            self.observation_space.sample(),
            batch_shape=batch_shape,
            padding=1,
        )
        self._samples["reward"] = dict_map(
            Array.from_numpy,
            self.reward_space.sample(),
            batch_shape=batch_shape,
            feature_shape=(),
            dtype=np.float32,
        )
        self._samples["terminated"] = Array(
            feature_shape=(),
            batch_shape=batch_shape,
            dtype=bool,
        )
        self._samples["truncated"] = Array(
            feature_shape=(),
            batch_shape=batch_shape,
            dtype=bool,
        )
        self._samples["env_info"] = dict_map(
            Array.from_numpy,
            {"action": self.action_space.sample()},
            batch_shape=batch_shape,
        )
        # reset() marks index `_step_ctr - 1`, which is -1 before any step
        # has been recorded (presumably the reason for `padding=1` here).
        self._batch_resets = Array(
            feature_shape=(),
            batch_shape=batch_shape,
            dtype=bool,
            padding=1,
        )

    def step(self, action: ActionType) -> EnvStepType:
        """Sample a random transition and, for real steps, record it.

        Args:
            action: Action passed by the caller. It is not used to compute
                the transition; a copy is returned in ``env_info["action"]``.

        Returns:
            Tuple ``(obs, reward, terminated, truncated, env_info)``.
            ``truncated`` is always False.
        """
        obs = self.observation_space.sample()
        reward = self.reward_space.sample()
        terminated = self._traj_counter >= self.episode_length
        truncated = False
        env_info = {"action": dict_map(np.asarray, action).copy()}

        # check the call stack to determine if this is a "real sample" or not
        # if just getting example or decorrelating, keep overwriting "reset"
        # observation
        names = [frame.name for frame in traceback.extract_stack()]
        if "random_step_async" in names:
            self._samples["observation"][0] = obs
        else:
            idx = self._step_ctr
            # the observation belongs to the following time step
            self._samples["observation"][idx + 1] = obs
            self._samples["reward"][idx] = reward
            self._samples["terminated"][idx] = terminated
            self._samples["truncated"][idx] = truncated
            self._samples["env_info"][idx] = env_info
            self._step_ctr = idx + 1

        self._traj_counter += 1
        return (obs, reward, terminated, truncated, env_info)

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> EnvResetType:
        """Start a new trajectory and record the reset observation.

        Args:
            seed: Accepted for interface compatibility; not used.
            options: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(obs, env_info)``. ``env_info`` is a copy of the info
            stored at index ``_step_ctr - 1``, or None if the write position
            is still 0.
        """
        self._traj_counter = 1
        # flag the most recently recorded step as the one preceding a reset
        self._batch_resets[self._step_ctr - 1] = True
        if not self.reset_automatically:
            # sampling batch may have stopped early
            batch_ctr = (self._step_ctr - 1) // self.batch_spec.T + 1
            # advance counter to the next batch
            self._step_ctr = batch_ctr * self.batch_spec.T
        obs = self.observation_space.sample()
        self._samples["observation"][self._step_ctr] = obs
        env_info = (
            self._samples["env_info"][self._step_ctr - 1].to_ndarray().copy()
            if self._step_ctr > 0
            else None
        )
        return obs, env_info

    @property
    def samples(self) -> ArrayDict:
        """Buffers holding the recorded observations, rewards, flags and infos."""
        return self._samples

    @property
    def resets(self) -> Array:
        """Boolean buffer marking the steps after which ``reset`` was called."""
        return self._batch_resets
