import jax
import jax.numpy as jnp
from flax import struct

from mfax.envs.base.toy.linear_quadratic import BaseLinearQuadraticEnvParams, BaseLinearQuadraticEnvironment, BaseLinearQuadraticGlobalState
from mfax.envs.sample.base import SampleEnvironment, SampleEnvParams, SampleLocalState, SampleGlobalState


@struct.dataclass
class SampleLinearQuadraticLocalState(SampleLocalState):
    """Per-agent local state of the sample-based linear-quadratic environment.

    Declares no fields of its own; everything is inherited from
    `SampleLocalState`.
    """


@struct.dataclass
class SampleLinearQuadraticGlobalState(SampleGlobalState, BaseLinearQuadraticGlobalState):
    """Global state of the sample-based linear-quadratic environment.

    Adds the field `m` to whatever fields the two base classes declare.
    """

    # Mean-field distribution over states. The environment in this file fills
    # it with the per-state fraction of agents (or zeros as a placeholder
    # while resetting).
    m: jax.Array


@struct.dataclass
class SampleLinearQuadraticEnvParams(SampleEnvParams, BaseLinearQuadraticEnvParams):
    """Parameters of the sample-based linear-quadratic environment.

    Extends the parameters of both base classes with the population size.
    """

    # number of sampled agents used to represent the mean field
    n_agents: int = 10_000


class SampleLinearQuadraticEnvironment(SampleEnvironment, BaseLinearQuadraticEnvironment):
    """Linear-quadratic environment with a sample-based mean field.

    The mean field is the empirical distribution over states of
    `params.n_agents` agents that are stepped individually.
    """

    @property
    def obs_dim(self) -> int:
        """Length of the observation vector built by `get_local_obs`."""
        if self.params.partially_observable:
            # only the scalar mean-field mean is observed
            return 1
        # distribution over states, plus one entry each for `z` and `time`
        return self.params.num_states + 2

    def mf_step_env(
        self,
        key: jax.Array,
        local_s: SampleLinearQuadraticLocalState,
        global_s: SampleLinearQuadraticGlobalState,
        vec_a: jax.Array,
    ) -> tuple[jax.Array, SampleLinearQuadraticLocalState, SampleLinearQuadraticGlobalState, jax.Array, jax.Array, jax.Array]:
        """
        Steps environment forward for a given global state and vector of actions for each agent.

        Args:
            key: PRNG key, split into one sub-key per agent.
            local_s: Local states of all agents, batched along the leading axis.
            global_s: Current global state, shared by all agents.
            vec_a: Action index of each agent, batched along the leading axis.

        Returns:
            `(next_local_obs, next_local_s, next_global_s, vec_r, terminated,
            truncated)`, each passed through `jax.lax.stop_gradient`.
        """
        step_rng = jax.random.split(key, self.params.n_agents)
        next_local_s = jax.vmap(self._single_idio_step, in_axes=(0, 0, 0, None))(
            step_rng, local_s, vec_a, global_s
        )

        # --- empirical mean field: fraction of agents in each state ---
        next_m = jnp.bincount(next_local_s.state, length=self.params.num_states) / self.params.n_agents
        next_time = global_s.time + 1
        # z is carried over unchanged from the current global state
        next_global_s = SampleLinearQuadraticGlobalState(m=next_m, z=global_s.z, time=next_time)
        next_local_obs = jax.vmap(self.get_local_obs, in_axes=(0, None))(next_local_s, next_global_s)

        # --- check for termination and truncation ---
        terminated = self.is_terminal(next_time)
        truncated = self.is_truncated(next_time)

        # --- select between step and terminated reward ---
        # Rewards are evaluated on the pre-step local state and action.
        # NOTE(review): the pair is unpacked here as (terminal, step) but as
        # (step, terminal) in `sa_step_env`; confirm the order returned by
        # `_single_reward`.
        vec_r_term, vec_r_st = jax.vmap(self._single_idio_reward, in_axes=(0, 0, None, None))(
            local_s, vec_a, global_s, next_global_s
        )
        vec_r = jax.lax.select(terminated, vec_r_term, vec_r_st)
        return (
            jax.lax.stop_gradient(next_local_obs),
            jax.lax.stop_gradient(next_local_s),
            jax.lax.stop_gradient(next_global_s),
            jax.lax.stop_gradient(vec_r),
            jax.lax.stop_gradient(terminated),
            jax.lax.stop_gradient(truncated),
        )

    def mf_reset_env(
        self,
        key: jax.Array,
    ) -> tuple[jax.Array, SampleLinearQuadraticLocalState, SampleLinearQuadraticGlobalState]:
        """
        Resets Mean Field distribution.

        Args:
            key: PRNG key for the common-noise draw and the per-agent resets.

        Returns:
            `(local_obs, local_s, global_s)` for the freshly reset population.
        """
        # Derive independent keys for the common-noise draw and the per-agent
        # resets; a single key must not be consumed by two samplers.
        noise_key, agents_key = jax.random.split(key)
        reset_rng = jax.random.split(agents_key, self.params.n_agents)

        # z is +1 or -1 (Bernoulli draw) when common noise is enabled, else 0
        z = jax.lax.select(
            self.params.common_noise,
            jax.lax.select(jax.random.bernoulli(noise_key), 1, -1),
            0,
        )

        # Placeholder global state: the real distribution is only known once
        # the agents have been sampled.
        dummy_m = jnp.zeros(self.params.num_states)
        dummy_global_s = SampleLinearQuadraticGlobalState(m=dummy_m, z=z, time=0)
        local_s = jax.vmap(self.sa_reset_env, in_axes=(0, None))(reset_rng, dummy_global_s)

        # --- empirical mean field: fraction of agents in each state ---
        m = jnp.bincount(local_s.state, length=self.params.num_states) / self.params.n_agents
        global_s = SampleLinearQuadraticGlobalState(m=m, z=z, time=0)
        local_obs = jax.vmap(self.get_local_obs, in_axes=(0, None))(local_s, global_s)
        return local_obs, local_s, global_s

    def _single_idio_step(
        self,
        key: jax.Array,
        local_s: SampleLinearQuadraticLocalState,
        action_idx: jax.Array,
        global_s: SampleLinearQuadraticGlobalState,
    ) -> SampleLinearQuadraticLocalState:
        """
        Returns next local state with idiosyncratic noise for a current state, action and global state.

        Args:
            key: PRNG key for the idiosyncratic noise draw.
            local_s: Local state of a single agent (scalar `state`).
            action_idx: Scalar index into `params.actions`.
            global_s: Current global state.
        """
        assert local_s.state.ndim == 0, "local_s must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        action = self.params.actions[action_idx]

        # --- step single agent forward ---
        deterministic_next_state_idx = self._single_step(local_s.state, action, global_s)

        # --- idiosyncratic noise ---
        delta = jax.random.choice(key, self.params.idio_atoms, p=self.params.idio_atoms_probs)
        # scaling by `idio_noise` (cast to delta's dtype) zeroes the shift when it is 0/False
        delta = delta * jnp.asarray(self.params.idio_noise, dtype=delta.dtype)
        # keep the index inside [0, num_states - 1]
        idio_next_state_idx = jnp.clip(
            deterministic_next_state_idx + delta, 0, self.params.num_states - 1
        ).astype(jnp.int32)

        # --- return next local state ---
        return SampleLinearQuadraticLocalState(state=idio_next_state_idx, time=local_s.time + 1)

    def _single_idio_reward(
        self,
        local_s: SampleLinearQuadraticLocalState,
        action_idx: jax.Array,
        global_s: SampleLinearQuadraticGlobalState,
        next_global_s: SampleLinearQuadraticGlobalState,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Returns reward for a current state, action and global state.

        Looks up the action for `action_idx` and returns the pair produced by
        `_single_reward` unchanged.
        """
        assert local_s.state.ndim == 0, "local_s must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        action = self.params.actions[action_idx]

        # --- calculate reward ---
        return self._single_reward(local_s.state, action, global_s, next_global_s)

    def sa_step_env(
        self,
        key: jax.Array,
        local_s: SampleLinearQuadraticLocalState,
        action: jax.Array,
        global_s: SampleLinearQuadraticGlobalState,
        next_global_s: SampleLinearQuadraticGlobalState,
    ) -> tuple[SampleLinearQuadraticLocalState, jax.Array, jax.Array]:
        """
        Unclosed step function for a single agent. Only moves one agent forward, so cannot return the updated global state or observation.

        Args:
            key: PRNG key for the idiosyncratic noise draw.
            local_s: Local state of the agent.
            action: Scalar index into `params.actions`.
            global_s: Current global state.
            next_global_s: Next global state, supplied by the caller.

        Returns:
            The next local state followed by the two reward values from
            `_single_idio_reward`, in the order that method returns them;
            all passed through `jax.lax.stop_gradient`.
        """
        # --- step single agent forward ---
        next_local_s = self._single_idio_step(key, local_s, action, global_s)
        # NOTE(review): unpacked here as (step, terminal) but as
        # (terminal, step) in `mf_step_env`; confirm the order returned by
        # `_single_reward`.
        r_step, r_term = self._single_idio_reward(local_s, action, global_s, next_global_s)
        return (
            jax.lax.stop_gradient(next_local_s),
            jax.lax.stop_gradient(r_step),
            jax.lax.stop_gradient(r_term),
        )

    def sa_reset_env(
        self,
        key: jax.Array,
        global_s: SampleLinearQuadraticGlobalState,
    ) -> SampleLinearQuadraticLocalState:
        """Sample an initial local state uniformly from `params.states`.

        `global_s` is accepted but not used by this implementation.
        """
        # uniform probabilities over the states
        m = jnp.ones(self.params.num_states) / self.params.num_states
        state = jax.random.choice(key, self.params.states, p=m).squeeze()
        return SampleLinearQuadraticLocalState(state=state, time=0)

    def get_local_obs(
        self,
        local_s: SampleLinearQuadraticLocalState,
        global_s: SampleLinearQuadraticGlobalState,
    ) -> jax.Array:
        """Build the 1-D observation vector for one agent.

        The observation depends only on `global_s`; `local_s` is not read.
        Partially observable: a single entry, the sum of `m` weighted by
        `params.discrete_states`. Otherwise: `m`, `z` and `time` flattened and
        concatenated in that order.
        """
        if self.params.partially_observable:
            mf_mean = jnp.sum(global_s.m * self.params.discrete_states)
            return jnp.array([mf_mean]).reshape(-1)
        return jnp.concatenate(
            [
                global_s.m.reshape(-1),
                global_s.z.reshape(-1),
                jnp.array(global_s.time).reshape(-1),
            ]
        )

    def normalize_obs(self, local_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """Rescale an observation produced by `get_local_obs`.

        Partially observable: the observation is divided by
        `params.num_states`. Otherwise, when `normalize_obs` is true, the last
        entry (time) is replaced by `1 - time / max_steps_in_episode`; when
        false, the observation is returned cast to float32.
        """
        if self.params.partially_observable:
            # NOTE(review): this branch ignores the `normalize_obs` flag and
            # always normalizes; confirm that is intended.
            normalized_local_obs = local_obs / self.params.num_states
            return normalized_local_obs
        else:
            # turn elapsed time into the remaining fraction of the episode
            normalized_local_obs = local_obs.at[..., -1].set(
                1 - (local_obs[..., -1] / self.params.max_steps_in_episode)
            )
            return jax.lax.select(normalize_obs, normalized_local_obs, local_obs.astype(jnp.float32))

    def normalize_local_s(self, local_states: jax.Array, normalize_states: bool = False) -> jax.Array:
        """Return `local_states` unchanged; `normalize_states` has no effect."""
        return local_states