# -*- coding: utf-8 -*-
# main.py

# use the generated environment and test all algorithms

import argparse
import os

import numpy as np

from utils import RandomSimplexVector
from Env import FiniteStateFiniteActionLinearMDP
from SafeLSVE import safeTrainer

# Command-line interface. No option is marked as required and none has a
# default, so a flag left off the command line reaches the script as None.
parser = argparse.ArgumentParser(description='Run Algorithms with respect to a Baseline')

# (flag, value type, help text) for every option, in the order they are
# registered and therefore listed by --help.
_CLI_OPTIONS = (
    ('-H', int, 'the length of horizon'),
    ('-S', int, 'the number of states'),
    ('-A', int, 'the number of actions'),
    ('-d', int, 'the number of dimensions of feature'),
    ('-env', str, 'the name of environment'),
    ('-k', float, 'the k to generate baseline'),
    ('-alpha', float, 'the alpha used in constraint'),
    ('-beta', float, 'the beta of LCB'),
    ('-N', int, 'the number of total epochs'),
    ('-M', int, 'the number of total trials'),
)
for _flag, _value_type, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, help=_help_text)

args = parser.parse_args()

# Build the MDP from the sizes given on the command line, then load the
# environment named by -env into it.
env = FiniteStateFiniteActionLinearMDP(
    H=args.H,
    S=args.S,
    A=args.A,
    d=args.d,
)
env.load_env(args.env)

# The optimal policy is only reported here; its second return value is unused.
optimal_value, _optimal_actions = env.best_gen()
print('optimal policy average total rewards: ', optimal_value)

# The baseline pair (b, actions) is what the rest of the script hands to the
# trainers.
# NOTE(review): the keyword is spelled `temprature_k` as in the original call;
# confirm against Env.baseline_gen before "correcting" it.
baseline_result = env.baseline_gen(temprature_k=args.k)
b, actions = baseline_result
print('baseline policy average total rewards: ', b)


# Algorithm name passed to epoch_train() by each trainer. The position in this
# tuple is the row index of that algorithm in every saved array:
#   0 -> with rho, 1 -> without rho, 2 -> non-Markov.
ALGOS = ('StepMix', 'StepNoMix', 'nonmarkov')

# Create the output directory before any training starts: np.save does not
# create missing parent directories, so without this a missing 'Results/'
# would only be discovered after a whole trial had already been trained.
os.makedirs('Results', exist_ok=True)

# Filename prefix shared by all outputs of this run; it encodes the
# hyper-parameters taken from the command line.
out_prefix = ('Results/kk_' + str(args.k)
              + '_lcb_beta_' + str(args.beta)
              + '_alpha_' + str(args.alpha))

for j in range(args.M):
    # One trainer per algorithm, each built from the same env and baseline.
    # NOTE(review): N=5000 is hard-coded and independent of the -N option;
    # confirm that this is intended.
    trainers = [
        safeTrainer(env, baseline=(b, actions), N=5000, alpha=args.alpha)
        for _ in ALGOS
    ]
    for trainer in trainers:
        trainer.beta_lcb = args.beta
        trainer.beta_ucb = 1

    # Per-algorithm histories of the three values returned by epoch_train().
    hist = [[] for _ in ALGOS]
    h0l = [[] for _ in ALGOS]
    rhol = [[] for _ in ALGOS]

    for _ in range(args.N):
        # Within an epoch the trainers are stepped in ALGOS order, exactly as
        # in the original interleaving (they share the same env object).
        for idx, (trainer, algo) in enumerate(zip(trainers, ALGOS)):
            r, h0, rho = trainer.epoch_train(algo=algo)
            hist[idx].append(r)
            h0l[idx].append(h0)
            rhol[idx].append(rho)

    # One .npy file per recorded quantity and trial index j.
    for tag, series in (('hist', hist), ('h0', h0l), ('rho', rhol)):
        np.save(out_prefix + '_' + tag + '_' + str(j) + '.npy', series)