import numpy as np
import random

from envs.tabular_mdp import TabularMDP
from envs.factored_mdp import FactoredMDP

def random_tabular(nS, nA, T):
    """Create a TabularMDP with randomly sampled transitions and rewards.

    The transition matrix has one row per index in ``range(nS * nA)`` and
    ``nS`` columns; each row is an independent draw from a Dirichlet
    distribution with all concentration parameters equal to one. The
    reward vector has ``nS * nA`` entries drawn uniformly from [0, 1).

    Args:
        nS: Number of states (column count of the transition matrix).
        nA: Number of actions; together with ``nS`` it fixes the row count.
        T: Forwarded unchanged to the ``TabularMDP`` constructor.

    Returns:
        The ``TabularMDP`` instance after ``setP`` and ``setR`` were called.
    """
    mdp = TabularMDP(nS, nA, T)
    num_pairs = nS * nA

    # random transitions: fill every row with a point on the simplex
    flat_prior = np.ones(nS)
    transitions = np.zeros(shape=(num_pairs, nS))
    for row in transitions:
        row[:] = np.random.dirichlet(flat_prior)
    mdp.setP(transitions)

    # reward function: one uniform draw per row of the transition matrix
    rewards = np.random.rand(num_pairs)
    mdp.setR(rewards)

    return mdp

def random_factored(d_S, d_A, Z, n, T):
    """Create a FactoredMDP with random scopes, transitions and rewards.

    One scope is drawn for each index in ``range(d_S)``. The scope size is
    uniform on ``1..Z`` and its members are sampled without replacement
    from ``range(d_S + d_A)``, then sorted. For every scope a transition
    table with ``n ** len(scope)`` rows and ``n`` columns is built, each
    row being a Dirichlet draw with unit concentration, together with a
    reward vector of ``n ** len(scope)`` uniform [0, 1) values.

    Args:
        d_S: Number of scopes to draw; also contributes to the index pool.
        d_A: Added to ``d_S`` to form the pool scope members come from.
        Z: Largest scope size that can be drawn.
        n: Base used to size the tables; presumably the number of values
            each variable can take (confirm against ``FactoredMDP``).
        T: Forwarded unchanged to the ``FactoredMDP`` constructor.

    Returns:
        The ``FactoredMDP`` after ``setscopes``, ``setP`` and ``setR``.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``Z`` is zero or a
            drawn scope size exceeds ``d_S + d_A``.
    """
    fmdp = FactoredMDP(d_S, d_A, Z, n, T)

    # random scopes
    # TODO: enforce at least one action in every scope
    pool_size = d_S + d_A
    scopes = []
    for _ in range(d_S):
        scope_size = 1 + np.random.choice(range(Z))
        members = np.random.choice(range(pool_size), scope_size, replace=False)
        scopes.append(np.sort(members))
    fmdp.setscopes(scopes)

    # random transitions: one conditional table per scope
    factors = []
    for scope in scopes:
        num_configs = n ** len(scope)
        table = np.zeros(shape=(num_configs, n))
        for row in table:
            row[:] = np.random.dirichlet(np.ones(n))
        factors.append(table)
    fmdp.setP(factors)

    # reward function: the reward factors reuse the transition scopes
    r_factors = [np.random.rand(n ** len(scope)) for scope in scopes]
    fmdp.setR(r_factors, scopes)

    return fmdp

def random_factored_sparse(d_S, d_A, Z, n, T, sparse=0.3):
    """Create a random FactoredMDP in which most reward factors may be zero.

    Scopes and transition tables are sampled as follows: one scope per
    index in ``range(d_S)``, with size uniform on ``1..Z`` and members
    drawn without replacement from ``range(d_S + d_A)`` and sorted; one
    table per scope with ``n ** len(scope)`` Dirichlet rows of length
    ``n``. Each reward factor is uniform on [0, 1) with probability
    ``sparse`` and identically zero otherwise.

    Args:
        d_S: Number of scopes to draw; also contributes to the index pool.
        d_A: Added to ``d_S`` to form the pool scope members come from.
        Z: Largest scope size that can be drawn.
        n: Base used to size the tables; presumably the number of values
            each variable can take (confirm against ``FactoredMDP``).
        T: Forwarded unchanged to the ``FactoredMDP`` constructor.
        sparse: Probability that a reward factor is non-zero.

    Returns:
        The ``FactoredMDP`` after ``setscopes``, ``setP`` and ``setR``.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``Z`` is zero, a
            drawn scope size exceeds ``d_S + d_A``, or ``sparse`` does not
            yield valid probabilities.
    """
    fmdp = FactoredMDP(d_S, d_A, Z, n, T)

    # random scopes
    # TODO: enforce at least one action in every scope
    pool_size = d_S + d_A
    scopes = []
    for _ in range(d_S):
        scope_size = 1 + np.random.choice(range(Z))
        members = np.random.choice(range(pool_size), scope_size, replace=False)
        scopes.append(np.sort(members))
    fmdp.setscopes(scopes)

    # random transitions: one conditional table per scope
    factors = []
    for scope in scopes:
        num_configs = n ** len(scope)
        table = np.zeros(shape=(num_configs, n))
        for row in table:
            row[:] = np.random.dirichlet(np.ones(n))
        factors.append(table)
    fmdp.setP(factors)

    # reward function: a 0/1 draw decides whether the factor carries reward
    keep_probs = [1 - sparse, sparse]
    r_factors = []
    for scope in scopes:
        num_configs = n ** len(scope)
        has_reward = np.random.choice([0, 1], 1, p=keep_probs)
        factor = np.random.rand(num_configs) if has_reward else np.zeros(num_configs)
        r_factors.append(factor)
    fmdp.setR(r_factors, scopes)

    return fmdp