from typing import Any
from functools import partial
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
from flax import struct

from mfax.envs.base.base import BaseEnvironment, BaseGlobalState, BaseMFSequence


@struct.dataclass
class SampleLocalState:
    """Local state container passed to and returned from the step/reset methods."""

    # --- state ---
    # Array holding the local state values (shape/dtype are set by the concrete
    # environment and are not fixed here).
    state: jax.Array
    # Time index; `sa_step` uses it (and `time + 1`) to index axis 0 of the
    # mean-field sequence. Defaults to 0.
    time: int = 0


@struct.dataclass
class SampleGlobalState(BaseGlobalState):
    """Global state for the sample environment.

    Declares no fields of its own; everything is inherited from
    ``BaseGlobalState``.
    """


@struct.dataclass
class SampleMFSequence(BaseMFSequence):
    """Mean-field sequence extended with two extra array fields.

    Fields inherited from ``BaseMFSequence`` are not visible in this file.
    """

    # presumably the per-agent actions recorded along the sequence
    # (name mirrors `mf_step`'s `vec_a`) -- TODO confirm against the producer
    vec_a: jax.Array
    # presumably the per-agent rewards recorded along the sequence
    # (name mirrors `mf_step`'s `vec_r`) -- TODO confirm against the producer
    vec_r: jax.Array


@struct.dataclass
class SampleEnvParams:
    """Parameter container for the sample environment."""

    # number of agents
    n_agents: int


class SampleEnvironment(BaseEnvironment, ABC):
    """Abstract base class for all Pushforward environments.

    Two families of entry points are provided:

    * ``mf_step`` / ``mf_reset`` advance the vectorised local states together
      with the global state, and auto-reset when the episode is done.
    * ``sa_step`` / ``sa_reset`` advance a single agent against a precomputed
      ``SampleMFSequence``.

    Concrete environments supply the ``*_env`` hooks, ``get_local_obs`` and the
    ``_single_idio_*`` methods.

    NOTE(review): the summary line says "Pushforward" while the class is named
    ``Sample...`` -- confirm which wording is intended.
    """

    # `self` is passed to jit as a static argument, so instances must be hashable.
    @partial(jax.jit, static_argnames=("self",))
    def mf_step(
        self,
        key: jax.Array,
        vec_local_s: SampleLocalState,
        global_s: SampleGlobalState,
        vec_a: jax.Array,
    ) -> tuple[jax.Array, jax.Array, SampleLocalState, SampleLocalState, SampleGlobalState, SampleGlobalState, jax.Array, jax.Array, jax.Array, dict[Any, Any]]:
        """Step the environment and auto-reset when the episode is done.

        Args:
            key: PRNG key; split into one key for the step and one for the reset.
            vec_local_s: current local states.
            global_s: current global state.
            vec_a: actions forwarded to ``mf_step_env``.

        Returns:
            A 10-tuple ``(vec_local_obs, vec_local_obs_st, vec_local_s,
            vec_local_s_st, global_s, global_s_st, vec_r, global_terminated,
            global_truncated, info)``. The ``*_st`` entries are the raw outputs
            of ``mf_step_env``; the un-suffixed entries are replaced by the
            outputs of ``mf_reset_env`` when the episode terminated or was
            truncated. ``info`` is always an empty dict.
        """
        key_step, key_reset = jax.random.split(key)

        (
            vec_local_obs_st,
            vec_local_s_st,
            global_s_st,
            vec_r,
            global_terminated,
            global_truncated,
        ) = self.mf_step_env(key_step, vec_local_s, global_s, vec_a)
        # The reset branch is evaluated unconditionally so that both candidates
        # are available to `select` below.
        vec_local_obs_re, vec_local_s_re, global_s_re = self.mf_reset_env(key_reset)

        # --- choose between reset and non-reset states and observations based on whether environment is terminated or truncated. ---
        global_done = jnp.logical_or(global_terminated, global_truncated)

        def pick(reset_leaf, step_leaf):
            # reset value when done, stepped value otherwise
            return jax.lax.select(global_done, reset_leaf, step_leaf)

        global_s = jax.tree.map(pick, global_s_re, global_s_st)
        vec_local_s = jax.tree.map(pick, vec_local_s_re, vec_local_s_st)
        vec_local_obs = pick(vec_local_obs_re, vec_local_obs_st)
        return (
            vec_local_obs,
            vec_local_obs_st,
            vec_local_s,
            vec_local_s_st,
            global_s,
            global_s_st,
            vec_r,
            global_terminated,
            global_truncated,
            {},
        )

    @partial(jax.jit, static_argnames=("self",))
    def mf_reset(
        self,
        key: jax.Array,
    ) -> tuple[jax.Array, SampleLocalState, SampleGlobalState]:
        """Reset the environment by delegating to ``mf_reset_env``.

        Returns:
            ``(vec_local_obs, vec_local_s, global_s)`` exactly as produced by
            ``mf_reset_env``.
        """
        vec_local_obs, vec_local_s, global_s = self.mf_reset_env(key)
        return vec_local_obs, vec_local_s, global_s

    @abstractmethod
    def mf_step_env(
        self,
        key: jax.Array,
        vec_local_s: SampleLocalState,
        global_s: SampleGlobalState,
        vec_a: jax.Array,
    ) -> tuple[jax.Array, SampleLocalState, SampleGlobalState, jax.Array, jax.Array, jax.Array]:
        """
        Steps environment forward for a given global state and vector of actions for each agent.

        ``mf_step`` unpacks the result as ``(vec_local_obs, vec_local_s,
        global_s, vec_r, global_terminated, global_truncated)``.
        """
        raise NotImplementedError

    @abstractmethod
    def mf_reset_env(
        self,
        key: jax.Array
    ) -> tuple[jax.Array, SampleLocalState, SampleGlobalState]:
        """
        Resets Mean Field distribution.

        Callers unpack the result as ``(vec_local_obs, vec_local_s, global_s)``.
        """
        raise NotImplementedError

    @abstractmethod
    def _single_idio_step(self, key: jax.Array, local_s: SampleLocalState, action: jax.Array, global_s: SampleGlobalState) -> tuple[SampleLocalState]:
        """
        Returns the next local state for a single agent with idiosyncratic noise (i.e. stochastic step forward).
        """
        raise NotImplementedError

    @abstractmethod
    def _single_idio_reward(self, local_s: SampleLocalState, action: jax.Array, global_s: SampleGlobalState, next_global_s: SampleGlobalState) -> tuple[jax.Array, jax.Array]:
        """
        Returns reward for a current state, action and global state.

        NOTE(review): the annotation is a pair of arrays; presumably the step
        reward and the terminal reward consumed by ``sa_step`` -- confirm.
        """
        raise NotImplementedError

    def sa_step(self, key: jax.Array, mf_sequence: SampleMFSequence, local_s: SampleLocalState, action: jax.Array) -> tuple[jax.Array, SampleLocalState, jax.Array, jax.Array, jax.Array]:
        """Step a single agent against a precomputed mean-field sequence.

        The global state at index ``local_s.time`` drives the transition; the
        entry at ``local_s.time + 1`` supplies the observation context and the
        termination / truncation flags, so ``mf_sequence`` must contain an entry
        at ``local_s.time + 1`` along axis 0.

        Args:
            key: PRNG key; split into one key for the step and one for the reset.
            mf_sequence: sequence whose ``global_s``, ``global_terminated`` and
                ``global_truncated`` leaves are indexed along axis 0 by time.
            local_s: current local state of the agent.
            action: action forwarded to ``sa_step_env``.

        Returns:
            ``(local_obs, local_s, reward, global_terminated, global_truncated)``.
            When the episode is done, ``local_obs`` / ``local_s`` come from the
            reset branch and ``reward`` is the terminal reward returned by
            ``sa_step_env``; otherwise the stepped values and step reward are used.
        """
        key_step, key_reset = jax.random.split(key)
        t_now = local_s.time
        t_next = local_s.time + 1

        # --- local step is based on current global state ---
        global_s = jax.tree.map(lambda x: jnp.take(x, t_now, axis=0), mf_sequence.global_s)
        next_global_s = jax.tree.map(lambda x: jnp.take(x, t_next, axis=0), mf_sequence.global_s)
        local_s_step, reward_step, reward_term = self.sa_step_env(key_step, local_s, action, global_s, next_global_s)
        local_obs_step = self.get_local_obs(local_s_step, next_global_s)
        # The reset candidate is sampled against the *next* global state.
        local_s_reset = self.sa_reset_env(key_reset, next_global_s)
        local_obs_reset = self.get_local_obs(local_s_reset, next_global_s)

        # --- observation, termination and truncation are based on next global state ---
        global_terminated = jax.tree.map(lambda x: jnp.take(x, t_next, axis=0), mf_sequence.global_terminated)
        global_truncated = jax.tree.map(lambda x: jnp.take(x, t_next, axis=0), mf_sequence.global_truncated)

        # --- choose between reset and non-reset state based on whether environment is terminated or truncated ---
        global_done = jnp.logical_or(global_terminated, global_truncated)

        def pick(done_leaf, step_leaf):
            # first argument when done, second otherwise
            return jax.lax.select(global_done, done_leaf, step_leaf)

        local_s = jax.tree.map(pick, local_s_reset, local_s_step)
        local_obs = pick(local_obs_reset, local_obs_step)

        reward = pick(reward_term, reward_step)
        return local_obs, local_s, reward, global_terminated, global_truncated

    def sa_reset(self, key: jax.Array, mf_sequence: SampleMFSequence) -> tuple[jax.Array, SampleLocalState]:
        """Resets single agent by sampling from mean field distribution. Single Agent branch is not self-sustaining, as next step requires next mean-field distribution.

        Uses the entry at index 0 of ``mf_sequence.global_s`` and returns
        ``(local_obs, local_s)``.
        """
        global_s = jax.tree.map(lambda x: jnp.take(x, 0, axis=0), mf_sequence.global_s)
        local_s = self.sa_reset_env(key, global_s)
        local_obs = self.get_local_obs(local_s, global_s)
        return local_obs, local_s

    def sa_step_env(self, key: jax.Array, local_s: SampleLocalState, action: jax.Array, global_s: SampleGlobalState, next_global_s: SampleGlobalState) -> tuple[SampleLocalState, jax.Array, jax.Array]:
        """
        Unclosed step function for a single agent. Only moves one agent forward, so cannot return the updated global state or observation.

        ``sa_step`` unpacks the result as ``(local_s, reward_step, reward_term)``.
        Not marked abstract; the base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def sa_reset_env(self, key: jax.Array, global_s: SampleGlobalState) -> SampleLocalState:
        """
        Resets single agent by sampling from mean field distribution.

        Not marked abstract; the base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_local_obs(self, local_s: SampleLocalState, global_s: SampleGlobalState) -> jax.Array:
        """Gets local observation from Mean Field."""
        raise NotImplementedError

    def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """
        Transform global observation for feeding into policy network. Must work on batched observations.

        Not marked abstract; the base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def normalize_local_s(self, local_s: jax.Array, normalize_states: bool = False) -> jax.Array:
        """
        Transform local state for feeding into policy network. Must work on batched observations.

        Not marked abstract; the base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError