import os

import numpy as np
import matplotlib.pyplot as plt

from algs.taxi.psrl import psrl, Prior
from envs.taxi.taxi2 import TaxiEnv
from envs.tabular_mdp import TabularMDP
from envs.taxi.factored_mdp import FactoredMDP, factored_to_tabular, ind_to_x
# from envs.taxi.taxi import TaxiEnv
# NOTE: this import rebinds TabularMDP, shadowing the one from envs.tabular_mdp.
from tabular_mdp import TabularMDP

# --- Experiment configuration -------------------------------------------
# NOTE(review): d_S, d_A, Z, n, nactions and nn are not read anywhere else in
# this script; presumably factored-MDP dimensions kept from a related
# experiment -- confirm before removing.
d_S = 4
d_A = 1
Z = 5
n_iters = 400  # passed to psrl() and used as the length of each regret curve
n = np.array([5, 5, 2, 1])
nactions = np.array([6])
T = 15  # third argument of TabularMDP (presumably the horizon -- confirm)
n_runs = 20  # number of independently seeded runs
nn = np.concatenate((n, nactions))

# One RNG seed per run. The default is the original author's absolute path;
# set the TAXI_SEEDS_PATH environment variable to run on another machine.
SEEDS_PATH = os.environ.get(
    "TAXI_SEEDS_PATH",
    "/Users/giorgiaramponi/Desktop/documents/Github/causal-psrl/scripts/taxi/seeds.npy",
)
seeds = np.load(SEEDS_PATH)

# Build the taxi environment and mirror its dynamics into a tabular MDP.
mdp = TaxiEnv()
nS = mdp.num_states
nA = mdp.num_actions
mdp_tabular = TabularMDP(nS, nA, T)

# Copy the environment's transition and reward tables onto the tabular model.
P, R = mdp.get_transitions_rewards()
mdp_tabular.P = P
mdp_tabular.R = R

# All mass on state index 0 (mu is presumably the initial-state
# distribution -- confirm). Set further entries to 1 to start from several
# states; the division below then rescales the vector to sum to 1.
initial_dist = np.zeros(nS)
initial_dist[0] = 1
mdp_tabular.mu = initial_dist / initial_dist.sum()

# Reference curve T * sqrt(nS * nA * k) for k = 0 .. n_iters - 1 (presumably a
# regret upper bound -- confirm). It does not depend on the run, so it is
# computed once; only the commented-out plot call at the end of the script
# refers to it.
steps = np.arange(n_iters)
ub = T * np.sqrt(nS * nA * steps)

# Row i holds the regret curve returned by psrl() for run i.
regrets = np.zeros((n_runs, n_iters))
for i in range(n_runs):
    print(i)
    # Seed NumPy's global RNG with this run's stored seed.
    np.random.seed(seeds[i])
    # Fresh prior for every run.
    prior = Prior(mdp_tabular.nS, mdp_tabular.nA)
    # The returned policy is not used by this script.
    policy, regret = psrl(mdp_tabular, prior, n_iters)
    regrets[i] = regret
    # Persist each run separately; np.save appends ".npy" to the name.
    np.save(f"regret_taxi_psrl_final_{i}", regrets[i])

# Average the regret curves across runs (axis 0) and show the result.
mean_regret = np.mean(regrets, axis=0)
plt.plot(mean_regret)
# plt.plot(ub)  # optional: overlay the reference curve
plt.show()