import numpy as np
import matplotlib.pyplot as plt
import random
import os
from Utils_PerformativeRL import *
import pdb
import time

# ---------------------------------------------------------------------------
# MDP environment setup
# ---------------------------------------------------------------------------
# The clock starts before the environment is built, so the reported total
# covers environment construction as well as both algorithm runs.
start_time = time.time()
env_dict = env_setup(
    seed_init=1,
    num_states=5,
    num_actions=4,
    rho=None,
    transP_func=None,
    reward_func=None,
    gamma=0.95,
)

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
lambda1 = 0.5          # second positional argument of both algorithms; also tags the "V_reg" outputs
num_iters = 400        # iteration budget shared by both algorithms
stepsize_0FW = 0.01    # Stepsize for our Algorithm 1
stepsize_RT = 0.01     # Stepsize for repeated training

# ---------------------------------------------------------------------------
# Run both algorithms on the same environment
# ---------------------------------------------------------------------------
result1 = our_0FW(
    env_dict,
    lambda1,
    num_iters=num_iters,
    num_samples=1000,
    Delta=0.001,
    delta=1e-4,
    beta=stepsize_0FW,
    pi0=None,
    numV_iters=1000,
    V_iter_eps=1e-12,
    is_print=True,
    is_printV=False,
)

result2 = repeat_train(
    env_dict,
    lambda1,
    outer_iters=num_iters,
    inner_iters=100,
    eta=stepsize_RT,
    pi0=None,
    is_print=True,
    numV_iters=1000,
    V_iter_eps=1e-12,
    is_printV=False,
)

# Report the wall-clock duration of everything above, in minutes.
total_minutes = (time.time() - start_time) / 60
print(f"Total time consumption: {total_minutes} minutes.")


# Save results: write the 'V_unreg', 'V_reg' and 'Entropy' entries of each
# algorithm's result to its own .npy file under result_root.
result_root = "./Results_PRL/"
os.makedirs(result_root, exist_ok=True)

# Each file name starts with a per-algorithm prefix; only the regularized
# value file carries the lambda value in its name.
for algo_prefix, algo_result in (("0FW", result1), ("RepeatedTraining", result2)):
    np.save(f"{result_root}{algo_prefix}_V_unreg.npy", algo_result['V_unreg'])
    np.save(f"{result_root}{algo_prefix}_V_reg_lambda{lambda1}.npy", algo_result['V_reg'])
    np.save(f"{result_root}{algo_prefix}_V_entropy.npy", algo_result['Entropy'])



#Plot results for unregularized objective
# Figure geometry (inches) and font sizes used by the plotting code.
width = 8
height = 6
fontsize_lgd = 17
fontsize_axes = 17
label_size = 17
num_size = 17

# Set the rc defaults BEFORE the figure/axes are created: rc values are
# defaults picked up when artists are built, so setting them after plotting
# does not reliably affect the tick labels of the figure already in progress.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

V_unreg_0FW = result1['V_unreg']
V_unreg_RT = result2['V_unreg']
L = len(V_unreg_0FW)   # number of recorded values for Algorithm 1

plt.figure(figsize=(width, height))
plt.plot(range(L), V_unreg_0FW, color="red", label="Our Algorithm 1")
# Use this series' own length for its x-axis so a length mismatch between the
# two algorithms' outputs cannot make plt.plot raise.
plt.plot(range(len(V_unreg_RT)), V_unreg_RT, color="black", linestyle="dashed",
         label="Repeated Training Algorithm")   # label typo "Traning" fixed
plt.legend(prop={'size': fontsize_lgd})
plt.xlabel("Iteration t", fontsize=fontsize_axes)
plt.ylabel("Unregularized Value " + r'$V_{0,\pi_t}^{\pi_t}$', fontsize=fontsize_axes)
plt.savefig(result_root + "simulation_unreg.png", dpi=200, bbox_inches='tight')

#Plot results for regularized objective
# Set the rc defaults BEFORE the figure/axes are created: rc values are
# defaults picked up when artists are built, so they belong ahead of
# plt.figure rather than after the plotting calls.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

V_reg_0FW = result1['V_reg']
V_reg_RT = result2['V_reg']

plt.figure(figsize=(width, height))
# Each series is plotted against its own index range instead of a length
# taken from a different array, so mismatched lengths cannot make plt.plot raise.
plt.plot(range(len(V_reg_0FW)), V_reg_0FW, color="red", label="Our Algorithm 1")
plt.plot(range(len(V_reg_RT)), V_reg_RT, color="black", linestyle="dashed",
         label="Repeated Training Algorithm")   # label typo "Traning" fixed
plt.legend(prop={'size': fontsize_lgd})
plt.xlabel("Iteration t", fontsize=fontsize_axes)
plt.ylabel("Regularized Value " + r'$V_{\lambda,\pi_t}^{\pi_t}$', fontsize=fontsize_axes)
# The lambda value is embedded in the file name so runs with different
# regularization settings do not overwrite each other's figure.
plt.savefig(result_root + "simulation_lambda" + str(lambda1) + ".png", dpi=200, bbox_inches='tight')




