import numpy as np
import matplotlib.pyplot as plt
import datetime, time, json, os, sys, multiprocessing

from environments import Environment
from labelers import StochasticLabeler
from algorithms import SERE

if __name__ == '__main__':
    # Experiment driver.
    # Usage: python <this script> <config_name>
    # Reads configs/<config_name>.json, runs SERE once for every
    # (seed, lambda) pair and saves the collected indices / stopping times
    # (together with a copy of the config) to results/ as a .npy file.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <config_name>  (reads configs/<config_name>.json)")

    with open('configs/' + sys.argv[1] + '.json', encoding='utf-8') as json_file:
        config = json.load(json_file)

    # Create the output directory up front: otherwise a long sweep would
    # only fail at the very last np.save call and lose every result.
    os.makedirs("results", exist_ok=True)

    start_time = datetime.datetime.now()
    start_print = start_time.strftime("(%Y-%b-%d %I:%M%p)")

    # config["S"] / config["A"] are indexed by config["mdp"]
    # (presumably one entry per MDP definition — confirm against configs/).
    mdp_index = config["mdp"]
    S = np.array(config["S"])[mdp_index]
    A = np.array(config["A"])[mdp_index]
    H = config["H"]
    epsi = np.array(config["epsilon"])[0]  # only the first epsilon is used
    # atleast_1d lets a config give a single scalar instead of a list;
    # a 0-d array would otherwise break the lamb[j] / seeds[i] indexing.
    lamb = np.atleast_1d(config["lambda"])  # every lambda value is swept
    delt = config["delta"]

    r_min = config.get("r_min", 0)
    r_max = config.get("r_max", 1)
    seeds = np.atleast_1d(config.get("seed", 0))

    # Start from a shallow copy of the config so the saved file records
    # the settings that produced it.
    results = dict(config)
    results["true_index"] = np.zeros((H, seeds.size))
    results["est_index"] = np.zeros((lamb.size, H, seeds.size))
    results["t_stop"] = np.zeros((lamb.size, H, seeds.size))

    print(f"{start_print} Starting run with config file : 'configs/{sys.argv[1]}.json'")

    for i, seed in enumerate(seeds):
        # Seeding numpy's global RNG before building the environment makes
        # each seed's run reproducible (assuming the project code draws from
        # np.random — confirm in environments/labelers/algorithms).
        np.random.seed(seed)

        # One environment and labeler per seed, shared by every lambda value.
        env = Environment(S=S, A=A, H=H, r_max=r_max, r_min=r_min)
        labeler = StochasticLabeler(env=env)

        for j, l in enumerate(lamb):
            # flush so the "START" half of the line is visible while SERE runs.
            print(f"Run {i}/{seeds.size-1}: lambda {l} --- START ... ", end='', flush=True)
            alg = SERE(S=S, A=A, H=H, e=epsi, l=l, d=delt, labeler=labeler, extra_args=config)
            alg.run()

            for h in range(H):
                est_B = alg.basis[:, :, h]
                env_B = env.get_optimal_basis(h)

                true_index = env.get_true_index(env_B, h)
                est_index = env.get_true_index(est_B, h)

                # true_index is not indexed by lambda: it is rewritten on
                # every lambda pass and the last pass's value is kept.
                results["true_index"][h, i] = true_index
                results["est_index"][j, h, i] = est_index
                results["t_stop"][j, h, i] = alg.t_stop[h]

            print("END")

    end_time = datetime.datetime.now()
    end_print = end_time.strftime("(%Y-%b-%d %I:%M%p)")
    print(f"{end_print} Ended run. Elapsed time: {end_time-start_time}")
    filename = f"SA_{S}-Rmax_{r_max}_lambda.npy"

    # A dict is stored as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save("results/" + filename, results)