from functools import partial
from typing import Optional, Union

import jax
import jax.numpy as jnp
import numpy as np

# Select CPU as JAX's default platform. This runs at import time and changes a
# global JAX configuration value, so it also affects any other JAX code in the
# same process. CPU is preferred here because the oscillator arrays are tiny
# (``n_dim`` defaults to 4), which the original author found faster on CPU.
jax.config.update("jax_platform_name", "cpu")


class Oscillators:
    """
    A class to represent oscillators with phase dependent frequencies.

    Each oscillator ``i`` has a phase ``theta[i]`` (radians, kept in ``[0, 2 * pi)``).
    The phase advances at ``omega_swing[i]`` while ``sin(theta[i]) > 0`` (swing
    phase, i.e. ``theta`` in ``(0, pi)``) and at ``omega_stance[i]`` otherwise
    (stance phase). The oscillators are independent (no coupling). Phases are
    integrated with explicit Euler steps of size ``time_step``.

    :param omega_swing: Angular frequency (rad per unit time) in the swing phase,
        either a scalar shared by all oscillators or one value per oscillator
    :param omega_stance: Angular frequency (rad per unit time) in the stance phase,
        either a scalar shared by all oscillators or one value per oscillator
    :param time_step: time step for integration
    :param theta_init: Initial phases; defaults to a copy of ``phase_shifts``
    :param phase_shifts: Phase offset of each oscillator (defaults to zeros)
    :param n_dim: Number of oscillators
    """

    def __init__(
        self,
        omega_swing: Union[float, np.ndarray] = 1 * 2 * np.pi,
        omega_stance: Union[float, np.ndarray] = 1 * 2 * np.pi,
        time_step: float = 0.001,
        theta_init: Optional[np.ndarray] = None,
        phase_shifts: Optional[np.ndarray] = None,
        n_dim: int = 4,
    ):
        self.n_dim = n_dim

        # Normalise each frequency independently. Any mix of scalars (int or
        # float) and per-oscillator arrays is therefore accepted.
        self._omega_swing = self._to_frequency_array(omega_swing, n_dim)
        self._omega_stance = self._to_frequency_array(omega_stance, n_dim)
        self._dt = time_step

        # Per-oscillator phase offsets (a vector, not a coupling matrix).
        if phase_shifts is None:
            self.phase_shifts = np.zeros(n_dim)
        else:
            self.phase_shifts = np.asarray(phase_shifts, dtype=float)

        # Initial conditions: explicit phases if given, otherwise start at the offsets.
        if theta_init is None:
            self.theta = self.phase_shifts.copy()
        else:
            self.theta = np.asarray(theta_init, dtype=float)

    @staticmethod
    def _to_frequency_array(omega: Union[float, np.ndarray], n_dim: int) -> np.ndarray:
        """Return ``omega`` as a float array; a scalar is repeated ``n_dim`` times."""
        omega = np.asarray(omega, dtype=float)
        if omega.ndim == 0:
            return np.full(n_dim, float(omega))
        return omega

    @staticmethod
    def integrate_one_step(
        theta: jnp.ndarray,
        omega_swing: jnp.ndarray,
        omega_stance: jnp.ndarray,
        dt: float,
    ) -> jnp.ndarray:
        """
        Advance the phases by one explicit Euler step.

        :param theta: current phases
        :param omega_swing: angular frequencies used where ``sin(theta) > 0``
        :param omega_stance: angular frequencies used everywhere else
        :param dt: integration time step
        :return: the new phases, wrapped into ``[0, 2 * pi)``
        """
        in_swing_phase = jnp.sin(theta) > 0
        theta_dot = jnp.where(in_swing_phase, omega_swing, omega_stance)
        # Integrate and keep theta in [0, 2 * pi)
        return (theta + dt * theta_dot) % (2 * np.pi)

    @staticmethod
    @jax.jit
    def n_updates(
        n_updates: int,
        theta: jnp.ndarray,
        omega_swing: jnp.ndarray,
        omega_stance: jnp.ndarray,
        dt: float,
    ) -> jnp.ndarray:
        """
        Apply ``integrate_one_step`` ``n_updates`` times (JIT compiled).

        A ``jax.lax.fori_loop`` replaces a Python ``for`` loop. The Python loop
        would be unrolled at trace time, so program size and compile time would
        grow with ``n_updates``, and every new count would trigger a recompile.

        :param n_updates: number of integration steps (``0`` returns ``theta``)
        :param theta: current phases
        :param omega_swing: angular frequencies in the swing phase
        :param omega_stance: angular frequencies in the stance phase
        :param dt: integration time step
        :return: the phases after ``n_updates`` steps
        """
        # fori_loop needs the carried value to keep one dtype across iterations.
        # Start from the floating dtype that a single integration step produces.
        dtype = jnp.result_type(theta, omega_swing, omega_stance, dt, float)
        theta = jnp.asarray(theta, dtype=dtype)

        def body(_, current):
            return Oscillators.integrate_one_step(current, omega_swing, omega_stance, dt)

        return jax.lax.fori_loop(0, n_updates, body, theta)

    def update(self, n_updates: int = 1) -> np.ndarray:
        """
        Update oscillator states.

        :param n_updates: number of integration steps of size ``time_step``
        :return: the new phases (also stored in ``self.theta``)
        """
        new_theta = Oscillators.n_updates(n_updates, self.theta, self._omega_swing, self._omega_stance, self._dt)
        # Convert the JAX result back to a (writable) NumPy array, matching the
        # annotated return type and the type of the initial ``theta``.
        self.theta = np.array(new_theta)
        return self.theta
