from common_funcs import *
from UCBVI_funcs import *
from CUCBVI_funcs import *
from FMDP_funcs import *
from F_OVI_funcs import *
import matplotlib.pyplot as plt
import pickle


# ---------------------------------------------------------------------------
# Experiment configuration (module-level constants read by the sweep below)
# ---------------------------------------------------------------------------
a_dim = 3                       # passed as k= to gen_actions / gen_Pa
X_max_list = [2, 3, 4, 5, 6]    # one experiment setting per entry
Z_max = 2                       # gen_Pa draws its values from range(1, Z_max + 1)
S = 2                           # only referenced by the commented-out bonus formulas
H = 2                           # handed to every algorithm together with K
K = 5000                        # also the length of each cumulative-regret series
T = H * K
delta = 1e-5                    # only referenced by the commented-out bonus formulas
n_sim = 1                       # number of repetitions per X_max setting
d_s = 3                         # passed as d_s= to gen_states and the factored generators
s_vals = range(2)               # passed as vals= to gen_states and the factored generators
# np.random.seed(1)

# Final mean cumulative regret per X_max setting, one independent list per algorithm.
reg_UCBVI = [0 for _ in X_max_list]
reg_CUCBVI = [0 for _ in X_max_list]
reg_fac_CUCBVI = [0 for _ in X_max_list]
reg_fac_UCBVI = [0 for _ in X_max_list]

def _cum_regret_before_episode(per_episode_regret, n_episodes):
    """Return the regret accumulated *before* each episode.

    Entry 0 is 0 and entry i (i >= 1) is the sum of the first i values of
    ``per_episode_regret``, so the last episode's own regret is not part of
    the final entry.  This is the same series the script previously built
    with ``sum(per_episode_regret[:i])`` in a loop, computed in O(n) rather
    than O(n^2) (values agree up to floating-point rounding).

    Assumes ``per_episode_regret`` is a 1-D sequence with at least
    ``n_episodes - 1`` numeric entries -- TODO confirm against the
    UCBVI_* return values.
    """
    totals = np.zeros(n_episodes)
    n_prefix = max(n_episodes - 1, 0)
    totals[1:] = np.cumsum(np.asarray(per_episode_regret, dtype=float)[:n_prefix])
    return totals


# Sweep over the X_max settings; for each one, run every algorithm n_sim
# times and record the last entry of its mean cumulative-regret series.
for j, X_max in enumerate(X_max_list):

    # quantities that stay fixed for this X_max setting
    actions = gen_actions(k=a_dim, vals=range(X_max + 1))
    A = len(actions)
    # L_A = math.log(5.0*A*S*T/delta)
    Pa = gen_Pa(k=a_dim, vals=range(1, Z_max + 1))
    Z = len(Pa)
    # L_Pa = math.log(5.0*Z*S*T/delta)
    # L_fac_Pa = 0.1*math.log(5.0*Z*d_s*len(s_vals)*T/delta)
    # Hand-picked constants used in place of the commented-out log formulas above.
    L_A, L_Pa, L_fac_A, L_fac_Pa = 1, 0.6, 0.6, 0.2
    states = gen_states(d_s=d_s, vals=s_vals)

    # results: one row per simulation run, one column per episode
    overall_cum_regret_UCBVI = np.zeros(shape=(n_sim, K))
    overall_cum_regret_CUCBVI = np.zeros(shape=(n_sim, K))
    overall_cum_regret_fac_CUCBVI = np.zeros(shape=(n_sim, K))
    overall_cum_regret_fac_UCBVI = np.zeros(shape=(n_sim, K))

    for nn in range(n_sim):
        print(j, nn)
        # parameters: a fresh problem instance is generated for every run
        Cprob = gen_Z_prob(actions, states, Pa)
        # change ways to generate R_Pa: random or deterministic
        # R_Pa = gen_reward(Pa=Pa,states=states)
        R_Pa = gen_fac_reward(Pa=Pa, d_s=d_s, vals=s_vals, states=states)
        R = get_all_reward(R_mat=R_Pa, Cprob=Cprob, actions=actions, states=states)
        # P_tran_Pa = gen_tran_prob_Pa(Pa=Pa,states=states)
        P_tran_Pa = gen_fac_tran_prob_Pa(Pa=Pa, states=states, d_s=d_s, vals=s_vals)
        P_tran = gen_tran_prob(actions=actions, states=states, P_tran_Pa=P_tran_Pa, Cprob=Cprob)
        # NOTE(review): V_star is not read anywhere below; the call is kept
        # because get_V_star is opaque from here -- confirm it can be dropped.
        V_star = get_V_star(H=H, R=R, actions=actions, states=states, P_tran=P_tran)

        # UCBVI
        reward_UCBVI, regret_UCBVI, Q = UCBVI_actions(K, H, L_A, actions, states, d_s, P_tran, R, a_dim)
        cum_regret_H_UCBVI = _cum_regret_before_episode(regret_UCBVI, K)
        overall_cum_regret_UCBVI[nn, :] = cum_regret_H_UCBVI

        # CUCBVI
        reward_CUCBVI, regret_CUCBVI, Q_CUCBVI = UCBVI_PAs(
            K, H, L_Pa, actions, states, d_s, Pa, P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
        cum_regret_H_CUCBVI = _cum_regret_before_episode(regret_CUCBVI, K)
        overall_cum_regret_CUCBVI[nn, :] = cum_regret_H_CUCBVI

        # factored CUCBVI
        reward_fac_CUCBVI, regret_fac_CUCBVI, Q_fac_CUCBVI = UCBVI_fac_PAs(
            K, H, L_fac_Pa, actions, states, s_vals, d_s, Pa, P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
        cum_regret_H_fac_CUCBVI = _cum_regret_before_episode(regret_fac_CUCBVI, K)
        overall_cum_regret_fac_CUCBVI[nn, :] = cum_regret_H_fac_CUCBVI

        # factored UCBVI
        reward_fac_UCBVI, regret_fac_UCBVI, Q_fac_UCBVI = UCBVI_fac_actions(
            K, H, L_fac_A, actions, states, s_vals, d_s, Pa, P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
        cum_regret_H_fac_UCBVI = _cum_regret_before_episode(regret_fac_UCBVI, K)
        overall_cum_regret_fac_UCBVI[nn, :] = cum_regret_H_fac_UCBVI

    # average the cumulative-regret curves over the simulation runs
    mean_UCBVI_regret = np.mean(overall_cum_regret_UCBVI, axis=0)
    mean_CUCBVI_regret = np.mean(overall_cum_regret_CUCBVI, axis=0)
    mean_fac_CUCBVI_regret = np.mean(overall_cum_regret_fac_CUCBVI, axis=0)
    mean_fac_UCBVI_regret = np.mean(overall_cum_regret_fac_UCBVI, axis=0)

    # keep only the final point of each mean curve for this X_max
    reg_UCBVI[j] = mean_UCBVI_regret[-1]
    reg_CUCBVI[j] = mean_CUCBVI_regret[-1]
    reg_fac_UCBVI[j] = mean_fac_UCBVI_regret[-1]
    reg_fac_CUCBVI[j] = mean_fac_CUCBVI_regret[-1]


# Persist the per-X_max final regrets, one pickle file per algorithm.
# Imported here so this section is self-contained.
import os

# NOTE(review): the folder name says "N10" while n_sim is set to 1 above --
# confirm the directory label matches the run being saved.
results_dir = 'results/Xmax_H2_K5000_N10_4algo_v2'
# Create the output folder up front so a missing directory cannot discard
# the results of a finished simulation with a FileNotFoundError.
os.makedirs(results_dir, exist_ok=True)

for file_stem, final_regrets in (
    ('reg_UCBVI', reg_UCBVI),
    ('reg_CUCBVI', reg_CUCBVI),
    ('reg_fac_UCBVI', reg_fac_UCBVI),
    ('reg_fac_CUCBVI', reg_fac_CUCBVI),
):
    with open(results_dir + '/' + file_stem + '.pickle', 'wb') as f:
        pickle.dump(final_regrets, f, protocol=pickle.HIGHEST_PROTOCOL)

# fig,ax = plt.subplots(1,1)
# ax.plot(range(K),cum_regret_H_UCBVI,'k-')
# ax.plot(range(K),cum_regret_H_CUCBVI,'b-')
# plt.show()


# print(R_Pa)
# print(Cprob)
# print(R)

