import numpy as np
import copy

from envs.factored_mdp import x_to_ind, ind_to_x, factored_to_tabular

def generate_causal_prior(Z, scopes, eta):
    """Reveal (a random part of) each true scope as a causal prior.

    Args:
        Z: threshold compared against ``eta``; when ``eta > Z`` every scope
            is returned whole. (Presumably the maximum scope size — confirm.)
        scopes: one scope per state feature; each may be a NumPy array, a
            list or a tuple of feature indices.
        eta: number of indices to reveal per scope.

    Returns:
        A list with one list of ints per scope. When ``eta > Z`` the scopes
        are copied in their original order. Otherwise, scopes longer than
        ``eta`` are replaced by ``eta`` indices sampled without replacement,
        and every entry is returned sorted.
    """
    if eta > Z:
        # np.asarray accepts lists/tuples as well as arrays; ndarray-only
        # ``.tolist()`` would fail on plain Python sequences.
        return [np.asarray(scope).tolist() for scope in scopes]
    causal_prior = []
    for scope in scopes:
        if len(scope) > eta:
            edges = np.random.choice(scope, size=eta, replace=False)
        else:
            edges = scope
        causal_prior.append(np.sort(edges).tolist())
    return causal_prior

def compute_factset(Z, Z_j, indset):
    """Enumerate the strict supersets of ``Z_j`` built from ``indset``.

    Each returned scope is ``Z_j`` plus at least one index from ``indset``
    that is not already in ``Z_j``, has at most ``Z`` elements, and is
    returned as a sorted list. ``Z_j`` itself is not included. No scope is
    produced twice.

    Args:
        Z: maximum size of a generated scope.
        Z_j: starting scope (list, tuple or array of indices).
        indset: set of candidate indices; it is not modified.

    Returns:
        A list of sorted index lists.
    """
    # Remove the indices already in Z_j *before* checking for emptiness.
    # Otherwise pop() could run on an empty set and raise KeyError.
    candidates = indset - set(Z_j)
    if len(Z_j) >= Z or not candidates:
        return []
    i = candidates.pop()
    nZ_j = list(Z_j) + [i]
    # First the supersets that contain i, then the ones that do not.
    return (
        [np.sort(nZ_j).tolist()]
        + compute_factset(Z, nZ_j, candidates)
        + compute_factset(Z, Z_j, candidates)
    )

# class to instantiate the prior
class HierarchicalPrior():
    """Hierarchical prior over factored transition models.

    For every state feature ``j`` the prior maintains:

    * ``factorizations[j]``: candidate scopes (lists of indices in
      ``0 .. d_S + d_A - 1``). Every candidate contains ``causal_prior[j]``.
    * ``hyperparams[j]``: a probability vector over those candidates. It
      starts uniform and is reweighted by ``update``.
    * ``params[j][z]``: Dirichlet pseudo-counts of shape
      ``(n ** len(scope), n)``, with one row per joint value of the scope.
    """

    def __init__(self, d_S, d_A, Z, n, causal_prior):
        """Enumerate candidate scopes and set up a uniform initial prior.

        Args:
            d_S: number of state features.
            d_A: number of action features. Candidate parents are drawn from
                the indices ``0 .. d_S + d_A - 1``.
            Z: maximum size of an enumerated candidate scope.
            n: number of values per feature (CPT rows are indexed by
                ``n ** len(scope)``).
            causal_prior: one index list per state feature. It is contained
                in every candidate scope of that feature, and is itself a
                candidate when non-empty.
        """
        self.d_S = d_S
        self.d_A = d_A
        self.Z = Z
        self.n = n
        self.causal_prior = causal_prior
        self.factorizations = [[] for _ in range(d_S)]
        self.hyperparams = [[] for _ in range(d_S)]
        self.params = [[] for _ in range(d_S)]
        for j in range(d_S):
            # All strict supersets of causal_prior[j] (size <= Z), plus
            # causal_prior[j] itself when it is non-empty.
            scopes_j = compute_factset(Z, causal_prior[j], set(range(d_S + d_A)))
            if len(causal_prior[j]) > 0:
                scopes_j.append(causal_prior[j])
            self.factorizations[j] = scopes_j
            num_scopes = len(scopes_j)
            self.hyperparams[j] = np.ones(num_scopes) / num_scopes
            # Dirichlet(1, ..., 1) pseudo-counts for every row of every candidate CPT.
            self.params[j] = [np.ones(shape=(n ** len(scope), n)) for scope in scopes_j]

    def draw(self):
        """Sample a factored model from the current prior.

        Returns:
            ``(scopes, factors)``. For each state feature ``j``, ``scopes[j]``
            is a candidate scope sampled with probabilities ``hyperparams[j]``.
            ``factors[j]`` is an ``(n ** len(scopes[j]), n)`` matrix whose rows
            are independent Dirichlet draws from the matching pseudo-counts.
        """
        n = self.n
        scopes = []
        factors = []
        for j in range(self.d_S):
            ZZ_j = len(self.factorizations[j])
            z = np.random.choice(range(ZZ_j), p=self.hyperparams[j])
            Z_j = self.factorizations[j][z]
            scopes.append(Z_j)
            P_j = np.zeros(shape=(n ** len(Z_j), n))
            for i in range(n ** len(Z_j)):
                P_j[i] = np.random.dirichlet(self.params[j][z][i])
            factors.append(P_j)
        return scopes, factors

    def update(self, x, y):
        """Condition the prior on one observed transition ``x -> y``.

        For every state feature ``j`` and candidate scope ``z``:

        * ``hyperparams[j][z]`` is multiplied by the posterior predictive
          probability of ``y[j]`` given ``x`` under scope ``z``;
        * the Dirichlet count of the observed outcome is then incremented.

        The weights of each feature are renormalized afterwards.

        Args:
            x: array of predecessor features, indexed with each scope (a list
                of ints). It is presumably the NumPy concatenation of state
                and action; confirm with callers.
            y: observed next-state features. ``y[j]`` is used as a column
                index into the count matrices.
        """
        for j in range(self.d_S):
            y_j = y[j]
            weights = self.hyperparams[j]
            for z, Z_j in enumerate(self.factorizations[j]):
                counts = self.params[j][z]
                ind = x_to_ind(x[Z_j], self.n)
                # Sequential Bayesian model averaging needs p(y_j | x, past data),
                # i.e. the predictive computed BEFORE this observation is counted.
                predictive = counts[ind, y_j] / np.sum(counts[ind])
                weights[z] = weights[z] * predictive
                counts[ind, y_j] += 1
            self.hyperparams[j] = weights / np.sum(weights)


# value iteration
def value_iteration(mdp):
    """Compute a greedy optimal policy for a finite-horizon tabular MDP.

    Backward induction: ``Q[T-1] = R`` and ``Q[t] = R + P max_a Q[t+1]``.
    ``P`` is multiplied by a length-``nS`` vector and the result reshaped to
    ``(nS, nA)``, so it acts as an ``(nS * nA) x nS`` matrix.

    Returns:
        ``(policy, value)``. ``policy`` has shape ``(T, nS, nS * nA)``; row
        ``s`` of ``policy[t]`` holds a single 1 at column ``s * nA + a*``,
        where ``a*`` is the first maximizer of ``Q[t][s]``. ``value`` is the
        float ``mu . V[0]``.
    """
    n_states, n_actions, horizon = mdp.nS, mdp.nA, mdp.T
    transition = mdp.P
    reward = mdp.R.reshape(n_states, n_actions)
    initial_dist = mdp.mu

    q_values = np.zeros(shape=(horizon, n_states, n_actions))
    policy = np.zeros(shape=(horizon, n_states, n_states * n_actions))
    values = np.zeros(shape=(horizon, n_states))
    state_idx = np.arange(n_states)

    for step in reversed(range(horizon)):
        if step < horizon - 1:
            future = np.max(q_values[step + 1], axis=1)
            q_values[step] = reward + np.dot(transition, future).reshape(n_states, n_actions)
        else:
            q_values[step] = reward
        # argmax returns the first maximizer per row (same tie-breaking as a per-state loop).
        greedy = np.argmax(q_values[step], axis=1)
        policy[step][state_idx, state_idx * n_actions + greedy] = 1.
        values[step] = np.dot(policy[step], q_values[step].reshape(n_states * n_actions))

    expected_return = np.dot(initial_dist, values[0])
    return (policy, expected_return.item())

# policy evaluation
def policy_evaluation(mdp, policy):
    """Return the expected finite-horizon return of ``policy`` in ``mdp``.

    ``policy[t]`` is an ``(nS, nS * nA)`` matrix (the layout built by
    ``value_iteration``), so ``policy[t] @ Q[t].reshape(nS * nA)`` gives the
    state values at step ``t``.

    Backward induction: ``Q[T-1] = R``, ``Q[t] = R + P V[t+1]`` and
    ``V[t] = policy[t] Q[t]``. The result is ``mu . V[0]`` as a float.
    """
    nS = mdp.nS
    nA = mdp.nA
    P = mdp.P
    R = mdp.R.reshape(nS, nA)
    T = mdp.T
    mu = mdp.mu
    Q = np.zeros(shape=(T, nS, nA))
    V = np.zeros(shape=(T, nS))
    for t in range(T - 1, -1, -1):
        if t == T - 1:
            # Setting Q explicitly at the last step also covers T == 1, where
            # Q[0] must equal R rather than stay at zero.
            Q[t] = R
        else:
            Q[t] = R + np.dot(P, V[t + 1]).reshape(nS, nA)
        V[t] = np.dot(policy[t], Q[t].reshape(nS * nA))
    value = np.dot(mu, V[0])
    return value.item()

# collect data
def collect_data_update_prior(fmdp, prior, policy, eps_greedy=False, eps=1e-2):
    """Roll out one episode of ``policy`` in ``fmdp``, updating ``prior`` each step.

    Args:
        fmdp: factored environment providing ``n``, ``nA``, ``d_A``, ``T``,
            ``reset()`` and ``step(a)``.
        prior: object with an ``update(x, next_state)`` method (for example
            ``HierarchicalPrior``).
        policy: array indexed as ``policy[t][s_ind]``. The action
            probabilities for state ``s_ind`` are the entries in columns
            ``s_ind * nA`` through ``(s_ind + 1) * nA - 1``.
        eps_greedy: if True, mix the policy's action distribution with the
            uniform one.
        eps: mixing weight of the uniform distribution.

    Returns:
        The list of values returned by ``fmdp.step``, one per time step.
    """
    n = fmdp.n
    nA = fmdp.nA
    data = []
    s = fmdp.reset()
    for t in range(fmdp.T):
        # draw action
        s_ind = x_to_ind(s, n)
        probs = policy[t][s_ind][s_ind * nA: s_ind * nA + nA]
        if eps_greedy:
            # Mixture with the uniform distribution. A deterministic policy
            # gets 1 - eps + eps/nA on its action and eps/nA elsewhere, and
            # stochastic policies stay normalized as well.
            probs = (1 - eps) * probs + eps / nA
        a_ind = np.random.choice(range(nA), p=probs)
        a = ind_to_x(a_ind, n, fmdp.d_A)
        x = np.concatenate((s, a))
        # mdp transition; index 3 of the returned tuple is used as the next
        # state (NOTE(review): confirm against fmdp.step's return layout)
        tran = fmdp.step(a)
        ns = tran[3]
        data.append(tran)
        # posterior update
        prior.update(x, ns)
        s = ns
    return data

# causal psrl method
def causal_psrl(fmdp_star, model_prior, n_iters, eps_greedy=False, eps=1e-2):
    """Run posterior-sampling RL with a (causal) model prior.

    Each iteration:

    1. samples scopes and factors from ``model_prior``;
    2. plans optimally in the sampled model;
    3. runs that policy for one episode in the true model ``fmdp_star``,
       updating the prior;
    4. records the regret of the policy against the true optimal value.

    Args:
        fmdp_star: the true factored MDP.
        model_prior: prior with ``draw()`` and ``update(x, y)`` methods.
        n_iters: number of sample/plan/act iterations.
        eps_greedy: passed to ``collect_data_update_prior``.
        eps: passed to ``collect_data_update_prior``.

    Returns:
        ``(policy_t, regret, model_error, value)``:

        * ``policy_t``: the last planned policy, or ``None`` when
          ``n_iters == 0``;
        * ``regret``: the cumulative regret after each iteration;
        * ``model_error``: a per-iteration distance between true and sampled
          transition matrices;
        * ``value``: the true value of each iteration's policy.
    """
    mdp_star = factored_to_tabular(fmdp_star)
    _, value_star = value_iteration(mdp_star)
    # main loop
    cum_regret = 0
    regret = np.zeros(n_iters)
    model_error = np.zeros(n_iters)
    value = np.zeros(n_iters)
    # Deep copy so that setscopes/setP below never modify fmdp_star.
    fmdp_t = copy.deepcopy(fmdp_star)
    policy_t = None  # returned as-is when n_iters == 0
    for t in range(n_iters):
        # draw model from prior
        scopes_t, factors_t = model_prior.draw()
        fmdp_t.setscopes(scopes_t)
        fmdp_t.setP(factors_t)
        mdp_t = factored_to_tabular(fmdp_t)
        # compute optimal policy for the sampled model
        policy_t, _ = value_iteration(mdp_t)
        # collect rollouts in the true model and update the prior
        collect_data_update_prior(fmdp_star, model_prior, policy_t, eps_greedy, eps)
        # compute regret
        value_t = policy_evaluation(mdp_star, policy_t)
        cum_regret = cum_regret + value_star - value_t
        regret[t] = cum_regret
        # NOTE(review): for a 2-D P, np.linalg.norm(..., 1) is the induced
        # matrix 1-norm (max absolute column sum), not the entrywise L1
        # distance. Confirm this is the intended error metric.
        model_error[t] = np.linalg.norm(mdp_star.P - mdp_t.P, 1)
        value[t] = value_t

    return policy_t, regret, model_error, value