import numpy as np
import matplotlib.pyplot as plt
import datetime, time, json, os, sys, multiprocessing

from environments import Environment
from labelers import StochasticLabeler
from algorithms import SERE

if __name__ == '__main__':
    # Usage: python <this script> <config-name>
    # Reads configs/<config-name>.json, runs SERE once for every
    # (seed, epsilon) pair, and saves the optimal vs. estimated policy values
    # together with the config under results/.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <config-name>  (reads configs/<config-name>.json)")
    config_name = sys.argv[1]

    with open('configs/' + config_name + '.json', encoding='utf-8') as json_file:
        config = json.load(json_file)

    start_time = datetime.datetime.now()
    start_print = start_time.strftime("(%Y-%b-%d %I:%M%p)")

    # "S" and "A" are lists in the config; "mdp" selects which entry to use.
    mdp_index = config["mdp"]
    S = np.array(config["S"])[mdp_index]
    A = np.array(config["A"])[mdp_index]
    H = config["H"]
    # atleast_1d lets "epsilon" be given either as a single value or a list;
    # a bare scalar would otherwise be a 0-d array that cannot be indexed.
    epsi = np.atleast_1d(np.array(config["epsilon"]))
    # NOTE(review): lambda is always taken at index 2 rather than at
    # mdp_index like S and A -- confirm this is intended.
    lamb = np.array(config["lambda"])[2]
    delt = config["delta"]

    r_min = config.get("r_min", 0)
    r_max = config.get("r_max", 1)
    # Same scalar-or-list handling as epsilon; defaults to a single seed 0.
    seeds = np.atleast_1d(np.array(config.get("seed", 0)))

    # Keep a copy of the full config in the results so the output file is
    # self-describing.
    results = dict(config)
    results["V_opt"] = np.zeros(seeds.size)
    results["V_est"] = np.zeros((epsi.size, seeds.size))

    print(f"{start_print} Starting run with config file : 'configs/{config_name}.json'")

    for i, seed in enumerate(seeds):
        np.random.seed(seed)

        # One environment/labeler per seed, shared across all epsilons.
        env = Environment(S=S, A=A, H=H, r_max=r_max, r_min=r_min)
        labeler = StochasticLabeler(env=env)

        for j, e in enumerate(epsi):
            # flush so the START marker is visible while alg.run() is working
            print(f"Run {i}/{seeds.size-1}: epsilon {e} --- START ... ", end='', flush=True)
            alg = SERE(S=S, A=A, H=H, e=e, l=lamb, d=delt, labeler=labeler, extra_args=config)
            alg.run()

            r_est = alg.rew_est
            V_opt, V_est = env.compare_policy_value_true_est(r_est)
            # Column 0 of each value array is dotted with env.mu (presumably
            # first time step / initial-state distribution -- TODO confirm).
            # V_opt[i] is rewritten for every epsilon of this seed.
            results["V_opt"][i] = V_opt[:, 0] @ env.mu
            results["V_est"][j, i] = V_est[:, 0] @ env.mu

            print("END")

    end_time = datetime.datetime.now()
    end_print = end_time.strftime("(%Y-%b-%d %I:%M%p)")
    print(f"{end_print} Ended run. Elapsed time: {end_time-start_time}")

    # Create the output directory up front; otherwise np.save would fail only
    # after the entire run has finished.
    os.makedirs("results", exist_ok=True)
    filename = f"SA_{S}-Rmax_{r_max}_epsilon.npy"
    # results is a dict, so np.save stores it as a 0-d object (pickled) array;
    # read it back with np.load(path, allow_pickle=True).item().
    np.save("results/" + filename, results)