import numpy as np
import matplotlib.pyplot as plt
from envs.taxi.taxi2 import TaxiEnv
from envs.taxi.factored_mdp import FactoredMDP
from algs.taxi.causal_psrl import causal_psrl, HierarchicalPrior, generate_causal_prior



# ---------------------------------------------------------------------------
# Experiment configuration.
#
# d_S, d_A, Z, n, nactions and T are handed positionally to FactoredMDP
# further down in this script; their precise meaning is defined by that class.
# NOTE(review): d_S / d_A look like the number of state / action factors
# (they match len(n) and len(nactions)) -- confirm against FactoredMDP.
# ---------------------------------------------------------------------------
d_S, d_A = 4, 1
Z = 5
T = 15

# Per-factor cardinalities as integer arrays.
n = np.array((5, 5, 2, 1))
nactions = np.array((6,))
# State cardinalities followed by the action cardinality, in one flat array.
nn = np.concatenate([n, nactions])

# Size of the regret table: n_runs independent runs, n_iters entries per run.
n_runs = 20
n_iters = 400

# Seeds are read from disk and indexed by run number when seeding NumPy's RNG.
seeds = np.load("/causal-psrl/scripts/taxi/seeds.npy")
print(seeds.shape)

# ---------------------------------------------------------------------------
# Environment construction: build the Taxi environment, then wrap its rewards,
# scopes and factored transitions in a FactoredMDP.
# ---------------------------------------------------------------------------
mdp = TaxiEnv()
P, R = mdp.get_transitions_rewards()

# Quantities read off the Taxi environment.
factors = mdp.factored_P
scopes = mdp.scopes
nS = mdp.num_states
nA = mdp.num_actions

# The factored MDP receives the reward table R from the Taxi environment plus
# the configuration constants defined at the top of the script.
fmdp = FactoredMDP(d_S, d_A, Z, R, n, nactions, T)
fmdp.setscopes(scopes)
fmdp.setP(factors)

# Initial-state distribution: all probability mass on state index 0.
mu = np.zeros(fmdp.nS)
mu[0] = 1
fmdp.setmu(mu=mu)

# One independent empty list per entry of fmdp.scopes.
n_scopes = len(fmdp.scopes)
previous_edges = [[] for _ in range(n_scopes)]

# ---------------------------------------------------------------------------
# Main experiment: n_runs seeded runs of causal PSRL, one regret row per run.
# ---------------------------------------------------------------------------
regrets = np.zeros((n_runs, n_iters))

for run_idx in range(n_runs):
    print(run_idx)

    # Seed NumPy's global RNG with this run's entry from the seeds array.
    np.random.seed(seeds[run_idx])

    # A fresh prior is built for every run, after seeding.
    causal_prior = generate_causal_prior(fmdp.Z, fmdp.scopes, 2)
    factor_sizes = np.concatenate((fmdp.n, fmdp.nactions))
    prior = HierarchicalPrior(
        fmdp.d_S,
        fmdp.d_A,
        fmdp.Z,
        factor_sizes,
        causal_prior,
    )

    # causal_psrl returns four values; only the second one is stored.
    _, run_regret, error, value = causal_psrl(fmdp, R, prior, n_iters)
    regrets[run_idx] = run_regret

    # Checkpoint the whole table after each run (np.save appends ".npy");
    # rows for runs that have not executed yet are still zero.
    np.save(f"regret_taxi_causal_final_{run_idx}", regrets)

# Plot the regret curve averaged over runs (axis 0 is the run axis).
mean_regret = regrets.mean(axis=0)
plt.plot(mean_regret)
plt.show()



