import numpy as np

from envs.tabular_mdp import TabularMDP

# convert from index to features
def ind_to_x(ind, n, d):
    """Decode flat index `ind` into its `d` base-`n` digits.

    Digits are returned most-significant first as an int numpy array, left
    padded with zeros up to length `d` (e.g. ind=5, n=2, d=3 -> [1, 0, 1]).
    `ind` is expected to be smaller than n ** d; otherwise the zero padding
    would have negative length and numpy raises.
    """
    digits = []  # least-significant digit first
    remainder = ind
    while remainder > 0:
        remainder, digit = divmod(remainder, n)
        digits.append(digit)
    padding = np.zeros(d - len(digits), dtype=int)
    low_to_high = np.concatenate((np.array(digits, dtype=int), padding))
    return low_to_high[::-1]

# convert from features to index
def x_to_ind(x, n):
    """Encode feature vector `x` (most-significant digit first) as a flat index.

    Treats `x` as the digits of a base-`n` number, so this is the inverse of
    ``ind_to_x`` for in-range inputs (e.g. [1, 0, 1] with n=2 -> 5). Returns a
    Python int; an empty `x` encodes to 0.
    """
    index = 0
    # Horner's rule: shift the accumulated value one base-n place, add the digit.
    for digit in x:
        index = index * n + digit
    return int(index)

def prod_probabilities(fmdp, x, nsf):
    """Probability of next-state features `nsf` given state-action features `x`.

    Multiplies, over every transition factor j, the entry
    ``fmdp.factors[j][parent][nsf[j]]``, where `parent` is the base-n index of
    the features of `x` selected by ``fmdp.scopes[j]``.
    Assumes `x` supports fancy indexing by each scope (e.g. a numpy array).
    """
    probability = 1
    for j, scope in enumerate(fmdp.scopes):
        parent = x_to_ind(x[scope], fmdp.n)
        probability = probability * fmdp.factors[j][parent][nsf[j]]
    return probability

def sum_rewards(fmdp, x):
    """Total reward of state-action features `x` under the factored reward model.

    Adds, over every reward factor j, ``fmdp.r_factors[j][parent]``, where
    `parent` is the base-n index of the features of `x` selected by
    ``fmdp.r_scopes[j]``.
    Assumes `x` supports fancy indexing by each scope (e.g. a numpy array).
    """
    total = 0  # named `total` to avoid shadowing the builtin `sum`
    for j, scope in enumerate(fmdp.r_scopes):
        parent = x_to_ind(x[scope], fmdp.n)
        total = total + fmdp.r_factors[j][parent]
    return total

# convert from factored to tabular
def factored_to_tabular(fmdp):
    """Expand a factored MDP into an equivalent ``TabularMDP``.

    A flat state index s and action index a are combined into the row index
    ``sa = s * nA + a``. Because ``nA = n ** d_A``, the ``d_X`` base-n digits of
    `sa` are exactly the state features followed by the action features, so
    ``ind_to_x(sa, n, d_X)`` recovers the state-action feature vector.

    Builds:
      * P, shape (nS * nA, nS): P[sa, ns] is the product of factor
        probabilities for next-state features ``ind_to_x(ns, n, d_S)``;
      * R, shape (nS * nA,): R[sa] is the summed factored reward;
      * mu: the factored MDP's initial distribution, passed through unchanged.
    """
    n = fmdp.n
    nS = n ** fmdp.d_S
    nA = n ** fmdp.d_A
    mdp = TabularMDP(nS, nA, fmdp.T)
    # Decode every next-state index once; the original re-decoded all of them
    # for every (s, a) pair.
    next_features = [ind_to_x(ns, n, fmdp.d_S) for ns in range(nS)]
    P = np.zeros(shape=(nS * nA, nS))
    R = np.zeros(shape=(nS * nA))
    # Iterating sa directly visits rows in the same order as the nested
    # (s, a) loops, and decodes each state-action vector only once.
    for sa in range(nS * nA):
        x = ind_to_x(sa, n, fmdp.d_X)
        for ns, nsf in enumerate(next_features):
            P[sa, ns] = prod_probabilities(fmdp, x, nsf)
        R[sa] = sum_rewards(fmdp, x)
    mdp.setP(P)
    mdp.setR(R)
    mdp.setmu(fmdp.mu_features)
    return mdp



class FactoredMDP():
    """Finite-horizon MDP whose states and actions are vectors of discrete features.

    States have `d_S` features and actions `d_A` features, each taking values in
    ``0 .. n-1``. The concatenated state-action vector x (``d_X = d_S + d_A``
    features) drives both models:

    * transition factor j (``factors[j]``) is indexed by the base-n encoding of
      ``x[scopes[j]]`` and gives a distribution over next-state feature j;
    * reward factor j (``r_factors[j]``) is indexed by the base-n encoding of
      ``x[r_scopes[j]]`` and contributes an additive reward term.

    ``setscopes``, ``setP`` and ``setR`` must be called before ``step``, and
    ``reset`` must be called before the first ``step``.
    """

    def __init__(self, d_S, d_A, Z, n, T):
        self.d_S = d_S
        self.d_A = d_A
        self.d_X = d_S + d_A
        self.Z = Z  # stored as given; not read anywhere in this class
        self.n = n  # support size of every feature: values are 0 .. n-1
        self.T = T  # horizon: step() returns None once t reaches T
        # Initial distribution over flat state indices; by default the
        # episode always starts in state 0 (all features zero).
        self.mu_features = np.zeros(n ** d_S)
        self.mu_features[0] = 1
        # Number of flat states and actions.
        self.nS = n ** d_S
        self.nA = n ** d_A

    def setscopes(self, scopes):
        """Set the transition scopes: scopes[j] selects the parents of next-state feature j."""
        self.scopes = scopes

    def setP(self, factors):
        """Set the transition factors: factors[j][parent_index] is a distribution over feature j."""
        self.factors = factors

    def setR(self, r_factors, r_scopes):
        """Set the reward factors and the state-action features each one reads."""
        self.r_factors = r_factors
        self.r_scopes = r_scopes

    def setmu(self, mu):
        """Replace the initial distribution over flat state indices."""
        self.mu_features = mu

    def step(self, a):
        """Advance one time step using action features `a`.

        Returns ``(t, s, a, ns, r)`` for the current time step t, current
        state features s, next state features ns and reward r, then moves to
        ns. Returns None once the horizon T has been reached.
        """
        if self.t >= self.T:
            return None
        x = np.concatenate((self.s, a))
        ns = np.empty(self.d_S, dtype=int)
        for j in range(self.d_S):
            # Sample next-state feature j from its factor, conditioned on
            # the base-n index of its parent features.
            ind_j = x_to_ind(x[self.scopes[j]], self.n)
            ns[j] = np.random.choice(range(self.n), p=self.factors[j][ind_j])
        # Sum over *all* reward factors, exactly as the tabular conversion
        # does. Previously this looped over range(d_S), which dropped extra
        # reward factors or raised IndexError when there were fewer than d_S.
        r = sum_rewards(self, x)
        out = (self.t, self.s, a, ns, r)
        self.s = ns
        self.t = self.t + 1
        return out

    def reset(self):
        """Start a new episode: set t = 0 and sample initial state features from mu."""
        self.t = 0
        ind = np.random.choice(range(len(self.mu_features)), p=self.mu_features)
        self.s = ind_to_x(ind, self.n, self.d_S)
        return self.s

    def seed(self, seed=None):
        """Seed numpy's global random generator (affects all np.random users)."""
        np.random.seed(seed)