"""Implementation of a Jax-accelerated cartpole environment."""
from __future__ import annotations

from typing import Any, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey

import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import (
    FunctionalJaxEnv,
    FunctionalJaxVectorEnv,
)
from gymnasium.utils import EzPickle


# Render state handed between `render_init`, `render_image` and `render_close`:
# the pygame surface that frames are blitted onto, and a pygame clock.
# The pygame types are spelled as strings because pygame is only imported
# lazily inside the render methods, never at module import time.
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"]  # type: ignore  # noqa: F821


class CartPoleFunctional(
    FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
    """Cartpole but in jax and functional.

    The state (which is also the observation) is the array
    ``(x, x_dot, theta, theta_dot)``. Action ``0`` applies ``-force_mag`` to the
    cart and action ``1`` applies ``+force_mag``. The dynamics are advanced with
    one explicit Euler step of size ``tau`` per transition.

    The examples below print the shape of the transition result rather than its
    values; the initial states only depend on ``x_init`` and the random key.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
        >>> key = jax.random.PRNGKey(0)
        >>> env = CartPoleFunctional({"x_init": 0.5})
        >>> state = env.initial(key)
        >>> print(state)
        [ 0.46532142 -0.27484107  0.13302994 -0.20361817]
        >>> print(env.transition(state, 0).shape)
        (4,)
        >>> env.transform(jax.jit)
        >>> state = env.initial(key)
        >>> print(state)
        [ 0.46532142 -0.27484107  0.13302994 -0.20361817]
        >>> print(env.transition(state, 0).shape)
        (4,)
        >>> vkey = jax.random.split(key, 10)
        >>> env.transform(jax.vmap)
        >>> vstate = env.initial(vkey)
        >>> print(vstate)
        [[ 0.25117755 -0.03159595  0.09428263  0.12404168]
         [ 0.231457    0.41420317 -0.13484478  0.29151905]
         [-0.11706758 -0.37130308  0.13587534  0.33141208]
         [-0.4613737   0.36557996  0.3950702   0.3639989 ]
         [-0.14707637 -0.34273267 -0.32374108 -0.48110402]
         [-0.45774353  0.3633288  -0.3157575  -0.03586268]
         [ 0.37344885 -0.279778   -0.33894253  0.07415426]
         [-0.20234215  0.39775252 -0.2556088   0.32877135]
         [-0.2572986  -0.29943776 -0.45600426 -0.35740316]
         [ 0.05436695  0.35021234 -0.36484408  0.2805779 ]]
        >>> print(env.transition(vstate, jnp.array([0 for _ in range(10)])).shape)
        (10, 4)
    """

    # Physical constants of the cart-pole system used by `transition`.
    gravity = 9.8
    masscart = 1.0
    masspole = 0.1
    total_mass = masspole + masscart
    length = 0.5
    # Product of the pole mass and the pole length, as required by the equations
    # of motion in `transition` (the sum of the two is not a meaningful quantity).
    polemass_length = masspole * length
    # Magnitude of the force applied to the cart; the action only picks its sign.
    force_mag = 10.0
    # Step size of the explicit Euler update in `transition`.
    tau = 0.02
    # Termination thresholds: 12 degrees expressed in radians, and the cart position limit.
    theta_threshold_radians = 12 * 2 * np.pi / 360
    x_threshold = 2.4
    # Every component of the initial state is sampled uniformly from [-x_init, x_init).
    x_init = 0.05

    # Size of the surface that `render_image` draws on.
    screen_width = 600
    screen_height = 400

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng: PRNGKey) -> jax.Array:
        """Initial state generation.

        Args:
            rng: Jax random key used to sample the state.

        Returns:
            An array of shape ``(4,)`` sampled uniformly between ``-x_init`` and ``x_init``.
        """
        return jax.random.uniform(
            key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
        )

    def transition(
        self, state: jax.Array, action: int | jax.Array, rng: None = None
    ) -> StateType:
        """Cartpole transition.

        Args:
            state: The current state ``(x, x_dot, theta, theta_dot)``.
            action: ``0`` pushes the cart with ``-force_mag``, ``1`` with ``+force_mag``.
            rng: Unused, the transition is deterministic.

        Returns:
            The next state as a float32 array ``(x, x_dot, theta, theta_dot)``.
        """
        x, x_dot, theta, theta_dot = state
        # Maps action 0 -> -force_mag and action 1 -> +force_mag.
        force = jnp.sign(action - 0.5) * self.force_mag
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Explicit Euler step: positions advance with the old velocities,
        # velocities advance with the accelerations computed above.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

        return state

    def observation(self, state: jax.Array) -> jax.Array:
        """Cartpole observation, which is the state itself."""
        return state

    def _out_of_bounds(self, state: jax.Array) -> jax.Array:
        """Returns whether the cart position or the pole angle of ``state`` exceeds its threshold."""
        x, _, theta, _ = state

        return (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        )

    def terminal(self, state: jax.Array) -> jax.Array:
        """Checks if the state is terminal.

        Args:
            state: The state ``(x, x_dot, theta, theta_dot)`` to check.

        Returns:
            A boolean array that is true when ``x`` is outside ``[-x_threshold, x_threshold]``
            or ``theta`` is outside ``[-theta_threshold_radians, theta_threshold_radians]``.
        """
        return self._out_of_bounds(state)

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> jax.Array:
        """Computes the reward for the state transition using the action.

        Args:
            state: The state the action was taken from; the only argument that is used.
            action: Unused.
            next_state: Unused.

        Returns:
            ``0.0`` if ``state`` is already out of bounds, otherwise ``1.0``.
        """
        terminated = self._out_of_bounds(state)

        reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
        return reward

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> tuple[RenderStateType, np.ndarray]:
        """Renders an image of the state using the render state.

        Args:
            state: The state to draw; index 0 is read as the cart position and index 2 as the pole angle.
            render_state: The ``(screen, clock)`` pair the frame is blitted onto.

        Returns:
            The ``(screen, clock)`` pair and a copy of the screen pixels with the first two axes swapped.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e
        screen, clock = render_state

        # The drawing surface always uses the class-level screen size; the track
        # spanning [-x_threshold, x_threshold] is scaled to the full screen width.
        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        x = state

        surf = pygame.Surface((self.screen_width, self.screen_height))
        surf.fill((255, 255, 255))

        # Cart: a black rectangle centred on the cart position.
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0))

        # Pole: a rectangle rotated by the pole angle around the axle.
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101))

        # Axle: a small circle where the pole is attached to the cart.
        gfxdraw.aacircle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        # Track: a horizontal line at the cart height.
        gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))

        # The scene is drawn with y growing upwards, so flip it vertically before blitting.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialises the render state for a screen width and height.

        Args:
            screen_width: Width of the screen surface that is created.
            screen_height: Height of the screen surface that is created.

        Returns:
            The ``(screen, clock)`` pair to pass to ``render_image``.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock

    def render_close(self, render_state: RenderStateType) -> None:
        """Closes the render state.

        Args:
            render_state: Unused; pygame is shut down globally.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()


class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
    """Environment wrapper around a jit-transformed ``CartPoleFunctional``."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: str | None = None, **kwargs: Any):
        """Builds the functional CartPole, jits it and wraps it as an environment.

        Args:
            render_mode: Render mode forwarded to ``FunctionalJaxEnv``.
            **kwargs: Forwarded unchanged to ``CartPoleFunctional`` and to ``EzPickle``.
        """
        # Hand the constructor arguments to EzPickle before anything else is set up.
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = CartPoleFunctional(**kwargs)
        func_env.transform(jax.jit)

        FunctionalJaxEnv.__init__(
            self, func_env, metadata=self.metadata, render_mode=render_mode
        )


class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    """Vectorized environment wrapper around a jit-transformed ``CartPoleFunctional``."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(
        self,
        num_envs: int,
        render_mode: str | None = None,
        max_episode_steps: int = 200,
        **kwargs: Any,
    ):
        """Builds the functional CartPole, jits it and wraps it as a vector environment.

        Args:
            num_envs: Number of environments, forwarded to ``FunctionalJaxVectorEnv``.
            render_mode: Render mode forwarded to ``FunctionalJaxVectorEnv``.
            max_episode_steps: Episode step limit forwarded to ``FunctionalJaxVectorEnv``.
            **kwargs: Forwarded unchanged to ``CartPoleFunctional`` and to ``EzPickle``.
        """
        # Arguments shared by the EzPickle record and the vector-env initialiser.
        wrapper_args = {
            "num_envs": num_envs,
            "render_mode": render_mode,
            "max_episode_steps": max_episode_steps,
        }
        EzPickle.__init__(self, **wrapper_args, **kwargs)

        jitted_env = CartPoleFunctional(**kwargs)
        jitted_env.transform(jax.jit)

        FunctionalJaxVectorEnv.__init__(
            self,
            func_env=jitted_env,
            metadata=self.metadata,
            **wrapper_args,
        )
