from typing import Any
from functools import partial
from abc import ABC, abstractmethod
from typing import Optional

import jax
import jax.numpy as jnp
from flax import struct

from mfax.envs.base.base import BaseEnvironment, BaseMFSequence, BaseGlobalState

@struct.dataclass
class PushforwardGlobalState(BaseGlobalState):
    """Global state of a pushforward environment, extending ``BaseGlobalState`` with ``m``."""

    # NOTE(review): presumably the mean-field distribution over local states
    # (one mass entry per state) -- confirm against concrete environments.
    m: jax.Array


@struct.dataclass
class PushforwardEnvParams:
    """Immutable parameter container for pushforward environments."""

    # NOTE(review): presumably the enumeration of local states used by the
    # environment -- shape and dtype are not visible here, confirm with callers.
    states: jax.Array


@struct.dataclass
class PushforwardMFSequence(BaseMFSequence):
    """Mean-field sequence record, extending ``BaseMFSequence`` with pushforward data."""

    # Global observation(s) stored for the sequence.
    global_obs: jax.Array
    # Optional hidden state; ``None`` is allowed by the annotation.
    # NOTE(review): presumably a recurrent-policy hidden state -- confirm.
    global_hidden: Optional[jax.Array]
    # NOTE(review): presumably action probabilities per (state, action),
    # matching the ``prob_a`` argument of the environment methods -- confirm.
    prob_a: jax.Array
    # NOTE(review): presumably the reward matrix returned by ``mf_step`` -- confirm.
    mat_r: jax.Array


class PushforwardEnvironment(BaseEnvironment, ABC):
    """Abstract base class for all Pushforward environments.

    Concrete environments supply the per-(state, action) transition and
    reward primitives (``_single_pushforward_step`` and
    ``_single_pushforward_reward``) together with ``mf_step_env``,
    ``mf_reset_env`` and ``get_global_obs``. This class lifts those
    primitives over every state and action with ``jax.vmap``.

    The jitted methods pass ``self`` to ``jax.jit`` as a static argument, so
    instances must be hashable. The methods read ``self.n_states``,
    ``self.n_actions`` and ``self.state_indices``, which are not defined in
    this class.
    NOTE(review): presumably these are provided by ``BaseEnvironment`` or by
    the concrete subclass -- confirm.
    """

    @partial(jax.jit, static_argnames=("self",))
    def mf_step(
        self,
        key: jax.Array,
        global_s: PushforwardGlobalState,
        prob_a: jax.Array,
    ) -> tuple[
        jax.Array,
        jax.Array,
        PushforwardGlobalState,
        PushforwardGlobalState,
        jax.Array,
        jax.Array,
        jax.Array,
        dict[Any, Any],
    ]:
        """Advance the mean-field environment one step, auto-resetting when done.

        Args:
            key: PRNG key; split into one key for the step and one for the reset.
            global_s: Current global state.
            prob_a: Action probabilities forwarded to ``mf_step_env``.

        Returns:
            An 8-tuple ``(global_obs, global_obs_st, global_s, global_s_st,
            mat_r, global_terminated, global_truncated, info)`` where

            * ``global_obs`` / ``global_s`` are the observation and state
              after auto-reset: the reset values if the step terminated or
              truncated, the stepped values otherwise;
            * ``global_obs_st`` / ``global_s_st`` are the raw outputs of
              ``mf_step_env`` before any reset is applied;
            * ``mat_r``, ``global_terminated`` and ``global_truncated`` are
              passed through unchanged from ``mf_step_env``;
            * ``info`` is always an empty dict.
        """
        key_step, key_reset = jax.random.split(key)

        global_obs_st, global_s_st, mat_r, global_terminated, global_truncated = (
            self.mf_step_env(key_step, global_s, prob_a)
        )

        # The reset is computed unconditionally; ``select`` below decides
        # whether its result is actually used.
        global_obs_re, global_s_re = self.mf_reset_env(key_reset)

        # --- Choose between reset and non-reset state based on whether the
        # --- environment is terminated or truncated.
        global_done = jnp.logical_or(global_terminated, global_truncated)
        global_s = jax.tree.map(
            lambda x, y: jax.lax.select(global_done, x, y), global_s_re, global_s_st
        )
        global_obs = jax.lax.select(global_done, global_obs_re, global_obs_st)
        return (
            global_obs,
            global_obs_st,
            global_s,
            global_s_st,
            mat_r,
            global_terminated,
            global_truncated,
            {},
        )

    @partial(jax.jit, static_argnames=("self",))
    def mf_reset(
        self,
        key: jax.Array,
    ) -> tuple[jax.Array, PushforwardGlobalState]:
        """Reset the environment; jitted wrapper around ``mf_reset_env``.

        Args:
            key: PRNG key forwarded to ``mf_reset_env``.

        Returns:
            ``(global_obs, global_s)`` exactly as returned by ``mf_reset_env``.
        """
        global_obs, global_s = self.mf_reset_env(key)
        return global_obs, global_s

    def _pushforward_over_states_and_actions(
        self, global_s: PushforwardGlobalState
    ) -> tuple[jax.Array, jax.Array]:
        """Evaluate ``_single_pushforward_step`` for every (state, action) pair.

        Shared by ``mf_expected_value`` and ``mf_transition``. Not jitted on
        its own; it is traced as part of the jitted callers.

        Returns:
            ``(next_state_idxs, next_state_probs)``: the outputs of
            ``_single_pushforward_step`` stacked with the state axis first
            and the action axis second. Any trailing axes are whatever the
            primitive returns.
        """
        action_ids = jnp.arange(self.n_actions)
        # --- vmap over actions (state and global state are broadcast) ---
        over_actions = jax.vmap(self._single_pushforward_step, in_axes=(None, 0, None))

        def single_state(i):
            return over_actions(self.state_indices[i], action_ids, global_s)

        # --- vmap over states ---
        return jax.vmap(single_state, in_axes=0)(jnp.arange(self.n_states))

    @partial(jax.jit, static_argnames=("self",))
    def mf_expected_value(
        self, vec: jax.Array, prob_a: jax.Array, global_s: PushforwardGlobalState
    ) -> jax.Array:
        """Compute the expected value of ``vec`` at the next state, per current state.

        For every state, sums ``vec[next_idx] * next_prob * prob_a`` over the
        action axis and the successor axis.

        Args:
            vec: Array indexed by next-state indices.
            prob_a: Action probabilities; a trailing axis is appended so it
                broadcasts against the successor axis.
                NOTE(review): assumes shape ``(n_states, n_actions)`` -- confirm.
            global_s: Global state forwarded to ``_single_pushforward_step``.

        Returns:
            Array with one expected value per state (leading axis of length
            ``n_states``). No ``stop_gradient`` is applied.
        """
        next_state_idxs, next_state_probs = self._pushforward_over_states_and_actions(
            global_s
        )
        # Reduce over axes 1 (actions) and 2 (successors); this assumes the
        # stacked primitive outputs are rank 3.
        expected_values = jnp.sum(
            vec[next_state_idxs] * next_state_probs * prob_a[..., None], axis=(1, 2)
        )

        # --- no stop gradient ---
        return expected_values

    @partial(jax.jit, static_argnames=("self",))
    def mf_transition(
        self, m: jax.Array, prob_a: jax.Array, global_s: PushforwardGlobalState
    ) -> jax.Array:
        """Push the distribution ``m`` forward one step under ``prob_a``.

        Each (state, action, successor) triple contributes
        ``m[state] * prob_a[state, action] * next_prob`` of mass, which is
        scatter-added into the successor's entry of the result.

        Args:
            m: Mass per state; two trailing axes are appended so it
                broadcasts over the action and successor axes.
            prob_a: Action probabilities; a trailing axis is appended so it
                broadcasts over the successor axis.
                NOTE(review): assumes shape ``(n_states, n_actions)`` -- confirm.
            global_s: Global state forwarded to ``_single_pushforward_step``.

        Returns:
            Array of shape ``(n_states,)`` holding the accumulated next-step mass.
        """
        next_state_idxs, next_state_probs = self._pushforward_over_states_and_actions(
            global_s
        )
        mass = m[..., None, None] * next_state_probs * prob_a[..., None]
        # ``.at[...].add`` accumulates contributions landing on the same index.
        next_m = (
            jnp.zeros((self.n_states,))
            .at[next_state_idxs.reshape(-1)]
            .add(mass.reshape(-1))
        )
        return next_m

    @partial(jax.jit, static_argnames=("self",))
    def mf_reward(
        self, global_s: PushforwardGlobalState, next_global_s: PushforwardGlobalState
    ) -> tuple[jax.Array, jax.Array]:
        """Evaluate ``_single_pushforward_reward`` for every (state, action) pair.

        Args:
            global_s: Global state before the step.
            next_global_s: Global state after the step.

        Returns:
            ``(mat_r_step, mat_r_term)``: the two outputs of
            ``_single_pushforward_reward`` stacked with the state axis first
            and the action axis second.
            NOTE(review): the names suggest per-step and terminal rewards --
            confirm with concrete implementations.
        """
        action_ids = jnp.arange(self.n_actions)
        # --- vmap over actions (state and both global states are broadcast) ---
        over_actions = jax.vmap(
            self._single_pushforward_reward, in_axes=(None, 0, None, None)
        )

        def single_state(i):
            return over_actions(self.state_indices[i], action_ids, global_s, next_global_s)

        # --- vmap over states ---
        mat_r_step, mat_r_term = jax.vmap(single_state, in_axes=0)(
            jnp.arange(self.n_states)
        )
        return mat_r_step, mat_r_term

    @abstractmethod
    def mf_step_env(
        self,
        key: jax.Array,
        global_s: PushforwardGlobalState,
        prob_a: jax.Array,
    ) -> tuple[jax.Array, PushforwardGlobalState, jax.Array, jax.Array, jax.Array]:
        """Environment-specific step, without auto-reset.

        Must return ``(global_obs, global_s, mat_r, global_terminated,
        global_truncated)``; ``mf_step`` unpacks the result in that order.
        """
        raise NotImplementedError

    @abstractmethod
    def mf_reset_env(
        self, key: jax.Array
    ) -> tuple[jax.Array, PushforwardGlobalState]:
        """Resets Mean Field distribution.

        Must return ``(global_obs, global_s)``.
        """
        raise NotImplementedError

    @abstractmethod
    def _single_pushforward_step(
        self, state_idx: int, action: int, global_s: PushforwardGlobalState
    ) -> tuple[jax.Array, jax.Array]:
        """
        Returns the next indices and probabilities of the next state for a current state, action and global state.
        """
        raise NotImplementedError

    @abstractmethod
    def _single_pushforward_reward(
        self,
        state_idx: int,
        action: int,
        global_s: PushforwardGlobalState,
        next_global_s: PushforwardGlobalState,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Calculates the (expected, if depends on next state) reward for a single pushforward step.
        """
        raise NotImplementedError

    @abstractmethod
    def get_global_obs(self, global_s: PushforwardGlobalState) -> jax.Array:
        """
        Gets global observation from Mean Field.
        """
        raise NotImplementedError

    def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """
        Transform global observation for feeding into policy network. Must work on batched observations.

        This definition is not marked abstract; the default implementation
        raises ``NotImplementedError`` when called.
        """
        raise NotImplementedError

    def normalize_local_s(self, local_s: jax.Array, normalize_states: bool = False) -> jax.Array:
        """
        Transform local state for feeding into policy network. Must work on batched observations.

        This definition is not marked abstract; the default implementation
        raises ``NotImplementedError`` when called.
        """
        raise NotImplementedError