import os
import pickle

import matplotlib.pyplot as plt

# Star imports keep their original relative order: a later module may override
# names exported by an earlier one.
from common_funcs import *
from UCBVI_funcs import *
from CUCBVI_funcs import *
from FMDP_funcs import *
from F_OVI_funcs import *

# ---- global constants ----
a_dim = 3          # forwarded as k to gen_actions and gen_Pa
X_max = 3          # action components are drawn from range(X_max + 1)
Z_max = 2          # Pa components are drawn from range(1, Z_max + 1)
S = 2
H = 5              # passed as H to get_V_star and the UCBVI routines (presumably the horizon)
K = 5000           # length of every per-run regret series (presumably the episode count)
T = H * K
delta = 1e-5
n_sim = 10         # number of independent simulation runs
d_s = 3
s_vals = range(2)
# np.random.seed(1)

# ---- fixed global variables ----
actions = gen_actions(k=a_dim, vals=range(X_max + 1))
A = len(actions)
# L_A = math.log(5.0*A*S*T/delta)
Pa = gen_Pa(k=a_dim, vals=range(1, Z_max + 1))
Z = len(Pa)
# L_Pa = math.log(5.0*Z*S*T/delta)
# L_fac_Pa = 0.1*math.log(5.0*Z*d_s*len(s_vals)*T/delta)
# Hand-picked constants used in place of the log-based formulas commented out above.
L_A = 0.8
L_Pa = 0.5
L_fac_A = 0.5
L_fac_Pa = 0.2
states = gen_states(d_s=d_s, vals=s_vals)
# states = [[5*i] for i in range(S)]
print(states)

# results: one (n_sim x K) matrix per algorithm, filled row by row
# (row index = simulation run) with that run's cumulative-regret series.
(
    overall_cum_regret_UCBVI,
    overall_cum_regret_CUCBVI,
    overall_cum_regret_fac_UCBVI,
    overall_cum_regret_fac_CUCBVI,
) = (np.zeros(shape=(n_sim, K)) for _ in range(4))

def _sum_before_each(values, n_out):
    """Return an array ``c`` of length ``n_out`` with ``c[i] == sum(values[:i])``.

    ``c[0]`` is therefore 0 and entry ``i`` excludes ``values[i]`` itself.
    One ``np.cumsum`` pass replaces the quadratic pattern of calling
    ``sum(values[:i])`` separately for every index.

    Args:
        values: one-dimensional sequence of numbers (list or array).
        n_out: number of entries to produce.

    Returns:
        1-D float ndarray of length ``n_out``.
    """
    flat = np.asarray(values, dtype=float)
    # prefix[j] is the sum of the first j entries; prefix[0] is the empty sum.
    prefix = np.concatenate(([0.0], np.cumsum(flat)))
    # Indices past the end of ``values`` saturate at the grand total, which is
    # exactly what slicing ``values[:i]`` does for i > len(values).
    return prefix[np.minimum(np.arange(n_out), flat.size)]


# Run n_sim independent simulations. Each one draws a fresh problem instance,
# runs the four algorithms on it and stores each algorithm's cumulative-regret
# series in row ``nn`` of the matching result matrix.
for nn in range(n_sim):

    # parameters
    Cprob = gen_Z_prob(actions, states, Pa)
    # change ways to generate R_Pa: random or deterministic
    # v1 R_Pa
    # R_Pa = gen_reward(Pa=Pa,states=states)
    R_Pa = gen_fac_reward(Pa=Pa, d_s=d_s, vals=s_vals, states=states)
    R = get_all_reward(R_mat=R_Pa, Cprob=Cprob, actions=actions, states=states)
    # P_tran_Pa = gen_tran_prob_Pa(Pa=Pa,states=states)
    P_tran_Pa = gen_fac_tran_prob_Pa(Pa=Pa, states=states, d_s=d_s, vals=s_vals)
    # print(P_tran_Pa.shape)
    # print(Cprob.shape)
    P_tran = gen_tran_prob(actions=actions, states=states, P_tran_Pa=P_tran_Pa, Cprob=Cprob)
    # V_star is not used further in this script; kept so the call still runs.
    V_star = get_V_star(H=H, R=R, actions=actions, states=states, P_tran=P_tran)

    # UCBVI
    reward_UCBVI, regret_UCBVI, Q = UCBVI_actions(K, H, L_A, actions, states, d_s, P_tran, R, a_dim)
    # Collapse axis 1 of the reward array (presumably the steps within an episode).
    reward_H_UCBVI = np.sum(reward_UCBVI, axis=1).tolist()
    cum_reward_H_UCBVI = _sum_before_each(reward_H_UCBVI, K)
    cum_regret_H_UCBVI = _sum_before_each(regret_UCBVI, K)
    overall_cum_regret_UCBVI[nn, :] = cum_regret_H_UCBVI

    # CUCBVI
    reward_CUCBVI, regret_CUCBVI, Q_CUCBVI = UCBVI_PAs(
        K, H, L_Pa, actions, states, d_s, Pa, P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
    cum_regret_H_CUCBVI = _sum_before_each(regret_CUCBVI, K)
    overall_cum_regret_CUCBVI[nn, :] = cum_regret_H_CUCBVI

    # factored UCBVI
    # NOTE(review): this call receives L_fac_Pa although L_fac_A is defined and
    # never used anywhere in this script -- confirm which constant is intended.
    reward_fac_UCBVI, regret_fac_UCBVI, Q_fac_UCBVI = UCBVI_fac_actions(
        K, H, L_fac_Pa, actions, states, s_vals, d_s, Pa,
        P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
    cum_regret_H_fac_UCBVI = _sum_before_each(regret_fac_UCBVI, K)
    overall_cum_regret_fac_UCBVI[nn, :] = cum_regret_H_fac_UCBVI

    # factored CUCBVI
    reward_fac_CUCBVI, regret_fac_CUCBVI, Q_fac_CUCBVI = UCBVI_fac_PAs(
        K, H, L_fac_Pa, actions, states, s_vals, d_s, Pa,
        P_tran_Pa, P_tran, R_Pa, R, a_dim, Cprob)
    cum_regret_H_fac_CUCBVI = _sum_before_each(regret_fac_CUCBVI, K)
    overall_cum_regret_fac_CUCBVI[nn, :] = cum_regret_H_fac_CUCBVI
# Persist the cumulative-regret matrices, one pickle file per algorithm.
results_dir = 'results/fix_3_3_4algo'
# Create the output folder first: without it open() raises FileNotFoundError
# and the finished simulation results would be lost.
os.makedirs(results_dir, exist_ok=True)
for file_stem, regret_matrix in (
    ('reg_UCBVI', overall_cum_regret_UCBVI),
    ('reg_CUCBVI', overall_cum_regret_CUCBVI),
    ('reg_fac_UCBVI', overall_cum_regret_fac_UCBVI),
    ('reg_fac_CUCBVI', overall_cum_regret_fac_CUCBVI),
):
    with open(f'{results_dir}/{file_stem}.pickle', 'wb') as f:
        pickle.dump(regret_matrix, f, protocol=pickle.HIGHEST_PROTOCOL)
# print(sum([x<0 for x in regret_UCBVI]))
# print(sum([x<0 for x in regret_CUCBVI]))
# print(sum([x<0 for x in regret_fac_CUCBVI]))
# fig,ax = plt.subplots(1,1)
# ax.plot(range(K),cum_regret_H_UCBVI,'k-')
# ax.plot(range(K),cum_regret_H_CUCBVI,'b-')
# ax.plot(range(K),cum_regret_H_fac_CUCBVI,'r-')
# plt.show()


# print(R_Pa)
# print(Cprob)
# print(R)
