# import numpy as np
import jax.numpy as np
import numpy as onp
# RUNNING SHAPE ASSUMPTIONS
# P.shape = (n_states, n_actions, n_states)
# r.shape = (n_states, n_actions)
# pi.shape = (n_states, n_actions)

def get_V(P, r, gamma, pi):
    """Evaluate a policy in closed form.

    Solves the linear system ``(I - gamma * P_pi) V = r_pi`` where ``P_pi``
    is the policy-averaged transition matrix and ``r_pi`` the
    policy-averaged reward.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        pi: policy, shape (n_states, n_actions).

    Returns:
        Array of shape (n_states,) holding the solution V.
    """
    P_pi = get_P_pi(P, pi)
    r_pi = np.sum(r * pi, axis=-1)
    # Solve the system directly instead of forming the explicit inverse and
    # multiplying: same solution, less work, and consistent with get_d_pi.
    system = np.eye(pi.shape[0]) - gamma * P_pi
    return np.linalg.solve(system, r_pi)

def get_batch_V(P, r, gamma, pis):
    """Evaluate a batch of policies in closed form.

    For every policy in the batch, solves
    ``(I - gamma * P_pi) V = r_pi``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        pis: policies, assumed shape (batch, n_states, n_actions).

    Returns:
        Array of shape (batch, n_states), one solution V per policy.
    """
    # P_pis[b, s, t] = sum_a pis[b, s, a] * P[s, a, t]
    P_pis = np.sum(P[np.newaxis, :, :, :] * pis[:, :, :, np.newaxis], axis=2)
    # r_pis[b, s] = sum_a pis[b, s, a] * r[s, a]
    r_pis = np.sum(r[np.newaxis, :, :] * pis, axis=2)
    n_states = pis.shape[1]
    systems = np.eye(n_states)[np.newaxis, :, :] - gamma * P_pis
    # Solve each system instead of inverting it. The right-hand side gets an
    # explicit trailing axis so every r_pi is unambiguously a column vector
    # in the batched solve; that axis is dropped again from the result.
    solutions = np.linalg.solve(systems, r_pis[:, :, np.newaxis])
    return solutions[:, :, 0]

def get_batch_Q(P, r, gamma, pis):
    """Compute action values for a batch of policies.

    Uses the batched state values from ``get_batch_V`` and returns
    ``r[s, a] + gamma * sum_t P[s, a, t] * V[b, t]`` for every batch entry.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        pis: policies, shape (batch, n_states, n_actions).

    Returns:
        Array of shape (batch, n_states, n_actions).
    """
    batch_values = get_batch_V(P, r, gamma, pis)
    # Insert (state, action) axes so the values line up with P's last axis.
    next_values = batch_values[:, np.newaxis, np.newaxis, :]
    expected_next = np.sum(P[np.newaxis, :, :, :] * next_values, axis=-1)
    return r[np.newaxis, :, :] + gamma * expected_next

def get_soft_V(P, r, gamma, tau, pi):
    """Evaluate a policy with an entropy bonus weighted by ``tau``.

    Solves ``(I - gamma * P_pi) V = r_pi`` where the per-state reward is
    ``sum_a pi * r - tau * sum_a pi * log(pi)``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        tau: weight on the ``-sum_a pi * log(pi)`` term.
        pi: policy, shape (n_states, n_actions).

    Returns:
        Array of shape (n_states,) holding the solution V.
    """
    P_pi = get_P_pi(P, pi)
    # Use the convention 0 * log(0) = 0. Evaluating log(0) directly gives
    # -inf and 0 * -inf is NaN, which would contaminate every state value
    # as soon as the policy assigns zero probability to any action.
    # Only exact zeros are substituted, so invalid (negative) entries still
    # surface as NaN.
    log_arg = np.where(pi == 0, 1.0, pi)
    pi_log_pi = np.sum(pi * np.log(log_arg), axis=1)
    r_pi = np.sum(r * pi, axis=-1) - tau * pi_log_pi
    # Solve the system directly instead of forming the explicit inverse.
    system = np.eye(pi.shape[0]) - gamma * P_pi
    return np.linalg.solve(system, r_pi)
    

def get_Q(P, r, gamma, pi):
    """Compute the action values of a policy.

    Takes the state values from ``get_V`` and returns
    ``r[s, a] + gamma * sum_t P[s, a, t] * V[t]``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        pi: policy, shape (n_states, n_actions).

    Returns:
        Array of shape (n_states, n_actions).
    """
    state_values = get_V(P, r, gamma, pi)
    num_states = P.shape[0]
    # Align the value vector with P's next-state (last) axis.
    aligned_values = state_values.reshape((1, 1, num_states))
    expected_next = np.sum(P * aligned_values, axis=-1)
    return r + gamma * expected_next

def get_soft_Q(P, r, gamma, tau, pi):
    """Compute action values from the ``tau``-weighted soft state values.

    When ``tau == 0`` this delegates to ``get_Q``. Otherwise it takes the
    state values from ``get_soft_V`` and returns
    ``r[s, a] + gamma * sum_t P[s, a, t] * V[t]``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        tau: entropy weight passed through to ``get_soft_V``.
        pi: policy, shape (n_states, n_actions).

    Returns:
        Array of shape (n_states, n_actions).
    """
    if tau == 0:
        return get_Q(P, r, gamma, pi)

    soft_values = get_soft_V(P, r, gamma, tau, pi)
    num_states = P.shape[0]
    # Align the value vector with P's next-state (last) axis.
    aligned_values = soft_values.reshape((1, 1, num_states))
    expected_next = np.sum(P * aligned_values, axis=-1)
    return r + gamma * expected_next
    
def value_iteration(P, r, gamma, V, iters=1):
    """Apply ``iters`` sweeps of the max-backup to a state-value vector.

    Each sweep replaces the values with
    ``max_a (r[s, a] + gamma * sum_t P[s, a, t] * V[t])``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        V: starting values; reshaped to (1, 1, n_states) on each sweep.
        iters: number of sweeps; with 0 sweeps ``V`` is returned unchanged.

    Returns:
        The values after the final sweep, shape (n_states,).
    """
    num_states = P.shape[0]
    values = V
    for _ in range(iters):
        aligned_values = values.reshape((1, 1, num_states))
        backup = r + gamma * np.sum(P * aligned_values, axis=-1)
        values = np.max(backup, axis=-1)
        # Sanity checks: one value per state, and no NaNs crept in.
        assert values.shape[0] == num_states
        assert not np.isnan(values).any()
    return values

def get_optimal_V(P, r, gamma, iters=int(1e5)):
    """Run ``value_iteration`` from an all-zero start for ``iters`` sweeps.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        iters: number of sweeps to perform.

    Returns:
        Whatever ``value_iteration`` returns after ``iters`` sweeps.
    """
    initial_values = onp.zeros(P.shape[0])
    return value_iteration(P, r, gamma, initial_values, iters)


def action_value_iteration(P, r, gamma, Q, iters=1):
    """Apply ``iters`` sweeps of the max-backup to an action-value table.

    Each sweep replaces the table with
    ``r[s, a] + gamma * sum_t P[s, a, t] * max_b Q[t, b]``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        Q: starting table, shape (n_states, n_actions).
        iters: number of sweeps; with 0 sweeps ``Q`` is returned unchanged.

    Returns:
        The table after the final sweep, shape (n_states, n_actions).
    """
    num_states, num_actions = P.shape[0], r.shape[1]
    expected_shape = (num_states, num_actions)
    q_values = Q
    for _ in range(iters):
        # Greedy value of every next state, aligned with P's last axis.
        greedy_values = np.reshape(np.max(q_values, axis=-1), (1, 1, num_states))
        q_values = r + gamma * np.sum(greedy_values * P, axis=-1)
        assert q_values.shape == expected_shape, q_values.shape
    return q_values

def get_optimal_Q(P, r, gamma, iters=int(1e5)):
    """Run ``action_value_iteration`` from an all-zero table.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        r: rewards, shape (n_states, n_actions).
        gamma: discount factor.
        iters: number of sweeps to perform.

    Returns:
        Whatever ``action_value_iteration`` returns after ``iters`` sweeps.
    """
    table_shape = (P.shape[0], r.shape[1])
    initial_table = np.zeros(table_shape)
    return action_value_iteration(P, r, gamma, initial_table, iters=iters)

def get_greedy_pi(Q):
    """Build the deterministic policy that is greedy with respect to ``Q``.

    Args:
        Q: action values, shape (n_states, n_actions).

    Returns:
        A NumPy array of shape (n_states, n_actions) with a single 1 per
        row at the arg-max action (first maximum on ties) and 0 elsewhere.
    """
    n_states, n_actions = Q.shape[0], Q.shape[1]
    # Convert the arg-max indices to a host NumPy array once, so the one-hot
    # fill is a single vectorized assignment rather than a Python loop that
    # pulls one scalar per state out of the array.
    best_actions = onp.asarray(np.argmax(Q, axis=-1))
    pi = onp.zeros((n_states, n_actions))
    pi[onp.arange(n_states), best_actions] = 1
    return pi

def get_P_pi(P, pi):
    """Average the transition tensor over actions under a policy.

    Computes ``P_pi[s, t] = sum_a pi[s, a] * P[s, a, t]``.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        pi: policy, shape (n_states, n_actions).

    Returns:
        Array of shape (n_states, n_states).
    """
    # Trailing axis lets the action probabilities broadcast over next states.
    action_weights = np.expand_dims(pi, -1)
    weighted_transitions = P * action_weights
    return np.sum(weighted_transitions, axis=1)


def get_d_pi(P, gamma, pi, rho):
    """Solve for the unnormalized discounted state weighting of a policy.

    Returns ``d`` satisfying ``(I - gamma * P_pi^T) d = rho``, i.e.
    ``d = (I - gamma * P_pi^T)^{-1} rho``. No normalization is applied.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        gamma: discount factor.
        pi: policy, shape (n_states, n_actions).
        rho: right-hand side of the system (presumably the start-state
            distribution -- confirm against callers).

    Returns:
        The solution ``d`` of the linear system.
    """
    policy_transitions = get_P_pi(P, pi)
    system = np.eye(P.shape[0]) - gamma * policy_transitions.T
    return np.linalg.solve(system, rho)

def get_stationary_d(P, pi):
    """Find a distribution ``d`` with ``d = d P_pi`` for the given policy.

    Takes the eigenvector of ``P_pi^T`` whose eigenvalue is within 1e-5 of
    1 and rescales it so its entries sum to 1.

    Args:
        P: transitions, shape (n_states, n_actions, n_states).
        pi: policy, shape (n_states, n_actions).

    Returns:
        Real-valued array of shape (n_states,) summing to 1.

    Raises:
        IndexError: if no eigenvalue lies within 1e-5 of 1.
    """
    P_pi = np.sum(P * pi[:, :, np.newaxis], axis=1)
    vals, vecs = np.linalg.eig(P_pi.T)
    # First eigenvalue that is numerically equal to 1.
    ix = np.where(np.abs(vals - 1) < 1e-5)[0][0]
    d = vecs[:, ix]
    # Normalize before dropping the imaginary part: eig returns complex
    # arrays, and dividing by the (complex) sum first removes any common
    # phase factor, so what remains is real up to rounding.
    d = d / d.sum()
    return np.real(d)
