import numpy as np
import math
import itertools

def gen_actions(k,vals):
    """Enumerate every length-``k`` action vector over the alphabet ``vals``.

    Args:
        k: number of components in each action vector.
        vals: the values each component may take.

    Returns:
        A list of ``len(vals) ** k`` lists, in ``itertools.product`` order
        (lexicographic with respect to the order of ``vals``).
    """
    return list(map(list, itertools.product(vals, repeat=k)))

def gen_reward(Pa,states):
    """Draw a random outcome-reward table R(s, z).

    Args:
        Pa: sequence of outcomes z; only its length is used.
        states: sequence of states; only its length is used.

    Returns:
        Array of shape ``(len(states), len(Pa))``. Row index = state and
        column index = outcome. Entries are sampled i.i.d. from Uniform[0, 1)
        using NumPy's global random state.
    """
    n_states, n_outcomes = len(states), len(Pa)
    return np.random.uniform(low=0, high=1, size=(n_states, n_outcomes))

def gen_Pa(k,vals):
    """Enumerate every length-``k`` outcome vector over the alphabet ``vals``.

    Returns:
        A list of ``len(vals) ** k`` lists, in ``itertools.product`` order.
    """
    return [[*combo] for combo in itertools.product(vals, repeat=k)]

def gen_Z_prob(actions,states,Pa):
    """Sample a random outcome distribution P(z | a, s) for every (a, s) pair.

    Each (action, state) slice is drawn independently from a flat
    Dirichlet(1, ..., 1), using NumPy's global random state.

    Args:
        actions, states, Pa: sequences; only their lengths are used.

    Returns:
        Array of shape ``(len(actions), len(states), len(Pa))``. Every
        ``Cprob[a, s, :]`` is a probability vector.
    """
    n_actions, n_states, n_outcomes = len(actions), len(states), len(Pa)
    flat_alpha = [1] * n_outcomes
    Cprob = np.zeros(shape=(n_actions, n_states, n_outcomes))
    # np.ndindex visits (a, s) in row-major order, so the samples are drawn
    # in the same sequence as a nested "for a: for s:" loop.
    for a_pos, s_pos in np.ndindex(n_actions, n_states):
        Cprob[a_pos, s_pos] = np.random.dirichlet(alpha=flat_alpha)
    return Cprob

def gen_tran_prob_Pa(Pa,states):
    """Sample a random next-state distribution P(s' | z, s).

    Each (state, outcome) slice is drawn independently from a flat
    Dirichlet(1, ..., 1) over the states, using NumPy's global random state.

    Args:
        Pa, states: sequences; only their lengths are used.

    Returns:
        Array of shape ``(len(states), len(Pa), len(states))`` indexed as
        ``P[s, z, s']``. Every ``P[s, z, :]`` is a probability vector.
    """
    n_outcomes, n_states = len(Pa), len(states)
    flat_alpha = [1] * n_states
    P = np.zeros(shape=(n_states, n_outcomes, n_states))
    # Row-major walk over (s, z): same draw order as "for s: for z:".
    for s_pos, z_pos in np.ndindex(n_states, n_outcomes):
        P[s_pos, z_pos] = np.random.dirichlet(alpha=flat_alpha)
    return P

def gen_tran_prob(actions,states,P_tran_Pa,Cprob):
    """Marginalise out the outcome z to get the action-level transition P(s' | a, s).

    P[s, a, s'] = sum_z Cprob[a, s, z] * P_tran_Pa[s, z, s']

    Args:
        actions, states: sequences; only their lengths are used for the
            shape check.
        P_tran_Pa: array indexed ``[s, z, s']`` (shape ``(S, Z, S)``).
        Cprob: array indexed ``[a, s, z]`` (shape ``(A, S, Z)``).

    Returns:
        Float array of shape ``(len(states), len(actions), len(states))``.
    """
    A, S = len(actions), len(states)
    Cprob = np.asarray(Cprob, dtype=float)
    P_tran_Pa = np.asarray(P_tran_Pa, dtype=float)
    # A single contraction over z replaces the former per-cell Python triple
    # loop (one np.inner call per (s, a, s') entry).
    P = np.einsum('asz,szt->sat', Cprob, P_tran_Pa)
    if P.shape != (S, A, S):
        raise ValueError(
            f"expected transition shape {(S, A, S)}, got {P.shape}; "
            "check Cprob / P_tran_Pa dimensions against actions/states"
        )
    return P

def get_reward(R_mat,Cprob,actions,states,x,a):
    """Return the expected immediate reward R(x, a).

    R(x, a) = sum_z R_mat[x, z] * Cprob[a, x, z]

    That is, the outcome rewards of state ``x`` weighted by the outcome
    distribution P(z | a, x).

    Args:
        R_mat: S-by-n_pa reward matrix (row = state, column = outcome z).
            The row axis is the state, matching ``R_mat[x_index, :]`` below.
        Cprob: A-by-S-by-n_pa matrix, ``Cprob[a, s, :]`` = P(z | a, s).
        actions: list of actions; ``a`` is located with ``list.index``.
        states: list of states; ``x`` is located with ``list.index``.
        x: the state.
        a: the action.

    Returns:
        The scalar expected reward R(x, a).

    Raises:
        ValueError: if ``x`` is not in ``states`` or ``a`` is not in ``actions``
            (raised by ``list.index``).
    """
    a_index = actions.index(a)
    x_index = states.index(x)
    return np.inner(R_mat[x_index,:],Cprob[a_index,x_index,:])

def get_all_reward(R_mat,Cprob,actions,states):
    """Build the full S-by-A expected-reward matrix.

    R[s, a] = sum_z R_mat[s, z] * Cprob[a, s, z]

    Args:
        R_mat: reward matrix indexed ``[s, z]`` (shape ``(S, n_pa)``).
        Cprob: outcome distribution indexed ``[a, s, z]``
            (shape ``(A, S, n_pa)``).
        actions, states: sequences whose lengths give A and S. Positions in
            these lists correspond to the array indices. Assumes the entries
            are distinct (true for ``itertools.product`` actions).

    Returns:
        Float array of shape ``(len(states), len(actions))``.
    """
    S, A = len(states), len(actions)
    R_mat = np.asarray(R_mat, dtype=float)
    Cprob = np.asarray(Cprob, dtype=float)
    # Contract over z directly by position. The previous per-cell call did two
    # list.index scans per entry (O(S*A*(S+A))) just to recover i and j.
    R = np.einsum('sz,asz->sa', R_mat[:S], Cprob[:A, :S])
    return R

def get_Q_star(x,a,V_prev,R,actions,states,P_tran):
    """One-step Bellman backup for a single state-action pair.

    Q(x, a) = R[x, a] + sum_{s'} P_tran[x, a, s'] * V_prev[s']

    Args:
        x: state, located in ``states`` with ``list.index``.
        a: action, located in ``actions`` with ``list.index``.
        V_prev: value vector over next states (length S).
        R: reward matrix indexed ``[s, a]``.
        actions, states: lists used to map ``a`` / ``x`` to array indices.
        P_tran: transition tensor indexed ``[s, a, s']``.

    Returns:
        The scalar backed-up value Q(x, a).
    """
    s_pos = states.index(x)
    a_pos = actions.index(a)
    immediate = R[s_pos, a_pos]
    future = np.inner(V_prev, P_tran[s_pos, a_pos, :])
    return immediate + future

def get_V_star(H,R,actions,states,P_tran):
    """Compute optimal finite-horizon state values by backward induction.

    V*_h(s) = max_a [ R[s, a] + sum_{s'} P_tran[s, a, s'] * V*_{h+1}(s') ],
    with the terminal condition V*_H = 0.

    Args:
        H: horizon (number of decision steps); must be >= 0.
        R: reward matrix indexed ``[s, a]`` (shape ``(S, A)``).
        actions, states: sequences; only their lengths are used.
        P_tran: transition tensor indexed ``[s, a, s']`` (shape ``(S, A, S)``).

    Returns:
        Float array of shape ``(H + 1, S)``. Row ``h`` holds V*_h and row
        ``H`` is all zeros. For ``H == 0`` this is a single zero row.

    Raises:
        ValueError: if ``H`` is negative.
    """
    if H < 0:
        raise ValueError(f"horizon H must be non-negative, got {H}")
    S = len(states)
    R = np.asarray(R, dtype=float)
    P_tran = np.asarray(P_tran, dtype=float)
    V_star = np.zeros(shape=(H+1,S))
    # Row H is the zero terminal value, so the step h = H-1 naturally reduces
    # to max_a R[s, a]. No special case is needed, and H == 0 no longer
    # overwrites row 0 through the V_star[-1] alias.
    for h in range(H - 1, -1, -1):
        Q_h = R + P_tran @ V_star[h + 1]   # (S, A) action values at step h
        V_star[h] = Q_h.max(axis=1)
    return V_star



# actions = gen_actions(k=3,vals=[0,1,2,3])
# states = [1,2,3,4]
# print(gen_reward(actions,states))
# print((gen_reward(actions,states)).shape)
