import numpy as np
import math
from common_funcs import *
import itertools

def h_bonus_Pa(H,L_Pa,N_x_Pa,x,z,states,Pa):
    """Return the exploration bonus for state ``x`` under parent assignment ``z``.

    The bonus is ``H * sqrt(1 / n) * L_Pa``, where ``n`` is the visit count
    ``N_x_Pa[states.index(x), Pa.index(z)]``. ``n`` should be positive; a zero
    count gives an infinite (numpy) or failing (plain int) division.

    Args:
        H: horizon, used as a multiplicative scale.
        L_Pa: additional multiplicative scale (log / confidence factor).
        N_x_Pa: count table indexed ``[state index, parent-assignment index]``.
        x: the state; located in ``states`` with ``list.index``.
        z: the parent assignment; located in ``Pa`` with ``list.index``.
        states: list of all states.
        Pa: list of all parent assignments.

    Returns:
        The bonus as a float.
    """
    visits = N_x_Pa[states.index(x), Pa.index(z)]
    inv_sqrt_visits = math.sqrt(1.0 / visits)
    return H * inv_sqrt_visits * L_Pa


def UCB_Q_val_Pa(Hist_states,Hist_PAs,N_x_Pa_y,N_x_Pa,k,Pa,actions,states,d_s,Q,R_Pa,H,L_Pa,Cprob):
    """Fold episode ``k`` into the visit counts and fill ``Q[k+1]`` by backward induction.

    Args:
        Hist_states: state history indexed ``[:, step, episode]``; column ``k``
            holds the H+1 states visited in episode ``k``.
        Hist_PAs: parent-assignment history indexed ``[step, episode, :]``.
        N_x_Pa_y: counts indexed ``[state, assignment, next state]``;
            updated in place with the transitions of episode ``k``.
        N_x_Pa: counts indexed ``[state, assignment]``; updated in place.
        k: index of the episode just completed.
        Pa: list of parent assignments (looked up with ``list.index``).
        actions: list of actions (only its length is used).
        states: list of states (looked up with ``list.index``).
        d_s: unused; kept for interface compatibility.
        Q: Q-value table indexed ``[episode, step, state, assignment]``;
            the slice ``Q[k+1]`` is overwritten in place.
        R_Pa: reward table indexed ``[state, assignment]``.
        H: horizon; also the upper cap on every Q-value.
        L_Pa: scale factor forwarded to ``h_bonus_Pa``.
        Cprob: probabilities indexed ``[action, state, assignment]``; the row
            ``Cprob[a, s, :]`` weights the assignment Q-values for action ``a``.

    Returns:
        The same ``Q`` array object, with ``Q[k+1]`` filled in.
    """
    A, S, Z = len(actions), len(states), len(Pa)
    # V_h_x[H] is never written, so the terminal value stays 0.
    V_h_x = np.zeros(shape=(H + 1, S))

    # Add the H transitions observed in episode k to the running counts.
    for i in range(H):
        i_x = states.index(Hist_states[:, i, k].tolist())
        i_z = Pa.index(Hist_PAs[i, k, :].tolist())
        i_y = states.index(Hist_states[:, i + 1, k].tolist())
        N_x_Pa_y[i_x, i_z, i_y] += 1
        N_x_Pa[i_x, i_z] += 1

    # Empirical transition probabilities P_hat[s, z, :] = N(s, z, :) / N(s, z).
    # Pairs that were never visited keep an all-zero row. Passing where= skips the
    # division there, so no divide-by-zero warning is raised.
    visited = N_x_Pa > 0
    P_hat_x_Pa = np.zeros(shape=(S, Z, S))
    np.divide(N_x_Pa_y, N_x_Pa[:, :, None], out=P_hat_x_Pa, where=visited[:, :, None])

    for h in range(H - 1, -1, -1):
        for i_s in range(S):
            for i_Pa in range(Z):
                if visited[i_s, i_Pa]:
                    b = h_bonus_Pa(H, L_Pa, N_x_Pa, states[i_s], Pa[i_Pa], states, Pa)
                    PV_x_Pa = np.inner(P_hat_x_Pa[i_s, i_Pa, :], V_h_x[h + 1, :])
                    # Optimistic estimate: reward + expected next value + bonus, capped at H.
                    Q[k + 1, h, i_s, i_Pa] = min(H, R_Pa[i_s, i_Pa] + PV_x_Pa + b)
                else:
                    # Unvisited (state, assignment) pairs keep the maximal value H.
                    Q[k + 1, h, i_s, i_Pa] = H
        # V_h(s) = max over actions of the Cprob-weighted average of assignment Q-values.
        for i_s in range(S):
            Q_MDP = [np.inner(Cprob[i_a, i_s, :], Q[k + 1, h, i_s, :]) for i_a in range(A)]
            V_h_x[h, i_s] = max(Q_MDP)
    return Q

def UCBVI_PAs(K,H,L_Pa,actions,states,d_s,Pa,P_tran_Pa,P_tran,R_Pa,R,a_dim,Cprob):
    """Simulate K episodes of a UCB value-iteration agent acting through parent assignments.

    Each episode starts from a uniformly drawn state. At every step the agent:

    * picks the action maximising ``Cprob[a, s, :] . Q[k, h, s, :]``;
    * samples a parent assignment from ``Cprob[a, s, :]``;
    * samples the next state from ``P_tran_Pa[s, z, :]``.

    After every episode except the last, the counts and ``Q[k+1]`` are refreshed
    via ``UCB_Q_val_Pa``.

    Args:
        K: number of episodes.
        H: episode length (horizon); also the optimistic initial Q-value.
        L_Pa: bonus scale forwarded to ``UCB_Q_val_Pa``.
        actions: list of actions.
        states: list of states (looked up with ``list.index``).
        d_s: length of a state vector.
        Pa: list of parent assignments.
        P_tran_Pa: next-state probabilities indexed ``[state, assignment, next state]``.
        P_tran: forwarded to ``get_V_star`` (from ``common_funcs``).
        R_Pa: reward table forwarded to ``UCB_Q_val_Pa``.
        R: forwarded to ``get_V_star``; ``R[state, action]`` is also the
            reward collected at each step.
        a_dim: length of a parent-assignment vector.
        Cprob: probabilities indexed ``[action, state, assignment]``.

    Returns:
        ``(reward, regret, Q)``:

        * ``reward``: a ``(K, H)`` array of per-step rewards.
        * ``regret``: a list where ``regret[k]`` is ``V_star[0, start_k]`` minus
          the total reward of episode ``k``. ``V_star`` is presumably the optimal
          value table; confirm in ``common_funcs.get_V_star``.
        * ``Q``: the ``(K, H, S, Z)`` Q-value table.
    """
    n_states, n_actions, n_pas = len(states), len(actions), len(Pa)
    reward = np.zeros(shape=(K, H))
    regret = [0] * K
    # Visited states, indexed [state component, step 0..H, episode].
    Hist_states = np.zeros(shape=(d_s, H + 1, K))
    # Sampled parent assignments, indexed [step, episode, assignment component].
    Hist_PAs = np.zeros(shape=(H, K, a_dim))
    N_x_Pa_y = np.zeros(shape=(n_states, n_pas, n_states))
    N_x_Pa = np.zeros(shape=(n_states, n_pas))
    # Optimistic initialisation: every Q entry starts at the horizon H.
    Q = H * np.ones(shape=[K, H, n_states, n_pas])
    V_star = get_V_star(H=H, R=R, actions=actions, states=states, P_tran=P_tran)

    for k in range(K):
        start = states[np.random.choice(range(n_states))]
        V_star_init = V_star[0, states.index(start)]
        Hist_states[:, 0, k] = start
        x = start
        for h in range(H):
            s_idx = states.index(x)
            q_row = Q[k, h, s_idx, :]
            action_values = [np.inner(Cprob[i_a, s_idx, :], q_row) for i_a in range(n_actions)]
            a_ind = np.argmax(action_values)
            z_ind = np.random.choice(a=range(n_pas), size=1, p=Cprob[a_ind, s_idx, :])[0]
            x_next_ind = np.random.choice(a=range(n_states), size=1, p=P_tran_Pa[s_idx, z_ind, :])[0]
            x_next = states[x_next_ind]
            reward[k, h] = R[s_idx, a_ind]
            Hist_PAs[h, k, :] = Pa[z_ind]
            Hist_states[:, h + 1, k] = x_next
            x = x_next
        # Q has only K slices, so there is no Q[K] to fill after the final episode.
        if k + 1 < K:
            Q = UCB_Q_val_Pa(
                Hist_states[:, :, :k + 1],
                Hist_PAs[:, :k + 1, :],
                N_x_Pa_y,
                N_x_Pa,
                k,
                Pa,
                actions,
                states,
                d_s,
                Q,
                R_Pa,
                H,
                L_Pa,
                Cprob,
            )
        regret[k] = V_star_init - sum(reward[k, :])
    return reward, regret, Q
