import numpy as np

from envs.tabular_mdp import TabularMDP
import time

def ind_to_x(ind, val, n, d):
    """Decode a flat index into its mixed-radix feature vector.

    This is the inverse of ``x_to_ind``: the feature selected by ``val[0]``
    is the most significant digit and the one selected by ``val[-1]`` the
    least significant.

    Args:
        ind: Flat index to decode.
        val: Sequence of positions into ``n``; ``n[v]`` is the radix (number
            of possible values) of the corresponding output feature.
        n: Indexable collection of radices.
        d: Size of the flat index space. It is expected to equal the product
            of ``n[v]`` over ``v`` in ``val``; the decoding is only
            meaningful under that condition.

    Returns:
        Integer ``numpy`` array of length ``len(val)`` with the decoded
        feature values (empty when ``val`` is empty).
    """
    # Work with plain Python ints and floor division so the decoding stays
    # exact for any size of index space (float division silently loses
    # precision once values exceed 2**53).
    remainder = int(ind)
    base = int(d)
    digits = []
    for v in val:
        # After this division ``base`` is the weight of the current digit.
        base //= int(n[v])
        digit, remainder = divmod(remainder, base)
        digits.append(digit)
    return np.array(digits, dtype=int)

# convert from features to index
def x_to_ind(x, val, n):
    """Encode a feature vector as a single flat (mixed-radix) index.

    Args:
        x: Sequence of feature values; ``x[-1]`` is the least significant
            digit.
        val: Sequence of positions into ``n``; ``n[val[k]]`` is the radix
            associated with ``x[k]``.
        n: Indexable collection of radices.

    Returns:
        The flat index as a Python ``int``.
    """
    radices = [n[v] for v in val]
    total = 0
    weight = 1
    # Walk from the least significant digit to the most significant one,
    # growing the positional weight as we go.
    for pos in reversed(range(len(x))):
        total += weight * x[pos]
        weight *= radices[pos]
    return int(total)

def factored_to_tabular(fmdp):
    """Build a ``TabularMDP`` whose transition matrix matches ``fmdp``.

    Row ``s * nA + a`` of the transition matrix holds the distribution over
    flat next states obtained by multiplying the per-factor probabilities of
    ``fmdp`` (see ``prod_probabilities``).

    Only the transition matrix and the initial distribution are transferred;
    the reward is NOT set on the returned object.

    Args:
        fmdp: Factored MDP exposing ``n``, ``nactions``, ``T``, ``d_S``,
            ``d_A``, ``nS``, ``mu_features``, ``scopes`` and ``factors``.

    Returns:
        The populated ``TabularMDP`` instance.
    """
    # nS, nA, T and create the object
    nS = fmdp.n.prod()
    nA = fmdp.nactions.prod()
    mdp = TabularMDP(nS, nA, fmdp.T)

    n = fmdp.n
    # Loop-invariant pieces, built once instead of once per matrix entry.
    radices = np.concatenate((n, fmdp.nactions))
    sa_scope = range(fmdp.d_S + fmdp.d_A)
    s_scope = range(fmdp.d_S)
    # Feature vector of every flat next state; it does not depend on (s, a).
    next_features = [ind_to_x(ns, s_scope, n, fmdp.nS) for ns in range(nS)]

    # set P
    P = np.zeros(shape=(nS * nA, nS))
    for sa in range(nS * nA):
        # sa enumerates s * nA + a, with the action as the fastest index.
        x = ind_to_x(sa, sa_scope, radices, nS * nA)
        for ns, nsf in enumerate(next_features):
            P[sa, ns] = prod_probabilities(fmdp, x, nsf)
    mdp.setP(P)

    # set mu
    mdp.setmu(fmdp.mu_features)
    return mdp

def prod_probabilities(fmdp, x, nsf):
    """Return the probability of next-state features ``nsf`` given ``x``.

    The result is the product, over every factor ``j``, of
    ``fmdp.factors[j][row][nsf[j]]`` where ``row`` is the flat index of the
    entries of ``x`` selected by the scope ``fmdp.scopes[j]``.

    Args:
        fmdp: Factored MDP exposing ``factors``, ``scopes``, ``n`` and
            ``nactions``.
        x: ``numpy`` array of state features followed by action features.
        nsf: Next-state feature values, one per factor.

    Returns:
        The product of the per-factor probabilities (``1`` when there are no
        scopes).
    """
    # The radix vector is the same for every factor, so build it once.
    radices = np.concatenate((fmdp.n, fmdp.nactions))
    prob = 1
    for j, scope_j in enumerate(fmdp.scopes):
        row = x_to_ind(x[scope_j], scope_j, radices)
        prob = prob * fmdp.factors[j][row][int(nsf[j])]
    return prob



# convert from factored to tabular



class FactoredMDP():
    """Finite-horizon MDP whose transitions factor over state features.

    A state is a vector of ``d_S`` features and an action a vector of
    ``d_A`` features. Next-state feature ``j`` is sampled from
    ``factors[j][row]``, where ``row`` is the flat index of the entries of
    the concatenated (state, action) vector selected by ``scopes[j]``.

    ``setscopes`` and ``setP`` must be called before ``step``, and ``reset``
    must be called before the first ``step`` of an episode.
    """

    def __init__(self, d_S, d_A, Z, R, n, nactions, T):
        """Store the problem dimensions and the reward table.

        Args:
            d_S: Number of state features.
            d_A: Number of action features.
            Z: Stored as ``self.Z``; not used by this class itself.
            R: Array-like with ``prod(n) * prod(nactions)`` entries; it is
                reshaped to ``(nS, nA)``.
            n: Support size of each state feature.
            nactions: Support size of each action feature.
            T: Horizon (number of steps per episode).
        """
        self.d_S = d_S
        self.d_A = d_A
        self.d_X = d_S + d_A
        self.Z = Z
        self.n = np.array(n)  # support of the state features
        self.nactions = np.array(nactions)  # support of the action features
        self.T = T
        # states and actions
        self.nS = self.n.prod()
        self.nA = self.nactions.prod()
        # initial distribution over flat states: all mass on state 0
        self.mu_features = np.zeros(self.nS)
        self.mu_features[0] = 1
        # asarray lets callers pass plain sequences as well as arrays
        self.R = np.asarray(R).reshape((self.nS, self.nA))
        # episode state; populated by reset()
        self.t = None
        self.s = None

    def setscopes(self, scopes):
        """Set the scope (indices into the state-action vector) of each factor."""
        self.scopes = scopes

    def setP(self, factors):
        """Set the per-feature transition factors."""
        self.factors = factors

    def setR(self, r_factors, r_scopes):
        """Store reward factors and their scopes (not used by ``step``)."""
        self.r_factors = r_factors
        self.r_scopes = r_scopes

    def setmu(self, mu):
        """Set the initial distribution over flat state indices."""
        self.mu_features = mu

    def step(self, a):
        """Advance the episode by one step using action features ``a``.

        Args:
            a: Sequence of ``d_A`` action feature values.

        Returns:
            ``(t, s, a, ns, r)`` with the time step, the state features
            before the transition, the action, the sampled next-state
            features and the reward ``R[s, a]`` looked up by flat state and
            flat joint-action index; ``None`` once ``T`` steps were taken.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if self.t is None or self.s is None:
            raise RuntimeError("reset() must be called before step()")
        if self.t >= self.T:
            return None

        x = np.concatenate((self.s, a))
        # Radices of the joint (state, action) vector; same for all features.
        radices = np.concatenate((self.n, self.nactions))
        ns = np.empty(self.d_S, dtype=int)
        for j in range(self.d_S):
            # compute next state feature
            scope_j = self.scopes[j]
            ind_j = x_to_ind(x[scope_j], scope_j, radices)
            P_j = self.factors[j]
            ns[j] = np.random.choice(range(self.n[j]), p=P_j[ind_j])

        # compute reward: R is an (nS, nA) table, so both the state and the
        # action feature vectors must be flattened before the lookup.
        # Indexing with the raw feature vector ``a`` would pick one column
        # per action feature instead of the joint action.
        s_ind = x_to_ind(self.s, np.arange(self.d_S), self.n)
        a_ind = x_to_ind(a, np.arange(self.d_A), self.nactions)
        r = self.R[s_ind, a_ind]

        out = (self.t, self.s, a, ns, r)
        self.s = ns
        self.t = self.t + 1
        return out

    def reset(self):
        """Start a new episode and return the initial state features.

        The flat initial state is sampled from ``mu_features`` and decoded
        into its feature vector.
        """
        self.t = 0
        ind = np.random.choice(range(len(self.mu_features)), p=self.mu_features)
        self.s = ind_to_x(ind, np.arange(self.d_S), self.n, self.nS)
        return self.s

    def seed(self, seed=None):
        """Seed NumPy's global random generator, which this class samples from."""
        np.random.seed(seed)