from typing import TypeVar, Generic
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
from flax import struct

# Type variables that parametrise ``BaseEnvironment`` (they appear in its
# ``Generic[...]`` base). The bounds are written as strings because the bound
# classes are defined further down in this module.
TGlobalState = TypeVar("TGlobalState", bound="BaseGlobalState")
TEnvParams = TypeVar("TEnvParams", bound="BaseEnvParams")


@struct.dataclass
class BaseGlobalState:
    """Base global state: the common-noise value and the current time.

    This class is the bound of the ``TGlobalState`` type variable, so
    environment-specific global states are expected to derive from it.
    """

    # Common noise.
    z: jax.Array
    # Current time; this is the value handed to ``is_terminal`` by
    # ``BaseEnvironment.discount``.
    time: int


@struct.dataclass
class BaseMFSequence:
    """Bundle of a global state with its termination and truncation flags.

    NOTE(review): the name suggests a mean-field sequence whose fields are
    stacked along a time axis -- confirm against the code that builds it.
    """

    # Global state (common noise and time).
    global_s: BaseGlobalState
    # Termination flag(s) associated with ``global_s``.
    global_terminated: jax.Array
    # Truncation flag(s) associated with ``global_s``.
    global_truncated: jax.Array


@struct.dataclass
class BaseEnvParams:
    """Base environment parameters.

    This class is the bound of the ``TEnvParams`` type variable and the
    parameter type accepted by ``BaseEnvironment.__init__``.
    """

    # Maximum number of steps in one episode.
    max_steps_in_episode: int
    # Idiosyncratic-noise flag -- presumably enables per-agent noise when
    # True; confirm against the environments that read it.
    idio_noise: bool
    # Common-noise flag -- presumably enables the shared noise when True;
    # confirm against the environments that read it.
    common_noise: bool


class BaseEnvironment(Generic[TGlobalState, TEnvParams], ABC):
    """Abstract base class for all Model-Based Mean Field environments.

    The class is generic over the concrete global-state type
    (``TGlobalState``) and the concrete parameter type (``TEnvParams``).
    A subclass can bind both, e.g. ``BaseEnvironment[MyState, MyParams]``,
    and the method signatures below narrow accordingly for type checkers.

    Subclasses must implement ``is_terminal`` and ``is_truncated``.
    ``_single_step`` and ``_single_reward`` are not marked abstract, but
    their implementations here raise ``NotImplementedError``.
    """

    def __init__(self, params: TEnvParams) -> None:
        """Store the environment parameters on the instance.

        Args:
            params: Environment parameters, kept as ``self.params``.
        """
        self.params = params

    @property
    def n_states(self) -> int:
        """Number of local states: the leading dimension of ``params.states``.

        ``BaseEnvParams`` declares no ``states`` field, so the concrete
        parameter class must provide one that exposes ``shape``.
        """
        return self.params.states.shape[0]

    @property
    def state_indices(self) -> jax.Array:
        """Integer indices ``0 .. n_states - 1`` as a JAX array."""
        return jnp.arange(self.n_states)

    def _single_step(
        self,
        state: jax.Array,
        action: jax.Array,
        global_s: TGlobalState,
    ) -> tuple[jax.Array, jax.Array]:
        """Deterministic step forward for a single agent.

        Returns the next local state for a single agent with no
        idiosyncratic noise.

        NOTE(review): the return annotation declares a pair of arrays while
        the description above mentions only the next state -- confirm what
        the second element is.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _single_reward(
        self,
        state: jax.Array,
        action: jax.Array,
        global_s: TGlobalState,
        next_global_s: TGlobalState,
    ) -> tuple[jax.Array, jax.Array]:
        """Reward for a state, action and global state.

        Calculates the reward (the expected reward, if it depends on the
        next state). Returns the reward for a step and the terminal reward.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def is_terminal(self, time: int) -> jax.Array:
        """Checks whether Mean Field is terminal (for finite horizon environments)."""
        raise NotImplementedError

    @abstractmethod
    def is_truncated(self, time: int) -> jax.Array:
        """Checks whether Mean Field is truncated (for infinite horizon environments)."""
        raise NotImplementedError

    def discount(self, global_s: TGlobalState) -> jax.Array:
        """Return zero discount if the episode has terminated, else one.

        Args:
            global_s: Global state; only its ``time`` field is read.

        Returns:
            ``0.0`` where ``is_terminal(global_s.time)`` holds, ``1.0``
            otherwise.
        """
        terminated = self.is_terminal(global_s.time)
        return jax.lax.select(terminated, 0.0, 1.0)