import os
import pickle
from copy import deepcopy

import numpy as np
from scipy.stats import truncexpon
from tqdm import tqdm

from comab.algo.algo_factory import build_algo
from comab.environment import CoMABEnv

# Output folder for the per-algorithm result pickles written below;
# exist_ok=True turns this into a no-op when the folder is already there.
os.makedirs("exp_3_results", mode=0o777, exist_ok=True)

# Parameters for the environment
T = 500
D = np.floor(np.sqrt(T)).astype(int)
N = 5
K = 1
p = np.arange(K, dtype=int) + 2
P = max(p)
F = [truncexpon(b=1) for k in range(K)]
R = max(_F.support()[1] for _F in F)
F_name = "truncexpon"
# Create the environment
env = CoMABEnv(p, F, N)

# List of algos to bench
algos = [
    "ucb1",
    "exp3",
    "osub",
    "local_greedy",
    "greedy_grid"
]

# Number of simulations per algorithm
num_simulations = 20

# Benchmark every algorithm: run `num_simulations` independent simulations of
# T rounds each against the shared `env`, then pickle the results per algorithm.
for algo_name in tqdm(algos, desc="algorithm", position=0, leave=False):
    # One row per simulation, one entry per round.
    all_allocations = []
    all_expected_rewards = []

    # Run multiple simulations
    for _ in tqdm(range(num_simulations), desc="simulations", position=1, leave=False):
        algo = build_algo(
            algo_name, K, N, T, p,
            D=D, R=R, delta=1. / np.sqrt(np.arange(T) + 1), c=1., alpha=0.2,
        )

        allocations = []
        expected_rewards = []

        for t in tqdm(range(T), desc="time steps", position=2, leave=False):
            n = algo.n
            # Record the allocation that is actually played in round t, and its
            # expected reward, BEFORE algo.update replaces it with the next
            # round's allocation. The deep copy protects the record in case
            # algo.n is mutated in place by the update.
            played = deepcopy(n)
            allocations.append(played)
            expected_rewards.append(env.r(played))

            arms_with_observation, observed_gains, observed_costs = env.step(n)
            algo.update(arms_with_observation, observed_gains, observed_costs, t)

        all_allocations.append(allocations)
        all_expected_rewards.append(expected_rewards)

    # Pickle contents:
    #   "allocations":     int array, axis 0 = simulation, axis 1 = round t
    #                      (trailing axes follow the shape of algo.n -- TODO confirm)
    #   "expected_regret": env.r_star minus the expected reward of the allocation
    #                      played at each round (axis 0 = simulation, axis 1 = round)
    out_path = os.path.join("exp_3_results", f"{algo_name}_K={K}_N={N}_F={F_name}.pickle")
    with open(out_path, 'wb') as f:
        pickle.dump(
            {
                "algo_name": algo_name,
                "n_star": env.n_star,
                "allocations": np.array(all_allocations).astype(int),
                "expected_regret": (env.r_star - np.array(all_expected_rewards)),
            },
            f, pickle.HIGHEST_PROTOCOL)
