import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import itertools

from multiprocessing import Pool
from bandit import Bandit
from bandit import CTB_Bandit

from algorithms.toptwo_ucb import TTUCB
from algorithms.adapTT import AdaPTopTwo
from algorithms.dp_se import DPSE
from algorithms.imp_adapTT import ImprovedAdaPTopTwo
from algorithms.dp_tt import DPTT

from utils import read_results




def run_simulation(config):
    """Run one simulation of the algorithm named in ``config["algo_name"]``.

    The config dict must provide the keys ``"mu"``, ``"eps"`` and
    ``"algo_name"``. The whole dict is passed to the algorithm's constructor,
    and the algorithm is then run on a bandit built from ``config["mu"]``.

    Returns nothing; whatever the algorithm produces is not propagated here.

    Raises:
        KeyError: if a required key is missing from ``config`` or
            ``config["algo_name"]`` is not a registered algorithm name.
    """
    # Registry mapping algorithm names to their classes.
    registry = {
        "tt_ucb": TTUCB,
        "adap_tt": AdaPTopTwo,
        "dp_se": DPSE,
        "imp_ada": ImprovedAdaPTopTwo,
        "dp_tt": DPTT,
    }

    # Bandit parameters.
    means = config["mu"]
    eps_value = config["eps"]

    # Build the algorithm from the full configuration.
    selected = config["algo_name"]
    learner = registry[selected](config)

    # For local DP ("ctb_tt") the algorithm runs on a CTB_Bandit instance,
    # otherwise on a plain Bandit.
    # NOTE(review): "ctb_tt" has no entry in the registry above, so the lookup
    # raises KeyError before this check can ever be true -- confirm which
    # class is meant to handle it.
    if selected != "ctb_tt":
        learner.run(Bandit(means))
    else:
        learner.run(CTB_Bandit(eps_value, means))

if __name__ == "__main__":
    # Script entry point. For every bandit environment below:
    #   1. run each algorithm `num_simulations` times for every
    #      (epsilon, delta) pair, in parallel worker processes;
    #   2. read the results back from ./experiments/<exp_name>/<subfolder>
    #      and plot mean +/- std of `taus` against epsilon, one curve per
    #      algorithm, saved as a PDF.

    # hyperparameters
    Deltas = [1e-2]
    # epsilon grid: 10 log-spaced values from 1e-3 to 1e2
    # (alternative grid: 10**(np.linspace(-1.5, 2, 10)))
    Epsilons = 10**(np.linspace(-3, 2, 10))
    combinations = list(itertools.product(Epsilons, Deltas))
    num_simulations = 1000  # 100 to run faster
    num_workers = 64

    # tag appended to the experiment name (formats as "13051011")
    run_tag = 13_05_1011

    # means of the bandit instances
    mu0 = np.array([0.1, 0.3, 0.5, 0.7, 0.9])  # env0
    mu1 = np.array([0.5, 0.9, 0.9, 0.9, 0.95])  # env1
    mu2 = np.array([0.75, 0.625, 0.5, 0.375, 0.25])  # env2
    mu3 = np.array([0.75, 0.70, 0.70, 0.70, 0.70])  # env3
    mu4 = np.array([0.75, 0.53125, 0.375, 0.28125, 0.25])  # env4
    mu5 = np.array([0.75, 0.71875, 0.625, 0.46875, 0.25])  # env5

    list_mus = np.array([mu0, mu1, mu2, mu3, mu4, mu5])

    # Loop-invariant lists: algorithms to run and the result subfolders that
    # are read back for plotting.
    algo_names = ["tt_ucb", "adap_tt", "dp_se", "imp_ada", "dp_tt"]
    subfolders = ["/TTUCB", "/AdaPTT", "/DPSE", "/ImpAdaPTT", "/DPTT"]

    for i, mu in enumerate(list_mus):
        K = len(mu)
        # Experiment name
        exp_id = f"env{i}_{num_simulations}simulations_{num_workers}workers"
        exp_name = f"Experiment_{exp_id}_{run_tag}"
        print(f"###########Starting the experiments for env{i} with means {mu}#######")
        for epsilon, delta in combinations:
            for alg in algo_names:
                config = {"exp_name": exp_name, "mu": mu, "K": K, "algo_name": alg,
                          "beta": 0.5, "eps": epsilon, "delta": delta}

                # One pool per configuration, as before, but inside a context
                # manager so the worker processes are released once the
                # (blocking) map call has finished instead of piling up.
                with Pool(processes=num_workers) as pool:
                    pool.map(run_simulation, [config] * num_simulations)
                print(f"experiment done for {alg} and epsilon={epsilon}")

        print(f"###########The experiment has ended for env{i} with means {mu}#######")

        fig, ax = plt.subplots()
        for sub in subfolders:
            folder = "./experiments/" + exp_name + sub
            name, epsilons, taus, _astars = read_results(folder)
            print(sub)
            print("epsilons:", epsilons)
            # Statistics are taken along axis 1 of `taus`.
            # assumes taus has one row per epsilon value -- TODO confirm
            # against read_results
            Y_mean = np.mean(taus, axis=1)
            Y_std = np.std(taus, axis=1)
            print("mean:", Y_mean)
            print("std:", Y_std)
            ax.plot(epsilons, Y_mean, label=name)
            ax.fill_between(epsilons, Y_mean - Y_std, Y_mean + Y_std,
                            alpha=0.2)

        # Set the y-axis label to scientific notation.
        # NOTE(review): set_yscale('log') below presumably replaces the tick
        # formatter, so this setting may have no visible effect -- confirm
        # before removing. It must stay before set_yscale.
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
        ax.yaxis.offsetText.set_visible(True)
        ax.legend(loc='best')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('epsilons')
        ax.set_title('log-sample complexity')
        plt.savefig(f"./experiment_{exp_name}.pdf")
        plt.show()
        # Release the figure so figures do not accumulate across environments.
        plt.close(fig)
    
