import os

from policies import *

# ------------------------------------------------------------------------------#
# Main experiment
# ------------------------------------------------------------------------------#

# Parameters: simulated steps per trial and number of independent trials.
n_steps = 10_000
n_trials = 10_000

# Prior hyperparameters, handed unchanged to every runner.
mu = 0.0
var = 1e3

# Environment code imported from `policies`; the GAUSSIAN variant is kept
# below as a commented-out alternative.
# environment = GAUSSIAN
environment = LAPLACE
postfix = "ready"  # tag appended to each results file name

# Policy names are "<family>-<eta>", generated family by family in a fixed
# order (TS, then TS-VI, then UCB).
policies = [
    f"{family}-{eta}"
    for family, etas in (
        ("TS", ("0.9", "1.0", "1.1")),
        ("TS-VI", ("0.3", "0.4", "0.5")),
        ("UCB", ("0.9", "1.0", "1.1")),
    )
    for eta in etas
]

# Pre-allocate one (n_trials, n_steps) buffer per policy for rewards (float64)
# and pulls (int32). At the sizes above this is 9 policies x (800 MB + 400 MB)
# ~= 10.8 GB, so reduce n_trials / n_steps for quick runs.
all_rewards = {
    name: np.zeros((n_trials, n_steps), dtype=np.float64) for name in policies
}
all_pulls = {
    name: np.zeros((n_trials, n_steps), dtype=np.int32) for name in policies
}

print("Starting warm-up ...")

# Warm up the JIT compilers with a tiny 2-arm, 10-step run of each policy
# family, so compilation presumably happens here rather than during the first
# real trial. Every runner receives the same argument types the main grid
# passes: the module-level prior `mu`/`var` and `environment`, float64 arm
# arrays, an int step count, int32/float64 output buffers and a float eta.
# NOTE(review): this assumes the runners are type-specialized JIT dispatchers,
# so matching argument *types* is what matters; the values and the buffer
# length are irrelevant -- confirm against `policies`.
_warmup_steps = 10
# The buffers only need to hold `_warmup_steps` entries, not `n_steps`.
_dummy_pulls = np.empty(_warmup_steps, np.int32)
_dummy_rewards = np.empty(_warmup_steps, np.float64)
# Shared by all three runners, just as the main grid shares r_loc/r_scale
# across policies and trials.
_dummy_loc = np.array([0.0, 0.0])
_dummy_scale = np.array([1.0, 1.0])

for _runner, _eta in ((run_UCB_jit, 0.5), (run_TS_jit, 0.9), (run_TSVI_jit, 0.3)):
    _runner(
        mu,
        var,
        _dummy_loc,
        _dummy_scale,
        _warmup_steps,
        _dummy_pulls,
        _dummy_rewards,
        _eta,
        environment,
    )

print("Warm-up finished.")

# Run the full grid: every policy on every delta, n_trials trials each, with
# one results file saved per delta.

# All three runners share the argument list
# (mu, var, r_loc, r_scale, n_steps, pulls, rewards, eta, environment).
_RUNNERS = {"TS": run_TS_jit, "TS-VI": run_TSVI_jit, "UCB": run_UCB_jit}


def _parse_policy(name):
    """Split a policy name "FAMILY-ETA" into ``(runner, eta)``.

    FAMILY is everything before the last "-" (e.g. "TS", "TS-VI", "UCB");
    ETA is parsed with ``float``.

    Raises:
        ValueError: if FAMILY has no runner or ETA is not a number.
    """
    family, _, eta_text = name.rpartition("-")
    if family not in _RUNNERS:
        raise ValueError(f"unknown policy family {family!r} in {name!r}")
    return _RUNNERS[family], float(eta_text)


# Validate every policy name before any (long) simulation starts. Otherwise an
# unrecognized name would silently leave its buffers untouched and still be saved.
_policy_specs = {p: _parse_policy(p) for p in policies}

# Create the output directory up front so np.save cannot fail only after a
# full delta's worth of simulation.
os.makedirs("results", exist_ok=True)

for delta in [0.3, 0.5]:
    print(f"delta = {delta}")
    # Two arms located at -delta and +delta.
    r_loc = np.array([-delta, delta], dtype=np.float64)
    # NOTE(review): sqrt(2 - environment) depends on the numeric environment
    # code from `policies`; presumably it normalizes the reward noise between
    # the GAUSSIAN and LAPLACE variants -- confirm.
    r_scale = np.array([1.0, 1.0], dtype=np.float64) * np.sqrt(2 - environment)

    # Alternative 6-arm configuration with unequal scales (disabled):
    # r_loc = np.array(
    #     [-3 * delta, -2 * delta, -1 * delta, +1 * delta, +2 * delta, +3 * delta],
    #     dtype=np.float64,
    # )
    # r_scale = np.array([1.0, 1.0, 1.0, 2.0, 2.0, 2.0], dtype=np.float64) * np.sqrt(
    #     2 - environment
    # )

    n_arms = len(r_loc)

    for p in policies:
        run, eta = _policy_specs[p]
        print(p.split("-"), eta)
        pulls_grid = all_pulls[p]
        rewards_grid = all_rewards[p]
        for trial in tqdm(range(n_trials)):
            # Rows are views into the pre-allocated grids; the runner's return
            # value is ignored, so it is expected to write results in place.
            # NOTE(review): the grids are reused across deltas, so this assumes
            # each run overwrites all n_steps entries of its rows -- confirm.
            run(
                mu,
                var,
                r_loc,
                r_scale,
                n_steps,
                pulls_grid[trial],
                rewards_grid[trial],
                eta,
                environment,
            )

    # Save per delta. np.save pickles the dict inside an object array, so read
    # it back with np.load(path, allow_pickle=True).item().
    np.save(
        f"results/{dict_TYPE[environment]}-{n_arms}-{delta}-{postfix}.npy",
        {"rewards": all_rewards, "pulls": all_pulls},
    )
