from functools import cached_property
from typing import Optional, Tuple

import chex
import jax.numpy as jnp
import numpy as np
from jumanji import specs
from jumanji.env import Environment, State
from jumanji.specs import Array, MultiDiscreteArray, Spec
from jumanji.types import TimeStep
from jumanji.wrappers import MultiToSingleWrapper, Wrapper

from stoix.base_types import Observation


class JumanjiWrapper(Wrapper):
    """Wrap a Jumanji environment so its timesteps carry the chosen observation.

    Two modes are supported, selected by ``observation_attribute``:

    * a field name: that field of the environment's observation is cast to
      ``float32``, reshaped to ``self._obs_shape`` and packed, together with an
      action mask and ``state.step_count``, into an ``Observation``;
    * ``None``: the environment's own observation is passed through, with its
      ``action_mask`` field (when it has one) replaced by the mask selected by
      this wrapper.

    In both modes the timestep's ``extras`` are replaced by an empty dict.
    """

    def __init__(
        self,
        env: Environment,
        observation_attribute: Optional[str],
        flatten_observation: bool = False,
        multi_agent: bool = False,
        use_action_mask: bool = False,
    ) -> None:
        """Build the wrapper.

        Args:
            env: The environment to wrap. If its action spec is a
                ``MultiDiscreteArray`` it is first wrapped in
                ``MultiDiscreteToDiscrete``.
            observation_attribute: Name of the observation field to expose, or
                ``None`` to keep the environment's full observation.
            flatten_observation: Whether to flatten the selected field to 1-D.
                Requires ``observation_attribute`` to be set.
            multi_agent: Whether to additionally wrap ``env`` in
                ``MultiToSingleWrapper``.
            use_action_mask: Whether to use the observation's ``action_mask``
                (when present) instead of an all-``True`` mask.

        Raises:
            ValueError: If ``flatten_observation`` is requested while
                ``observation_attribute`` is ``None``.
        """
        # Fail early with a clear message: there is no single array shape to
        # flatten when the full observation is passed through.
        if observation_attribute is None and flatten_observation:
            raise ValueError(
                "flatten_observation=True requires observation_attribute to be set; "
                "the full observation cannot be flattened."
            )

        if isinstance(env.action_spec, MultiDiscreteArray):
            env = MultiDiscreteToDiscrete(env)
        if multi_agent:
            env = MultiToSingleWrapper(env)

        self._env = env

        self._observation_attribute = observation_attribute
        self._flatten_observation = flatten_observation
        self._use_action_mask = use_action_mask
        self._use_true_observation = observation_attribute is None

        if not self._use_true_observation:
            obs_shape = vars(self._env.observation_spec)[observation_attribute].shape
            if flatten_observation:
                # int(...) keeps the shape a tuple of Python ints; np.prod
                # returns a NumPy scalar, and a float (1.0) for an empty shape.
                obs_shape = (int(np.prod(obs_shape)),)
            self._obs_shape = obs_shape

        # Fallback mask that marks every action as legal.
        self._legal_action_mask = jnp.ones((self._env.action_spec.num_values,), dtype=jnp.bool_)

    def _convert_timestep(self, state: State, timestep: TimeStep) -> TimeStep:
        """Return `timestep` with its observation converted and extras emptied."""
        raw_observation = timestep.observation
        has_mask = hasattr(raw_observation, "action_mask")

        if has_mask and self._use_action_mask:
            legal_action_mask = raw_observation.action_mask
        else:
            legal_action_mask = self._legal_action_mask

        if self._use_true_observation:
            observation = raw_observation
            if has_mask:
                # Keeps the env's mask when use_action_mask is set, otherwise
                # overwrites it with the all-True fallback.
                observation = observation._replace(action_mask=legal_action_mask)
        else:
            selected = raw_observation._asdict()[self._observation_attribute]
            agent_view = selected.astype(jnp.float32).reshape(self._obs_shape)
            observation = Observation(agent_view, legal_action_mask, state.step_count)

        return timestep.replace(
            observation=observation,
            extras={},
        )

    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
        """Reset the wrapped environment and convert the resulting timestep.

        Args:
            key: Random key forwarded to the wrapped environment's ``reset``.

        Returns:
            The wrapped environment's state and the converted timestep.
        """
        state, timestep = self._env.reset(key)
        return state, self._convert_timestep(state, timestep)

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
        """Step the wrapped environment and convert the resulting timestep.

        Args:
            state: Current environment state.
            action: Action forwarded to the wrapped environment's ``step``.

        Returns:
            The next state and the converted timestep.
        """
        state, timestep = self._env.step(state, action)
        return state, self._convert_timestep(state, timestep)

    def action_spec(self) -> Spec:
        """Return the action spec of the wrapped environment."""
        return self._env.action_spec

    def observation_spec(self) -> Spec:
        """Return the spec of the observations produced by this wrapper.

        When the full observation is passed through, this is the wrapped
        environment's own observation spec; otherwise it describes the
        ``Observation`` built from the selected field.
        """
        if self._use_true_observation:
            return self._env.observation_spec
        return specs.Spec(
            Observation,
            "ObservationSpec",
            agent_view=Array(shape=self._obs_shape, dtype=jnp.float32),
            action_mask=Array(shape=(self.action_spec().num_values,), dtype=jnp.bool_),
            step_count=Array(shape=(), dtype=jnp.int32),
        )


class MultiDiscreteToDiscrete(Wrapper):
    """Present an environment's multi-discrete action space as one discrete space.

    A flat action index is decoded into one component per entry of the wrapped
    action spec's ``num_values`` before being passed to the wrapped
    environment's ``step``.
    """

    def __init__(self, env: Environment):
        """Wrap `env` and remember the per-component sizes of its action spec."""
        super().__init__(env)
        self._action_spec_num_values = env.action_spec.num_values

    def apply_factorisation(self, x: chex.Array) -> chex.Array:
        """Decode the flat action `x` into its per-component actions.

        `x` is repeatedly split with ``divmod`` by the component sizes, from
        the last component down to the second; the final quotient becomes the
        first component. Components are stacked along a new last axis, in
        component order, using the dtype of the stored component sizes.
        """
        radices = self._action_spec_num_values
        quotient = x
        digits = []
        # Peel components off from the last index down to index 1.
        for idx in reversed(range(1, radices.shape[0])):
            quotient, digit = jnp.divmod(quotient, radices[idx])
            digits.append(digit)
        # Whatever is left over is component 0.
        digits.append(quotient)
        digits.reverse()
        return jnp.stack(digits, axis=-1, dtype=radices.dtype)

    def inverse_factorisation(self, y: chex.Array) -> chex.Array:
        """Encode the per-component action `y` back into a flat action.

        `y` is split into one equal section per component along its last axis
        and the sections are folded together as
        ``num_values[i] * accumulated + section[i]``. The split axis is not
        removed, so the result keeps it (length 1 when the last axis of `y`
        has exactly one entry per component).
        """
        radices = self._action_spec_num_values
        leading, *trailing = jnp.split(y, radices.shape[0], axis=-1)
        combined = leading
        for position, component in enumerate(trailing, start=1):
            combined = radices[position] * combined + component
        return combined

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
        """Decode the flat `action` and step the wrapped environment with it."""
        factorised_action = self.apply_factorisation(action)
        next_state, next_timestep = self._env.step(state, factorised_action)
        return next_state, next_timestep

    @cached_property
    def action_spec(self) -> specs.Spec:
        """Discrete action spec whose size is the product of the component sizes."""
        component_sizes = np.asarray(self._env.action_spec.num_values)
        total_actions = int(np.prod(component_sizes))
        return specs.DiscreteArray(total_actions, name="action")


class MultiBoundedToBounded(Wrapper):
    """Present an environment's bounded action space as a flat 1-D bounded space.

    Flat actions received by ``step`` are reshaped back to the wrapped
    environment's action shape before being forwarded.
    """

    def __init__(self, env: Environment):
        """Wrap `env` and remember the shape of its action spec."""
        super().__init__(env)
        self._true_action_shape = env.action_spec.shape

    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
        """Reshape the flat `action` to the original shape and step the wrapped env."""
        action = action.reshape(self._true_action_shape)
        state, timestep = self._env.step(state, action)
        return state, timestep

    # NOTE(review): this is a plain method, while `__init__` reads
    # `env.action_spec` as an attribute; confirm callers invoke `action_spec()`.
    def action_spec(self) -> specs.Spec:
        """Return the flattened bounded action spec.

        The returned spec has shape ``(size,)`` where ``size`` is the product
        of the original action shape. Scalar bounds are passed through
        unchanged; array-valued bounds are broadcast to the original action
        shape and flattened in the same (row-major) order that ``step`` uses
        to reshape actions, so that they line up with the flat shape.
        """
        original_action_spec = self._env.action_spec
        original_shape = tuple(original_action_spec.shape)
        size = int(np.prod(np.asarray(original_shape)))

        def _flatten_bound(bound):
            """Flatten a non-scalar bound so it matches the `(size,)` shape."""
            bound_array = np.asarray(bound)
            if bound_array.ndim == 0:
                # Scalars broadcast against any shape; keep them untouched.
                return bound
            return np.broadcast_to(bound_array, original_shape).reshape((size,))

        return specs.BoundedArray(
            (size,),
            minimum=_flatten_bound(original_action_spec.minimum),
            maximum=_flatten_bound(original_action_spec.maximum),
            dtype=original_action_spec.dtype,
            name="action",
        )
