import jax
import jax.numpy as jnp
from flax import struct

from mfax.envs.base.toy.beach_bar_1d import BaseBeachBar1DEnvParams, BaseBeachBar1DEnvironment, BaseBeachBar1DGlobalState
from mfax.envs.sample.base import SampleEnvironment, SampleEnvParams, SampleLocalState, SampleGlobalState


@struct.dataclass
class SampleBeachBar1DLocalState(SampleLocalState):
    """Local (per-agent) state of the sampled 1D beach-bar environment.

    Declares no fields of its own; everything comes from `SampleLocalState`.
    """


@struct.dataclass
class SampleBeachBar1DGlobalState(SampleGlobalState, BaseBeachBar1DGlobalState):
    """Global state of the sampled 1D beach-bar environment."""

    # mean-field distribution over states (the environment builds it as a normalized bincount)
    m: jax.Array
    # bar location; the environment's reset methods use it as an index into `m`
    bar_loc: jax.Array


@struct.dataclass
class SampleBeachBar1DEnvParams(SampleEnvParams, BaseBeachBar1DEnvParams):
    """Parameters of the sampled 1D beach-bar environment."""

    # number of agents representing mean field
    n_agents: int = 10_000


class SampleBeachBar1DEnvironment(
      SampleEnvironment, 
      BaseBeachBar1DEnvironment
      ):
      
  @property
  def obs_dim(self) -> int:
    if self.params.partially_observable:
      return 3
    else:
      return self.params.num_states + 3

  def mf_step_env(
        self,
        key: jax.Array,
        local_s: SampleBeachBar1DLocalState,
        global_s: SampleBeachBar1DGlobalState,
        vec_a: jax.Array,
    ) -> tuple[jax.Array, SampleBeachBar1DLocalState, SampleBeachBar1DGlobalState, jax.Array, jax.Array, jax.Array]:
        """
        Steps environment forward for a given global state and vector of actions for each agent.

        Args:
            key: PRNG key. It is split internally so that the per-agent noise and the
                common-noise draw never consume the same key.
            local_s: batched local states, one entry per agent along the leading axis.
            global_s: current global state (mean field, z, time, bar location).
            vec_a: action index of every agent, batched along the leading axis.

        Returns:
            `(next_local_obs, next_local_s, next_global_s, vec_r, terminated, truncated)`,
            each passed through `jax.lax.stop_gradient`.
        """
        # --- derive independent keys: a key that is split must not also be sampled from ---
        agent_key, common_key = jax.random.split(key)
        step_rng = jax.random.split(agent_key, self.params.n_agents)
        next_local_s = jax.vmap(self._single_idio_step, in_axes=(0, 0, 0, None))(step_rng, local_s, vec_a, global_s)

        # --- empirical mean field: fraction of agents in each state ---
        next_m = jnp.bincount(next_local_s.state, length=self.params.num_states) / self.params.n_agents
        next_time = global_s.time + 1

        # --- z is redrawn only when common noise is on and next_time hits the episode midpoint ---
        next_z = jax.lax.select(
            self.params.common_noise & (next_time == (self.params.max_steps_in_episode // 2)),
            jax.random.bernoulli(common_key).astype(jnp.int32),
            global_s.z.astype(jnp.int32),
        )
        next_global_s = SampleBeachBar1DGlobalState(
            m=next_m,
            z=next_z,
            time=next_time,
            bar_loc=global_s.bar_loc
        )
        next_local_obs = jax.vmap(self.get_local_obs, in_axes=(0, None))(next_local_s, next_global_s)

        # --- check for termination and truncation ---
        terminated = self.is_terminal(next_time)
        truncated = self.is_truncated(next_time)

        # --- select between step and terminated reward ---
        # NOTE(review): sa_step_env unpacks the same helper as (r_step, r_term), i.e. the
        # opposite order to the one assumed here — confirm against `_single_reward`.
        vec_r_term, vec_r_st = jax.vmap(self._single_idio_reward, in_axes=(0, 0, None, None))(local_s, vec_a, global_s, next_global_s)
        vec_r = jax.lax.select(terminated, vec_r_term, vec_r_st)
        return (
          jax.lax.stop_gradient(next_local_obs),
          jax.lax.stop_gradient(next_local_s),
          jax.lax.stop_gradient(next_global_s),
          jax.lax.stop_gradient(vec_r),
          jax.lax.stop_gradient(terminated),
          jax.lax.stop_gradient(truncated)
        )


  def mf_reset_env(
        self,
        key: jax.Array
    ) -> tuple[jax.Array, SampleBeachBar1DLocalState, SampleBeachBar1DGlobalState]:
        """
        Resets Mean Field distribution.

        Args:
            key: PRNG key. It is split internally so that the per-agent resets and the
                bar-location draw never consume the same key.

        Returns:
            `(local_obs, local_s, global_s)` for the freshly sampled population.
        """
        # --- derive independent keys: a key that is split must not also be sampled from ---
        agent_key, bar_key = jax.random.split(key)
        reset_rng = jax.random.split(agent_key, self.params.n_agents)

        # --- bar index is drawn from [floor(0.25 * n), ceil(0.75 * n)), clipped to a non-empty valid range ---
        bar_loc_min = jnp.clip(jnp.floor(0.25 * self.params.num_states), 0, self.params.num_states - 1)
        bar_loc_max = jnp.clip(jnp.ceil(0.75 * self.params.num_states), bar_loc_min + 1, self.params.num_states)
        bar_loc = self.params.discrete_states[
            jax.random.randint(
                bar_key,
                minval=bar_loc_min,
                maxval=bar_loc_max,
                shape=(),
            )
        ]
        # --- z is whether bar is open ---
        z = jnp.array(1, dtype=jnp.int32)

        # --- placeholder global state: sa_reset_env only needs bar_loc, the mean field is filled in below ---
        dummy_global_s = SampleBeachBar1DGlobalState(
            m=jnp.zeros(self.params.num_states),
            z=z,
            time=0,
            bar_loc=bar_loc,
        )
        local_s = jax.vmap(self.sa_reset_env, in_axes=(0, None))(reset_rng, dummy_global_s)

        # --- empirical mean field, with the bar cell forced to zero mass and renormalized ---
        m = jnp.bincount(local_s.state, length=self.params.num_states) / self.params.n_agents
        m = m.at[bar_loc].set(0.0)
        m = m / jnp.sum(m)
        global_s = SampleBeachBar1DGlobalState(
            m=m,
            z=z,
            time=0,
            bar_loc=bar_loc,
        )
        local_obs = jax.vmap(self.get_local_obs, in_axes=(0, None))(local_s, global_s)
        return local_obs, local_s, global_s


  def _single_idio_step(self, key: jax.Array, local_s: SampleBeachBar1DLocalState, action_idx: int, global_s: SampleBeachBar1DGlobalState) -> SampleBeachBar1DLocalState:
        """
        Returns next local state with idiosyncratic noise for a current state, action and global state.

        Args:
            key: PRNG key used to sample the idiosyncratic noise atom.
            local_s: local state of a single (unbatched) agent.
            action_idx: 0-d index into `self.params.actions`; must expose `.ndim`.
            global_s: current global state; its `bar_loc` is used for the legality projection.

        Returns:
            A single `SampleBeachBar1DLocalState` (not a tuple) with the new state index
            and `time` advanced by one.
        """
        assert local_s.state.ndim == 0, "local_s must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        action = self.params.actions[action_idx]

        # --- step single agent forward ---
        deterministic_next_state_idx = self._single_step(local_s.state, action, global_s)

        # --- idiosyncratic noise ---
        delta = jax.random.choice(key, self.params.idio_atoms, p=self.params.idio_atoms_probs)
        # scale the sampled atom by `idio_noise`, cast to the atom's dtype
        delta = delta * jnp.asarray(self.params.idio_noise, dtype=delta.dtype)
        # keep the perturbed index inside [0, num_states - 1]
        idio_next_state_idx = jnp.clip(deterministic_next_state_idx + delta, 0, self.params.num_states - 1).astype(jnp.int32)
        # NOTE(review): `_project_to_legal` is defined outside this file; presumably it keeps
        # the agent off illegal cells such as the bar — confirm against the base class.
        idio_next_state_idx = self._project_to_legal(local_s.state, idio_next_state_idx, global_s.bar_loc)

        # --- return next local state ---
        next_local_s = SampleBeachBar1DLocalState(state=idio_next_state_idx, time=local_s.time + 1)
        return next_local_s
  

  def _single_idio_reward(self, local_s: SampleBeachBar1DLocalState, action_idx: int, global_s: SampleBeachBar1DGlobalState, next_global_s: SampleBeachBar1DGlobalState) -> tuple[jax.Array, jax.Array]:
        """
        Evaluates the reward pair of one agent for its current state, action and global states.

        The action index is translated through `self.params.actions` and the call is
        delegated to `self._single_reward`, whose two outputs are returned unchanged.
        """
        # --- guard: this helper works on a single, unbatched agent ---
        assert local_s.state.ndim == 0, "local_s must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        # --- look up the action value and delegate ---
        return self._single_reward(
            local_s.state,
            self.params.actions[action_idx],
            global_s,
            next_global_s,
        )
    

  def sa_step_env(self, key: jax.Array, local_s: SampleBeachBar1DLocalState, action: int, global_s: SampleBeachBar1DGlobalState, next_global_s: SampleBeachBar1DGlobalState) -> tuple[SampleBeachBar1DLocalState, jax.Array, jax.Array]:
        """
        Unclosed step function for a single agent. Only moves one agent forward, so cannot return the updated global state or observation.
        """
        # --- step single agent forward ---
        next_local_s = self._single_idio_step(key, local_s, action, global_s)
        r_step, r_term = self._single_idio_reward(local_s, action, global_s, next_global_s)
        return (
          jax.lax.stop_gradient(next_local_s), 
          jax.lax.stop_gradient(r_step), 
          jax.lax.stop_gradient(r_term)
        ) 


  def sa_reset_env(self, key: jax.Array, global_s: SampleBeachBar1DGlobalState) -> SampleBeachBar1DLocalState:
        """
        Samples an initial local state from `self.params.states`, uniformly over every
        entry except the one at `global_s.bar_loc`, and starts its clock at 0.
        """
        # --- NB: the mean-field may not have been set yet, but other GlobalState parameters should have been ---
        n_states = self.params.num_states
        init_probs = jnp.ones(n_states) / n_states
        # --- no agent may start on the bar: zero its mass, then renormalize ---
        init_probs = init_probs.at[global_s.bar_loc].set(0.0)
        init_probs = init_probs / jnp.sum(init_probs)
        sampled = jax.random.choice(key, self.params.states, p=init_probs)
        return SampleBeachBar1DLocalState(state=sampled.squeeze(), time=0)


  def get_local_obs(self, local_s: SampleBeachBar1DLocalState, global_s: SampleBeachBar1DGlobalState) -> jax.Array:
        if self.params.partially_observable:
            mf_mean = jnp.sum(global_s.m * self.params.discrete_states)
            return jnp.concatenate([jnp.array([mf_mean]).reshape(-1), jnp.array([global_s.z]).reshape(-1), jnp.array([global_s.bar_loc]).reshape(-1)])
        return jnp.concatenate([global_s.m.reshape(-1), jnp.array([global_s.z]).reshape(-1), jnp.array([global_s.bar_loc]).reshape(-1), jnp.array(global_s.time).reshape(-1)])


  def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        if self.params.partially_observable:
            # --- normalize location of mean of mean-field and bar location ---
            normalized_global_obs = global_obs.at[..., 0].set(1 - (global_obs[..., 0] / self.params.num_states))
            normalized_global_obs = global_obs.at[..., 2].set(1 - (global_obs[..., 2] / self.params.num_states))
            return normalized_global_obs
        else:
            # --- normalize bar location and time ---
            normalized_global_obs = global_obs.at[..., -2].set(1 - (global_obs[..., -2] / self.params.num_states))
            normalized_global_obs = global_obs.at[..., -1].set(1 - (global_obs[..., -1] / self.params.max_steps_in_episode))
            return jax.lax.select(normalize_obs, normalized_global_obs, global_obs.astype(jnp.float32))


  def normalize_local_s(self, local_states: jax.Array, normalize_states: bool = False) -> jax.Array:
        return local_states