import numpy as np
import matplotlib.pyplot as plt
from algs.psrl import psrl, Prior
from envs.tabular_mdp import TabularMDP

def random_tabular(nS, nA, T):
    """Build a TabularMDP with randomly drawn transitions and rewards.

    Every one of the ``nS * nA`` rows of the transition matrix is an
    independent draw from a flat Dirichlet (all concentrations 1) over the
    ``nS`` next states. Every one of the ``nS * nA`` rewards is drawn
    uniformly from [0, 1). Randomness comes from NumPy's global RNG.

    Args:
        nS: number of states.
        nA: number of actions.
        T: forwarded unchanged to ``TabularMDP`` (presumably the horizon --
            TODO confirm against ``envs.tabular_mdp``).

    Returns:
        The ``TabularMDP`` instance with ``setP``/``setR`` applied.
    """
    n_pairs = nS * nA
    flat_alpha = np.ones(nS)

    env = TabularMDP(nS, nA, T)
    # One Dirichlet sample per row: an (n_pairs, nS) array of probability
    # vectors over next states.
    env.setP(np.random.dirichlet(flat_alpha, size=n_pairs))
    # One uniform [0, 1) reward per row.
    env.setR(np.random.rand(n_pairs))
    return env

# Experiment configuration.
nS = 10         # number of states (size of each next-state distribution)
nA = 2          # number of actions
T = 100         # forwarded to TabularMDP; presumably the episode horizon -- TODO confirm
n_iters = 1000  # iteration count passed to psrl
# Set to True to overlay the reference curve T * sqrt(nS * nA * k)
# (k = iteration index) on top of the observed regret.
PLOT_BOUND = False

mdp = random_tabular(nS, nA, T)
prior = Prior(mdp.nS, mdp.nA)
policy, regret = psrl(mdp, prior, n_iters)

# plot
steps = np.arange(n_iters)
ub = T * np.sqrt(nS * nA * steps)
plt.plot(regret, label="PSRL regret")
if PLOT_BOUND:
    plt.plot(ub, label="T * sqrt(nS * nA * k)")
plt.xlabel("iteration")
plt.ylabel("regret")
plt.legend()
plt.show()