import numpy as np
import copy

# class to instantiate the prior
class Prior:
    """Dirichlet prior over the rows of a tabular transition matrix.

    One concentration vector of length ``nS`` is stored for each of the
    ``nS * nA`` rows; every concentration starts at one.
    """

    def __init__(self, nS, nA):
        """Record the sizes and initialise all concentrations to one."""
        self.nS, self.nA = nS, nA
        self.params = np.ones(shape=(nS * nA, nS))

    def draw(self):
        """Return a ``(nS * nA, nS)`` array with one Dirichlet sample per row."""
        n_rows = self.nS * self.nA
        sample = np.zeros(shape=(n_rows, self.nS))
        # Rows are sampled in index order, one call to the global NumPy RNG each.
        for row in range(n_rows):
            sample[row] = np.random.dirichlet(self.params[row])
        return sample

    def update(self, mask):
        """Add ``mask`` element-wise to the concentrations (rebinds ``params``)."""
        updated = self.params + mask
        self.params = updated

# value iteration
def value_iteration(mdp):
    """Finite-horizon backward induction; returns ``(policy, value)``.

    Reads ``nS``, ``nA``, ``P``, ``R``, ``T`` and ``mu`` from ``mdp``.
    ``policy`` has shape ``(T, nS, nS * nA)``: for each step and state a
    single 1 is placed at column ``s * nA + a`` for the greedy action ``a``.
    ``value`` is ``mu`` dotted with the step-0 state values, as a Python
    scalar.
    """
    n_states, n_actions, horizon = mdp.nS, mdp.nA, mdp.T
    transitions = mdp.P
    rewards = mdp.R.reshape(n_states, n_actions)
    init_dist = mdp.mu

    q_values = np.zeros(shape=(horizon, n_states, n_actions))
    greedy = np.zeros(shape=(horizon, n_states, n_states * n_actions))
    state_values = np.zeros(shape=(horizon, n_states))

    last = horizon - 1
    for step in reversed(range(horizon)):
        if step < last:
            # reward plus expected best value of the following step
            lookahead = np.max(q_values[step + 1], axis=1)
            backup = np.dot(transitions, lookahead)
            q_values[step] = rewards + backup.reshape(n_states, n_actions)
        else:
            # final step: nothing follows, so Q is just the reward
            q_values[step] = rewards
        # one-hot encode the greedy action of every state
        for state, action_values in enumerate(q_values[step]):
            best = np.argmax(action_values)
            greedy[step, state, state * n_actions + best] = 1.
        flat_q = q_values[step].reshape(n_states * n_actions)
        state_values[step] = np.dot(greedy[step], flat_q)

    # expected return under the initial state distribution
    expected_return = np.dot(init_dist, state_values[0])
    return (greedy, expected_return.item())

# policy evaluation
def policy_evaluation(mdp, policy):
    """Return the expected finite-horizon return of ``policy`` on ``mdp``.

    Reads ``nS``, ``nA``, ``P``, ``R``, ``T`` and ``mu`` from ``mdp``.
    ``policy[t]`` is multiplied with the flattened ``(nS * nA,)`` Q-values,
    so it is used as an ``(nS, nS * nA)`` matrix (the layout produced by
    ``value_iteration``).

    The result is ``mu`` dotted with the step-0 state values, returned as a
    Python scalar. A horizon of ``T == 1`` yields the one-step expected
    reward (previously the backward loop was skipped and 0 was returned).
    """
    nS = mdp.nS
    nA = mdp.nA
    P = mdp.P
    R = mdp.R.reshape(nS, nA)
    T = mdp.T
    mu = mdp.mu
    Q = np.zeros(shape=(T, nS, nA))
    V = np.zeros(shape=(T, nS))
    # At the last step nothing follows, so Q there is just the reward.
    # Seeding it here lets every step, including T == 1, share one update.
    Q[T - 1] = R
    for t in range(T - 1, -1, -1):
        # value of following the policy from step t onwards
        V[t] = np.dot(policy[t], Q[t].reshape(nS * nA))
        if t > 0:
            # back the value up one step through the transition model
            Q[t - 1] = R + np.dot(P, V[t]).reshape(nS, nA)
    # expected return under the initial state distribution
    value = np.dot(mu, V[0])

    return value.item()

# collect data
def collect_data(mdp, policy, eps_greedy=False, eps=1e-2):
    """Roll out ``policy`` on ``mdp`` for ``mdp.T`` steps.

    Args:
        mdp: environment exposing ``nS``, ``nA``, ``T``, ``reset()`` and
            ``step(a)``. ``step`` must return a length-5 record whose
            element 3 is the next state (it is used as an index).
        policy: array indexed as ``policy[t][s]``; the action probabilities
            of state ``s`` are the slice ``[s * nA, s * nA + nA)``.
        eps_greedy: when true, mix the action distribution with the uniform
            one: ``(1 - eps) * probs + eps / nA``. For a one-hot row this
            gives exactly the probabilities the previous add-then-clip
            formula produced; unlike that formula it also stays normalised
            for stochastic rows.
        eps: mixing weight of the uniform distribution.

    Returns:
        ``(data, mask)``: ``data`` is a ``(T, 5)`` array of the step
        records, and ``mask`` is an ``(nS * nA, nS)`` array counting the
        observed ``(s, a) -> ns`` transitions.
    """
    nA = mdp.nA
    # every row is overwritten in the loop, so no initialisation is needed
    data = np.empty(shape=(mdp.T, 5))
    mask = np.zeros(shape=(mdp.nS * nA, mdp.nS))
    # interaction loop
    s = mdp.reset()
    for t in range(mdp.T):
        # draw action
        probs = policy[t][s][s * nA: s * nA + nA]
        if eps_greedy:
            probs = (1 - eps) * probs + eps / nA
        a = np.random.choice(range(nA), p=probs)
        # mdp transition
        tran = mdp.step(a)
        ns = tran[3]
        # save data
        data[t] = tran
        mask[s * nA + a, ns] += 1
        # next state
        s = ns

    return (data, mask)

# psrl method
def psrl(mdp_star, model_prior, n_iters, eps_greedy=False, eps=1e-2):
    """Run ``n_iters`` rounds of posterior sampling against ``mdp_star``.

    Each round draws a transition model from ``model_prior``, installs it
    on a deep copy of ``mdp_star`` via ``setP``, plans with
    ``value_iteration``, rolls the resulting policy out on ``mdp_star``
    with ``collect_data``, and feeds the transition counts back into
    ``model_prior`` (which is therefore mutated in place).

    Args:
        mdp_star: the true environment; also used to evaluate each policy.
        model_prior: object with ``draw()`` and ``update(mask)``.
        n_iters: number of rounds.
        eps_greedy, eps: forwarded to ``collect_data``.

    Returns:
        ``(policy_t, regret, model_error, value)``. ``policy_t`` is the
        policy of the last round, or ``None`` when ``n_iters`` is 0
        (previously that case raised ``UnboundLocalError``). The other
        three are length-``n_iters`` arrays: cumulative regret, the
        ``np.linalg.norm(..., 1)`` distance between the true and sampled
        ``P``, and the true value of each round's policy.
    """
    # only the optimal value is needed, as the regret baseline
    _, value_star = value_iteration(mdp_star)
    # main loop
    cum_regret = 0
    regret = np.zeros(n_iters)
    model_error = np.zeros(n_iters)
    value = np.zeros(n_iters)
    mdp_t = copy.deepcopy(mdp_star)
    # defined up front so the return below is valid even with zero rounds
    policy_t = None
    for t in range(n_iters):
        # draw model from prior
        P_t = model_prior.draw()
        mdp_t.setP(P_t)
        # compute optimal policy for the sampled model
        policy_t, _ = value_iteration(mdp_t)
        # collect rollouts on the true environment
        _, mask_t = collect_data(mdp_star, policy_t, eps_greedy, eps)
        # posterior update
        model_prior.update(mask_t)
        # compute regret against the optimal value of the true environment
        value_t = policy_evaluation(mdp_star, policy_t)
        cum_regret = cum_regret + value_star - value_t
        regret[t] = cum_regret
        # NOTE(review): for a 2-D P, ord=1 is the maximum absolute column
        # sum, not an entry-wise L1 distance; confirm this is intended.
        model_error[t] = np.linalg.norm(mdp_star.P - mdp_t.P, 1)
        value[t] = value_t

    return policy_t, regret, model_error, value