import numpy as np
import matplotlib.pyplot as plt
import random

from envs.factored_mdp import FactoredMDP, factored_to_tabular
from algs.factored_psrl import factored_psrl, FactoredPrior
from algs.causal_psrl import causal_psrl, HierarchicalPrior, generate_causal_prior


def random_factored(d_S, d_A, Z, n, T):
    """Build a FactoredMDP with randomly drawn scopes, transitions and rewards.

    One scope is drawn per state variable (``d_S`` scopes in total). A scope
    is a sorted random subset of the ``d_S + d_A`` variable indices. Its size
    is drawn uniformly from ``1..Z``, capped at ``d_S + d_A`` so the subset
    sampling cannot fail.

    For every scope:

    * the transition table has ``n ** len(scope)`` rows, each an independent
      Dirichlet(1, ..., 1) draw over ``n`` outcomes;
    * the reward table holds ``n ** len(scope)`` uniform [0, 1) values.

    Randomness comes from two separate global generators: the stdlib
    ``random`` module (scopes) and ``np.random`` (transitions and rewards).
    Seed both for a reproducible instance.

    Args:
        d_S: number of state variables; one factor is drawn per state variable.
        d_A: number of action variables (extra indices a scope may include).
        Z: maximum scope size; must be >= 1.
        n: base of the table sizes. Tables are sized as if every variable
            takes ``n`` values.
        T: forwarded to ``FactoredMDP`` unchanged.

    Returns:
        The ``FactoredMDP`` after ``setscopes``, ``setP`` and ``setR``.

    Raises:
        ValueError: if ``Z < 1``.
    """
    if Z < 1:
        raise ValueError(f"Z (maximum scope size) must be >= 1, got {Z}")

    fmdp = FactoredMDP(d_S, d_A, Z, n, T)

    # random scopes: cap the size so random.sample never asks for more
    # elements than exist (the original raised ValueError when Z > d_S + d_A)
    n_vars = d_S + d_A
    max_scope = min(Z, n_vars)
    scopes = [
        np.sort(random.sample(range(n_vars), random.randint(1, max_scope)))
        for _ in range(d_S)
    ]
    fmdp.setscopes(scopes)

    # random transitions: one Dirichlet row per joint assignment of the scope
    factors = [
        np.random.dirichlet(np.ones(n), size=n ** len(scope))
        for scope in scopes
    ]
    fmdp.setP(factors)

    # reward function: drawn after all transitions to keep the same
    # np.random draw order as before
    r_factors = [np.random.rand(n ** len(scope)) for scope in scopes]
    fmdp.setR(r_factors, scopes)
    return fmdp

# --- Experiment configuration -------------------------------------------------
d_S = 3  # number of state variables
d_A = 2  # number of action variables
Z = 5  # maximum scope size per factor
n = 2  # table base: each variable treated as taking n values
T = 10  # forwarded to FactoredMDP unchanged (presumably the horizon; confirm)
n_iters = 1000  # iteration count passed to causal_psrl

# Seed for both global RNGs used here (stdlib `random` and `np.random`).
# None reseeds from OS entropy, giving a fresh random instance each run as
# before. Set an int to make the generated MDP and the run reproducible.
SEED = None
random.seed(SEED)
np.random.seed(SEED)

fmdp = random_factored(d_S, d_A, Z, n, T)

# Alternative baseline (disabled): plain factored PSRL given the true scopes.
# prior = FactoredPrior(fmdp.scopes, fmdp.n)
# policy, regret, model_error = factored_psrl(fmdp, prior, n_iters)

print(fmdp.Z, fmdp.d_S, fmdp.scopes)
# NOTE(review): meaning of the literal 2 is defined by generate_causal_prior;
# confirm there before changing it.
causal_prior = generate_causal_prior(fmdp.Z, fmdp.scopes, 2)
print(causal_prior)
prior = HierarchicalPrior(fmdp.d_S, fmdp.Z, fmdp.n, causal_prior)
print(prior)
_, regret, error, value = causal_psrl(fmdp, prior, n_iters, eps_greedy=0.2, eps=0.3)

# Plot the regret trace returned by causal_psrl against its index
# (presumably one entry per iteration).
plt.plot(regret)
plt.xlabel("iteration")
plt.ylabel("regret")
plt.show()
# Model-error plot for the disabled factored_psrl baseline above:
# plt.plot(model_error)
# plt.show()