from Algorithms import *
from Environment import *
from Environment import ProphetInequalityEnv
import numpy as np
import pickle
from pathlib import Path
from tqdm import tqdm
import multiprocessing



def _feed_observation(algorithm, name, env, t, noniid):
    """Show round ``t`` of ``env`` to ``algorithm`` with the arguments its kind expects."""
    x = env.get_item(t)
    y = env.recommend_and_feedback(t)
    if noniid:
        l_info = env.get_inform_dis()[0]
        h_info = env.get_inform_dis()[1]
        if name == 'Oracle':
            algorithm.run(t, x, y, env.theta, l_info, h_info)
        elif name == 'Gusein-Zade':
            # This baseline takes no distribution information.
            algorithm.run(t, x, y)
        else:
            algorithm.run(t, x, y, l_info, h_info)
    elif name == 'Oracle':
        algorithm.run(t, x, y, env.theta)
    else:
        algorithm.run(t, x, y)


def _dump_result(values, result_dir, name, n, d, i, noise_std, suffix):
    """Pickle ``values`` into ``result_dir`` under the experiment's file-name scheme."""
    filename = f'{name}n{n}d{d}repeat{i}noise_std{noise_std}{suffix}'
    with open(result_dir / filename, 'wb') as f:
        pickle.dump(values, f)


def run(repeat, d, n, i, noniid, noise_std=0.1):
    """Run one repetition of the experiment and pickle each algorithm's rewards.

    Every algorithm gets its own environment built from the same seed. The
    algorithm is fed rounds ``0 .. n-1`` until it reports ``stopped``; at that
    point the reward of its stopping time ``tau`` (0 when ``tau == n``) and the
    environment's optimal reward are recorded. If an algorithm never reports
    ``stopped``, empty lists are written for it.

    Two files are written per algorithm into ``./result``:
    ``<name>n<n>d<d>repeat<i>noise_std<noise_std>alg.txt`` and
    ``...oracle.txt`` (both are pickles despite the ``.txt`` suffix).

    Args:
        repeat: Total number of repetitions; kept for interface compatibility,
            not used by a single run.
        d: Forwarded to every environment/algorithm constructor
            (presumably the feature dimension; confirm against those classes).
        n: Number of rounds (horizon) of the run.
        i: Index of this repetition; also used as the random seed.
        noniid: Select the non-iid environment and algorithms when true,
            the iid ones otherwise.
        noise_std: Forwarded to the environment constructor.
    """
    print('repeat', i)
    seed = i

    # NOTE(review): the oracle (and, for non-iid, one extra environment) is
    # constructed but never run. The constructions are kept, in their original
    # order, in case the constructors consume shared random state; confirm
    # against Algorithms/Environment before removing them.
    if noniid:
        make_env = Noniid_ProphetInequalityEnv
        _unused_env = make_env(seed, d, n, noise_std)
        algorithms = [
            ETD_LCBT_NonIID(seed, d, n),
            ETD_LCBT_NonIID_Window(seed, d, n),
        ]
        _unused_oracle = Oracle_NonIID(seed, d, n)
    else:
        make_env = ProphetInequalityEnv
        algorithms = [
            ETD_LCBT(seed, d, n),
            greedy(seed, d, n),
        ]
        _unused_oracle = Oracle(seed, d, n)
    algorithms.append(Secretary(seed, d, n))

    # Key every container by the names the algorithms actually report, so that
    # no lookup depends on a hard-coded list of algorithm names.
    envs = {}
    exp_reward = {}
    oracle_reward = {}
    for algorithm in algorithms:
        name = algorithm.name()
        exp_reward[name] = []
        oracle_reward[name] = []
        envs[name] = make_env(seed, d, n, noise_std)

    # Allow run() to be called directly, without run_multiprocessing().
    result_dir = Path('./result')
    result_dir.mkdir(parents=True, exist_ok=True)

    for algorithm in algorithms:
        name = algorithm.name()
        env = envs[name]
        algorithm.reset()
        print(name)
        # np.arange keeps ``t`` a NumPy integer, as the downstream code received before.
        for t in tqdm(np.arange(n)):
            _feed_observation(algorithm, name, env, t, noniid)
            if algorithm.stopped:
                if algorithm.tau == n:
                    # tau == n is recorded as zero reward.
                    exp_reward[name].append(0)
                else:
                    exp_reward[name].append(env.stop_and_choose(algorithm.tau))
                oracle_reward[name].append(env.get_optimal_reward())
                break

        _dump_result(exp_reward[name], result_dir, name, n, d, i, noise_std, 'alg.txt')
        _dump_result(oracle_reward[name], result_dir, name, n, d, i, noise_std, 'oracle.txt')

def run_multiprocessing(repeat, d, n, noniid, noise_std=0.1):
    """Run ``repeat`` independent repetitions of ``run`` in a process pool.

    Creates ``./result`` if needed, then calls
    ``run(repeat, d, n, i, noniid, noise_std)`` for every ``i`` in
    ``range(repeat)``. Blocks until all repetitions have finished.

    Args:
        repeat: Number of repetitions to run.
        d: Forwarded to ``run``.
        n: Forwarded to ``run``.
        noniid: Forwarded to ``run``.
        noise_std: Forwarded to ``run``.
    """
    Path("./result").mkdir(parents=True, exist_ok=True)

    jobs = [(repeat, d, n, i, noniid, noise_std) for i in range(repeat)]
    # Never start more workers than there are jobs; Pool needs at least one.
    num_processes = max(1, min(multiprocessing.cpu_count(), len(jobs)))

    # starmap blocks until every job is done, and leaving the ``with`` block
    # shuts the pool down, so no explicit close()/join() is needed afterwards.
    with multiprocessing.Pool(processes=num_processes) as pool:
        pool.starmap(run, jobs)

if __name__ == '__main__':
    # Experiment configuration for a script run.
    noniid = True
    d, repeat, noise_std = 2, 10, 0.8

    # The non-iid setting uses a shorter horizon than the iid one.
    n = 30000 if noniid else 100000

    print(d, n, repeat, noise_std)
    run_multiprocessing(repeat, d, n, noniid, noise_std)
