import jax
import jax.numpy as jnp
from flax import struct
from dataclasses import field

from mfax.envs.base.toy.beach_bar_1d import BaseBeachBar1DEnvParams, BaseBeachBar1DEnvironment, BaseBeachBar1DGlobalState
from mfax.envs.pushforward.base import PushforwardEnvParams, PushforwardEnvironment, PushforwardGlobalState


@struct.dataclass
class PushforwardBeachBar1DGlobalState(PushforwardGlobalState, BaseBeachBar1DGlobalState):
    """
    Global state of the pushforward 1D beach bar environment.

    Declares no fields of its own; it only combines the two parent state classes.
    Within this file, instances are built with `m`, `z`, `time` and `bar_loc`.
    """


@struct.dataclass
class PushforwardBeachBar1DEnvParams(
    PushforwardEnvParams, 
    BaseBeachBar1DEnvParams
    ):
    """
    Parameters of the pushforward 1D beach bar environment.

    Combines the two parent parameter classes and declares `states` with a default.
    """
    # --- `states` needs a default value so that dataclass field ordering stays valid under multiple inheritance ---
    # --- the default is an empty (0, 0) placeholder array; presumably overwritten by the caller (TODO confirm) ---
    states: jax.Array = field(default_factory=lambda: jnp.empty((0, 0))) 


class PushforwardBeachBar1DEnvironment(
    PushforwardEnvironment, 
    BaseBeachBar1DEnvironment
    ):

    @property
    def obs_dim(self) -> int:
        """
        Length of the global observation vector built by `get_global_obs`.

        Partially observable: 3 entries (mean of the mean-field, z, bar location).
        Otherwise: `num_states + 3` entries (mean-field, z, bar location, time);
        this assumes the mean-field has `num_states` entries.
        """
        # --- both modes carry exactly three scalar features ---
        n_scalar_features = 3
        if self.params.partially_observable:
            return n_scalar_features
        return self.params.num_states + n_scalar_features


    def mf_reset_env(self, key: jax.Array) -> tuple[jax.Array, PushforwardBeachBar1DGlobalState]:
        """
        Builds the initial global state and its observation.

        The bar location is drawn from `params.discrete_states` at a random index in
        [floor(0.25 * num_states), ceil(0.75 * num_states)). The initial mean-field is
        uniform over `self.n_states` entries with zero mass at the bar location,
        renormalized to sum to one. The bar starts open (z = 1) and time starts at 0.

        Returns:
            (global observation, global state)
        """
        n = self.params.num_states

        # --- sample bar location at IQR of num_states ---
        lower = jnp.clip(jnp.floor(0.25 * n), 0, n - 1)
        # --- upper bound is exclusive and kept strictly above the lower bound ---
        upper = jnp.clip(jnp.ceil(0.75 * n), lower + 1, n)
        bar_idx = jax.random.randint(key, shape=(), minval=lower, maxval=upper)
        bar_loc = self.params.discrete_states[bar_idx]

        # --- uniform mean-field with no mass on the bar location ---
        uniform = jnp.ones(self.n_states) / self.n_states
        mean_field = uniform.at[bar_loc].set(0.0)
        mean_field = mean_field / jnp.sum(mean_field)

        # --- z is whether bar is open ---
        bar_open = jnp.array(1, dtype=jnp.int32)

        initial_s = PushforwardBeachBar1DGlobalState(
            m=mean_field,
            z=bar_open,
            time=0,
            bar_loc=bar_loc,
        )
        return self.get_global_obs(initial_s), initial_s


    def mf_step_env(
        self, key: jax.Array, global_s: PushforwardBeachBar1DGlobalState, prob_a: jax.Array
    ) -> tuple[jax.Array, PushforwardBeachBar1DGlobalState, jax.Array, jax.Array, jax.Array]:
        """
        Advances the global state by one step.

        The mean-field is pushed forward with `mf_transition`, time is incremented, and
        when `params.common_noise` is set, z is resampled from a Bernoulli draw at the
        step where the new time equals `max_steps_in_episode // 2`. The bar location is
        carried over unchanged.

        Returns:
            (next global observation, next global state, reward, terminated, truncated),
            each wrapped in `stop_gradient`. The reward is the terminal reward from
            `mf_reward` when `terminated` is true, otherwise the step reward.
        """
        m_next = self.mf_transition(global_s.m, prob_a, global_s)
        t_next = global_s.time + 1

        # --- common noise: z is only resampled at the episode midpoint ---
        resample_z = self.params.common_noise & (t_next == (self.params.max_steps_in_episode // 2))
        sampled_z = jax.random.bernoulli(key).astype(jnp.int32)
        carried_z = global_s.z.astype(jnp.int32)
        z_next = jax.lax.select(resample_z, sampled_z, carried_z)

        successor = PushforwardBeachBar1DGlobalState(
            m=m_next,
            z=z_next,
            time=t_next,
            bar_loc=global_s.bar_loc,
        )

        terminated = self.is_terminal(t_next)
        truncated = self.is_truncated(t_next)

        # --- pick terminal or per-step reward depending on termination ---
        r_step, r_term = self.mf_reward(global_s, successor)
        reward = jax.lax.select(terminated, r_term, r_step)

        outputs = (self.get_global_obs(successor), successor, reward, terminated, truncated)
        return tuple(jax.lax.stop_gradient(item) for item in outputs)


    def _single_pushforward_step(
        self,
        state: int,
        action_idx: int,
        global_s: PushforwardBeachBar1DGlobalState,
    ):
        """
        Returns the support and probabilities of the next local state for one
        (state, action, global state) triple under idiosyncratic noise.

        `state` and `action_idx` must be 0-dimensional arrays (their `.ndim` is checked).

        Returns:
            (next_state_idxs, probs): the noise-shifted next-state indices followed by the
            deterministic next-state index (int32), and the matching probabilities
            (`idio_atoms_probs * idio_noise` followed by `1 - idio_noise`, normalized).
            Both are wrapped in `stop_gradient`.
        """
        assert state.ndim == 0, "state must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        params = self.params

        # --- step single agent forward ---
        base_idx = self._single_step(state, params.actions[action_idx], global_s)

        # --- shift by each noise atom, keep inside the grid, then project to a legal state ---
        shifted_idxs = jnp.clip(base_idx + params.idio_atoms, 0, params.num_states - 1)
        project = jax.vmap(self._project_to_legal, in_axes=(None, 0, None))
        noisy_idxs = project(state, shifted_idxs, global_s.bar_loc)

        # --- support: noisy outcomes first, deterministic outcome last ---
        support = jnp.concatenate([noisy_idxs, jnp.array([base_idx])], axis=0)
        next_state_idxs = support.astype(jnp.int32)

        noise_level = params.idio_noise
        weights = jnp.concatenate(
            [params.idio_atoms_probs * noise_level, jnp.array([1.0 - noise_level])],
            axis=0,
        )
        # --- normalize; fall back to dividing by 1 when the total mass is not positive ---
        total = weights.sum()
        probs = weights / jnp.where(total > 0, total, 1.0)

        return jax.lax.stop_gradient(next_state_idxs), jax.lax.stop_gradient(probs)


    def _single_pushforward_reward(
        self,
        state: int,
        action_idx: int,
        global_s: PushforwardBeachBar1DGlobalState,
        next_global_s: PushforwardBeachBar1DGlobalState,
    ):
        """
        Returns the result of `_single_reward` for a current state, action index,
        global state and next global state.

        `state` and `action_idx` must be 0-dimensional arrays (their `.ndim` is checked).
        """
        assert state.ndim == 0, "state must be an integer"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        # --- map the action index to its action, then delegate to the single-agent reward ---
        return self._single_reward(
            state,
            self.params.actions[action_idx],
            global_s,
            next_global_s,
        )


    def get_global_obs(self, global_s: PushforwardBeachBar1DGlobalState) -> jax.Array:
        """
        Flattens the global state into a 1D observation vector.

        Partially observable: [mean of the mean-field, z, bar location].
        Otherwise: [mean-field (flattened), z, bar location, time].
        """
        # --- pieces shared by both observation layouts ---
        z_part = jnp.array([global_s.z]).reshape(-1)
        bar_part = jnp.array([global_s.bar_loc]).reshape(-1)

        if self.params.partially_observable:
            # --- summarize the mean-field by its expected location ---
            mf_mean = jnp.sum(global_s.m * self.params.discrete_states)
            mean_part = jnp.array([mf_mean]).reshape(-1)
            return jnp.concatenate([mean_part, z_part, bar_part])

        time_part = jnp.array(global_s.time).reshape(-1)
        return jnp.concatenate([global_s.m.reshape(-1), z_part, bar_part, time_part])


    def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """
        Optionally rescales the location/time entries of a global observation.

        The entry layout follows `get_global_obs`:
        - partially observable: index 0 (mean of the mean-field) and index 2 (bar location)
          are mapped to `1 - value / num_states`;
        - otherwise: index -2 (bar location) is mapped to `1 - value / num_states` and
          index -1 (time) to `1 - value / max_steps_in_episode`.

        Args:
            global_obs: observation(s); the normalized entries live on the last axis.
            normalize_obs: when false, the observation is returned unnormalized.

        Returns:
            A float32 array with the same shape as `global_obs`.
        """
        # --- cast first so the in-place updates below never truncate to an integer dtype ---
        obs = global_obs.astype(jnp.float32)
        num_states = self.params.num_states

        if self.params.partially_observable:
            # --- normalize location of mean of mean-field and bar location ---
            normalized = obs.at[..., 0].set(1 - (obs[..., 0] / num_states))
            # --- chain on `normalized` (not `obs`) so the first update is kept ---
            normalized = normalized.at[..., 2].set(1 - (obs[..., 2] / num_states))
        else:
            # --- normalize bar location and time ---
            normalized = obs.at[..., -2].set(1 - (obs[..., -2] / num_states))
            # --- chain on `normalized` (not `obs`) so the first update is kept ---
            normalized = normalized.at[..., -1].set(
                1 - (obs[..., -1] / self.params.max_steps_in_episode)
            )

        # --- the flag is honoured in both observation modes ---
        return jax.lax.select(normalize_obs, normalized, obs)


    def normalize_local_s(self, local_states: jax.Array, normalize_states: bool = False) -> jax.Array:
        """
        Returns `local_states` unchanged; `normalize_states` is accepted but not used here.
        """
        return local_states