###POLICY ITERATION FUNCTIONS TO COMPUTE THE EMPIRICAL OPTIMAL POLICY
import numpy as np
import itertools
# import time
def Greedy(R, gamma, P, V):
    """Return the greedy policy and one-step backup for a value function.

    Computes ``Q[s, a] = R[s, a] + gamma * sum_s' P[s, a, s'] * V[s']`` and
    maximises it over the action axis.

    Parameters
    ----------
    R : ndarray, shape (Ns, Na)
        Reward for each (state, action) pair.
    gamma : float
        Discount factor multiplying the expected next-state value.
    P : ndarray, shape (Ns, Na, Ns)
        Transition weights indexed as ``P[s, a, s']``.
    V : array-like, shape (Ns,) or (Ns, 1)
        Value function to back up. Both a flat vector and a column are
        accepted.

    Returns
    -------
    pi : ndarray of int, shape (Ns,)
        Maximising action per state (``np.argmax`` picks the lowest index
        on ties).
    V : ndarray, shape (Ns, 1)
        Row-wise maximum of ``Q``, as a column vector.
    Q : ndarray, shape (Ns, Na)
        The backed-up action values.
    """
    # Normalise V to a column so a flat (Ns,) vector works as well as the
    # (Ns, 1) layout produced by Value().
    v_col = np.asarray(V).reshape(-1, 1)
    # (Ns, Na, Ns) @ (Ns, 1) -> (Ns, Na, 1); drop the singleton axis.
    expected_next = np.matmul(P, v_col)[..., 0]
    Q = R + gamma * expected_next
    pi = np.argmax(Q, axis=1)
    V_greedy = np.max(Q, axis=1).reshape(-1, 1)
    return pi, V_greedy, Q
def Value(pi, R, gamma, P):
    """Evaluate a deterministic policy by solving its linear Bellman system.

    Builds the policy-induced transition matrix ``P_pi[s, s'] =
    P[s, pi[s], s']`` and reward vector ``R_pi[s] = R[s, pi[s]]``, then
    solves ``(I - gamma * P_pi) V = R_pi``.

    Parameters
    ----------
    pi : array-like of int, shape (Ns,) or (Ns, 1)
        Action chosen in each state.
    R : ndarray, shape (Ns, Na)
        Reward for each (state, action) pair.
    gamma : float
        Discount factor.
    P : ndarray, shape (Ns, Na, Ns)
        Transition weights indexed as ``P[s, a, s']``.

    Returns
    -------
    ndarray, shape (Ns, 1)
        Solution of the linear system, as a column vector.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``I - gamma * P_pi`` is singular.
    """
    Ns = R.shape[0]
    # Flatten so that a column-shaped policy indexes with plain integers
    # instead of 1-element arrays (assigning those into scalar slots relies
    # on a conversion NumPy has deprecated).
    actions = np.asarray(pi).reshape(-1)
    states = np.arange(Ns)
    # Integer-array indexing picks row (s, pi[s]) for every state at once.
    # Cast to float64 to match the float64 work buffers used previously.
    P_pi = np.asarray(P)[states, actions].astype(np.float64)
    R_pi = np.asarray(R)[states, actions].astype(np.float64).reshape(-1, 1)
    V = np.linalg.solve(np.eye(Ns) - gamma * P_pi, R_pi)
    return V
def policy_iteration(mdp):
    """Run policy iteration on ``mdp`` until the greedy policy stops changing.

    Alternates exact policy evaluation (``Value``) with greedy improvement
    (``Greedy``), starting from the policy that picks action 0 in every
    state.

    Parameters
    ----------
    mdp : object
        Must expose ``Ns`` (number of states), ``gamma``, ``P`` and ``R``;
        ``P`` and ``R`` are passed unchanged to ``Value`` and ``Greedy``.

    Returns
    -------
    pi : ndarray
        Greedy policy from the final improvement step.
    V : ndarray
        Value returned by that same ``Greedy`` call.
    Q : ndarray
        Action values returned by that same ``Greedy`` call.
    """
    Ns = mdp.Ns
    gamma = mdp.gamma
    P = mdp.P
    R = mdp.R
    # Flat (Ns,) policy, matching the shape Greedy() returns; a column here
    # would make the first equality test broadcast to an (Ns, Ns) matrix.
    pi = np.zeros(Ns, dtype=int)
    V = Value(pi, R, gamma, P)
    while True:
        pi_1, V_1, Q = Greedy(R, gamma, P, V)
        # Stop once the improvement step reproduces the current policy, or
        # the one-step backup no longer moves the value function.
        policy_stable = np.array_equal(pi_1, pi)
        if policy_stable or np.abs(V_1 - V).max() < 1e-15:
            return pi_1, V_1, Q
        pi = pi_1
        V = Value(pi, R, gamma, P)