import os
import random
from env import *
from algorithms import *
import time
from multiprocessing import Pool

def run(args):
    """Run every model in a model list once with a fixed seed.

    Args:
        args: ``(model_list, seed)`` packed in one tuple so the function can be
            used with ``Pool.map``. ``model_list`` is a sequence of
            ``(model_name, model_class, params)``; each agent is built as
            ``model_class(*params)`` and must expose a ``run()`` method.

    Returns:
        ``(run_returns, times)``: the value returned by each agent's ``run()``
        and the CPU time (``time.process_time``) it took, both in
        ``model_list`` order.
    """
    import numpy as np  # local import: keeps the worker self-contained

    model_list, seed = args
    run_returns = []
    times = []
    for _model_name, model, params in model_list:
        # Re-seed before every model so all algorithms face the same random
        # stream. NumPy's global RNG is seeded too: pool workers forked from
        # the parent inherit an identical NumPy state, so without this any
        # np.random use would ignore the per-run seed.
        random.seed(seed)
        np.random.seed(seed)
        agent = model(*params)
        start_time = time.process_time()
        episodic_return = agent.run()
        t = time.process_time() - start_time
        run_returns.append(episodic_return)
        times.append(t)
    return run_returns, times

runs = 10

K = 100000
delta = 0.05

MDPtype = "Riverswim"
S = 10
H = 40
if MDPtype == "Riverswim":
    env = make_riverSwim(epLen=H, nState=S)
else:
    print("Unknown MDP type")
    exit(0)

optval = env.calculateOptimalValue()[0, 0]

uniform = False

# Create target Directory if don't exist
out_dir = f"./data/S{S}H{H}_" +("" if uniform else "non") + "uniform" 
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print("Directory " , out_dir ,  " Created ")
else:    
    print("Directory " , out_dir ,  " already exists")

model_list = [
    ('UCRL2', UCRL2, (env, K, delta, uniform)),
    ('UCBVI-BF', UCBVI_Azar, (env, K, delta, uniform)),
    ('Euler', Euler, (env, K, delta, uniform)),
    ('ORLC', ORLC, (env, K, delta, uniform)),
    ('MVP', MVP, (env, K, delta, uniform)),
    ('EQO' , EQO_Kunaware, (env, K, delta, uniform)),
]

with Pool(min(runs, 10)) as pool:
    results = pool.map(run, [ (model_list, (i+1)*1234) for i in range(runs)])

run_returns = np.asarray([results[i][0] for i in range(runs)])

times = [ sum([results[j][1][i] for j in range(runs)]) / runs for i in range(len(model_list)) ]

with open(f"{out_dir}/{MDPtype}_time.txt", "a") as myfile:
    for i, model in enumerate(model_list):
        myfile.write(f"{model[0]} {times[i]}\n")

filename = f"{out_dir}/{MDPtype}"

run_regrets = np.cumsum(optval - run_returns, axis = 2)
avg_regret = np.mean(run_regrets, axis = 0)
std_regret = np.std(run_regrets, axis = 0)

np.savetxt(f"{filename}.csv", np.concatenate((avg_regret, std_regret), axis = 0))
