import jax
import jax.numpy as jnp
from flax import struct

from mfax.envs.base.macro.endogenous import BaseEndogenousEnvironment, BaseEndogenousEnvParams, BaseEndogenousGlobalState
from mfax.envs.sample.base import SampleEnvironment, SampleEnvParams, SampleLocalState, SampleGlobalState


@struct.dataclass
class SampleEndogenousLocalState(SampleLocalState):
  """Per-agent local state used by the endogenous sample environment.

  Declares no fields of its own; everything is inherited from `SampleLocalState`.
  """


@struct.dataclass
class SampleEndogenousGlobalState(SampleGlobalState, BaseEndogenousGlobalState):
  """Global (shared) state used by the endogenous sample environment.

  Declares no fields of its own; it only combines `SampleGlobalState` with
  `BaseEndogenousGlobalState`.
  """


@struct.dataclass
class SampleEndogenousEnvParams(SampleEnvParams, BaseEndogenousEnvParams):
  """Parameters of the endogenous sample environment.

  On construction, `__post_init__` derives a discretised state grid from
  `lower_bound`, `upper_bound` and `num_states` and stores it on the instance
  as `pivots`, `discrete_states`, `states` and `n_states`.
  """

  # number of agents representing mean field
  n_agents: int = 10000

  # states per dimension
  lower_bound: jax.Array = jnp.array([0.0, 0.1])
  upper_bound: jax.Array = jnp.array([99, 2.0])
  num_states: tuple[int, int] = (200, 5)

  # idiosyncratic noise parameters
  idio_atoms: jax.Array = jnp.array([-1, 0, 1])
  idio_atoms_probs: jax.Array = jnp.array([0.1, 0.8, 0.1])

  def __post_init__(self):
    """Build the geometrically spaced state grid and attach it to the instance.

    Sets:
      pivots: per-dimension shift `max(|lower_bound|, 0.25)` added before
        geometric spacing so the geomspace endpoints are shifted away from zero.
      discrete_states: list with one 1D grid per state dimension.
      states: array of shape (prod(num_states), num_dims) holding every
        combination of the per-dimension grids ("ij" ordering).
      n_states: number of rows in `states`.
    """
    # NOTE(review): only BaseEndogenousEnvParams' hook is invoked here; if
    # SampleEnvParams defines its own __post_init__ it is not run - confirm intended.
    BaseEndogenousEnvParams.__post_init__(self)

    pivots = jnp.maximum(jnp.abs(self.lower_bound), 0.25)

    # Geometric spacing on the shifted axis, shifted back and floored at 0.
    # The lower clip bound is passed positionally (as in the rest of this file):
    # the `a_min=` keyword is deprecated in newer JAX releases.
    discrete_states = [
        jnp.clip(jnp.geomspace(lb + p, ub + p, ns) - p, 0)
        for lb, ub, ns, p in zip(self.lower_bound, self.upper_bound, self.num_states, pivots)
    ]

    # Cartesian product of the per-dimension grids, one row per grid point.
    mesh = jnp.meshgrid(*discrete_states, indexing="ij")
    states = jnp.stack([axis_values.ravel() for axis_values in mesh], axis=1)

    # object.__setattr__ bypasses the dataclass' normal attribute assignment
    # (presumably required because the struct dataclass is frozen).
    object.__setattr__(self, "pivots", pivots)
    object.__setattr__(self, "discrete_states", discrete_states)
    object.__setattr__(self, "states", states)
    object.__setattr__(self, "n_states", len(states))


class SampleEndogenousEnvironment(
    SampleEnvironment, 
    BaseEndogenousEnvironment
    ):

    def _prices(self, local_s: SampleEndogenousLocalState, z: float) -> tuple[float, float]:
        assert local_s.state.ndim == 2, "Local state must be 2D array of shape (num_agents, num_state_dims)"
        # --- sum over nuisance variable to obtain marginal distributions --- 
        av_wealth = jnp.mean(local_s.state[:, 0])
        av_income = jnp.mean(local_s.state[:, 1])

        # --- prices --- 
        interest_rate = jnp.exp(z) * self.params.cobb_douglas_alpha * ((av_income / (av_wealth + 1e-6)) ** (1 - self.params.cobb_douglas_alpha))
        wage = jnp.exp(z) * (1 - self.params.cobb_douglas_alpha) * ((av_wealth / (av_income + 1e-6)) ** self.params.cobb_douglas_alpha)
        return interest_rate, wage


    def mf_step_env(
        self,
        key: jax.Array,
        local_s: SampleEndogenousLocalState,
        global_s: SampleEndogenousGlobalState,
        vec_a: jax.Array,
    ) -> tuple[jax.Array, SampleEndogenousLocalState, SampleEndogenousGlobalState, jax.Array, jax.Array, jax.Array]:
        """
        Steps environment forward for a given global state and vector of actions for each agent.

        Args:
            key: PRNG key; split internally into one key for the aggregate shock and one per agent.
            local_s: batched local state (leading axis of size `params.n_agents`).
            global_s: current global state.
            vec_a: batched actions, one per agent.

        Returns:
            `(next_local_obs, next_local_s, next_global_s, vec_r, terminated, truncated)`,
            every element wrapped in `stop_gradient`.
        """
        # --- derive independent keys: a PRNG key must not be consumed more than once, ---
        # --- so the aggregate shock gets its own key instead of reusing the parent key ---
        idio_key, common_key = jax.random.split(key)
        step_rng = jax.random.split(idio_key, self.params.n_agents)

        # --- idiosyncratic transition for every agent ---
        next_local_s = jax.vmap(self._single_idio_step, in_axes=(0, 0, 0, None))(step_rng, local_s, vec_a, global_s)

        # --- aggregate state: AR(1)-style update of z, then prices from the new population ---
        next_z = self.params.rho * global_s.z + self.params.nu * jax.random.normal(common_key)
        next_interest_rate, next_wage = self._prices(next_local_s, next_z)
        next_time = global_s.time + 1
        next_global_s = SampleEndogenousGlobalState(z=next_z, time=next_time, interest_rate=next_interest_rate, wage=next_wage)
        next_local_obs = jax.vmap(self.get_local_obs, in_axes=(0, None))(next_local_s, next_global_s)

        # --- check for termination and truncation ---
        terminated = self.is_terminal(next_time)
        truncated = self.is_truncated(next_time)

        # --- select between step and terminated reward ---
        # NOTE(review): this unpacks the helper's result as (terminal, step) whereas sa_step_env
        # unpacks the same helper as (step, terminal) - confirm the order returned by _single_reward.
        vec_r_term, vec_r_st = jax.vmap(self._single_idio_reward, in_axes=(0, 0, None, None))(local_s, vec_a, global_s, next_global_s)
        vec_r = jax.lax.select(terminated, vec_r_term, vec_r_st)
        return (
          jax.lax.stop_gradient(next_local_obs), 
          jax.lax.stop_gradient(next_local_s), 
          jax.lax.stop_gradient(next_global_s), 
          jax.lax.stop_gradient(vec_r), 
          jax.lax.stop_gradient(terminated), 
          jax.lax.stop_gradient(truncated)
        )


    def mf_reset_env(
        self, 
        key: jax.Array
    ) -> tuple[jax.Array, SampleEndogenousLocalState, SampleEndogenousGlobalState]:
        """
        Resets Mean Field distribution.

        Samples `params.n_agents` initial local states, computes the prices they imply at z = 0
        and returns `(local_obs, local_s, global_s)`.
        """
        initial_z = 0.0
        agent_keys = jax.random.split(key, self.params.n_agents)

        # --- prices need the population, so agents are first sampled against a placeholder global state ---
        placeholder_global_s = SampleEndogenousGlobalState(z=initial_z, time=0, interest_rate=0, wage=0)
        sample_agents = jax.vmap(self.sa_reset_env, in_axes=(0, None))
        local_s = sample_agents(agent_keys, placeholder_global_s)

        # --- build the actual global state from the sampled population ---
        interest_rate, wage = self._prices(local_s, initial_z)
        global_s = SampleEndogenousGlobalState(z=initial_z, time=0, interest_rate=interest_rate, wage=wage)

        observe_agents = jax.vmap(self.get_local_obs, in_axes=(0, None))
        return observe_agents(local_s, global_s), local_s, global_s


    def _single_idio_step(self, key: jax.Array, local_s: SampleEndogenousLocalState, action: jax.Array, global_s: SampleEndogenousGlobalState) -> SampleEndogenousLocalState:
        """
        Returns next local state with idiosyncratic noise for a current state, action and global state.

        Args:
            key: PRNG key used to draw the idiosyncratic income shock.
            local_s: single-agent local state; `local_s.state` must be 1D (index 0 is used as
                wealth, index 1 as income).
            action: either a 0-d index into `params.discrete_actions`, or a 1-d continuous
                action which is squeezed and clipped to [0, 1].
            global_s: current global state, forwarded to `_single_step`.

        Returns:
            A single `SampleEndogenousLocalState` (not a tuple) whose state is
            `[next_wealth, next_income]`.
        """
        assert local_s.state.ndim == 1, "local_s must be a 1D array of shape (2,)"
        assert action.ndim in (0, 1), f"action ndim ({action.ndim}) must be 0 or 1"
        
        # --- convert to (clipped) continuous action ---
        if action.ndim == 0:
            action = self.params.discrete_actions[action]
        else:
            action = jnp.clip(action.squeeze(), 0.0, 1.0)

        # --- step single agent forward (only the wealth component of the result is kept) ---
        deterministic_next_state = self._single_step(local_s.state, action, global_s)
        next_wealth = deterministic_next_state[0]

        # --- idiosyncratic noise: move income along its discrete grid ---
        income = local_s.state[1]
        # snap current income to the nearest grid point
        income_idx = jnp.argmin(jnp.abs(self.params.discrete_states[1] - income)).astype(jnp.int32)
        # random index shift drawn from the configured atoms
        delta = jax.random.choice(key, self.params.idio_atoms, p=self.params.idio_atoms_probs)
        # NOTE(review): idio_noise is cast to the dtype of the atoms; with integer atoms a
        # fractional idio_noise would be truncated - confirm it is meant as an integer scale/flag.
        delta = delta * jnp.asarray(self.params.idio_noise, dtype=delta.dtype)
        # keep the shifted index inside the income grid
        idio_next_income_idx = jnp.clip(income_idx + delta, 0, self.params.num_states[1] - 1).astype(jnp.int32)
        idio_next_income = self.params.discrete_states[1][idio_next_income_idx]

        # --- return next local state ---
        next_local_s = SampleEndogenousLocalState(state=jnp.array([next_wealth, idio_next_income]))
        return next_local_s
    

    def _single_idio_reward(self, local_s: SampleEndogenousLocalState, action: jax.Array, global_s: SampleEndogenousGlobalState, next_global_s: SampleEndogenousGlobalState) -> tuple[jax.Array, jax.Array]:
        """
        Returns the pair of rewards produced by `_single_reward` for a single agent's current state,
        action, global state and next global state.

        A 0-d `action` is treated as an index into `params.discrete_actions`; a 1-d `action` is
        squeezed and clipped to [0, 1] before being passed on.
        """
        assert local_s.state.ndim == 1, "local_s must be a 1D array of shape (2,)"
        assert action.ndim in (0, 1), f"action ndim ({action.ndim}) must be 0 or 1"

        # --- convert to (clipped) continuous action ---
        is_discrete_index = action.ndim == 0
        continuous_action = (
            self.params.discrete_actions[action]
            if is_discrete_index
            else jnp.clip(action.squeeze(), 0.0, 1.0)
        )

        # --- calculate reward ---
        rewards = self._single_reward(local_s.state, continuous_action, global_s, next_global_s)
        return rewards


    def sa_step_env(self, key: jax.Array, local_s: SampleEndogenousLocalState, action: jax.Array, global_s: SampleEndogenousGlobalState, next_global_s: SampleEndogenousGlobalState) -> tuple[SampleEndogenousLocalState, jax.Array, jax.Array]:
        """
        Unclosed single-agent step: advances one agent only, so it cannot produce the updated
        global state or an observation. The caller supplies `next_global_s`.

        Returns `(next_local_s, first_reward, second_reward)` with gradients stopped, where the
        two rewards are the elements returned by `_single_idio_reward`, in that order.
        """
        # --- step single agent forward ---
        next_local_s = self._single_idio_step(key, local_s, action, global_s)

        # NOTE(review): the pair is named (step, terminal) here, while mf_step_env unpacks the same
        # helper as (terminal, step) - confirm which order _single_reward actually returns.
        r_step, r_term = self._single_idio_reward(local_s, action, global_s, next_global_s)

        outputs = (next_local_s, r_step, r_term)
        return tuple(jax.lax.stop_gradient(item) for item in outputs)


    def sa_reset_env(self, key: jax.Array, global_s: SampleEndogenousGlobalState) -> SampleEndogenousLocalState:
        """
        Samples an initial local state for one agent, uniformly over the rows of `params.states`.

        `global_s` is accepted but not used.
        """
        n_states = self.params.n_states
        uniform_probs = jnp.ones(n_states) / n_states
        sampled_state = jax.random.choice(key, self.params.states, p=uniform_probs)
        return SampleEndogenousLocalState(state=sampled_state, time=0)


    def get_local_obs(self, local_s: SampleEndogenousLocalState, global_s: SampleEndogenousGlobalState) -> jax.Array:
        if self.params.partially_observable:
            return jnp.array([global_s.interest_rate, global_s.wage])
        else:
            return jnp.array([global_s.interest_rate, global_s.wage, global_s.time])
    

    def normalize_obs(self, global_obs: jax.Array, normalize_obs: bool = False) -> jax.Array:
        """
        Transform global observation for feeding into policy network. Must work on batched observations.
        """
        if self.params.partially_observable:
            return global_obs
        else:
            # --- mean field distribution does not need normalising, only time [and, potentially, common noise] ---
            normalized_global_obs = global_obs.at[..., -1].set(1 - (global_obs[..., -1] / self.params.max_steps_in_episode))
            return jax.lax.select(normalize_obs, normalized_global_obs, global_obs.astype(jnp.float32))


    def normalize_local_s(self, local_states: jax.Array, normalize_states: bool = False) -> jax.Array:
        """
        Transform local state for feeding into policy network. Must work on batched observations.
        """
        if not normalize_states:
            return local_states

        D = self.params.pivots.shape[0]
        assert local_states.shape[-1] == D, f"expected last dimension to be {D}, got {local_states.shape[-1]}"

        # --- geometric normalization ---
        ratio = (self.params.upper_bound + self.params.pivots) / (self.params.lower_bound + self.params.pivots)
        ratio_is_small = jnp.isclose(ratio, 1.0)
        x_shifted = local_states + self.params.pivots
        frac = jnp.clip(x_shifted / (self.params.lower_bound + self.params.pivots), 1e-12, None)
        geom_denom = jnp.where(ratio_is_small, 1.0, jnp.log(ratio))
        geom_u = jnp.log(frac) / geom_denom

        # --- use linear normalization when geometric formula is ill conditioned ---
        width = (self.params.upper_bound - self.params.lower_bound)
        safe_width = jnp.where(width == 0, 1.0, width)
        linear_raw = (local_states - self.params.lower_bound) / safe_width
        linear_u = jnp.where(jnp.isclose(self.params.upper_bound, self.params.lower_bound), 0.0, linear_raw)

        # --- pick per-dimension formula and clip to [0, 1] ---
        u = jnp.where(ratio_is_small, linear_u, geom_u)
        return jnp.clip(u, 0.0, 1.0)