from common_funcs import *
import numpy as np
import math
import itertools

def h_bonus_action(H, L_A, N_x_a, x, a, states, actions):
    """Return the optimism bonus for the state-action pair (x, a).

    The bonus is ``H * sqrt(1 / N(x, a)) * L_A``. Here ``N(x, a)`` is the entry
    of ``N_x_a`` at the position of ``x`` in ``states`` (row) and of ``a`` in
    ``actions`` (column).

    Raises ``ValueError`` (from ``list.index``) if ``x`` is not in ``states``
    or ``a`` is not in ``actions``. A zero count is not guarded against.
    """
    row, col = states.index(x), actions.index(a)
    inv_visits = 1.0 / N_x_a[row, col]
    return H * math.sqrt(inv_visits) * L_A

def UCB_Q_val_action(Hist_states,Hist_actions,N_x_a_y,N_x_a,k,actions,states,d_s,Q,R,H,L_A):
    """Fold episode ``k`` into the visit counts and recompute optimistic Q-values.

    This does one planning step. First, the H transitions of episode ``k`` are
    added to the counts. Then an empirical transition model is built from the
    counts. Finally, backward induction is run over ``h = H-1 .. 0`` with an
    exploration bonus from ``h_bonus_action``; results are clipped at ``H``.

    Parameters
    ----------
    Hist_states : np.ndarray
        State history, read as ``Hist_states[:, step, k]`` for
        ``step`` in ``0..H``.
    Hist_actions : np.ndarray
        Action history, read as ``Hist_actions[step, k, :]`` for
        ``step`` in ``0..H-1``.
    N_x_a_y : np.ndarray, shape (S, A, S)
        Transition counts N(s, a, s'). **Updated in place.**
    N_x_a : np.ndarray, shape (S, A)
        Visit counts N(s, a). **Updated in place.**
    k : int
        Episode index to fold in. ``Q[k + 1]`` is written, so ``k + 1`` must
        be a valid first-axis index of ``Q``.
    actions, states : list
        Action and state lists. History entries are converted with
        ``tolist()`` and located with ``list.index``.
    d_s : int
        Unused; kept for interface compatibility.
    Q : np.ndarray
        Q-table indexed ``Q[episode, h, s, a]``. ``Q[k + 1]`` is
        **overwritten in place**.
    R : np.ndarray
        Reward table indexed ``R[s, a]``.
    H : int
        Horizon length; also the upper clip for Q-values.
    L_A : float
        Bonus scale, forwarded to ``h_bonus_action``.

    Returns
    -------
    np.ndarray
        The same ``Q`` object that was passed in, after mutation.

    Raises
    ------
    ValueError
        If a history state/action is not present in ``states``/``actions``.
    """
    A, S = len(actions), len(states)

    # Add the k-th episode's (x, a, y) transitions to the running counts.
    for step in range(H):
        x = Hist_states[:, step, k].tolist()
        a = Hist_actions[step, k, :].tolist()
        y = Hist_states[:, step + 1, k].tolist()
        i_x, i_a = states.index(x), actions.index(a)
        N_x_a_y[i_x, i_a, states.index(y)] += 1
        N_x_a[i_x, i_a] += 1

    # Empirical model P_hat[s, a, :] = N(s, a, .) / N(s, a). Rows for
    # never-visited (s, a) stay all-zero; `where` avoids dividing by zero.
    visits = N_x_a[:, :, np.newaxis]
    P_hat_x_a = np.divide(N_x_a_y, visits, out=np.zeros((S, A, S)), where=visits > 0)

    # Backward induction. V_h_x[H] stays 0 (nothing is collected after the horizon).
    V_h_x = np.zeros((H + 1, S))
    for h in range(H - 1, -1, -1):
        for i_s in range(S):
            for i_a in range(A):
                if N_x_a[i_s, i_a] > 0:
                    b = h_bonus_action(H, L_A, N_x_a, states[i_s], actions[i_a], states, actions)
                    PV_x_a = np.inner(P_hat_x_a[i_s, i_a, :], V_h_x[h + 1, :])
                    Q[k + 1, h, i_s, i_a] = min(H, R[i_s, i_a] + PV_x_a + b)
                else:
                    # Unvisited pair: keep the maximally optimistic value.
                    Q[k + 1, h, i_s, i_a] = H
        # Greedy value of each state at step h.
        V_h_x[h, :] = Q[k + 1, h].max(axis=1)
    return Q

def UCBVI_actions(K,H,L_A,actions,states,d_s,P_tran,R,a_dim):
    """Run ``K`` episodes of UCB value iteration and record per-episode regret.

    Each episode works as follows:

    - It starts from a uniformly sampled state.
    - At every step it acts greedily with respect to ``Q[k]``.
    - It samples the next state from ``P_tran[s, a, :]``.
    - After the episode, ``UCB_Q_val_action`` refreshes ``Q[k + 1]``. This is
      skipped for the last episode, which has no ``Q[k + 1]``.

    The regret of an episode is ``V_star[0, start]`` minus the episode's total
    reward. ``V_star`` comes from ``get_V_star``; presumably it holds the
    optimal values (TODO confirm against common_funcs).

    Returns
    -------
    tuple
        ``(reward, regret, Q)``. ``reward`` is a (K, H) array of per-step
        rewards, ``regret`` is a length-K list, and ``Q`` is the (K, H, S, A)
        Q-table.
    """
    S, A = len(states), len(actions)

    # Per-episode bookkeeping.
    reward = np.zeros(shape=(K, H))
    regret = [0] * K
    Hist_states = np.zeros(shape=(d_s, H + 1, K))
    Hist_actions = np.zeros(shape=(H, K, a_dim))

    # Model statistics and the optimistic Q-table (initialised to the max value H).
    N_x_a_y = np.zeros(shape=(S, A, S))
    N_x_a = np.zeros(shape=(S, A))
    Q = H * np.ones(shape=[K, H, S, A])

    V_star = get_V_star(H=H, R=R, actions=actions, states=states, P_tran=P_tran)

    for k in range(K):
        start = states[np.random.choice(range(S))]
        V_star_init = V_star[0, states.index(start)]
        Hist_states[:, 0, k] = start

        cur = start
        for h in range(H):
            s_idx = states.index(cur)
            a_idx = np.argmax(Q[k, h, s_idx, :])
            nxt_idx = np.random.choice(a=range(S), size=1, p=P_tran[s_idx, a_idx, :])[0]
            nxt = states[nxt_idx]
            reward[k, h] = R[s_idx, a_idx]
            Hist_actions[h, k, :] = actions[a_idx]
            Hist_states[:, h + 1, k] = nxt
            cur = nxt

        # Q has only K slices, so the final episode cannot write Q[k + 1].
        if k < K - 1:
            Q = UCB_Q_val_action(
                Hist_states[:, :, :(k + 1)],
                Hist_actions[:, :(k + 1), :],
                N_x_a_y, N_x_a, k, actions, states, d_s, Q, R, H, L_A,
            )
        regret[k] = V_star_init - sum(reward[k, :])
    return reward, regret, Q

