import numpy as np
import itertools
import math
from common_funcs import *
from CUCBVI_funcs import *
import matplotlib.pyplot as plt

def gen_states(d_s,vals):
    """Enumerate all joint states of a factored state space.

    Args:
        d_s: number of state factors (length of each state vector).
        vals: iterable of values a single factor can take; the same domain
            is used for every factor.

    Returns:
        A list of states, each a list of length ``d_s``, in the order
        produced by ``itertools.product(vals, repeat=d_s)``.
    """
    # every factor shares the same domain, so the joint space is vals^d_s
    joint = itertools.product(vals, repeat=d_s)
    return list(map(list, joint))

def gen_fac_reward(Pa,d_s,vals,states):
    """Draw a random factored reward table over (state, Pa-entry) pairs.

    A per-factor reward ``R_raw[d, v, j]`` is sampled uniformly from
    ``[0, 1/d_s)`` for every factor ``d``, factor value index ``v`` and
    Pa index ``j``. The reward of a joint state is the sum of its factors'
    rewards, so every entry of the result lies in ``[0, 1)``.

    Args:
        Pa: sequence whose length fixes the second axis of the result.
        d_s: number of state factors (must be non-zero: it is used as a divisor).
        vals: domain shared by every factor; must support ``.index``.
        states: sequence of joint states, each indexable by ``0..d_s-1``
            with entries taken from ``vals``.

    Returns:
        ``numpy.ndarray`` of shape ``(len(states), len(Pa))``.

    Raises:
        ValueError: if a state component is not in ``vals``.
    """
    n_pa = len(Pa)
    # R_raw: d_s by len(S_i) by len(Pa) array
    R_raw = np.random.uniform(low=0, high=1.0/d_s, size=(d_s, len(vals), n_pa))
    # R: len(states) by len(Pa) array
    R = np.zeros(shape=(len(states), n_pa))
    for d in range(d_s):
        # index of each state's d-th component in dom(S_d)=vals, looked up
        # once per state rather than once per (state, Pa) pair
        col = np.array([vals.index(s[d]) for s in states], dtype=int)
        # add factor d's reward row for every state at once
        R += R_raw[d][col, :]
    return R

def gen_fac_tran_prob_Pa(Pa,states,d_s,vals):
    """Draw a random factored transition kernel over joint states.

    For every factor ``d``, current factor value and Pa index, a
    distribution over the factor's next value is sampled from a flat
    Dirichlet. The joint kernel is the product of the per-factor
    probabilities, i.e. factors transition independently given the
    current state and the Pa entry.

    Args:
        Pa: sequence whose length fixes the middle axis of the result.
        states: sequence of joint states, each indexable by ``0..d_s-1``
            with entries taken from ``vals``.
        d_s: number of state factors.
        vals: domain shared by every factor; must support ``.index``.

    Returns:
        ``numpy.ndarray`` ``P`` of shape ``(S, Z, S)`` with
        ``S = len(states)`` and ``Z = len(Pa)``;
        ``P[i, j, i_next]`` is the probability of moving from state ``i``
        to state ``i_next`` under Pa entry ``j``.

    Raises:
        ValueError: if a state component is not in ``vals``.
    """
    Z,S = len(Pa),len(states)
    n_vals = len(vals)
    # P_raw: P(s_i'|Z,s_i)
    P_raw = np.zeros(shape=(d_s, n_vals, Z, n_vals))
    # NOTE: the draw order (factor, value, Pa) is kept so that a seeded run
    # reproduces the same kernel as before.
    for d in range(d_s):
        for i in range(n_vals):
            for j in range(Z):
                P_raw[d,i,j,:] = np.random.dirichlet(alpha=[1]*n_vals)
    P = np.ones(shape=(S,Z,S))
    for d in range(d_s):
        # value index of every state's d-th component, resolved once
        col = np.array([vals.index(s[d]) for s in states], dtype=int)
        # P_raw[d][col] -> (S, Z, n_vals); [:, :, col] -> (S, Z, S):
        # factor d's probability for every (state, Pa, next state) triple
        P *= P_raw[d][col][:, :, col]
    return P

def h_bonus_fac_Pa(H,L_fac_Pa,N_x_Pa,x,z,states,s_vals,d,Pa):
    """Exploration bonus for factor ``d`` of state ``x`` under Pa entry ``z``.

    Computes ``H * sqrt(1 / N) * L_fac_Pa`` where ``N`` is the count stored at
    ``N_x_Pa[d, s_vals.index(x[d]), Pa.index(z)]``.

    Args:
        H: multiplier applied to the bonus.
        L_fac_Pa: second multiplier applied to the bonus.
        N_x_Pa: count array indexed as ``[factor, value index, Pa index]``.
        x: joint state; only ``x[d]`` is read.
        z: Pa entry, located in ``Pa`` with ``Pa.index``.
        states: unused here.
        s_vals: factor domain, used to turn ``x[d]`` into an index.
        d: factor index.
        Pa: sequence of Pa entries.

    Returns:
        The bonus value. The count is used as a divisor, so callers are
        expected to pass only pairs with a positive count.
    """
    value_idx = s_vals.index(x[d])
    pa_idx = Pa.index(z)
    visits = N_x_Pa[d, value_idx, pa_idx]
    return H * math.sqrt(1.0 / visits) * L_fac_Pa

def UCB_Q_val_fac_Pa(Hist_states,s_vals,Hist_PAs,N_x_Pa_y,N_x_Pa,k,Pa,actions,states,d_s,Q,R_Pa,H,L_fac_Pa,Cprob):
    """Update visit counts with episode ``k`` and compute optimistic Q for episode ``k+1``.

    Steps:
      1. Add the transitions of episode ``k`` to the per-factor counts
         (``N_x_Pa_y`` and ``N_x_Pa`` are modified in place).
      2. Build the estimated joint transition kernel as the product of the
         per-factor empirical transition frequencies.
      3. Run backward induction over ``h = H-1, ..., 0`` writing
         ``Q[k+1, h, s, z] = min(H, R_Pa[s, z] + P_hat[s, z, :] . V[h+1] + bonus[s, z])``
         where ``bonus`` is the sum over factors of
         ``H * sqrt(1 / N_x_Pa[d, s_d, z]) * L_fac_Pa`` (the same formula
         as ``h_bonus_fac_Pa``) and
         ``V[h, s] = max_a sum_z Cprob[a, s, z] * Q[k+1, h, s, z]``.

    A (state, Pa) pair for which at least one factor has a zero count keeps
    the optimistic value ``H``. Previously that value was assigned and then
    immediately overwritten by the regular update (computed with a partial
    bonus), which discarded the optimism for unexplored pairs.

    Args:
        Hist_states: array indexed ``[factor, step, episode]`` with steps
            ``0..H``; entries must be members of ``s_vals``.
        s_vals: factor domain; must support ``.index``.
        Hist_PAs: array indexed ``[step, episode, :]`` whose last axis,
            converted with ``.tolist()``, must be an element of ``Pa``.
        N_x_Pa_y: counts ``[factor, value, Pa, next value]`` (updated in place).
        N_x_Pa: counts ``[factor, value, Pa]`` (updated in place).
        k: index of the episode whose history is added; ``Q[k+1]`` is written.
        Pa: sequence of Pa entries.
        actions: sequence of actions (only its length is used).
        states: sequence of joint states, each indexable by ``0..d_s-1``.
        d_s: number of state factors.
        Q: array indexed ``[episode, h, state, Pa]`` (updated in place).
        R_Pa: reward array indexed ``[state, Pa]``.
        H: horizon; also the cap on Q values.
        L_fac_Pa: bonus multiplier.
        Cprob: array indexed ``[action, state, Pa]`` used as weights over Pa
            entries when turning Q into a per-action value.

    Returns:
        ``Q`` (the same object that was passed in).
    """
    A,S,Z = len(actions),len(states),len(Pa)

    # update all Ns from the history of episode k
    for d in range(d_s):
        for step in range(H):
            x_idx = s_vals.index(Hist_states[d,step,k])
            z_idx = Pa.index(Hist_PAs[step,k,:].tolist())
            y_idx = s_vals.index(Hist_states[d,step+1,k])
            N_x_Pa_y[d,x_idx,z_idx,y_idx] += 1
            N_x_Pa[d,x_idx,z_idx] += 1

    # Per-(state, Pa) quantities that do not depend on h, built factor by factor.
    P_hat_x_Pa = np.ones(shape=(S,Z,S))
    bonus = np.zeros(shape=(S,Z))
    unvisited = np.zeros(shape=(S,Z),dtype=bool)
    for d in range(d_s):
        # value index of every state's d-th component
        col = np.array([s_vals.index(s[d]) for s in states],dtype=int)
        counts = np.asarray(N_x_Pa[d])[col,:]                    # (S, Z)
        seen = counts > 0
        unvisited |= ~seen
        # replace zero counts by 1 so the divisions below are well defined;
        # the affected entries are overridden with H in the backward pass
        safe_counts = np.where(seen,counts,1.0)
        bonus += np.where(seen,H*np.sqrt(1.0/safe_counts)*L_fac_Pa,0.0)
        trans_counts = np.asarray(N_x_Pa_y[d])[col][:,:,col]     # (S, Z, S)
        # estimated transition probability of factor d
        P_hat_x_Pa *= trans_counts/safe_counts[:,:,None]

    rewards = np.asarray(R_Pa)[:S,:Z]
    pa_weights = np.asarray(Cprob)[:A,:S,:]
    V_next = np.zeros(S)  # value after the last step is zero
    for h in range(H-1,-1,-1):
        q_h = np.minimum(H,rewards + np.dot(P_hat_x_Pa,V_next) + bonus)
        # unexplored pairs stay at the optimistic upper bound
        q_h[unvisited] = H
        Q[k+1,h,:,:] = q_h
        # value of each action = Cprob-weighted average of Q over Pa entries;
        # the state value is the best action
        V_next = np.einsum('asz,sz->as',pa_weights,q_h).max(axis=0)
    return Q

def UCBVI_fac_PAs(K,H,L_fac_Pa,actions,states,s_vals,d_s,Pa,P_tran_Pa,P_tran,R_Pa,R,a_dim,Cprob):
    """Simulate ``K`` episodes of length ``H`` with the factored UCB value-iteration agent.

    Each episode starts from a uniformly drawn state. At every step the
    agent picks the action maximising the ``Cprob``-weighted average of the
    current Q estimate, a Pa entry is sampled from ``Cprob`` for that action
    and state, and the next state is sampled from ``P_tran_Pa``. After each
    episode except the last, Q for the next episode is recomputed with
    ``UCB_Q_val_fac_Pa``.

    Args:
        K: number of episodes.
        H: steps per episode.
        L_fac_Pa: bonus multiplier forwarded to ``UCB_Q_val_fac_Pa``.
        actions: sequence of actions.
        states: sequence of joint states (each indexable by ``0..d_s-1``).
        s_vals: factor domain.
        d_s: number of state factors.
        Pa: sequence of Pa entries, each assignable to a row of length ``a_dim``.
        P_tran_Pa: transition array indexed ``[state, Pa, next state]``.
        P_tran: forwarded to ``get_V_star``.
        R_Pa: reward array forwarded to ``UCB_Q_val_fac_Pa``.
        R: reward array indexed ``[state, action]``; also forwarded to ``get_V_star``.
        a_dim: length of a Pa entry as stored in the history.
        Cprob: array indexed ``[action, state, Pa]`` of probabilities over Pa entries.

    Returns:
        Tuple ``(reward, regret, Q)``: the ``(K, H)`` array of collected
        rewards, a list with ``V_star[0, start] - sum of episode rewards``
        per episode, and the final Q array of shape ``(K, H, S, Z)``.
    """
    n_states,n_actions,n_pa = len(states),len(actions),len(Pa)
    n_vals = len(s_vals)

    reward = np.zeros(shape=(K,H))
    regret = [0]*K
    # trajectories: factor values per step, and the Pa entry applied at each step
    Hist_states = np.zeros(shape=(d_s,H+1,K))
    Hist_PAs = np.zeros(shape=(H,K,a_dim))
    # per-factor visit counters shared across episodes
    N_x_Pa_y = np.zeros(shape=(d_s,n_vals,n_pa,n_vals))
    N_x_Pa = np.zeros(shape=(d_s,n_vals,n_pa))
    # optimistic initialisation: every Q value starts at H
    Q = H*np.ones(shape=[K,H,n_states,n_pa])
    V_star = get_V_star(H=H,R=R,actions=actions,states=states,P_tran=P_tran)

    for episode in range(K):
        start_ind = np.random.choice(range(n_states))
        cur_state = states[start_ind]
        V_star_init = V_star[0,start_ind]
        for d in range(d_s):
            Hist_states[d,0,episode] = cur_state[d]

        for step in range(H):
            cur_ind = states.index(cur_state)
            # value of each action under the current optimistic Q
            action_vals = [
                np.inner(Cprob[i_a,cur_ind,:],Q[episode,step,cur_ind,:])
                for i_a in range(n_actions)
            ]
            a_ind = np.argmax(action_vals)
            z_ind = np.random.choice(a = range(n_pa),size=1,p=Cprob[a_ind,cur_ind,:])[0]
            next_ind = np.random.choice(a=range(n_states),size=1,p=P_tran_Pa[cur_ind,z_ind,:])[0]
            next_state = states[next_ind]

            reward[episode,step] = R[cur_ind,a_ind]
            Hist_PAs[step,episode,:] = Pa[z_ind]
            for d in range(d_s):
                Hist_states[d,step+1,episode] = next_state[d]
            cur_state = next_state

        # Q[episode+1] only exists while another episode remains
        if episode<K-1:
            Q = UCB_Q_val_fac_Pa(Hist_states[:,:,:(episode+1)],s_vals,Hist_PAs[:,:(episode+1),:],N_x_Pa_y,N_x_Pa,episode,Pa,actions,states,d_s,Q,R_Pa,H,L_fac_Pa,Cprob)
        regret[episode] = V_star_init - sum(reward[episode,:])
    return reward,regret,Q

# # global constants
# a_dim = 2
# X_max = 3
# Z_max = 2
# # S = 2
# H,K = 2,10000
# T = H*K
# delta = 1e-5
# n_sim = 5
# d_s = 2
# s_vals = range(2)
# # np.random.seed(1)
#
# #fixed global variables
# actions = gen_actions(k=a_dim,vals=range(X_max+1))
# A = len(actions)
# states = gen_states(d_s=d_s,vals=s_vals)
# S = len(states)
# print('states')
# print(states)
# L_A = math.log(5.0*A*S*T/delta)
# Pa = gen_Pa(k=a_dim,vals=range(1,Z_max+1))
# Z = len(Pa)
# # L_Pa = math.log(5.0*Z*S*T/delta)
# # L_fac_Pa = math.log(5.0*Z*d_s*len(s_vals)*T/delta)
# L_Pa,L_fac_Pa = 1,1
# # R = gen_fac_reward(Pa,d_s,s_vals,states)
# print('S,Z,A')
# print(S,Z,A)
#
# Cprob = gen_Z_prob(actions,states,Pa)
# R_Pa = gen_fac_reward(Pa,d_s,s_vals,states)
# print('R_Pa')
# print(R_Pa.shape)
# R = get_all_reward(R_mat=R_Pa,Cprob=Cprob,actions=actions,states=states)
# print('R')
# print(R.shape)
# P_tran_Pa = gen_fac_tran_prob_Pa(Pa,states,d_s,s_vals)
# print('P_tran_Pa')
# print(P_tran_Pa.shape)
# print(np.sum(P_tran_Pa[0,0,:]))
# P_tran = gen_tran_prob(actions=actions,states=states,P_tran_Pa=P_tran_Pa,Cprob=Cprob)
# print('P_tran')
# print(P_tran.shape)
# print(np.sum(P_tran[0,0,:]))
#
# reward_fac_CUCBVI,regret_fac_CUCBVI,Q_fac_CUCBVI = UCBVI_fac_PAs(K,H,L_fac_Pa,actions,states,s_vals,d_s,Pa,P_tran_Pa,P_tran,R_Pa,R,a_dim,Cprob)
# reward_CUCBVI,regret_CUCBVI,Q_CUCBVI = UCBVI_PAs(K,H,L_Pa,actions,states,d_s,Pa,P_tran_Pa,P_tran,R_Pa,R,a_dim,Cprob)
# # # print(regret_fac_CUCBVI)
# cum_regret_H_fac_CUCBVI = [0]*K
# cum_regret_H_CUCBVI = [0]*K
# for i in range(1,K):
#     cum_regret_H_fac_CUCBVI[i] = sum(regret_fac_CUCBVI[:i])
#     cum_regret_H_CUCBVI[i] = sum(regret_CUCBVI[:i])
# #
# fig,ax = plt.subplots(1,1)
# ax.plot(range(K),cum_regret_H_fac_CUCBVI,'b-')
# ax.plot(range(K),cum_regret_H_CUCBVI,'k-')
# plt.show()
