from functools import cached_property

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jumanji import specs
from jumanji.env import Environment
from jumanji.types import TimeStep, restart, termination, transition

# Unit displacement per movement action: row ``i`` is used by action ``i``
# (0: up, 1: down, 2: left, 3: right).
DELTAS = jnp.asarray([[-1, 0], [1, 0], [0, -1], [0, 1]], dtype=jnp.float32)
# 2-D positions of the beacons whose distance a sense action measures.
BEACON_LOCS = jnp.asarray([[-2, -2], [0, 2.8], [2, -2]], dtype=jnp.float32)
# Target position; the stop reward shrinks with the distance to it.
GOAL = jnp.asarray([0, 0], dtype=jnp.float32)
# Lower and upper corners of the square the agent starts in and is clipped to.
START_BOX = jnp.asarray([[-5, -5], [5, 5]], dtype=jnp.float32)
STEP_REWARD = -0.05  # per-step cost (also included in the stop reward)
MAX_REWARD = 1.0  # stop bonus, reduced by the distance to GOAL
NUM_ACTIONS = 6  # up, down, left, right, stop, sense


class TriangulationState(eqx.Module):
    """Hidden state of the triangulation environment.

    Fields may hold JAX arrays or plain Python scalars (the generator starts
    ``step_count`` and ``terminal`` as ``0`` and ``False``).
    """

    player_location: jax.Array  # 2-D position of the agent
    active_beacon: int  # index into BEACON_LOCS of the beacon a sense action measures
    step_count: int  # number of steps taken so far in the episode
    terminal: bool  # set when the stop action (4) is taken

    @property
    def max_feature_value(self) -> int:
        """Not defined for this environment: the location features are continuous.

        Raises:
            ValueError: always.
        """
        raise ValueError('Triangulation environment has continuous states')

    def to_flat_features(self) -> jax.Array:
        """Flatten the state into ``[loc_0, loc_1, active_beacon]``.

        Leading batch dimensions are supported: ``player_location`` of shape
        ``(..., 2)`` together with ``active_beacon`` of shape ``(...)`` yields
        ``(..., 3)``. Concatenating along the last axis (rather than axis 0)
        keeps batched states from failing or, for a batch of two, silently
        producing a wrongly shaped result. The beacon index is promoted to the
        location's dtype by the concatenation.
        """
        beacon = jnp.expand_dims(jnp.asarray(self.active_beacon), -1)
        return jnp.concatenate([self.player_location, beacon], axis=-1)

    def from_flat_features(self, features: jax.Array) -> 'TriangulationState':
        """Inverse of ``to_flat_features``.

        Rebuilds a state from features laid out as ``(..., 3)``; the last column
        is cast back to an integer beacon index. ``step_count`` and ``terminal``
        are not part of the features and are copied from ``self``.
        """
        return TriangulationState(
            player_location=features[..., :2],
            active_beacon=features[..., 2].astype(int),
            step_count=self.step_count,
            terminal=self.terminal,
        )

    def observation_from_state(self):
        """Return the "no reading" observation (``distance=-1``) with every action enabled.

        The state alone does not determine a sensor reading, so this mirrors the
        placeholder observation emitted on non-sensing steps.
        """
        action_mask = jnp.ones(NUM_ACTIONS, dtype=jnp.int32)
        return TriangulationObservation(distance=-1, action_mask=action_mask)


# Agent-facing observation: an ordinal distance bin from a sense action
# (or the -1 sentinel when no reading was taken) plus an action mask.
class TriangulationObservation(eqx.Module):
    """Observation emitted by the triangulation environment."""

    distance: int  # sensed distance bin in [0, n_obs_bins), or -1 on reset / non-sensing steps
    action_mask: jax.Array  # (NUM_ACTIONS,) binary mask: 4 moves, stop, sense; all ones in this file


class TriangulationGenerator:
    """Samples initial ``TriangulationState`` values for new episodes."""

    def __call__(self, key: jax.Array) -> TriangulationState:
        """Draw a start state from ``key``.

        The location is uniform over ``START_BOX`` and the active beacon is a
        uniformly chosen index into ``BEACON_LOCS``; ``step_count`` starts at 0
        and ``terminal`` at False.
        """
        location_key, beacon_index_key = jax.random.split(key)
        low, high = START_BOX[0], START_BOX[1]
        n_beacons = len(BEACON_LOCS)
        return TriangulationState(
            player_location=jax.random.uniform(
                location_key, shape=(2,), dtype=jnp.float32, minval=low, maxval=high
            ),
            active_beacon=jax.random.randint(beacon_index_key, (), 0, n_beacons),
            step_count=0,
            terminal=False,
        )


class Triangulation(Environment[TriangulationState, specs.DiscreteArray, TriangulationObservation]):
    """2-D navigation POMDP where position is only observable through noisy beacon ranges.

    Actions (``NUM_ACTIONS`` in total):
        0-3: move one ``step_size`` along the matching row of ``DELTAS`` plus
             Gaussian noise, clipped to ``START_BOX``;
        4:   stop, which terminates the episode and pays
             ``STEP_REWARD + MAX_REWARD - ||position - GOAL||``;
        5:   sense, which returns a binned noisy distance to the active beacon.

    Every other step pays ``STEP_REWARD``, and non-sensing steps observe
    ``distance=-1``. The active beacon advances through ``BEACON_LOCS`` on every
    step.
    """

    def __init__(
        self,
        step_size: float = 0.5,
        movement_noise: float = 0.01,
        sensor_noise: float = 0.1,
        n_obs_bins: int = 20,
        time_limit: int = 100,
    ):
        """Configure the environment.

        Args:
            step_size: Distance covered by one movement action.
            movement_noise: Standard deviation of the Gaussian noise added to each
                coordinate of a movement.
            sensor_noise: Standard deviation of the Gaussian noise on the true
                beacon distance, applied before binning.
            n_obs_bins: Number of ordinal distance bins a sense action can return.
            time_limit: Step count at which the episode is terminated.
        """
        self.step_size = step_size
        self.movement_noise = movement_noise
        self.sensor_noise = sensor_noise
        self.time_limit = time_limit
        self.n_obs_bins = n_obs_bins
        self.generator = TriangulationGenerator()
        self.has_state_observation = False
        self.has_continuous_features = True
        super().__init__()

    @cached_property
    def observation_spec(self) -> specs.Spec[TriangulationObservation]:
        """Spec of an observation: a distance bin plus a binary action mask.

        NOTE(review): reset and non-sensing steps emit the sentinel
        ``distance=-1``, which lies outside this ``DiscreteArray`` range.
        """
        # The sensed distance is binned into n_obs_bins equal-width ordinal bins
        # spanning [0, sqrt(8)] (see _beacon_bin_probs).
        distance = specs.DiscreteArray(self.n_obs_bins, name='distance', dtype=jnp.int32)
        action_mask = specs.BoundedArray(
            shape=(NUM_ACTIONS,),
            dtype=jnp.int32,
            name='action_mask',
            minimum=0,
            maximum=1,
        )
        return specs.Spec(
            TriangulationObservation,
            'TriangulationObservationSpec',
            distance=distance,
            action_mask=action_mask,
        )

    @cached_property
    def action_spec(self) -> specs.DiscreteArray:
        """
        Moving up, down, left, and right, stopping and sensing
        """
        return specs.DiscreteArray(NUM_ACTIONS, name='action')

    def reset(
        self, key: jax.Array
    ) -> tuple[TriangulationState, TimeStep[TriangulationObservation]]:
        """Sample a start state and return it with a "no reading" first timestep."""
        state = self.generator(key)
        action_mask = jnp.ones(self.action_spec.num_values, dtype=jnp.int32)
        obs = TriangulationObservation(distance=-1, action_mask=action_mask)
        timestep = restart(observation=obs)
        return state, timestep

    def _scan_distance_mu(self, state: TriangulationState) -> jax.Array:
        """True Euclidean distance from the agent to the currently active beacon."""
        return jnp.linalg.norm(state.player_location - BEACON_LOCS[state.active_beacon])

    def _beacon_bin_probs(self, scan_distance_mu: float) -> jax.Array:
        """Probability of each distance bin given the true distance.

        The reading is modelled as ``N(scan_distance_mu, sensor_noise**2)``, and
        each bin gets the Gaussian mass between its edges. Entries are clipped to
        at least ``1e-12`` and renormalised, so every bin stays possible and the
        vector sums to one.

        NOTE(review): the bins only cover ``[0, sqrt(8)]`` while locations range
        over ``START_BOX``. For true distances well beyond ``sqrt(8)``, most bins
        hit the clip floor and the distribution can become close to uniform.
        Confirm this is the intended sensor model.
        """
        edges = jnp.linspace(0, jnp.sqrt(8.0), num=self.n_obs_bins + 1, dtype=jnp.float32)
        z_upper = (edges[1:] - scan_distance_mu) / self.sensor_noise
        z_lower = (edges[:-1] - scan_distance_mu) / self.sensor_noise
        p = jsp.stats.norm.cdf(z_upper) - jsp.stats.norm.cdf(z_lower)
        p = jnp.clip(p, 1e-12, 1.0)
        return p / jnp.sum(p)

    def _sample_beacon_bin(self, state: TriangulationState, key: jax.Array) -> jax.Array:
        """Sample a noisy, binned distance reading to the active beacon."""
        mu = self._scan_distance_mu(state)
        p = self._beacon_bin_probs(mu)
        return jax.random.categorical(key, jnp.log(p))

    def _beacon_bin_loglikelihood(
        self, observation: TriangulationObservation, state: TriangulationState
    ) -> jax.Array:
        """Log-likelihood of ``observation`` if ``state`` were the true state.

        A "no reading" observation (``distance < 0``) carries no information
        about the state, so its log-likelihood is 0 (probability one). Indexing
        the bin table with the ``-1`` sentinel would silently return the last
        bin's log-probability instead.
        """
        mu = self._scan_distance_mu(state)
        p = self._beacon_bin_probs(mu)
        distance = jnp.asarray(observation.distance)
        log_p = jnp.log(p[jnp.maximum(distance, 0)])
        return jnp.where(distance < 0, 0.0, log_p)

    def _get_observation(
        self, state: TriangulationState, action: int, key: jax.Array
    ) -> TriangulationObservation:
        """Return a sensor reading for the sense action (5), else the ``distance=-1`` placeholder."""
        action_mask = jnp.ones(self.action_spec.num_values, dtype=jnp.int32)
        no_obs = TriangulationObservation(distance=-1, action_mask=action_mask)
        return jax.lax.cond(
            action == 5,
            lambda: TriangulationObservation(
                distance=self._sample_beacon_bin(state, key),
                action_mask=action_mask,
            ),
            lambda: no_obs,
        )

    def _create_next_state(
        self, state: TriangulationState, action: int, key: jax.Array
    ) -> tuple[TriangulationState, jax.Array]:
        """Apply ``action`` to ``state`` and return the successor state and its reward.

        Terminal states are absorbing: once ``state.terminal`` is set, the reward
        is 0 and the successor stays terminal whatever action is taken.
        """
        dir_vec = DELTAS[action % 4]
        pos_update = dir_vec * self.step_size + jax.random.normal(key, (2,)) * self.movement_noise
        position = jax.lax.cond(
            action < 4,  # only the four movement actions change the location
            lambda: state.player_location + pos_update,
            lambda: state.player_location,
        )
        position = jnp.clip(position, START_BOX[0], START_BOX[1])
        # Just cycle through beacons to determine the active one.
        active_beacon = (state.active_beacon + 1) % BEACON_LOCS.shape[0]
        reward = jax.lax.cond(
            action == 4,  # stop
            lambda: STEP_REWARD + MAX_REWARD - jnp.linalg.norm(position - GOAL),
            lambda: STEP_REWARD,
        )
        reward = jax.lax.cond(state.terminal, lambda: 0.0, lambda: reward)
        # Keep terminal sticky; recomputing it from the action alone would let a
        # terminal state become non-terminal and resume paying rewards.
        terminal = jnp.logical_or(state.terminal, action == 4)
        return TriangulationState(
            player_location=position,
            active_beacon=active_beacon,
            step_count=state.step_count + 1,
            terminal=terminal,
        ), reward

    def step(
        self, state: TriangulationState, action: int, key: jax.Array
    ) -> tuple[TriangulationState, TimeStep[TriangulationObservation]]:
        """Advance the environment by one action.

        The observation is generated from the current ``state`` (before the
        move), while the reward and termination come from the successor state.
        The episode ends when the agent stops or ``time_limit`` steps have been
        taken.
        """
        obs_key, step_key = jax.random.split(key, 2)
        observation = self._get_observation(state, action, obs_key)
        next_state, reward = self._create_next_state(state, action, step_key)

        time_limit_exceeded = next_state.step_count >= self.time_limit
        player_stopped = next_state.terminal
        done = time_limit_exceeded | player_stopped
        timestep = jax.lax.cond(
            done,
            termination,
            transition,
            reward,
            observation,
        )
        return next_state, timestep

    def observation_weights(
        self,
        particles: TriangulationState,
        observations: TriangulationObservation,
        gold_observation: TriangulationObservation,
    ) -> jax.Array:
        """Per-particle weights for ``gold_observation``.

        ``particles`` is vmapped over, so every leaf is expected to carry a
        leading particle axis. For a "no reading" observation (``distance < 0``)
        the weights are uniform. Otherwise they are each particle's probability
        of the observed bin, normalised over particles and then shifted by
        ``1e-12``. ``observations`` is currently unused.
        """
        scan_distances = jax.vmap(self._scan_distance_mu)(particles)
        obs_bin_probs = jax.vmap(self._beacon_bin_probs)(scan_distances)
        p = obs_bin_probs[:, gold_observation.distance]
        return jax.lax.cond(
            gold_observation.distance < 0,
            lambda: jnp.ones(particles.player_location.shape[0])
            / particles.player_location.shape[0],
            lambda: p / jnp.sum(p) + 1e-12,
        )

    def default_policy(
        self, state: TriangulationState, observation: TriangulationObservation, key: jax.Array
    ) -> int:
        """Baseline policy: always stop (action 4) where the agent currently is."""
        return 4  # stop where we are


def build_triangulation_env(
    step_size: float = 0.5,
    movement_noise: float = 0.05,
    sensor_noise: float = 0.5,
    n_obs_bins: int = 20,
    time_limit: int = 100,
) -> Triangulation:
    """Construct a ``Triangulation`` environment.

    The defaults here (``movement_noise=0.05``, ``sensor_noise=0.5``) are
    noisier than the ``Triangulation`` constructor's own defaults.

    Args:
        step_size: Distance covered by one movement action.
        movement_noise: Standard deviation of the per-coordinate movement noise.
        sensor_noise: Standard deviation of the noise on sensed beacon distances.
        n_obs_bins: Number of ordinal distance bins a sense action can return.
        time_limit: Step count at which episodes are terminated.

    Returns:
        The configured environment.
    """
    # Forward by keyword so the call stays correct if the constructor's
    # parameter order ever changes.
    return Triangulation(
        step_size=step_size,
        movement_noise=movement_noise,
        sensor_noise=sensor_noise,
        n_obs_bins=n_obs_bins,
        time_limit=time_limit,
    )
