import numpy as np
from bandit.mab import *
from bandit.best_reward import *
import pickle
from pathlib import Path
import time
from scipy.stats import norm

def compute_sigma(N, target_min=0.1, target_max=0.9):
    """Return the normal scale that spreads N mid-quantile points over a range.

    The outermost quantile levels used are ``0.5 / N`` and ``(N - 0.5) / N``.
    The returned sigma is chosen so that the distance between the normal
    quantiles at those two levels equals ``target_max - target_min``.

    Args:
        N: Number of quantile points.
        target_min: Desired value of the lowest point.
        target_max: Desired value of the highest point.

    Returns:
        The standard deviation satisfying the spread condition above.
    """
    desired_span = target_max - target_min
    # Standard-normal quantiles of the first and last mid-quantile levels.
    lowest_z = norm.ppf(0.5 / N)
    highest_z = norm.ppf((N - 0.5) / N)
    return desired_span / (highest_z - lowest_z)

def gaussian_points(N):
    """Return N points at evenly spaced quantiles of a normal distribution.

    The quantile levels are the midpoints ``(i - 0.5) / N`` for ``i = 1..N``.
    The distribution is centred at 0.5 and its scale comes from
    ``compute_sigma(N)``, so the points are returned in ascending order.

    Args:
        N: Number of points to generate.

    Returns:
        A 1-D numpy array of N values.
    """
    mu = 0.5
    if N == 1:
        # The only quantile level is 0.5, i.e. the centre itself.  Going
        # through compute_sigma(1) would divide by a zero quantile spread,
        # yielding an infinite scale and a nan point.
        return np.array([mu])
    sigma = compute_sigma(N)
    u = (np.arange(1, N + 1) - 0.5) / N
    return norm.ppf(u, loc=mu, scale=sigma)

def get_means(instance, n_arms):
    """Return the arm means for a named bandit instance.

    Supported instances:
        "equal":  evenly spaced means ``(i + 0.5) / n_arms``, returned as a
                  list sorted in descending order.
        "single": a numpy array with every mean equal to 0.5.
        "bell":   the output of ``gaussian_points(n_arms)``.

    Args:
        instance: Name of the instance.
        n_arms: Number of arms.

    Returns:
        A sequence of ``n_arms`` means (list for "equal", numpy array
        otherwise).

    Raises:
        SystemExit: With status 1, after printing a message, when
            ``instance`` is not one of the supported names.
    """
    if instance == "equal":
        means = (np.arange(0, n_arms) + 0.5) / n_arms
        means = sorted(means, reverse=True)
    elif instance == "single":
        means = np.full(n_arms, 0.5)
    elif instance == "bell":
        means = gaussian_points(n_arms)
    else:
        print(f"Unknown instance: {instance}")
        # The site-provided exit() is meant for interactive sessions, may be
        # missing, and reports success (status 0); signal failure explicitly.
        raise SystemExit(1)
    return means

def get_algorithm(algorithm, epsilon, delta, R=0.5, S=0.5, reg=None):
    """Construct the algorithm object selected by name.

    Name to class mapping:
        "SE-arm"     -> SuccessiveElimination(epsilon, delta)
        "SE"         -> SuccessiveEliminationBME(epsilon, delta)
        "ellipsoid"  -> EllipsoidBME(epsilon, delta, R=R, S=S, reg=reg)
        "ellipsoid2" -> EllipsoidBME(..., tight=True)
        "ME"         -> MedianEliminationBAI2BME(epsilon, delta)
        "UGapEc"     -> UGapEc2BME(epsilon, delta)

    Args:
        algorithm: Name of the algorithm.
        epsilon: Passed to every algorithm constructor.
        delta: Passed to every algorithm constructor.
        R: Forwarded to EllipsoidBME only.
        S: Forwarded to EllipsoidBME only.
        reg: Forwarded to EllipsoidBME only.

    Returns:
        The constructed algorithm instance.

    Raises:
        SystemExit: With status 1, after printing a message, when
            ``algorithm`` is not one of the supported names.
    """
    if algorithm == "SE-arm":
        alg = SuccessiveElimination(epsilon, delta)
    elif algorithm == "SE":
        alg = SuccessiveEliminationBME(epsilon, delta)
    elif algorithm == "ellipsoid":
        alg = EllipsoidBME(epsilon, delta, R=R, S=S, reg=reg)
    elif algorithm == "ellipsoid2":
        alg = EllipsoidBME(epsilon, delta, R=R, S=S, reg=reg, tight=True)
    elif algorithm == "ME":
        alg = MedianEliminationBAI2BME(epsilon, delta)
    elif algorithm == "UGapEc":
        alg = UGapEc2BME(epsilon, delta)
    else:
        print(f"Unknown algorithm: {algorithm}")
        # The site-provided exit() is meant for interactive sessions, may be
        # missing, and reports success (status 0); signal failure explicitly.
        raise SystemExit(1)
    return alg

def get_filename(algorithm, instance, epsilon, delta, n_arms, seed, distribution="Bernoulli", reg=None, R=0.5, S=0.5):
    """Build the pickle file name that identifies one experiment configuration.

    The name always contains algorithm, instance, epsilon, delta, n_arms and
    seed.  The distribution is appended only when it is not "Bernoulli", and
    ``reg``, ``R`` and ``S`` are appended only for the ellipsoid algorithms.

    Returns:
        The file name, ending in ".pkl".
    """
    # The default distribution is left out of the name.
    dist_part = "" if distribution == "Bernoulli" else f"_{distribution}"
    # Only the ellipsoid variants are parameterised by reg, R and S.
    ellipsoid_part = f"_{reg}_{R}_{S}" if algorithm in ("ellipsoid", "ellipsoid2") else ""
    stem = f"result_{algorithm}_{instance}_{epsilon}_{delta}_{n_arms}_{seed}"
    return stem + dist_part + ellipsoid_part + ".pkl"

def run(algorithm, instance, epsilon, delta, n_arms, seed, distribution="Bernoulli", reg=None, R=0.5, S=0.5, rerun=False):
    """Run one experiment and cache its result as a pickle under ``results/``.

    If a result file for this exact configuration already exists and
    ``rerun`` is false, the stored result is loaded and returned without
    running anything.

    Args:
        algorithm: Algorithm name understood by ``get_algorithm``.
        instance: Instance name understood by ``get_means``.
        epsilon: Forwarded to the algorithm constructor.
        delta: Forwarded to the algorithm constructor.
        n_arms: Number of arms in the environment.
        seed: Seed for ``np.random.default_rng``.
        distribution: "Bernoulli" or "Gaussian" reward distribution.
        reg: Forwarded to the ellipsoid algorithms.
        R: Forwarded to the ellipsoid algorithms.
        S: Forwarded to the ellipsoid algorithms.
        rerun: When true, ignore any cached result and overwrite it.

    Returns:
        A dict with keys "best_mean", "max_mu", "n_samples" and "time"
        (wall-clock seconds spent in ``alg.run``).

    Raises:
        ValueError: If ``distribution`` is not a supported name.
    """
    path = Path("results")
    path.mkdir(exist_ok=True)
    filename = get_filename(algorithm, instance, epsilon, delta, n_arms, seed, distribution=distribution, reg=reg, R=R, S=S)
    path = path.joinpath(filename)

    if path.exists() and not rerun:
        # NOTE: pickle.load can execute arbitrary code; this is only
        # acceptable because results/ is expected to hold files written by
        # this function itself.
        with open(path, mode='br') as f:
            result = pickle.load(f)
        return result

    rng = np.random.default_rng(seed)

    # prepare environment
    means = get_means(instance, n_arms)
    is_ellipsoid = algorithm in ["ellipsoid", "ellipsoid2"]
    if is_ellipsoid:
        # shift [-0.5, 0.5] to make S=0.5
        means = np.array(means) - 0.5
    if distribution == "Bernoulli":
        env = BernoulliMAB(means, rng=rng)
    elif distribution == "Gaussian":
        stds = [0.5 for _ in means]
        env = GaussianMAB(means, stds, rng=rng)
    else:
        # Without this branch ``env`` would stay unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(f"Unknown distribution: {distribution}")

    alg = get_algorithm(algorithm, epsilon, delta, R=R, S=S, reg=reg)
    start_time = time.perf_counter()
    best_mean = alg.run(env)
    end_time = time.perf_counter()
    max_mu = alg.get_max_mu()
    if is_ellipsoid:
        # undo the shift applied to the means above
        best_mean += 0.5
        max_mu += 0.5
    result = {
        "best_mean": best_mean,
        "max_mu": max_mu,
        "n_samples": alg.get_total_n_sample(),
        "time": end_time - start_time,
    }
    print(result)

    # Save results
    with open(path, mode='wb') as f:
        pickle.dump(result, f)

    return result
