from numpy import ndarray as Array
import numpy as np
import tqdm


def step(
    x: Array,
    sigma: float = 0.25,
    dt: float = 1e-3,
    s: float = 10.0,
    r: float = 28.0,
    b: float = 8.0 / 3.0,
) -> Array:
    """
    Advance the stochastic L63 state by one time step dt with the Milstein scheme.

    Every component is driven by its own Wiener increment and the noise
    amplitude of component i is sigma * x_i (multiplicative noise).
    Input(s):
        - x (Array): states of the system at time t with shape (batch_size, 3).
        - sigma (float): noise level of the stochastic term.
        - dt (float): time step.
        - s (float): coefficient of the (x2 - x1) drift of the first component.
        - r (float): coefficient of x1 in the drift of the second component.
        - b (float): damping coefficient of x3 in the drift of the third component.
    Output(s):
        - Array: states at time (t+dt), components stacked along axis 1.
    """
    # Split the state into its three coordinates.
    u, v, w = x[:, 0], x[:, 1], x[:, 2]

    # Deterministic tendencies, in the same order as the coordinates.
    tendencies = (
        s * (v - u),
        r * u - v - u * w,
        u * v - b * w,
    )

    # One Gaussian draw per entry of x, scaled so each increment has variance dt.
    increments = np.sqrt(dt) * np.random.randn(*x.shape)

    advanced = []
    for k, (coord, tendency) in enumerate(zip((u, v, w), tendencies)):
        dw = increments[:, k]
        # Euler-Maruyama part plus the Milstein correction 0.5*sigma^2*x*(dW^2 - dt).
        advanced.append(
            coord + tendency * dt + sigma * coord * dw + 0.5 * sigma**2 * coord * (dw**2 - dt)
        )

    return np.stack(advanced, axis=1)


def integrate(
    x_0: Array,
    num_steps: int,
    dt: float,
    sigma: float,
    keep: bool = False,
    verbose: bool = False,
) -> Array:
    """
    Run num_steps successive calls to step, starting from x_0.

    step is called with its default drift coefficients; only dt and sigma
    are forwarded.
    Input(s):
        - x_0 (Array): initial condition with dimension (batch_size, 3).
        - num_steps (int): number of steps to do.
        - dt (float): time step for the integration.
        - sigma (float): noise level of the stochastic term in the stochastic L63 Equation.
        - keep (bool): if True, return the complete trajectory with dimension
          (num_steps+1, batch_size, 3), whose first entry is x_0.
        - verbose (bool): if True, use a tqdm progress bar.
    Output(s):
        - Array: the full trajectory if keep is True, otherwise only the final
          state with dimension (batch_size, 3).
    """
    state = x_0.copy()

    # The trajectory buffer only exists when the caller asked for it.
    history = None
    if keep:
        history = np.zeros((num_steps + 1, state.shape[0], 3))
        history[0] = x_0

    # Indices start at 1 because slot 0 of the trajectory holds the initial condition.
    indices = range(1, num_steps + 1)
    if verbose:
        indices = tqdm.tqdm(indices)

    for idx in indices:
        state = step(x=state, dt=dt, sigma=sigma)
        if history is not None:
            history[idx] = state

    return state if history is None else history
