import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import tqdm
import time
import numpy as np
from matplotlib import pyplot as plt
import mdp
from utils import reduce_curve, set_seed, time_str
from environments.riverswim import RiverSwim
from learners.agent import iterate_algorithm, parse_history
from learners.evi import EVI_based, set_BIAS_SPAN, set_BIAS

# ---------------------------------------------------------------------------
# Model parameters.
#   S -- number of states passed to RiverSwim(S).
#   A -- number of actions; NOTE(review): A is not read anywhere in this
#        script (RiverSwim is constructed from S only) — presumably kept for
#        documentation; confirm before relying on it.
# ---------------------------------------------------------------------------
S = 3
A = 2

# ---------------------------------------------------------------------------
# Experiment / plotting parameters.
#   T        -- number of learner steps per run (horizon).
#   RUNS     -- number of independent runs the regret curve is averaged over.
#   N_POINTS -- number of points passed as `num=` to reduce_curve when
#               down-sampling the curves for plotting.
# ---------------------------------------------------------------------------
T = 5_000
RUNS = 100
N_POINTS = 30

# Presumably seeds the random generators for reproducible runs; uncomment to use.
# set_seed(3)

print("\n>>> Model")
model = RiverSwim(S)
print(model)

# Hand the model's bias (as returned by model.bias()) to learners.evi.set_BIAS.
# NOTE(review): presumably the reference bias used by the EVI-based learners —
# confirm in learners/evi.py.
h = model.bias()
set_BIAS(h)


print("\n>>> Initializing learners ...")
# Baseline learner: UCRL2 with its prefab configuration.
learners = [EVI_based(model, config=EVI_based.prefabs["UCRL2"])]

# PMEVI variants, one per prior strength c. Each variant gets a bias prior made
# of one (s, s + 1, -c) triple per pair of adjacent states of the chain. For
# S = 3 this is exactly [(0, 1, -c), (1, 2, -c)]. Deriving it from S keeps the
# prior consistent with the model when S is changed: a hard-coded list would
# reference state 2 even for S < 3 and leave states >= 3 unconstrained.
# NOTE(review): the meaning/sign convention of each triple is defined by
# EVI_based.set_bias_prior — confirm there.
cs = [0, 0.5, 1.0, 1.5, 2.0]
for c in cs:
    learner = EVI_based(model, config=EVI_based.prefabs["PMEVI"])
    learner.set_name(f"PMEVI(c={c})")
    learner.set_bias_prior([(s, s + 1, -c) for s in range(S - 1)])
    learners.append(learner)

# Markers used for the plotted curves, in learner order.
markers = ["o", "x", "+", "s", "^", "*", "P"]
# Per-learner accumulator (length T + 1) for the gap-regret curves of all runs.
regrets = {learner.name(): np.zeros(T + 1) for learner in learners}
 
# Main experiment. In each of the RUNS runs, every learner is reset, played for
# T steps against `model`, and its gap-regret curve is added to regrets[name].
# The sum is divided by RUNS at plot time.
# time.perf_counter() is used for the elapsed-time/ETA display because, unlike
# time.time(), it is monotonic and unaffected by system clock adjustments.
t0 = time.perf_counter()
for run in range(RUNS):
    t_spend = time.perf_counter() - t0
    # Mean duration of the completed runs times the runs still to do;
    # max(run, 1) avoids dividing by zero before the first run has finished.
    t_rem = (RUNS - run) * t_spend / max(run, 1)
    print(f"Run {run+1} ... (spend {time_str(t_spend)}, remains {time_str(t_rem)})")
    for learner in learners:
        learner.reset(model)
        name = learner.name()
        # iterate_algorithm receives `history` and presumably appends to it in
        # place. NOTE(review): the leading 0 is presumably the initial state — confirm.
        history = [0]
        for _ in tqdm.tqdm(range(T), desc=name):
            iterate_algorithm(model, learner, history)
        # Drop the trailing entry before parsing.
        # NOTE(review): presumably the state reached after the last step — confirm.
        history.pop()
        # Assumes parse_history yields T + 1 values, matching regrets[name]'s shape.
        greg = np.array(parse_history(model, history)["gap regret"])
        regrets[name] += greg

print(">>> Rendering ... ")
fig, ax = plt.subplots()
X = reduce_curve(list(range(T + 1)), mode="linear", num=N_POINTS)
max_reg = 0.0
for i, learner in enumerate(learners):
    name = learner.name()
    Y = regrets[name] / RUNS  # average over runs
    max_reg = max(max_reg, float(np.max(Y)))
    Y = reduce_curve(Y, mode="linear", num=N_POINTS)
    # Cycle through the markers. zip(markers, learners) would silently drop
    # every learner beyond len(markers) from the plot.
    marker = markers[i % len(markers)]
    ax.plot(X, Y, label=name, marker=marker, fillstyle="none", color="k")

# ax.set_xscale("log")
# A lower x-limit of 10 is only needed on a log axis (t = 0 cannot be drawn
# there). On the default linear axis, start at 0 so the curves are shown from t = 0.
ax.set_xlim(10 if ax.get_xscale() == "log" else 0, max(X))
# Leave ~33% headroom above the highest averaged curve. Fall back to a unit
# range when every curve is identically zero, since (0, 0) is a degenerate range.
ax.set_ylim(0, 1.33 * max_reg if max_reg > 0 else 1.0)
ax.set_xlabel("Time $t$")
ax.set_ylabel(f"${{\\rm Reg}}(t)$ averaged over {RUNS} runs")
ax.legend()
plt.show()


