from bayesrl.environments import ChainWorld
import timeit
import numpy as np


def check_value_equal():
    """Print how far the NumPy and JAX value-function solvers disagree.

    For every configured chain size a ``ChainWorld`` task is built and solved
    with both ``solve_optimal_value_function_numpy`` and
    ``solve_optimal_value_function_jax``.  The maximum and mean absolute
    differences of the returned Q and V arrays are printed, followed by the
    raw arrays themselves for visual inspection.
    """
    gamma = 0.995
    # Add larger chains here (e.g. 100, 200) for a broader comparison.
    sizes = (50,)

    for num_states in sizes:
        env = ChainWorld(num_states, num_states - 1, num_states, 2*num_states - 1, -1, 0)
        q_numpy, v_numpy = env.solve_optimal_value_function_numpy(gamma)
        q_jax, v_jax = env.solve_optimal_value_function_jax(gamma)

        # Element-wise absolute gaps, computed once and reused for max/mean.
        q_gap = np.abs(q_numpy - q_jax)
        v_gap = np.abs(v_numpy - v_jax)
        max_err_q = np.max(q_gap)
        mean_err_q = np.mean(q_gap)
        max_err_v = np.max(v_gap)
        mean_err_v = np.mean(v_gap)

        print(f"N = {num_states}... \nmax_err_q {max_err_q} | max_err_v {max_err_v}\nmean_err_q {mean_err_q} | mean_err_v {mean_err_v}\n")

        # Dump the raw Q tables, then the V vectors (separated by a blank line).
        for q_table in (q_numpy, q_jax):
            print(q_table)

        print("\n", v_numpy)
        print(v_jax)


def _wait_for_result(result):
    """Block until any asynchronously computed arrays in *result* are ready.

    Objects exposing a ``block_until_ready`` method (as JAX arrays do) are
    waited on; everything else (e.g. NumPy arrays) is left untouched.  The
    result is returned unchanged.
    """
    items = result if isinstance(result, (tuple, list)) else (result,)
    for item in items:
        wait = getattr(item, "block_until_ready", None)
        if callable(wait):
            wait()
    return result


def _profile_solver(label, solver_name, state_sizes, discount_factor, repeats, warmup):
    """Time the ``ChainWorld`` method *solver_name* per chain size and print the mean."""
    for num_states in state_sizes:
        task = ChainWorld(num_states, num_states - 1, num_states, 2*num_states - 1, -1, 0)
        solve = getattr(task, solver_name)

        # Bind ``solve`` as a default so the closure never sees a later task.
        def run_once(solve=solve):
            return _wait_for_result(solve(discount_factor))

        # Untimed calls: keep one-off costs (e.g. first-call compilation, if
        # the solver is JIT-compiled) out of the reported average.
        for _ in range(warmup):
            run_once()

        elapsed_time = timeit.timeit(run_once, number=repeats)
        print(f"{label}: num_states={num_states}, avg_time={elapsed_time / repeats:.4f} seconds")


def profile_with_timeit(state_sizes=(100, 500, 1000), discount_factor=0.995, repeats=3, warmup=1):
    """Benchmark the NumPy and JAX optimal-value-function solvers.

    For each chain size a ``ChainWorld`` task is created and its
    ``solve_optimal_value_function_numpy`` / ``solve_optimal_value_function_jax``
    method is timed with ``timeit``; the mean wall-clock time per call is
    printed.

    Args:
        state_sizes: Iterable of chain sizes to benchmark.
        discount_factor: Discount factor passed to both solvers.
        repeats: Number of timed calls averaged per configuration (>= 1).
        warmup: Number of untimed calls made before timing (>= 0).  Pass ``0``
            to include first-call overhead in the average.

    Raises:
        ValueError: If ``repeats`` is less than 1 or ``warmup`` is negative.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    if warmup < 0:
        raise ValueError("warmup must be non-negative")

    # Materialize once so a one-shot iterable can feed both benchmark passes.
    sizes = tuple(state_sizes)

    print("Profiling NumPy Implementation...")
    _profile_solver("NumPy", "solve_optimal_value_function_numpy",
                    sizes, discount_factor, repeats, warmup)

    print("\nProfiling JAX Implementation...")
    _profile_solver("JAX", "solve_optimal_value_function_jax",
                    sizes, discount_factor, repeats, warmup)

# Script entry point: run the timing benchmark. There is no ``__main__``
# guard, so this also runs whenever the module is imported.
profile_with_timeit()
# Uncomment to print the NumPy-vs-JAX solver discrepancy as well.
#check_value_equal()
