import pickle,time,sys,os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
from replay_buffer import ReplayBuffer
import hashlib
import matplotlib.pyplot as plt
from option_env import OptionEnv
from dqr4_policy import DQR4_Policy
    

def run_train():
    """Train, save and evaluate one policy per (stock, seed, n_cvar) combination.

    Configuration is read from module-level globals rather than parameters:
    ``tr_all_mean``, ``tr_all_std``, ``alpha``, ``gamma``, ``epsilon``,
    ``buffer_size``, ``batch_size``, ``n_quantiles``, ``max_iter``,
    ``explore_iter``, ``eval_freq``, ``max_episode_len`` and ``n_rep``.
    They must all be defined before this function is called.

    For every combination, three files are written under ``res/``:

    * ``dqr4_option_res<stock>_...npy``       -- the value returned by ``pol.train``
    * ``dqr4_option_res<stock>_model_...h5``  -- ``pol.model`` saved via ``.save``
    * ``dqr4_option_res<stock>_eval_...npy``  -- ``np.array([r, r2])`` from the two
      ``eval_policy`` calls (without and with the ``gamma`` keyword)

    Returns:
        None. Results are persisted to disk and timings are printed.
    """
    ###gpu
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        # Cap the first GPU with a fixed-size virtual device (memory_limit=2048).
        # Alternative kept for reference:
        #   tf.config.experimental.set_memory_growth(gpus[0], True)
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)],
        )

    # exist_ok=True removes the check-then-create race: when several training
    # jobs are launched at once, a plain exists()/mkdir() pair can fail with
    # FileExistsError in whichever process loses the race.
    os.makedirs('./res', exist_ok=True)

    for stock_i in range(10):
        # To run a single seed / n_cvar per process, swap the lists below for
        # [int(sys.argv[1])] and [int(sys.argv[2])] respectively.
        for seed in [1, 2, 3]:
            for n_cvar in [50, 55, 60, 65, 70, 75, 80, 85, 90, 95]:
                # NOTE(review): all three environments, including the test one,
                # are built from the *training* statistics; the module-level
                # tst_all_mean / tst_all_std are not used here. Confirm that
                # this is intentional.
                tr_env = OptionEnv(tr_all_mean[stock_i], tr_all_std[stock_i])
                val_env = OptionEnv(tr_all_mean[stock_i], tr_all_std[stock_i])
                tst_env = OptionEnv(tr_all_mean[stock_i], tr_all_std[stock_i])

                pol = DQR4_Policy(
                    alpha,
                    gamma,
                    seed,
                    tr_env.observation_space.shape[0],
                    tr_env.action_space.n,
                    n_quantiles,
                    n_cvar,
                )
                res = pol.train(
                    epsilon,
                    buffer_size,
                    batch_size,
                    seed,
                    tr_env,
                    val_env,
                    tst_env,
                    max_iter,
                    explore_iter,
                    eval_freq,
                    max_episode_len,
                )
                np.save('res/dqr4_option_res{}_{:g}_{}_{}_{}_{}.npy'.format(stock_i, alpha, max_iter, n_quantiles, n_cvar, seed), res)
                pol.model.save('res/dqr4_option_res{}_model_{:g}_{}_{}_{}_{}.h5'.format(stock_i, alpha, max_iter, n_quantiles, n_cvar, seed))

                # Re-seed the test environment before each evaluation so both
                # runs start from the same environment seed.
                tst_env.seed(100)
                st = time.time()
                r, _ = pol.eval_policy(tst_env, n_rep, max_episode_len, s_step=0)
                print('eval {}: {:.1f} secs'.format(n_rep, time.time() - st))

                # Second evaluation additionally passes gamma (presumably to
                # score discounted returns -- confirm against eval_policy).
                tst_env.seed(100)
                st = time.time()
                r2, _ = pol.eval_policy(tst_env, n_rep, max_episode_len, gamma=gamma, s_step=0)
                print('eval2 {}: {:.1f} secs'.format(n_rep, time.time() - st))

                np.save('res/dqr4_option_res{}_eval_{:g}_{}_{}_{}_{}_{}.npy'.format(stock_i, alpha, max_iter, n_quantiles, n_cvar, n_rep, seed), np.array([r, r2]))




# Console formatting for numpy arrays: precision 5, no scientific notation,
# lines up to 150 characters wide.
np.set_printoptions(precision=5, suppress=True, linewidth=150)

### stocks=['AAPL','UNH','HD','GS','V','MCD','MSFT','BA','MMM','JNJ']
data = np.load('dow_top10_2005_2019.npy')
val_st_idx = 1962  ### train / test split on Jan 1, 2016

# Column 3 of the last axis is the series the statistics are computed from
# (presumably the closing price -- confirm against the data file layout).
_series = data[:, :, 3]

# Log ratio of each entry to its predecessor along axis 1, computed
# separately on each side of the split index.
z = np.log(_series[:, 1:val_st_idx] / _series[:, :val_st_idx - 1])
z2 = np.log(_series[:, val_st_idx + 1:] / _series[:, val_st_idx:-1])

# Per-row (axis=1) mean and standard deviation of those log ratios.
tr_all_mean, tr_all_std = z.mean(axis=1), z.std(axis=1)
tst_all_mean, tst_all_std = z2.mean(axis=1), z2.std(axis=1)

# ---- settings read as globals by run_train() ----
max_episode_len = 999

alpha = 0.0001
gamma = 0.999    # discount
epsilon = 0.02   # exploration
buffer_size = 50000
batch_size = 32
n_quantiles = 100

max_iter = 200000
explore_iter = 20000
eval_freq = 1000
n_rep = 1000

##### To train and save outputs
run_train()

