import numpy as np
from numba import njit

# Default number of trials generated by each simulate_trials_* function.
N_TRIALS = 200
# Default Euler step size for the diffusion_trial_* simulators; decision time is n_steps * DT.
DT = 1e-3


@njit
def diffusion_trial_m1a(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m1a.

    Evidence starts at ``boundary * beta`` and is advanced with Euler steps
    ``drift * dt + sqrt(dt) * s * N(0, 1)`` until it leaves (0, boundary).
    A trial-level visual encoding time is drawn from Normal(mu_tau_e, varsigma)
    and the N200 latency from Normal(encoding time, sigma).

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``rt`` is decision time + ``tau_m`` + encoding time, ``choice`` is 1.0
        when the upper bound was reached and 0.0 otherwise, ``z`` is the N200
        latency.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    x = boundary * beta
    steps = 0.0

    while x > 0 and x < boundary:
        x += step_mean + step_sd * np.random.normal()
        steps += 1.0

    encoding_time = np.random.normal(mu_tau_e, varsigma)
    n200 = np.random.normal(encoding_time, sigma)
    response_time = steps * dt + tau_m + encoding_time

    upper = 1.0 if x >= boundary else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m1a(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m1a.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m1a(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m1b(drift, boundary, beta, mu_tau_e, mu_tau_m, sigma, varsigma, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m1b.

    The decision process is the same Euler-discretised diffusion as m1a.
    Here the trial-to-trial variability ``varsigma`` is placed on the total
    non-decision time, drawn from Normal(mu_tau_e + mu_tau_m, varsigma).
    The N200 latency is Normal(ndt - mu_tau_m, sigma).

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``rt`` is decision time + non-decision time, ``choice`` is 1.0 for the
        upper bound and 0.0 otherwise, ``z`` is the N200 latency.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    evidence = boundary * beta
    steps = 0.0

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    # total non-decision time (encoding + motor), varying across trials
    total_ndt = np.random.normal(mu_tau_e + mu_tau_m, varsigma)
    # N200 latency is centred on the non-decision time minus the mean motor time
    n200 = np.random.normal(total_ndt - mu_tau_m, sigma)
    response_time = steps * dt + total_ndt

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m1b(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m1b.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, mu_tau_m, sigma, varsigma).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, mu_tau_m, sigma, varsigma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m1b(drift, boundary, beta, mu_tau_e, mu_tau_m, sigma, varsigma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m1c(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, s=1.0, dt=DT,
                        z_lower=0.0, z_upper=0.5, max_tries=1000000):
    """Simulates a trial from the joint diffusion model m1c.

    The decision process is the Euler-discretised diffusion used by m1a.
    The visual encoding time is uniform with mean ``mu_tau_e`` and standard
    deviation ``varsigma``, so its half-width is ``sqrt(12) * varsigma / 2``.
    The N200 latency ``z`` is Normal(encoding time, sigma). The
    (encoding time, z) pair is redrawn until ``z_lower < z < z_upper``.
    As a result, z is truncated to that window, and the returned encoding
    time is the one that produced the accepted z.

    Parameters
    ----------
    z_lower, z_upper : float
        Open acceptance window for the N200 latency (defaults 0.0 and 0.5,
        the values previously hard-coded).
    max_tries : int
        Maximum number of (encoding, z) draws before giving up. This guards
        against an endless loop when the window has (near) zero probability
        under the given parameters.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``rt`` is decision time + encoding time + ``tau_m``. ``choice`` is 1.0
        for the upper bound and 0.0 otherwise. ``z`` is the accepted N200
        latency.

    Raises
    ------
    ValueError
        If no z inside the window is drawn within ``max_tries`` attempts.
    """
    c = np.sqrt(dt) * s
    n_steps = 0.0
    evidence = boundary * beta

    while evidence > 0 and evidence < boundary:
        evidence += drift * dt + c * np.random.normal()
        n_steps += 1.0

    # Loop-invariant: a uniform with sd varsigma has half-width sqrt(12) * varsigma / 2.
    half_width = 0.5 * np.sqrt(12) * varsigma
    tau_encoding = 0.0
    z = 0.0  # float from the start so numba infers a single float64 type
    accepted = False
    for _ in range(max_tries):
        # visual encoding
        tau_encoding = mu_tau_e + np.random.uniform(-half_width, half_width)
        # N200 latency
        z = np.random.normal(tau_encoding, sigma)
        if z > z_lower and z < z_upper:
            accepted = True
            break
    if not accepted:
        raise ValueError("diffusion_trial_m1c: N200 latency never fell inside (z_lower, z_upper)")

    rt = n_steps * dt + tau_encoding + tau_m

    if evidence >= boundary:
        return (rt, 1.0, z)
    return (rt, 0.0, z)


@njit
def simulate_trials_m1c(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m1c.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m1c(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m2(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, gamma, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m2.

    This is identical to m1a except that the N200 latency is centred on
    ``gamma * encoding time`` rather than on the encoding time itself.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``rt`` is decision time + ``tau_m`` + encoding time, ``choice`` is 1.0
        for the upper bound and 0.0 otherwise, ``z`` is the N200 latency.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    x = boundary * beta
    steps = 0.0

    while x > 0 and x < boundary:
        x += step_mean + step_sd * np.random.normal()
        steps += 1.0

    encoding_time = np.random.normal(mu_tau_e, varsigma)
    n200 = np.random.normal(gamma * encoding_time, sigma)
    response_time = steps * dt + tau_m + encoding_time

    upper = 1.0 if x >= boundary else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m2(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m2.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, gamma).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, gamma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m2(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, gamma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m3(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, theta, s=1.0, dt=DT,
                       guess_rt_max=5.0):
    """Simulates a trial from the joint diffusion model m3.

    m3 mixes a guessing process with the diffusion process of m1a.

    With probability ``theta`` (a U(0, 1) draw exceeding ``1 - theta``), the
    trial is a guess: the response time is U(0, guess_rt_max) and the choice
    is a Bernoulli(0.5) draw.

    Otherwise, evidence diffuses from ``boundary * beta`` until it leaves
    (0, boundary), and the response time is decision time + ``tau_m`` +
    encoding time.

    In both cases, the encoding time is drawn from Normal(mu_tau_e, varsigma)
    and the N200 latency from Normal(encoding time, sigma).

    Parameters
    ----------
    guess_rt_max : float
        Upper limit of the uniform guess-RT distribution (default 5.0, the
        value previously hard-coded).

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is always returned as a float (1.0 or 0.0) in both branches.
    """
    u = np.random.uniform(0, 1)
    # guessing process
    if u > 1 - theta:
        rt = np.random.uniform(0, guess_rt_max)
        # binomial yields an integer; cast so both branches return three floats
        choice = float(np.random.binomial(1, 0.5))
        # visual encoding
        tau_encoding = np.random.normal(mu_tau_e, varsigma)
        # N200 latency
        z = np.random.normal(tau_encoding, sigma)
        return (rt, choice, z)

    # ddm process
    c = np.sqrt(dt) * s
    n_steps = 0.0
    evidence = boundary * beta

    while evidence > 0 and evidence < boundary:
        evidence += drift * dt + c * np.random.normal()
        n_steps += 1.0

    # visual encoding
    tau_encoding = np.random.normal(mu_tau_e, varsigma)
    # N200 latency
    z = np.random.normal(tau_encoding, sigma)
    rt = n_steps * dt + tau_m + tau_encoding

    if evidence >= boundary:
        return (rt, 1.0, z)
    return (rt, 0.0, z)


@njit
def simulate_trials_m3(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m3.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, theta).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, theta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m3(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, theta)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m4a(drift, boundary, beta, mu_tau_e, tau_m, tau, sigma_e, sigma_k, varsigma, k, theta, s=1.0,
                        dt=DT):
    """Simulate one trial of joint diffusion model m4a.

    The decision process is the Euler-discretised diffusion of m1a. After it
    finishes, a U(0, 1) draw decides how the N200 and the non-decision time
    are generated:

    * draw <= 1 - theta: encoding time ~ Normal(mu_tau_e, varsigma),
      N200 ~ Normal(encoding, sigma_e), rt = decision time + tau_m + encoding.
    * otherwise: N200 ~ Normal(k, sigma_k), rt = decision time + tau.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is 1.0 for the upper bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    evidence = boundary * beta
    steps = 0.0

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    upper = 1.0 if evidence >= boundary else 0.0
    decision_time = steps * dt
    mixture_draw = np.random.uniform(0, 1)

    if mixture_draw <= 1 - theta:
        # visual encoding drives both the N200 and the non-decision time
        encoding_time = np.random.normal(mu_tau_e, varsigma)
        n200 = np.random.normal(encoding_time, sigma_e)
        response_time = decision_time + tau_m + encoding_time
    else:
        # N200 drawn from its own distribution; fixed non-decision time tau
        n200 = np.random.normal(k, sigma_k)
        response_time = decision_time + tau

    return (response_time, upper, n200)


@njit
def simulate_trials_m4a(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m4a.

    ``params`` is unpacked as (drift, boundary, beta, mu_tau_e, tau_m, tau,
    sigma_e, sigma_k, varsigma, k, theta).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, tau, sigma_e, sigma_k, varsigma, k, theta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m4a(
            drift, boundary, beta, mu_tau_e, tau_m, tau, sigma_e, sigma_k, varsigma, k, theta
        )
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m4b(drift, boundary, beta, mu_tau_e, tau_m, sigma_e, varsigma, theta, s=1.0, dt=DT):
    """Simulates a trial from the joint diffusion model m4b.

    The decision process is the Euler-discretised diffusion of m1a. After it
    finishes, a U(0, 1) draw selects one of two branches:

    * draw <= 1 - theta: the encoding time is Normal(mu_tau_e, varsigma), the
      N200 latency is Normal(encoding, sigma_e), and
      rt = decision time + tau_m + encoding time.
    * otherwise: the N200 latency is drawn from
      Normal(mu_tau_e, sqrt(sigma_e**2 + varsigma**2)). That is the same
      marginal distribution as in the first branch, but unlinked to the rt.
      The rt uses the mean encoding time: decision time + tau_m + mu_tau_e.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is 1.0 for the upper bound and 0.0 otherwise.
    """
    c = np.sqrt(dt) * s
    n_steps = 0.0
    evidence = boundary * beta

    while evidence > 0 and evidence < boundary:
        evidence += drift * dt + c * np.random.normal()
        n_steps += 1.0

    choice = 1.0 if evidence >= boundary else 0.0
    rng = np.random.uniform(0, 1)

    if rng <= 1 - theta:
        # visual encoding
        tau_encoding = np.random.normal(mu_tau_e, varsigma)
        # N200 latency linked to this trial's encoding time
        z = np.random.normal(tau_encoding, sigma_e)
        rt = n_steps * dt + tau_m + tau_encoding
    else:
        # N200 latency from its marginal distribution, independent of rt
        z = np.random.normal(mu_tau_e, np.sqrt(sigma_e ** 2 + varsigma ** 2))
        rt = n_steps * dt + tau_m + mu_tau_e

    return (rt, choice, z)


@njit
def simulate_trials_m4b(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m4b.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma_e, varsigma, theta).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma_e, varsigma, theta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m4b(drift, boundary, beta, mu_tau_e, tau_m, sigma_e, varsigma, theta)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m5(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, a_slope, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m5 (linearly collapsing bounds).

    At time ``t = n_steps * dt`` the lower bound is ``a_slope * t`` and the
    upper bound is ``boundary - a_slope * t``. Evidence diffuses until it is
    no longer strictly between them. Encoding time and N200 latency are
    generated as in m1a.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is 1.0 when evidence is at or above the upper bound at the
        final step, else 0.0.
    """
    step_sd = np.sqrt(dt) * s
    evidence = boundary * beta
    steps = 0.0
    shrink = 0.0

    while True:
        # how far each bound has moved inward at the current step
        shrink = a_slope * steps * dt
        if not (evidence > shrink and evidence < (boundary - shrink)):
            break
        evidence += drift * dt + step_sd * np.random.normal()
        steps += 1.0

    encoding_time = np.random.normal(mu_tau_e, varsigma)
    n200 = np.random.normal(encoding_time, sigma)
    response_time = steps * dt + tau_m + encoding_time

    upper = 1.0 if evidence >= boundary - shrink else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m5(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m5.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, a_slope).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, a_slope = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m5(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, a_slope)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m6(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, lam, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m6 (Weibull-collapsing bounds).

    At time ``t = n_steps * dt`` both bounds have moved inward by
    ``collapse(t) = (1 - exp(-(t / lam) ** k)) * (-0.5 * delt * boundary)``
    with fixed shape ``k = 3`` and ``delt = -1``. The collapse therefore grows
    from 0 toward half the boundary separation. The lower bound is
    ``collapse(t)`` and the upper bound is ``boundary - collapse(t)``.
    Encoding time and N200 latency are generated as in m1a.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is 1.0 when evidence is at or above the upper bound at the
        final step, else 0.0.
    """
    k = 3
    delt = -1
    step_sd = np.sqrt(dt) * s
    max_collapse = -0.5 * delt * boundary
    evidence = boundary * beta
    steps = 0.0
    collapse = 0.0

    while True:
        collapse = (1 - np.exp(-(steps * dt / lam) ** k)) * max_collapse
        if not (evidence > collapse and evidence < (boundary - collapse)):
            break
        evidence += drift * dt + step_sd * np.random.normal()
        steps += 1.0

    encoding_time = np.random.normal(mu_tau_e, varsigma)
    n200 = np.random.normal(encoding_time, sigma)
    response_time = steps * dt + tau_m + encoding_time

    upper = 1.0 if evidence >= (boundary - collapse) else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m6(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m6.

    ``params`` is unpacked as
    (drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, lam).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, lam = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m6(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, lam)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m7(mu_drift, boundary, beta, tau, sigma, eta, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m7.

    A trial-specific drift is drawn from Normal(mu_drift, eta) and drives the
    Euler-discretised diffusion from ``boundary * beta`` until evidence
    leaves (0, boundary). The cpp value is Normal(trial drift, sigma).

    Returns
    -------
    (rt, choice, cpp) : tuple of float
        ``rt`` is decision time + ``tau``, ``choice`` is 1.0 for the upper
        bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    evidence = boundary * beta
    steps = 0.0
    trial_drift = mu_drift + eta * np.random.normal()
    step_mean = trial_drift * dt

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    # cpp reflects this trial's drift
    cpp = np.random.normal(trial_drift, sigma)
    response_time = steps * dt + tau

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, cpp)


@njit
def simulate_trials_m7(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m7.

    ``params`` is unpacked as (mu_drift, boundary, beta, tau, sigma, eta).
    Returns an (n_trials, 3) float array with columns rt, choice, cpp.
    """
    mu_drift, boundary, beta, tau, sigma, eta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, cpp = diffusion_trial_m7(mu_drift, boundary, beta, tau, sigma, eta)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = cpp
    return out


@njit
def diffusion_trial_m8(mu_drift, boundary, tau, sigma, gamma, eta, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m8.

    Like m7, but the starting point is fixed at the midpoint
    ``boundary * 0.5``, and the cpp value is Normal(gamma * trial drift, sigma).

    Returns
    -------
    (rt, choice, cpp) : tuple of float
        ``rt`` is decision time + ``tau``, ``choice`` is 1.0 for the upper
        bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    evidence = boundary * 0.5
    steps = 0.0
    trial_drift = mu_drift + eta * np.random.normal()
    step_mean = trial_drift * dt

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    cpp = np.random.normal(gamma * trial_drift, sigma)
    response_time = steps * dt + tau

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, cpp)


@njit
def simulate_trials_m8(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m8.

    ``params`` is unpacked as (mu_drift, boundary, tau, sigma, gamma, eta).
    Returns an (n_trials, 3) float array with columns rt, choice, cpp.
    """
    mu_drift, boundary, tau, sigma, gamma, eta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, cpp = diffusion_trial_m8(mu_drift, boundary, tau, sigma, gamma, eta)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = cpp
    return out


@njit
def diffusion_trial_m9(drift, boundary, beta, t_e, t_m, sigma_e, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m9.

    The encoding time ``t_e`` is fixed, not drawn per trial. The N200
    latency is Normal(t_e, sigma_e), and the non-decision time is t_e + t_m.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``choice`` is 1.0 for the upper bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    evidence = boundary * beta
    steps = 0.0

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    n200 = np.random.normal(t_e, sigma_e)
    response_time = steps * dt + t_e + t_m

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m9(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m9.

    ``params`` is unpacked as (drift, boundary, beta, t_e, t_m, sigma_e).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, t_e, t_m, sigma_e = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m9(drift, boundary, beta, t_e, t_m, sigma_e)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m10(drift, boundary, beta, t_e, t_m, sigma_e, gamma, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m10.

    Like m9, but the N200 latency is centred on ``gamma * t_e``.

    Returns
    -------
    (rt, choice, z) : tuple of float
        ``rt`` is decision time + t_e + t_m, ``choice`` is 1.0 for the upper
        bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    evidence = boundary * beta
    steps = 0.0

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    n200 = np.random.normal(gamma * t_e, sigma_e)
    response_time = steps * dt + t_e + t_m

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, n200)


@njit
def simulate_trials_m10(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m10.

    ``params`` is unpacked as (drift, boundary, beta, t_e, t_m, sigma_e, gamma).
    Returns an (n_trials, 3) float array with columns rt, choice, N200 latency.
    """
    drift, boundary, beta, t_e, t_m, sigma_e, gamma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, z = diffusion_trial_m10(drift, boundary, beta, t_e, t_m, sigma_e, gamma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = z
    return out


@njit
def diffusion_trial_m11(drift, boundary, beta, ndt, eta, s=1.0, dt=DT):
    """Simulate one trial of joint diffusion model m11.

    A trial-specific drift drawn from Normal(drift, eta) drives the diffusion.
    The cpp value is a separate draw from Normal(drift, eta).

    Returns
    -------
    (rt, choice, cpp) : tuple of float
        ``rt`` is decision time + ``ndt``, ``choice`` is 1.0 for the upper
        bound and 0.0 otherwise.
    """
    step_sd = np.sqrt(dt) * s
    evidence = boundary * beta
    steps = 0.0
    trial_drift = drift + eta * np.random.normal()
    step_mean = trial_drift * dt

    while evidence > 0 and evidence < boundary:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    # NOTE(review): cpp shares trial_drift's distribution but is drawn
    # independently of it (unlike m7, which centres cpp on trial_drift).
    # Confirm this decoupling is intended.
    cpp = np.random.normal(drift, eta)
    response_time = steps * dt + ndt

    upper = 1.0 if evidence >= boundary else 0.0
    return (response_time, upper, cpp)


@njit
def simulate_trials_m11(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m11.

    ``params`` is unpacked as (drift, boundary, beta, ndt, eta).
    Returns an (n_trials, 3) float array with columns rt, choice, cpp.
    """
    drift, boundary, beta, ndt, eta = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, cpp = diffusion_trial_m11(drift, boundary, beta, ndt, eta)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = cpp
    return out


@njit
def diffusion_trial_m12(drift, boundary, beta, ndt, sigma, gamma, s=1.0, dt=DT, max_steps=2e4):
    """Simulate one trial of joint diffusion model m12.

    The diffusion is capped at ``max_steps`` Euler steps. The cpp value is
    Normal(gamma * drift, sigma).

    If the cap is hit before either bound is reached, the choice is 1.0 when
    the evidence lies strictly above the midpoint ``boundary * 0.5``, and
    0.0 otherwise. The rt is then ``max_steps * dt + ndt``.

    Returns
    -------
    (rt, choice, cpp) : tuple of float
    """
    step_sd = np.sqrt(dt) * s
    step_mean = drift * dt
    evidence = boundary * beta
    steps = 0.0

    while evidence > 0 and evidence < boundary and steps < max_steps:
        evidence += step_mean + step_sd * np.random.normal()
        steps += 1.0

    cpp = np.random.normal(gamma * drift, sigma)
    response_time = steps * dt + ndt

    if evidence >= boundary:
        upper = 1.0
    elif evidence <= 0:
        upper = 0.0
    else:
        # step cap reached: side of the midpoint decides (exact midpoint -> 0.0)
        upper = 1.0 if evidence - boundary * 0.5 > 0 else 0.0
    return (response_time, upper, cpp)


@njit
def simulate_trials_m12(params, n_trials=N_TRIALS):
    """Draw ``n_trials`` independent trials from model m12.

    ``params`` is unpacked as (drift, boundary, beta, ndt, sigma, gamma).
    Returns an (n_trials, 3) float array with columns rt, choice, cpp.
    """
    drift, boundary, beta, ndt, sigma, gamma = params
    out = np.empty((n_trials, 3))
    for trial in range(n_trials):
        rt, choice, cpp = diffusion_trial_m12(drift, boundary, beta, ndt, sigma, gamma)
        out[trial, 0] = rt
        out[trial, 1] = choice
        out[trial, 2] = cpp
    return out
