from Environment import *
from Algorithms import *
import numpy as np
import pickle
from tqdm import tqdm
import multiprocessing
from pathlib import Path
import timeit


def _play_rounds(algorithm, env, T, K):
    """Play ``algorithm`` against ``env`` for ``T`` rounds.

    Returns a ``(rewards, elapsed_seconds)`` pair where ``rewards`` is the
    list of ``env.exp_reward(...)`` values, one per round, and
    ``elapsed_seconds`` is the wall-clock duration of the whole loop.
    """
    start = timeit.default_timer()
    rewards = []
    # There is no observation before the first round, so the algorithm is
    # fed an all-zero vector of length K for t == 1.
    feedback = np.zeros(K)
    # Rounds are 1-based and passed as NumPy integers, as before.
    for t in tqdm(np.arange(1, T + 1)):
        algorithm.run(t, feedback)
        offered = algorithm.offer()
        feedback = env.observe(offered)
        rewards.append(env.exp_reward(offered))
    return rewards, timeit.default_timer() - start


def run(T,repeat,d,N,K,L,i):
    """Run one repetition of the experiment and pickle its results.

    Every algorithm is built with seed ``i`` and played for ``T`` rounds
    against its own ``linear_Env`` instance created from the same seed.
    For each algorithm two pickle files are written under ``./result``:
    the elapsed wall-clock time and the cumulative-regret array.

    Args:
        T: number of rounds to play.
        repeat: total number of repetitions; unused here, kept so existing
            callers (e.g. ``pool.starmap``) do not need to change.
        d, N, K, L: problem parameters forwarded unchanged to
            ``linear_Env`` and to the algorithm constructors.
        i: repetition index; used as the random seed and in file names.

    Returns:
        None. Results are only written to disk and printed.
    """
    print('repeat',i)
    seed=i
    # Make the function usable on its own, not only via run_multiprocessing.
    Path('./result').mkdir(parents=True, exist_ok=True)

    Env=linear_Env(seed,d,N,K,L)
    # NOTE(review): the return value is unused; the call is kept in case
    # delta() touches the environment or RNG state -- confirm and drop.
    Env.delta()
    rev=Env.rev

    # Construction order is kept as in the original in case constructors
    # touch shared RNG state.
    alg_QMB=UCB_QMB(seed,Env.x, N, K, L, T)
    alg_TSQMB=TS_QMB(seed,Env.x, N, K, L, T)
    alg_UCB=UCB(seed,Env.x, N, K, L, T,rev)
    alg_Elim=Elimination(seed,Env.x, N, K, L, T, rev)
    alg_Elim2=Elimination2(seed,Env.x, N, K, L, T, rev)

    oracle_reward,_=Env.oracle()
    algorithms=[alg_Elim,alg_Elim2,alg_UCB,alg_QMB,alg_TSQMB]

    # One identically seeded environment per algorithm, all created before
    # any algorithm starts playing.
    envs={}
    for algorithm in algorithms:
        envs[algorithm.name()]=linear_Env(seed,d,N,K,L)

    exp_reward={}
    for algorithm in algorithms:
        algorithm.reset()
        name=algorithm.name()
        print(name)
        exp_reward[name],elapsed=_play_rounds(algorithm,envs[name],T,K)
        print('Time: ', elapsed)
        with open(f'./result/{name}_N_{N}K_{K}_repeat{i}_time2.txt', 'wb') as f:
            pickle.dump(elapsed, f)

    for algorithm in algorithms:
        name=algorithm.name()

        # Convert the reward list explicitly so the subtraction also works
        # when oracle_reward is a plain Python number.
        reg=oracle_reward-np.asarray(exp_reward[name])
        regret_sum=np.cumsum(reg)
        print('regret_sum',regret_sum)

        # Length-T view of the regret curve, padded with -1 where no round
        # was played (only relevant if the round loop ever stops early).
        padded=np.full(T,-1.0)
        padded[:len(regret_sum)]=regret_sum
        print('regret_sum_list',padded)

        filename_1=name+'T'+str(T)+'d'+str(d)+'N'+str(N)+'K'+str(K)+'L'+str(L)+'repeat'+str(i)+'regret2.txt'
        with open('./result/'+filename_1, 'wb') as f:
            pickle.dump(regret_sum, f)


def run_multiprocessing(T, repeat, d, N, K, L):
    """Run ``repeat`` independent repetitions of ``run`` in a process pool.

    Creates the ``./result`` directory if needed, then calls
    ``run(T, repeat, d, N, K, L, i)`` for every ``i`` in ``range(repeat)``,
    one task per repetition, and blocks until all of them have finished.

    Args:
        T, d, N, K, L: forwarded unchanged to ``run``.
        repeat: number of repetitions, i.e. number of tasks submitted.
    """
    Path("./result").mkdir(parents=True, exist_ok=True)

    jobs = [(T, repeat, d, N, K, L, i) for i in range(repeat)]
    # No point in spawning more workers than there are tasks; Pool requires
    # at least one process, hence the lower bound.
    num_processes = max(1, min(multiprocessing.cpu_count(), len(jobs)))

    with multiprocessing.Pool(processes=num_processes) as pool:
        pool.starmap(run, jobs)
        # Shut the workers down gracefully while the pool is still open.
        # Leaving the with-block calls terminate(), after which close() and
        # join() would have no useful effect.
        pool.close()
        pool.join()
    

if __name__=='__main__':
    # Experiment configuration: L and d are forwarded to the environment
    # and algorithms, T is the number of rounds, repeat the number of
    # seeded repetitions.
    L, T, d, repeat = 2, 5000, 2, 10

    # (N, K) settings to sweep. Other settings that were left commented
    # out in this script: (8, 5) with T=100000, and (4, 2).
    for N, K in [(7, 4)]:
        print(N,K)
        run_multiprocessing(T,repeat,d,N,K,L)
    


