import timeit

import pytest
import jax.numpy as jnp
from jax import random
from jax import jit

import fairgym.envs.action
import fairgym.envs.state

key = random.PRNGKey(758493)
# Split the seed so each random table gets independent draws; reusing the same
# key for both uniform() calls would make pr_x and pr_y1Gx identical.
key_x, key_y1Gx = random.split(key)


pr_g = jnp.array([0.3, 0.4, 0.3])
# Random (3, 1000) table, each row normalised to sum to 1 (presumably one row
# per entry of pr_g — confirm against create_state).
pr_x = random.uniform(key_x, shape=(3, 1000))
pr_x = pr_x / jnp.sum(pr_x, axis=1, keepdims=True)
pr_y1Gx = random.uniform(key_y1Gx, shape=(3, 1000))
# NOTE(review): normalised per row like pr_x; if this table is meant to hold
# P(Y=1 | x) values, each entry only needs to lie in [0, 1] — confirm intent.
pr_y1Gx = pr_y1Gx / jnp.sum(pr_y1Gx, axis=1, keepdims=True)

# Fixed action vector, reused as input by the action-level benchmark cases.
test_action = jnp.array((0.2, 0.5, 0.9))
# State assembled once from the synthetic tables above and shared by the
# benchmark cases that need a ready-made state.
test_state = fairgym.envs.state.create_state(pr_g, pr_x, pr_y1Gx)


@pytest.mark.parametrize(
    "fns,args,kwargs,number",
    (
        (
            (
                fairgym.envs.state.create_state,
                jit(fairgym.envs.state.create_state),
            ),
            (pr_g, pr_x, pr_y1Gx),
            {},
            1e3,
        ),
        (
            (
                fairgym.envs.action.threshold_action,
                jit(fairgym.envs.action.threshold_action),
            ),
            (test_state, test_action),
            {},
            1e3,
        ),
        (
            (fairgym.envs.action._take, jit(fairgym.envs.action._take)),
            (
                test_state.pr_Y0alX,
                test_action,
            ),
            {},
            1e3,
        ),
    ),
)
def test_jit_speedup(fns, args, kwargs, number):
    """
    Tests if jit speedups are implemented correctly and/or are worth it.

    Every function is called once before timing, so JIT compilation is not
    counted against the jitted variant. Each timed call then waits for its
    result, because JAX dispatches work asynchronously and would otherwise
    return before the computation finishes.

    :param fns: callables to time. When exactly two are given, ``fns[0]`` is
        the baseline and ``fns[1]`` the (jitted) candidate.
    :param args: positional arguments passed to every function.
    :param kwargs: keyword arguments passed to every function.
    :param number: calls per timer run (coerced to ``int``).
    :return: None
    """
    from jax.tree_util import tree_leaves

    def call_and_wait(fn):
        out = fn(*args, **kwargs)
        # Block on every array leaf so the timer measures the computation,
        # not just the asynchronous dispatch.
        for leaf in tree_leaves(out):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()
        return out

    n_calls = int(number)
    times = []
    for fn in fns:
        call_and_wait(fn)  # warm-up: compile/cache outside the timed region
        times.append(
            timeit.Timer(lambda fn=fn: call_and_wait(fn)).timeit(number=n_calls)
        )

    print("\n" + "\n".join(f"{t:>8.2f}s  {fn!r}" for t, fn in zip(times, fns)))
    if len(fns) == 2:
        # Assumes fns[0] is the base (un-jitted) implementation.
        # Tolerate the candidate being up to 10% slower to absorb timing noise.
        speedup_tol = -0.1
        base_time, candidate_time = times
        speedup = (base_time - candidate_time) / base_time
        assert speedup > speedup_tol, (
            f"{fns[1]!r} took {candidate_time:.4f}s vs base {base_time:.4f}s "
            f"(relative speedup {speedup:.3f} <= tolerance {speedup_tol})"
        )
