import jax
import jax.numpy as jnp
from flax import struct
from dataclasses import field

from mfax.envs.base.toy.linear_quadratic import BaseLinearQuadraticEnvParams, BaseLinearQuadraticEnvironment, BaseLinearQuadraticGlobalState
from mfax.envs.pushforward.base import PushforwardEnvParams, PushforwardEnvironment, PushforwardGlobalState


@struct.dataclass
class PushforwardLinearQuadraticGlobalState(PushforwardGlobalState, BaseLinearQuadraticGlobalState):
    """Global state for the pushforward linear-quadratic environment.

    Declares no fields of its own; it merges the pushforward and linear-quadratic
    global-state bases (instances in this module are built with ``m``, ``z`` and ``time``).
    """


@struct.dataclass
class PushforwardLinearQuadraticEnvParams(PushforwardEnvParams, BaseLinearQuadraticEnvParams):
    """Parameters for the pushforward linear-quadratic environment.

    Merges the pushforward and linear-quadratic parameter bases.
    """

    # ``states`` is given a default value (an empty 0x0 array) so the dataclass field
    # ordering produced by the multiple-inheritance merge stays valid (fields with
    # defaults may not be followed by fields without them).
    states: jax.Array = field(default_factory=lambda: jnp.empty(shape=(0, 0)))


class PushforwardLinearQuadraticEnvironment(
    PushforwardEnvironment,
    BaseLinearQuadraticEnvironment,
):
    """Linear-quadratic mean-field environment driven by a population distribution ``m``.

    Single-agent dynamics and rewards (``_single_step``, ``_single_reward``), the
    distribution update (``mf_transition``), the population reward (``mf_reward``) and the
    termination rules (``is_terminal``, ``is_truncated``) come from the parent classes.
    This class wires them together and defines the observation layout.
    """

    @property
    def obs_dim(self) -> int:
        """Length of the vector produced by ``get_global_obs``.

        Partially observable: 1 (the mean of ``params.discrete_states`` under ``m``).
        Fully observable: ``num_states`` entries of ``m`` plus ``z`` plus ``time``
        (assumes ``len(m) == params.num_states`` — TODO confirm against ``n_states``).
        """
        if self.params.partially_observable:
            return 1
        else:
            return self.params.num_states + 2

    def mf_reset_env(self, key: jax.Array) -> tuple[jax.Array, PushforwardLinearQuadraticGlobalState]:
        """Reset to a uniform population distribution at time 0.

        Args:
            key: PRNG key used to draw the sign of the common noise.

        Returns:
            ``(global_obs, global_state)``. If ``params.common_noise`` is set, ``z`` is +1 or -1
            with equal probability (``jax.random.bernoulli`` default ``p=0.5``); otherwise 0.
        """
        m = jnp.ones(self.n_states) / self.n_states
        z = jax.lax.select(self.params.common_noise, jax.lax.select(jax.random.bernoulli(key), 1, -1), 0)
        global_s = PushforwardLinearQuadraticGlobalState(m=m, z=z, time=0)
        return self.get_global_obs(global_s), global_s

    def mf_step_env(
        self, key: jax.Array, global_s: PushforwardLinearQuadraticGlobalState, prob_a: jax.Array
    ) -> tuple[jax.Array, PushforwardLinearQuadraticGlobalState, jax.Array, jax.Array, jax.Array]:
        """Advance the population distribution by one step.

        Args:
            key: PRNG key (not used by this method).
            global_s: current global state.
            prob_a: population action probabilities, forwarded to ``mf_transition``.

        Returns:
            ``(next_obs, next_global_s, reward, terminated, truncated)``, each wrapped in
            ``stop_gradient``. The common noise ``z`` is carried over unchanged and ``time``
            is incremented. The reward is ``mf_reward``'s terminal component on a terminating
            step and its per-step component otherwise.
        """
        next_m = self.mf_transition(global_s.m, prob_a, global_s)
        next_time = global_s.time + 1
        next_global_s = PushforwardLinearQuadraticGlobalState(m=next_m, z=global_s.z, time=next_time)

        terminated = self.is_terminal(next_time)
        truncated = self.is_truncated(next_time)

        mat_r_step, mat_r_term = self.mf_reward(global_s, next_global_s)
        mat_r = jax.lax.select(terminated, mat_r_term, mat_r_step)
        return (
            jax.lax.stop_gradient(self.get_global_obs(next_global_s)),
            jax.lax.stop_gradient(next_global_s),
            jax.lax.stop_gradient(mat_r),
            jax.lax.stop_gradient(terminated),
            jax.lax.stop_gradient(truncated),
        )

    def _single_pushforward_step(
        self, state: jax.Array, action_idx: jax.Array, global_s: PushforwardLinearQuadraticGlobalState
    ) -> tuple[jax.Array, jax.Array]:
        """Next-state candidates and weights for one (state, action) pair.

        The agent first moves via ``_single_step``. The candidates are:
          * one per idiosyncratic atom: the deterministic next state shifted by
            ``sigma * sqrt(1 - rho**2) * idio_atoms``, rounded to int32 and clipped to
            ``[0, num_states - 1]``, with weight ``idio_atoms_probs * idio_noise``;
          * the deterministic next state itself, with weight ``1 - idio_noise``.
        Weights are divided by their sum unless that sum is 0.

        Args:
            state: 0-d array, the local state passed to ``_single_step``.
            action_idx: 0-d array indexing ``params.actions``.
            global_s: current global state.

        Returns:
            ``(next_state_idxs, probs)``, both wrapped in ``stop_gradient``.
        """
        assert state.ndim == 0, f"state ndim ({state.ndim}) must be 0"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        action = self.params.actions[action_idx]

        # --- deterministic move of a single agent ---
        deterministic_next_state_idx = self._single_step(state, action, global_s)

        # --- idiosyncratic perturbations, snapped to the discrete state grid ---
        idio_scale = self.params.sigma * jnp.sqrt(1.0 - (self.params.rho**2))
        idio_next_state_idxs = jnp.clip(
            jnp.round(deterministic_next_state_idx + idio_scale * self.params.idio_atoms).astype(jnp.int32),
            0,
            self.params.num_states - 1,
        )

        # NOTE(review): the deterministic index is appended without clipping or casting; this
        # assumes ``_single_step`` already returns an in-range integer index — confirm.
        next_state_idxs = jnp.concatenate([idio_next_state_idxs, jnp.array([deterministic_next_state_idx])], axis=0)
        probs = jnp.concatenate(
            [self.params.idio_atoms_probs * self.params.idio_noise, jnp.array([1.0 - self.params.idio_noise])], axis=0
        )
        # Normalize, avoiding division by zero when every weight is 0.
        total = probs.sum()
        probs = probs / jnp.where(total > 0, total, 1.0)
        return jax.lax.stop_gradient(next_state_idxs), jax.lax.stop_gradient(probs)

    def _single_pushforward_reward(
        self,
        state: jax.Array,
        action_idx: jax.Array,
        global_s: PushforwardLinearQuadraticGlobalState,
        next_global_s: PushforwardLinearQuadraticGlobalState,
    ):
        """Reward for one (state, action) pair, delegated to ``_single_reward``.

        Args:
            state: 0-d array, the local state.
            action_idx: 0-d array indexing ``params.actions``.
            global_s: global state before the transition.
            next_global_s: global state after the transition.

        Returns:
            The value returned by ``_single_reward`` for the looked-up action.
        """
        assert state.ndim == 0, f"state ndim ({state.ndim}) must be 0"
        assert action_idx.ndim == 0, f"action_idx ndim ({action_idx.ndim}) must be 0"

        action = self.params.actions[action_idx]
        return self._single_reward(state, action, global_s, next_global_s)

    def get_global_obs(self, global_s: PushforwardLinearQuadraticGlobalState) -> jax.Array:
        """Build the flat global observation vector.

        Partially observable: ``[sum(m * params.discrete_states)]``, shape (1,).
        Fully observable: ``m`` flattened, followed by ``z`` and ``time``.
        """
        if self.params.partially_observable:
            mf_mean = jnp.sum(global_s.m * self.params.discrete_states)
            return jnp.array([mf_mean]).reshape(-1)
        return jnp.concatenate(
            [
                global_s.m.reshape(-1),
                jnp.array([global_s.z]).reshape(-1),
                jnp.array(global_s.time).reshape(-1),
            ]
        )

    def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """Optionally rescale a global observation.

        Partially observable: always divided by ``params.num_states``.
        Fully observable: returned as float32; when ``normalize_obs`` is true the trailing
        time entry is replaced by ``1 - time / params.max_steps_in_episode``.
        """
        if self.params.partially_observable:
            # NOTE(review): this branch ignores ``normalize_obs``, unlike the fully observable
            # branch — confirm that always normalizing here is intended.
            normalized_global_obs = global_obs / self.params.num_states
            return normalized_global_obs
        else:
            # Cast once so both ``lax.select`` operands share a dtype; select rejects mismatched
            # dtypes, which otherwise happens for int or float64 (x64) observations.
            float_obs = global_obs.astype(jnp.float32)
            normalized_global_obs = float_obs.at[..., -1].set(
                1 - (float_obs[..., -1] / self.params.max_steps_in_episode)
            )
            return jax.lax.select(normalize_obs, normalized_global_obs, float_obs)

    def normalize_local_s(self, local_states: jax.Array, normalize_states: bool = False) -> jax.Array:
        """Identity: local states are returned unchanged and ``normalize_states`` is ignored."""
        return local_states

