import os,sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sea
import pandas as pd

def load_all(alg, n_cvar_rng, n_stocks=10, seeds=(1, 2, 3), n_rep=1000):
    """Load the saved evaluation arrays of one algorithm for every CVaR level.

    For each level in ``n_cvar_rng`` and each stock index in
    ``range(n_stocks)`` the function reads
    ``res/{alg}_option_res{stock}_0.0001_200000_100_{n_cvar}_eval.npy`` from
    the current working directory.  When that combined file does not exist,
    the per-seed files ``..._{n_cvar}_{seed}_eval.npy`` are loaded instead
    and stacked along a new leading axis.

    Args:
        alg: Algorithm tag used as the file-name prefix (e.g. ``'dqr'``).
        n_cvar_rng: Iterable of CVaR levels; each value is formatted into the
            file name and used as a key of the returned dict.
        n_stocks: Number of stock indices to load (``0 .. n_stocks - 1``).
        seeds: Seed identifiers used for the per-seed fallback files; the
            leading axis of every loaded array must have this many entries.
        n_rep: Required length of axis 2 of every loaded array.

    Returns:
        Dict mapping each level to ``np.array`` of the per-stock arrays,
        i.e. an array whose leading axis indexes the stock.

    Raises:
        RuntimeError: If a loaded array has fewer than 3 dimensions or its
            shape does not match ``(len(seeds), *, n_rep, ...)``.
        OSError: Propagated from ``np.load`` when a required file is missing.
    """
    seeds = tuple(seeds)
    all_res = {}
    for n_cvar in n_cvar_rng:
        this_res = []
        for stock_i in range(n_stocks):
            env_name = 'option_res{}_0.0001_200000_100'.format(stock_i)
            res_path = 'res/{}_{}_{}'.format(alg, env_name, n_cvar)
            save_path = '{}_eval.npy'.format(res_path)
            if os.path.exists(save_path):
                res = np.load(save_path)
            else:
                # No combined file: rebuild it from the per-seed results.
                res = np.array([
                    np.load('{}_{}_eval.npy'.format(res_path, s))
                    for s in seeds
                ])
            if (res.ndim < 3 or res.shape[0] != len(seeds)
                    or res.shape[2] != n_rep):
                # The original used a bare `raise`, which only produced
                # "No active exception to reraise"; keep RuntimeError but
                # say what is actually wrong.
                raise RuntimeError(
                    'unexpected result shape {} for {!r}: expected '
                    '({}, *, {})'.format(
                        res.shape, res_path, len(seeds), n_rep))
            this_res.append(res)
        all_res[n_cvar] = np.array(this_res)
    return all_res

def extract_res(li_res, cvar_rng, n_rep=None):
    """Summarise, per level, the results stored under that same level.

    For every ``n_cvar`` in ``cvar_rng`` the 4-D array ``li_res[n_cvar]`` is
    reduced as follows: average the first ``round(n_cvar * n_rep / 100)``
    entries of axis 3, take the maximum over axis 2, then average over
    axis 1.

    NOTE(review): treating the leading slice of axis 3 as the lower tail
    presumably relies on that axis being sorted in ascending order -- confirm
    against the code that writes the ``*_eval.npy`` files.

    Args:
        li_res: Mapping from level to a 4-D array.
        cvar_rng: Iterable of levels (percentages) to evaluate; each must be
            a key of ``li_res``.
        n_rep: Number of samples that corresponds to 100%.  Defaults to the
            length of axis 3 of each array, so the function no longer depends
            on a module-level ``n_rep`` being defined.

    Returns:
        Array of shape ``(len(cvar_rng), li_res[level].shape[0])``.
    """
    res = []
    for n_cvar in cvar_rng:
        runs = li_res[n_cvar]
        total = runs.shape[3] if n_rep is None else n_rep
        # int() keeps the slice bound a plain integer even when n_cvar is a
        # numpy scalar.
        n_tail = int(round(n_cvar * total / 100))
        tail_mean = np.mean(runs[:, :, :, :n_tail], axis=3)
        res.append(np.mean(np.max(tail_mean, axis=2), axis=1))
    return np.array(res)

def extract_res_exp(li_res, cvar_rng, n_rep=None):
    """Summarise the level-100 results at every requested level.

    Unlike a per-level lookup, this always uses the single array
    ``li_res[100]`` and only varies how much of axis 3 is averaged: for each
    ``n_cvar`` in ``cvar_rng`` it averages the first
    ``round(n_cvar * n_rep / 100)`` entries of axis 3, takes the maximum over
    axis 2, then averages over axis 1.

    NOTE(review): treating the leading slice of axis 3 as the lower tail
    presumably relies on that axis being sorted in ascending order -- confirm
    against the code that writes the ``*_eval.npy`` files.

    Args:
        li_res: Mapping that must contain the key ``100`` with a 4-D array.
        cvar_rng: Iterable of levels (percentages) to evaluate.
        n_rep: Number of samples that corresponds to 100%.  Defaults to the
            length of axis 3 of ``li_res[100]``, so the function no longer
            depends on a module-level ``n_rep`` being defined.

    Returns:
        Array of shape ``(len(cvar_rng), li_res[100].shape[0])``.
    """
    res_exp = li_res[100]
    total = res_exp.shape[3] if n_rep is None else n_rep
    res = []
    for n_cvar in cvar_rng:
        # int() keeps the slice bound a plain integer even when n_cvar is a
        # numpy scalar.
        n_tail = int(round(n_cvar * total / 100))
        tail_mean = np.mean(res_exp[:, :, :, :n_tail], axis=3)
        res.append(np.mean(np.max(tail_mean, axis=2), axis=1))
    return np.array(res)

if __name__ == '__main__':
    # Script entry point: load the saved evaluations of both algorithms,
    # reduce them to one value per (level, stock) and plot the three curves.
    np.set_printoptions(5, suppress=True, linewidth=150)
    plt.ion()

    # Levels (in percent) available for each algorithm; only the 'dqr' range
    # includes 100, which extract_res_exp looks up as its fixed entry.
    dqr_cvar_rng = np.arange(10, 105, 5)
    dqr4_cvar_rng = np.arange(10, 100, 5)
    max_cvar = 105
    max_cvar4 = 100

    res_dqr = load_all('dqr', dqr_cvar_rng)
    res_dqr4d = load_all('dqr4d', dqr4_cvar_rng)

    # Read as a module-level global by extract_res / extract_res_exp, so it
    # has to be assigned before those calls.
    n_rep = 1000

    # Colour / line-style pairs; not referenced again in this part of the
    # script.
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']  # 'rgbmyck'
    specs = [(c, '*-') for c in colors] * 3
    specs2 = [(c, '*--') for c in colors] * 3

    all_res_exp = extract_res_exp(res_dqr, dqr_cvar_rng)
    all_res_dqr = extract_res(res_dqr, dqr_cvar_rng)
    all_res_dqr4d = extract_res(res_dqr4d, dqr4_cvar_rng)

    fig = plt.figure(1)
    fig.set_tight_layout(True)
    plt.clf()

    def _melt_curve(values, levels):
        # One column per level (as a fraction), one row per stock, reshaped
        # to long form so seaborn can aggregate over the rows of each level.
        wide = pd.DataFrame(values.T, columns=levels / 100)
        return pd.melt(wide, var_name='alpha level', value_name='CVaR')

    pd_exp = _melt_curve(all_res_exp, dqr_cvar_rng)
    pd_dqr = _melt_curve(all_res_dqr, dqr_cvar_rng)
    pd_dqr4d = _melt_curve(all_res_dqr4d, dqr4_cvar_rng)

    # The concat keys become an index level named 'Algorithm', which the
    # line plot uses for both hue and marker style.
    pd_all = pd.concat(
        [pd_exp, pd_dqr, pd_dqr4d],
        keys=['Risk-neutral', 'Markov action-selection', 'Proposed algorithm'],
        names=['Algorithm'],
    )

    sea.set_theme(style="darkgrid")
    sea.set(font_scale=2)
    h = sea.lineplot(
        data=pd_all,
        x='alpha level',
        y='CVaR',
        hue='Algorithm',
        style='Algorithm',
        markers=True,
        seed=1,
    )
    h.legend_.set_title(None)
    
