from Environment import *
from Algorithms import *
import numpy as np
import pickle
from tqdm import tqdm
import multiprocessing
from pathlib import Path


def run(T,repeat,d,N,K,L,i):
    """Run repetition ``i`` of the experiment and pickle each algorithm's regret.

    The environment and the three algorithms (``Elimination``, ``ETC_GS``,
    ``UCB_GS``) are built with ``seed = i``.  Every algorithm is then played
    for ``T`` rounds against its own ``linear_Env`` instance constructed from
    the same arguments, and the cumulative sum of
    ``oracle_reward - exp_reward`` is pickled to
    ``./result/<name>T<T>d<d>N<N>K<K>L<L>repeat<i>regret.txt``.

    Args:
        T: Number of rounds each algorithm is played.
        repeat: Total number of repetitions.  Not used inside this function;
            kept so existing callers (``pool.starmap``) keep working.
        d, N, K, L: Forwarded unchanged to ``linear_Env`` and (``N``, ``K``,
            ``L``) to the algorithm constructors.  ``K`` is also the length
            of the all-zero feedback vector passed in the first round.
        i: Index of this repetition; used as the random seed and in the
            output file name.

    Returns:
        None.  Results are written to disk only.
    """
    print('repeat',i)
    seed=i

    # Construction order is kept exactly as before in case the project
    # classes consume shared random state when they are created.
    Env=linear_Env(seed,d,N,K,L)
    delta=Env.delta()
    alg_UCB_GS=UCB_GS(seed,Env.x, N, K, L, T)
    alg_Elim=Elimination(seed,Env.x, N, K, L, T)
    alg_Etc_GS=ETC_GS(seed,Env.x, N, K, L, T,delta)

    oracle_reward,_=Env.oracle()

    algorithms=[alg_Elim, alg_Etc_GS, alg_UCB_GS]

    # Key everything by the name each algorithm reports, rather than by a
    # hard-coded list of names that could drift out of sync with it.
    exp_reward={}
    Env_dict={}
    for algorithm in algorithms:
        name=algorithm.name()
        exp_reward[name]=[]
        Env_dict[name]=linear_Env(seed,d,N,K,L)

    for algorithm in algorithms:
        algorithm.reset()
        name=algorithm.name()
        print(name)
        env=Env_dict[name]
        rewards=exp_reward[name]
        # The first round has no previous observation, so a zero vector of
        # length K is passed; later rounds pass the last observation.
        feedback=np.zeros(K)
        # np.arange keeps t a NumPy integer, as in the original loop.
        for t in tqdm(np.arange(1,T+1)):
            algorithm.run(t,feedback)
            S=algorithm.offer()
            feedback=env.observe(S)
            rewards.append(env.exp_reward(S))

    # Make sure the output directory exists even when run() is called
    # directly rather than through run_multiprocessing().
    Path('./result').mkdir(parents=True, exist_ok=True)

    for algorithm in algorithms:
        name=algorithm.name()
        # Convert explicitly so the subtraction also works when
        # oracle_reward is a plain Python number rather than a NumPy scalar.
        reg=oracle_reward-np.asarray(exp_reward[name])
        regret_sum=np.cumsum(reg)

        filename_1=name+'T'+str(T)+'d'+str(d)+'N'+str(N)+'K'+str(K)+'L'+str(L)+'repeat'+str(i)+'regret.txt'
        # The with-statement closes the file; no explicit close() is needed.
        with open('./result/'+filename_1, 'wb') as f:
            pickle.dump(regret_sum, f)


def run_multiprocessing(T, repeat, d, N, K, L):
    """Run ``repeat`` repetitions of ``run`` in parallel worker processes.

    Creates the ``./result`` directory if needed, then calls
    ``run(T, repeat, d, N, K, L, i)`` for every ``i`` in ``range(repeat)``
    through a ``multiprocessing.Pool`` and waits for all of them to finish.

    Args:
        T, d, N, K, L: Forwarded unchanged to ``run``.
        repeat: Number of repetitions; also forwarded to ``run``.

    Returns:
        None.
    """
    Path("./result").mkdir(parents=True, exist_ok=True)

    # Never start more workers than there are repetitions to run; keep at
    # least one so Pool() accepts the value when repeat is 0.
    num_processes = max(1, min(multiprocessing.cpu_count(), repeat))
    with multiprocessing.Pool(processes=num_processes) as pool:
        pool.starmap(run, [(T, repeat, d, N, K, L, i) for i in range(repeat)])
        # Shut the workers down cleanly while the pool is still alive;
        # leaving the with-block calls terminate(), after which close() and
        # join() would no longer do anything useful.
        pool.close()
        pool.join()
    

if __name__=='__main__':
    # Settings shared by every configuration in the sweep below.
    T, d, L, repeat = 1000, 3, 2, 10

    # Every (N, K) pair to run, in the same order as nested loops over
    # N (outer) and K (inner).
    grid = [(N, K) for N in (4, 5) for K in (2, 3, 4)]
    for N, K in grid:
        run_multiprocessing(T, repeat, d, N, K, L)
    


