# -*- coding: utf-8 -*-
# draw.py

# draw the figures from the data

import argparse
import os

import matplotlib.pyplot as plt
import numpy as np

from Env import FiniteStateFiniteActionLinearMDP

def screen_average(r, screen=10):
    """Smooth each row of ``r`` with a sliding-window mean along axis 1.

    For column ``i`` the window covers columns ``[i - screen, i + screen)``
    clipped to the array bounds: up to ``screen`` samples before ``i`` and
    up to ``screen`` samples starting at ``i`` (the right edge is exclusive).

    Args:
        r: 2-D array; each row is one series, each column one step.
        screen: Half-width of the window. With ``screen=0`` every window is
            empty, so the result is NaN everywhere.

    Returns:
        An array with the same shape as ``r``. Integer or boolean input
        yields a float64 result so the means are not truncated; floating
        input keeps its dtype.
    """
    r = np.asarray(r)
    # np.zeros_like(r) would inherit an integer dtype and silently truncate
    # every mean on assignment, so choose a floating dtype explicitly.
    out_dtype = r.dtype if np.issubdtype(r.dtype, np.inexact) else np.float64
    smoothed = np.zeros(r.shape, dtype=out_dtype)
    n_steps = r.shape[1]
    for i in range(n_steps):
        left = max(0, i - screen)
        right = min(n_steps, i + screen)
        smoothed[:, i] = np.mean(r[:, left:right], axis=-1)
    return smoothed

def regret(r, best):
    """Return the cumulative regret of each reward series.

    Entry ``[..., i]`` of the result equals
    ``best * (i + 1) - sum(r[..., :i + 1])``, i.e. the reward a policy
    earning ``best`` per step would have accumulated after ``i + 1`` steps
    minus the reward actually accumulated.

    Args:
        r: Array of rewards whose last axis indexes the steps.
        best: Per-step reference reward; a scalar, or one value per series.

    Returns:
        An array shaped like ``r``. The dtype follows NumPy's promotion of
        ``best`` and ``r``, so a fractional ``best`` combined with integer
        rewards is not truncated.
    """
    r = np.asarray(r)
    # 1, 2, ..., n: the number of steps elapsed at each column.
    steps = np.arange(1, r.shape[-1] + 1)
    # The trailing axis lets a per-series `best` broadcast against the steps.
    reference = np.asarray(best)[..., np.newaxis] * steps
    # A running sum replaces re-summing every prefix from scratch.
    return reference - np.cumsum(r, axis=-1)

# Command-line interface of the plotting script.
#
# -N and -M size the reward array and -alpha enters the violation-threshold
# arithmetic further down; if any of them were left at the default None the
# script would only fail later with an opaque TypeError.  They are therefore
# marked as required so a missing value is reported as a usage error.
# The remaining options are passed to the environment or used to build file
# names and are left optional as before.
parser = argparse.ArgumentParser(description='Draw the reward vs. epoch and regret vs. epoch graphs')
parser.add_argument('-H', type=int, help='the length of horizon')
parser.add_argument('-S', type=int, help='the number of states')
parser.add_argument('-A', type=int, help='the number of actions')
parser.add_argument('-d', type=int, help='the number of dimensions of feature')
parser.add_argument('-env', type=str, help='the name of environment')
parser.add_argument('-k', type=float, help='the k to generate baseline')
parser.add_argument('-alpha', type=float, required=True, help='the alpha used in constraint')
parser.add_argument('-beta', type=float, help='the beta of LCB')
parser.add_argument('-N', type=int, required=True, help='the number of total epochs')
parser.add_argument('-M', type=int, required=True, help='the number of total trials')
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Load the recorded reward histories.
# ---------------------------------------------------------------------------
batch = args.M
# r[a, n, t]: reward of algorithm slot a at epoch n in trial t.  The slots
# are labelled StepMix (0), StepNoMix (1) and EpsMix (2) in the plots below.
r = np.zeros([3, args.N, batch])

env = FiniteStateFiniteActionLinearMDP(H=args.H, S=args.S, A=args.A, d=args.d)
env.load_env(args.env)
# `best` is the per-epoch reference reward handed to regret() below.
best, _ = env.best_gen()

# LSVI-UCB histories: one file per trial, stacked and transposed so that the
# trials sit on the last axis, as in `r`.  Each file is expected to hold one
# length-N history (mean_ucb is reshaped to (1, N) further down).
r_ucb = np.array([np.load('Results/hist_UCB_' + str(i) + '.npy') for i in range(batch)])
r_ucb = r_ucb.transpose()
mean_ucb = np.mean(r_ucb, axis=-1)

# Common tag of the input files and of the output figure names.
dir_str = 'kk_' + str(args.k) + '_lcb_beta_' + str(args.beta) + '_alpha_' + str(args.alpha)
# `b` is the baseline value: it defines the violation threshold and is drawn
# as a horizontal reference line.
b, _ = env.baseline_gen(temprature_k=args.k)

for i in range(batch):
    r[:, :, i] = np.load('Results/' + dir_str + '_hist_' + str(i) + '.npy')

# Average over trials; row 3 of mean_r / rr / reg is LSVI-UCB.
mean_r = np.mean(r, axis=-1)
mean_r = np.concatenate((mean_r, mean_ucb.reshape([1, args.N])), axis=0)
rr = screen_average(mean_r, screen=20)
reg = regret(mean_r, best)

# savefig() does not create missing directories, so make sure the output
# directory exists before the first figure is written.
if not os.path.isdir('figure'):
    os.makedirs('figure')

# ---------------------------------------------------------------------------
# Figure 1: reward vs. epoch.
# ---------------------------------------------------------------------------
plt.figure(dpi=400)

# A single (epoch, trial) reward below this value counts as one violation.
violation_threshold = (1 - args.alpha) * b

# (row in mean_r / rr, per-trial rewards, legend name, colour, z-order)
reward_curves = [
    (0, r[0], 'StepMix', 'steelblue', 10),
    # (1, r[1], 'StepNoMix', 'brown', 8),  # disabled: not shown in the figures
    (2, r[2], 'EpsMix', 'mediumpurple', 9),
    (3, r_ucb, 'LSVI-UCB', 'tab:orange', 7),
]
for row, trials, name, colour, z in reward_curves:
    violations = np.sum(trials < violation_threshold)
    # Smoothed curve (opaque, labelled) on top of the raw mean (faint).
    plt.plot(rr[row], label=name + ', # of violation:' + str(violations),
             alpha=0.8, zorder=z, color=colour)
    plt.plot(mean_r[row], alpha=0.2, color=colour)

plt.plot(list(range(args.N)), np.ones(args.N) * b, '-.',
         label='baseline: ' + str(round(b, 2)), alpha=0.5, zorder=4)

plt.xlabel('epoch')
plt.ylabel('total reward')
plt.legend()
plt.savefig('figure/compare_algorithms_' + dir_str + '.png')

# ---------------------------------------------------------------------------
# Figure 2: regret vs. epoch.
# ---------------------------------------------------------------------------
plt.figure(dpi=400)
regret_lines = []
# StepNoMix (row 1) is skipped here as well.
for row, name in ((0, 'StepMix'), (2, 'EpsMix'), (3, 'LSVI-UCB')):
    line, = plt.plot(reg[row], label=name)
    regret_lines.append((reg[row][-1], line))

# List the legend entries from the largest to the smallest final regret;
# sorted() is stable, so ties keep their plotting order.
legend_handles = [line for _, line in sorted(regret_lines, key=lambda item: -item[0])]
plt.xlabel('epoch')
plt.ylabel('regret')
plt.legend(handles=legend_handles)
plt.savefig('figure/compare_regret_' + dir_str + '.png')