import pickle,time,sys,os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
from replay_buffer import ReplayBuffer
import hashlib
import matplotlib.pyplot as plt
from dqr_policy import DQR_Policy
from dqr4_policy import DQR4_Policy

class RealOptionEnv:
    """Optimal-stopping environment over one fixed, pre-recorded price path.

    The path is rescaled by its first value, so it starts at 1 (equal to
    ``s0``). On each step the agent either waits or exercises. It exercises
    when ``act % 2 > 0`` or when the step is taken at ``t == T - 2``.
    Exercising pays ``max(0, K - s)`` on the price ``s`` held *before*
    advancing, and ends the episode.

    Observations are ``[price, (T - t) / T]``.

    NOTE(review): ``step`` reads ``data[t]`` for ``t`` up to ``T - 1``
    (``T`` = 100), so ``data`` presumably needs at least ``T`` entries.
    The callers in this file pass 100-entry slices.
    """

    def __init__(self, data):
        # Gym-style spaces: this file only reads observation_space.shape and
        # action_space.n from them.
        self.action_space = gym.spaces.discrete.Discrete(2)
        self.observation_space = gym.spaces.box.Box(-np.inf, np.inf, shape=(2,), dtype='float64')
        # Horizon, strike and starting (normalised) price.
        self.T, self.K, self.s0 = 100, 1, 1
        # Normalise so the path's first value is 1, matching s0.
        self.data = data / data[0]

    def seed(self, s):
        """No-op: the price path is fixed, so there is nothing to seed."""
        return None

    def get_obs(self, s, t):
        """Return the observation ``[s, fraction of the horizon remaining at t]``."""
        time_left = (self.T - t) / self.T
        return np.array([s, time_left])

    def reset(self, s_step=None):
        """Restart at t = 0 with price ``s0``. ``s_step`` is accepted but ignored."""
        self.state_t, self.state_s = 0, self.s0
        return self.get_obs(self.state_s, self.state_t)

    def step(self, act):
        """Advance one time step.

        Returns an ``(obs, reward, done, None)`` tuple. ``reward`` is
        ``max(0, K - s)`` on the pre-step price when exercising, else 0.
        """
        parity = act % 2
        # Last decision point: exercise regardless of the requested action.
        forced = self.state_t == self.T - 2
        if self.state_t < self.T - 1:
            self.state_t += 1
        if forced or parity > 0:
            reward, done = np.maximum(0, self.K - self.state_s), True
        else:
            reward, done = 0, False
        self.state_s = self.data[self.state_t]
        return self.get_obs(self.state_s, self.state_t), reward, done, None
    

def eval_dynamic():
    all_res=[]
    all_res_last=[]
    for stock_i in range(10):
        sres=[]
        for i,n_cvar in enumerate(np.arange(50,105,5)):
            this_res=[]
            for seed in [1,2,3]:
                tr_env=RealOptionEnv(data[stock_i,val_st_idx:val_st_idx+100,3])
                pol=DQR_Policy(alpha,gamma,seed,tr_env.observation_space.shape[0],tr_env.action_space.n,n_quantiles,n_cvar)
                pol.model=keras.models.load_model('res/dqr_option_res{}_model_{:g}_{}_{}_{}_{}.h5'.format(stock_i,alpha,max_iter,n_quantiles,n_cvar,seed),compile=False)
                res=[]
                print('Evaluating dynamic stock {} cvar {} seed {} ...'.format(stock_i,n_cvar,seed),flush=True)
                for offset in np.arange(0,900,9):
                    tst_env=RealOptionEnv(data[stock_i,val_st_idx+offset:val_st_idx+offset+100,3])
                    r,_=pol.eval_policy(tst_env,1,max_episode_len,gamma=gamma,s_step=0)
                    res.append(r)
                res=np.sort(np.array(res).flatten())
                this_res.append([np.mean(res[:n]) for n in np.arange(50,105,5)])
            sres.append( np.mean(np.array(this_res),axis=0)[i] )
            if n_cvar==100:
                all_res_last.append(np.mean(np.array(this_res),axis=0))
        all_res.append(sres)
    return np.array(all_res),np.array(all_res_last)

def eval_static():
    all_res=[]
    for stock_i in range(10):
        sres=[]
        for i,n_cvar in enumerate(np.arange(50,100,5)):
            this_res=[]
            for seed in [1,2,3]:
                tr_env=RealOptionEnv(data[stock_i,val_st_idx:val_st_idx+100,3])
                pol=DQR4_Policy(alpha,gamma,seed,tr_env.observation_space.shape[0],tr_env.action_space.n,n_quantiles,n_cvar)
                pol.model=keras.models.load_model('res/dqr4_option_res{}_model_{:g}_{}_{}_{}_{}.h5'.format(stock_i,alpha,max_iter,n_quantiles,n_cvar,seed),compile=False)
                res=[]
                print('Evaluating static stock {} cvar {} seed {} ...'.format(stock_i,n_cvar,seed),flush=True)
                for offset in np.arange(0,900,9):
                    tst_env=RealOptionEnv(data[stock_i,val_st_idx+offset:val_st_idx+offset+100,3])
                    r,_=pol.eval_policy(tst_env,1,max_episode_len,gamma=gamma,s_step=0)
                    res.append(r)
                res=np.sort(np.array(res).flatten())
                this_res.append(np.mean(res[:n_cvar]))
            sres.append( np.mean(this_res) )
        all_res.append(sres)
    return np.array(all_res)



np.set_printoptions(precision=5, suppress=True, linewidth=150)

### hyperparameters; alpha, max_iter and n_quantiles also appear in the
### saved-model file names loaded by the evaluation routines
max_episode_len = 999
alpha = 0.0001
gamma = 0.999
n_quantiles = 100
max_iter = 200000

### stocks=['AAPL','UNH','HD','GS','V','MCD','MSFT','BA','MMM','JNJ']
data = np.load('dow_top10_2005_2019.npy')
val_st_idx = 1962  ### train / test split on Jan 1, 2016

# Per-stock log-ratios of consecutive column-3 values (presumably close prices
# -- confirm), on the training span (z) and the test span (z2), with their
# means / standard deviations. Not used by the code visible here; presumably
# kept for interactive inspection.
z = np.log(data[:, 1:val_st_idx, 3] / data[:, :val_st_idx - 1, 3])
z2 = np.log(data[:, val_st_idx + 1:, 3] / data[:, val_st_idx:-1, 3])
tr_all_mean, tr_all_std = np.mean(z, axis=1), np.std(z, axis=1)
tst_all_mean, tst_all_std = np.mean(z2, axis=1), np.std(z2, axis=1)

###gpu
# On the first visible GPU (if any), configure one logical device with
# memory_limit=2048 (MB). NOTE: TensorFlow documents that this must happen
# before the GPU is initialised, so keep it ahead of any model loading.
# (Alternative kept for reference: tf.config.experimental.set_memory_growth(gpus[0], True))
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)],
    )

# Static (DQR4) and dynamic (DQR) CVaR evaluations over all stocks.
res = eval_static()
res2, res_exp = eval_dynamic()

plt.ion()

plt.figure(2)
plt.clf()
# One curve per result array; rows are stocks, columns are CVaR levels.
for curve in (res, res2, res_exp):
    # x-axis: levels 0.50, 0.55, ... -- one per column, starting at 50 % in
    # 5 % steps like the evaluation routines. Built from integer percentages
    # so float-arange rounding can never add or drop an endpoint.
    tau_rng = np.arange(50, 50 + 5 * curve.shape[1], 5) / 100.0
    # Error bars: +/- 2 standard errors of the mean across stocks (rows).
    half_width = np.std(curve, axis=0) * 2 / np.sqrt(curve.shape[0])
    plt.errorbar(tau_rng, np.mean(curve, axis=0), half_width, fmt='.-', capsize=3)
plt.grid()
plt.legend(['Static','Dynamic','Expectation'])
plt.xlabel('alpha level')
plt.ylabel('CVaR')

