import numpy as np
import copy

from envs.factored_mdp import x_to_ind, ind_to_x, factored_to_tabular

# class to instantiate the prior
class FactoredPrior:
    """Independent Dirichlet prior over every row of each factor's transition table.

    Factor ``j`` has scope ``Z_j`` (``self.scopes[j]``); its table has
    ``n ** len(Z_j)`` rows and ``n`` columns.  Every row starts with an
    all-ones concentration vector.
    """

    def __init__(self, scopes, n):
        self.scopes = scopes
        self.n = n
        # One (n ** len(Z_j), n) concentration table per factor, all ones.
        self.params = [np.ones(shape=(n ** len(scope), n)) for scope in scopes]

    def draw(self):
        """Sample one transition table per factor.

        Returns a list whose ``j``-th entry is an ``(n ** len(Z_j), n)`` array;
        row ``i`` is drawn from ``Dirichlet(self.params[j][i])`` using the
        global NumPy RNG, factor by factor and row by row.
        """
        size = self.n
        sampled_tables = []
        for j, scope in enumerate(self.scopes):
            n_rows = size ** len(scope)
            alpha = self.params[j]
            table = np.zeros(shape=(n_rows, size))
            for row in range(n_rows):
                table[row] = np.random.dirichlet(alpha[row])
            sampled_tables.append(table)
        return sampled_tables

    def update(self, masks):
        """Posterior update: add observed counts ``masks[j]`` to factor ``j``.

        Each factor's concentration table is replaced by a new array
        (the previous array is not modified in place).
        """
        for j, _scope in enumerate(self.scopes):
            self.params[j] = np.add(self.params[j], masks[j])


# value iteration
def value_iteration(mdp):
    """Solve a finite-horizon tabular MDP by backward induction.

    Args:
        mdp: object exposing ``nS``, ``nA``, ``P``, ``R``, ``T`` and ``mu``.
            ``R`` must be reshapeable to ``(nS, nA)``.  ``np.dot(P, v)`` for a
            length-``nS`` vector ``v`` is reshaped to ``(nS, nA)``, so ``P`` is
            presumably ``(nS * nA, nS)`` -- confirm against the MDP class.

    Returns:
        ``(policy, value)``.  ``policy`` has shape ``(T, nS, nS * nA)`` with
        ``policy[t, s, s * nA + a] == 1`` for the greedy action ``a`` at step
        ``t``; ties go to the lowest action index.  ``value`` is the Python
        float ``mu . V_0``.
    """
    num_states, num_actions = mdp.nS, mdp.nA
    transitions = mdp.P
    rewards = mdp.R.reshape(num_states, num_actions)
    horizon = mdp.T
    q_values = np.zeros(shape=(horizon, num_states, num_actions))
    policy = np.zeros(shape=(horizon, num_states, num_states * num_actions))
    values = np.zeros(shape=(horizon, num_states))
    state_idx = np.arange(num_states)

    for step in reversed(range(horizon)):
        if step == horizon - 1:
            # Last step: no future value, Q is just the immediate reward.
            q_values[step] = rewards
        else:
            continuation = np.dot(transitions, np.max(q_values[step + 1], axis=1))
            q_values[step] = rewards + continuation.reshape(num_states, num_actions)
        # One-hot greedy action per state, stored in the flattened (s, a) layout.
        greedy = np.argmax(q_values[step], axis=1)
        policy[step, state_idx, state_idx * num_actions + greedy] = 1.
        values[step] = np.dot(policy[step], q_values[step].reshape(num_states * num_actions))

    expected_return = np.dot(mdp.mu, values[0])
    return (policy, expected_return.item())

# policy evaluation
def policy_evaluation(mdp, policy):
    """Return the expected finite-horizon return of ``policy`` in ``mdp``.

    Args:
        mdp: object exposing ``nS``, ``nA``, ``P``, ``R``, ``T`` and ``mu``
            (same layout as consumed by ``value_iteration``).
        policy: array indexed as ``policy[t]``, a ``(nS, nS * nA)`` matrix that
            maps the flattened ``(s, a)`` Q-vector to per-state values -- the
            layout produced by ``value_iteration``.

    Returns:
        Python float ``mu . V_0``.
    """
    nS = mdp.nS
    nA = mdp.nA
    P = mdp.P
    R = mdp.R.reshape(nS, nA)
    T = mdp.T
    mu = mdp.mu
    Q = np.zeros(shape=(T, nS, nA))
    V = np.zeros(shape=(T, nS))
    # Backward pass over all T steps (including t == 0), so that a horizon of
    # T == 1 evaluates to mu . (pi_0 R) instead of an all-zero Q[0].
    for t in range(T - 1, -1, -1):
        if t == T - 1:
            Q[t] = R
        else:
            Q[t] = R + np.dot(P, V[t + 1]).reshape(nS, nA)
        V[t] = np.dot(policy[t], Q[t].reshape(nS * nA))
    # expected return
    value = np.dot(mu, V[0])

    return value.item()

# collect data
def collect_data(fmdp, policy, eps_greedy=False, eps=1e-2):
    """Roll out ``policy`` for one episode of ``fmdp`` and count factor transitions.

    Args:
        fmdp: factored environment exposing ``n``, ``scopes``, ``T``, ``nA``,
            ``d_A``, ``reset()`` and ``step(a)``.
        policy: indexed as ``policy[t][s_ind]``, a row of length ``nS * nA``
            whose slice ``[s_ind * nA, (s_ind + 1) * nA)`` is the action
            distribution for state index ``s_ind`` (the layout produced by
            ``value_iteration``).
        eps_greedy: if True, mix the action distribution with the uniform one.
        eps: weight of the uniform distribution when ``eps_greedy`` is set.

    Returns:
        ``(data, masks)``.  ``data`` is the list of raw ``fmdp.step`` results in
        order.  ``masks[j][i, v]`` counts how often scope-``j`` input index ``i``
        (state and action concatenated, restricted to ``scopes[j]``) was
        followed by value ``v`` of next-state component ``j``.
    """
    n = fmdp.n
    nA = fmdp.nA
    masks = [np.zeros(shape=(n ** len(Z_j), n)) for Z_j in fmdp.scopes]
    # interaction loop
    data = []
    s = fmdp.reset()
    for t in range(fmdp.T):
        # draw action
        s_ind = x_to_ind(s, n)
        probs = policy[t][s_ind][s_ind * nA:(s_ind + 1) * nA]
        if eps_greedy:
            # (1 - eps) * pi + eps * uniform is a valid distribution for any pi.
            # For a one-hot pi it yields exactly the former add-then-clip values
            # (greedy: 1 - eps + eps/nA, others: eps/nA), but unlike that rule
            # it also sums to 1 for stochastic policy rows.
            probs = (1. - eps) * probs + eps / nA
        a_ind = np.random.choice(range(nA), p=probs)
        a = ind_to_x(a_ind, n, fmdp.d_A)
        x = np.concatenate((s, a))
        # mdp transition; element 3 of the step result is used as the next state
        tran = fmdp.step(a)
        ns = tran[3]
        # save data
        data.append(tran)
        for j, (Z_j, mask) in enumerate(zip(fmdp.scopes, masks)):
            ind = x_to_ind(x[Z_j], n)
            mask[ind, ns[j]] += 1
        # next state
        s = ns

    return (data, masks)

# factored psrl method
def factored_psrl(fmdp_star, model_prior, n_iters, eps_greedy=False, eps=1e-2):
    """Run posterior sampling RL (PSRL) against the factored MDP ``fmdp_star``.

    Each iteration does the following:
    1. Sample factored transition tables from ``model_prior``.
    2. Plan on the sampled model with ``value_iteration``.
    3. Run that policy for one episode on the true environment.
    4. Add the observed counts to the prior.
    5. Record the policy's true value and the cumulative regret.

    Args:
        fmdp_star: true factored MDP (must support ``copy.deepcopy`` and
            ``setP``).
        model_prior: object with ``draw()`` and ``update(masks)``, e.g.
            ``FactoredPrior``.
        n_iters: number of episodes.
        eps_greedy, eps: forwarded to ``collect_data``.

    Returns:
        ``(policy_t, regret, model_error, value)``:
        - ``policy_t``: the last iteration's policy (``None`` when
          ``n_iters == 0``).
        - ``regret``: cumulative regret after each iteration.
        - ``model_error``: per-iteration ``np.linalg.norm(P* - P_t, 1)``.
        - ``value``: the true value of each iteration's policy.
    """
    mdp_star = factored_to_tabular(fmdp_star)
    _, value_star = value_iteration(mdp_star)
    # main loop
    cum_regret = 0
    regret = np.zeros(n_iters)
    model_error = np.zeros(n_iters)
    value = np.zeros(n_iters)
    # Separate copy so setP on the sampled model never touches fmdp_star.
    fmdp_t = copy.deepcopy(fmdp_star)
    # Bound up front so n_iters == 0 returns None instead of raising
    # UnboundLocalError at the return statement.
    policy_t = None
    for t in range(n_iters):
        # draw model from prior
        factors_t = model_prior.draw()
        fmdp_t.setP(factors_t)
        mdp_t = factored_to_tabular(fmdp_t)
        # compute optimal policy for the sampled model
        policy_t, _ = value_iteration(mdp_t)
        # collect rollouts on the true environment
        _, masks_t = collect_data(fmdp_star, policy_t, eps_greedy, eps)
        # posterior update
        model_prior.update(masks_t)
        # compute regret of the sampled policy under the true model
        value_t = policy_evaluation(mdp_star, policy_t)
        cum_regret = cum_regret + value_star - value_t
        regret[t] = cum_regret
        # NOTE(review): for a 2-D array, np.linalg.norm(..., 1) is the induced
        # matrix 1-norm (max absolute column sum), not the entrywise L1
        # distance -- confirm this is the intended model-error metric.
        model_error[t] = np.linalg.norm(mdp_star.P - mdp_t.P, 1)
        value[t] = value_t

    return policy_t, regret, model_error, value