import numpy as np
import itertools
import math
from common_funcs import *
from CUCBVI_funcs import *
from FMDP_funcs import *
import matplotlib.pyplot as plt

# def gen_states(d_s,vals):
#     # assume each S_i domain has same size
#     return [list(l) for l in itertools.product(vals, repeat=d_s)]

# def gen_fac_reward(Pa,d_s,vals,states):
#     # R_raw: d_s by len(S_i) by len(Pa) array
#     R_raw = np.random.uniform(low=0, high=1.0/d_s, size=(d_s, len(vals), len(Pa)))
#     # R: len(states) by len(Pa) array
#     R = np.zeros(shape = (len(states),len(Pa)))
#     for i in range(len(states)):
#         for j in range(len(Pa)):
#             s,Z = states[i],Pa[j]
#             # s_ind: value index for each s_i in its domain dom(S_i)=vals
#             s_ind = [vals.index(x) for x in s]
#             R[i,j] = sum([R_raw[d,s_ind[d],j] for d in range(d_s)])
#     return R
#
# def gen_fac_tran_prob_Pa(Pa,states,d_s,vals):
#
#     Z,S = len(Pa),len(states)
#     P = np.zeros(shape=(S,Z,S))
#     # P_raw: P(s_i'|Z,s_i)
#     P_raw = np.zeros(shape=(d_s, len(vals),Z,len(vals)))
#     for d in range(d_s):
#         for i in range(len(vals)):
#             for j in range(Z):
#                 P_raw[d,i,j,:] = np.random.dirichlet(alpha=[1]*len(vals))
#     for i in range(S):
#         s = states[i]
#         # s_ind: vector of length d_s
#         s_ind = [vals.index(x) for x in s]
#         for j in range(Z):
#             for i_next in range(S):
#                 s_next = states[i_next]
#                 s_next_ind = [vals.index(x) for x in s_next]
#                 P[i,j,i_next] = np.prod([P_raw[d,s_ind[d],j,s_next_ind[d]] for d in range(d_s)])
#     return P

def h_bonus_fac_action(H,L_fac_A,N_x_fac,x,a,states,s_vals,d,d_s,actions):
    """Return the exploration bonus contributed by state factor ``d`` of ``(x, a)``.

    With ``n_d = N_x_fac[d, idx(x[d]), idx(a)]`` the bonus is::

        L_fac_A * ( H * sqrt(1 / (2 n_d))
                    + sum_{j > d, n_j > 0} 2 * H * len(s_vals) * sqrt(1 / (n_d * n_j)) )

    Args:
        H: horizon length.
        L_fac_A: multiplicative scale applied to the whole bonus.
        N_x_fac: count array indexed ``[factor, value index, action index]``.
        x: joint state, a sequence whose ``j``-th entry is a value in ``s_vals``.
        a: joint action; must be an element of ``actions``.
        states: unused here (kept for interface compatibility).
        s_vals: values a single state factor can take.
        d: the factor whose bonus term is computed.
        d_s: number of state factors.
        actions: list of all joint actions.

    ``n_d`` must be positive. Its caller in this file only invokes this function
    when the count is > 0.
    """
    x_d = s_vals.index(x[d])
    a_ind = actions.index(a)
    n_d = N_x_fac[d, x_d, a_ind]
    total = H * math.sqrt(1.0 / 2 / n_d)
    n_vals = len(s_vals)
    # Cross terms pair factor d with every later factor that has been observed.
    for j in range(d + 1, d_s):
        n_j = N_x_fac[j, s_vals.index(x[j]), a_ind]
        if n_j > 0:
            total += 2 * H * n_vals * math.sqrt(1.0 / n_d / n_j)
    return total * L_fac_A

def _update_fac_counts(Hist_states, Hist_actions, k, H, d_s, s_vals, actions, N_x_a_y, N_x_a):
    """Add the per-factor transitions of episode ``k`` to the count arrays (in place).

    For every step ``i < H`` and factor ``d``, the factor transition
    ``x_d --a--> y_d`` increments ``N_x_a_y[d, x_d, a, y_d]`` and ``N_x_a[d, x_d, a]``.
    Values are indexed by position in ``s_vals``; actions by position in ``actions``.
    """
    for i in range(H):
        # tolist() gives floats; if `actions` holds int lists, [0.0] == [0] still matches.
        i_a = actions.index(Hist_actions[i, k, :].tolist())
        for d in range(d_s):
            i_x = s_vals.index(Hist_states[d, i, k])
            i_y = s_vals.index(Hist_states[d, i + 1, k])
            N_x_a_y[d, i_x, i_a, i_y] += 1
            N_x_a[d, i_x, i_a] += 1


def _estimate_fac_transitions(N_x_a_y, N_x_a, state_idx, A):
    """Return the product-form estimate ``P_hat[s, a, s'] = prod_d P_d(s'_d | s_d, a)``.

    ``state_idx[i, d]`` is the ``s_vals`` position of factor ``d`` of state ``i``.
    Factor estimates of never-visited ``(x_d, a)`` pairs are 0, so the corresponding
    rows of ``P_hat`` are 0. Callers must not rely on those rows.
    """
    P_fac = np.zeros(N_x_a_y.shape)
    np.divide(N_x_a_y, N_x_a[..., None], out=P_fac, where=(N_x_a > 0)[..., None])
    S, d_s = state_idx.shape
    P_hat = np.ones((S, A, S))
    all_actions = np.arange(A)
    for d in range(d_s):
        cols = state_idx[:, d]
        # Gather P_fac[d, cols[i], a, cols[j]] into an (S, A, S) block.
        P_hat *= P_fac[d][np.ix_(cols, all_actions, cols)]
    return P_hat


def UCB_Q_val_fac_action(Hist_states,s_vals,Hist_actions,N_x_a_y,N_x_a,k,Pa,actions,states,d_s,Q,R,H,L_fac_A,Cprob):
    """Fold episode ``k`` into the counts and compute optimistic Q-values for episode ``k + 1``.

    Args:
        Hist_states: array indexed ``[factor, step, episode]``. Only episode ``k``
            (steps ``0..H``) is read.
        s_vals: values a single state factor can take.
        Hist_actions: array indexed ``[step, episode, :]``. Only episode ``k`` is read.
        N_x_a_y: per-factor transition counts ``[d, x_d, a, y_d]``, updated in place.
        N_x_a: per-factor visit counts ``[d, x_d, a]``, updated in place.
        k: index of the episode that just finished.
        Pa: unused.
        actions: list of joint actions.
        states: list of joint states, each a sequence of ``d_s`` factor values.
        d_s: number of state factors.
        Q: array indexed ``[episode, step, state, action]``. Row ``k + 1`` is
            overwritten in place.
        R: reward array indexed ``[state, action]``.
        H: horizon length.
        L_fac_A: bonus scale forwarded to ``h_bonus_fac_action``.
        Cprob: unused.

    Returns:
        ``Q`` (the same array object, with ``Q[k + 1]`` filled in).

    A ``(state, action)`` pair with any never-observed factor keeps the fully
    optimistic value ``H``. Otherwise the value is capped by the ``H - h`` steps
    remaining from 0-indexed step ``h``.
    """
    A, S = len(actions), len(states)
    _update_fac_counts(Hist_states, Hist_actions, k, H, d_s, s_vals, actions, N_x_a_y, N_x_a)

    # state_idx[i, d]: position of states[i][d] in s_vals (reshape keeps S == 0 valid).
    state_idx = np.array(
        [[s_vals.index(s[d]) for d in range(d_s)] for s in states], dtype=int
    ).reshape(S, d_s)
    P_hat = _estimate_fac_transitions(N_x_a_y, N_x_a, state_idx, A)
    visited = N_x_a > 0

    V = np.zeros(shape=(H + 1, S))
    for h in range(H - 1, -1, -1):
        for i_s in range(S):
            x = states[i_s]
            for i_a in range(A):
                if not all(visited[d, state_idx[i_s, d], i_a] for d in range(d_s)):
                    # Some factor of (x, a) was never observed: stay fully optimistic and
                    # do NOT fall through to the Bellman update (its estimate is invalid).
                    Q[k + 1, h, i_s, i_a] = H
                    continue
                bonus = sum(
                    h_bonus_fac_action(H, L_fac_A, N_x_a, x, actions[i_a], states, s_vals, d, d_s, actions)
                    for d in range(d_s)
                )
                PV = np.inner(P_hat[i_s, i_a, :], V[h + 1, :])
                # H - h steps remain from 0-indexed step h (V[H] == 0).
                Q[k + 1, h, i_s, i_a] = min(H - h, R[i_s, i_a] + PV + bonus)
        V[h, :] = Q[k + 1, h, :, :].max(axis=1)
    return Q

def UCBVI_fac_actions(K,H,L_fac_A,actions,states,s_vals,d_s,Pa,P_tran_Pa,P_tran,R_Pa,R,a_dim,Cprob):
    """Run ``K`` episodes of UCBVI with factored states/actions and track regret.

    Each episode starts from a uniformly sampled state. It acts greedily w.r.t.
    ``Q[k]`` for ``H`` steps, sampling next states from ``P_tran[state, action, :]``.
    After every episode except the last, it refreshes ``Q[k + 1]`` via
    ``UCB_Q_val_fac_action``.

    Args:
        K: number of episodes.
        H: horizon length.
        L_fac_A: bonus scale forwarded to ``UCB_Q_val_fac_action``.
        actions: list of joint actions (each of length ``a_dim``).
        states: list of joint states (each with at least ``d_s`` factors).
        s_vals: values a single state factor can take.
        d_s: number of state factors.
        Pa, Cprob: forwarded to ``UCB_Q_val_fac_action`` unchanged.
        P_tran_Pa, R_Pa: unused.
        P_tran: true transition array indexed ``[state, action, next_state]``.
        R: reward array indexed ``[state, action]``.
        a_dim: length of a joint action vector.

    Returns:
        ``(reward, regret, Q)``:
        ``reward`` is a ``(K, H)`` array of per-step rewards;
        ``regret[k]`` is ``V_star[0, start] - (sum of episode k rewards)``;
        ``Q`` is the final ``(K, H, S, A)`` optimistic Q-table.
    """
    n_states, n_actions, n_vals = len(states), len(actions), len(s_vals)
    reward = np.zeros(shape=(K, H))
    regret = [0] * K
    Hist_states = np.zeros(shape=(d_s, H + 1, K))
    Hist_actions = np.zeros(shape=(H, K, a_dim))
    N_x_a_y = np.zeros(shape=(d_s, n_vals, n_actions, n_vals))
    N_x_a = np.zeros(shape=(d_s, n_vals, n_actions))
    Q = H * np.ones(shape=[K, H, n_states, n_actions])
    # Presumably V_star[h, s] is the optimal value from step h in state s (see get_V_star).
    V_star = get_V_star(H=H, R=R, actions=actions, states=states, P_tran=P_tran)

    for k in range(K):
        start = np.random.choice(range(n_states))
        x = states[start]
        V_star_init = V_star[0, start]
        for d in range(d_s):
            Hist_states[d, 0, k] = x[d]

        for h in range(H):
            i_x = states.index(x)
            a_ind = np.argmax(Q[k, h, i_x, :])
            nxt = np.random.choice(a=range(n_states), size=1, p=P_tran[i_x, a_ind, :])[0]
            reward[k, h] = R[i_x, a_ind]
            Hist_actions[h, k, :] = actions[a_ind]
            x = states[nxt]
            for d in range(d_s):
                Hist_states[d, h + 1, k] = x[d]

        # The last episode's Q-table is never used, so skip the final update.
        if k < K - 1:
            Q = UCB_Q_val_fac_action(
                Hist_states[:, :, :(k + 1)], s_vals, Hist_actions[:, :(k + 1), :],
                N_x_a_y, N_x_a, k, Pa, actions, states, d_s, Q, R, H, L_fac_A, Cprob,
            )
        regret[k] = V_star_init - sum(reward[k, :])
    return reward, regret, Q

