# -*- coding: utf-8 -*-
# main.py

# Load a previously generated environment and run every algorithm on it.

import argparse
import os
import pickle
import random

import numpy as np

from utils import RandomSimplexVector
from Env import FiniteStateFiniteActionLinearMDP
from SafeLSVE import safeTrainer
from Offline import offlineTrainer

# Command-line interface. Each option is listed once as
# (flag, value type, help text) and registered in this order. None of the
# options is required or given a default, so an omitted option is left as
# None on the parsed namespace.
_CLI_OPTIONS = (
    ('-H', int, 'the length of horizon'),
    ('-S', int, 'the number of states'),
    ('-A', int, 'the number of actions'),
    ('-d', int, 'the number of dimensions of feature'),
    ('-env', str, 'the name of environment'),
    ('-alpha', float, 'the alpha used in constraint'),
    ('-lbeta', float, 'the beta of LCB'),
    ('-ubeta', float, 'the beta of UCB'),
    ('-N1', int, 'the number of offline trajectories'),
    ('-N2', int, 'the number of online episodes'),
    ('-k', float, 'the temperature of boltzman baseline'),
    ('-M', int, 'the number of trials'),
)

parser = argparse.ArgumentParser(description='Run Algorithms with respect to a Baseline')
for _flag, _value_type, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, help=_help_text)

# Parsed once at import time from the process's command line.
args = parser.parse_args()

# Construct the MDP with the sizes given on the command line, then hand the
# -env name to load_env (presumably restoring a stored environment -- confirm
# against Env.load_env).
env = FiniteStateFiniteActionLinearMDP(
    H=args.H,
    S=args.S,
    A=args.A,
    d=args.d,
)
env.load_env(args.env)

# baseline_gen returns a pair: `b`, reported below as the baseline's average
# total reward, and `actions`, which the trial loop passes on as the baseline
# policy. The keyword is spelled `temprature_k` because that is the name used
# by the call into Env.
b, actions = env.baseline_gen(temprature_k=args.k)

print('k=', args.k)
print('boltzmann policy with temprature k average total rewards: ', b)

# Run args.M independent trials. Every trial reseeds both RNGs with its own
# index, trains three safe trainers side by side for args.N2 episodes, and
# pickles the per-episode histories into the Results/ directory.
for seed in range(args.M):
    random.seed(seed)
    np.random.seed(seed)

    # Offline phase: collect args.N1 trajectories using the baseline actions.
    offline = offlineTrainer(env, beta=args.lbeta)
    offline.collect_trajectories(actions, args.N1)

    if args.N1 > 0:
        # Trainers 0 and 1 use the offline LCB policy as their baseline and are
        # given the collected trajectories. LCB_policy() is deliberately called
        # once per trainer, exactly as before, so each gets its own result.
        safe_trainer = [
            safeTrainer(
                env,
                baseline=(b, offline.LCB_policy()),
                N=args.N2,
                alpha=args.alpha,
                beta=(args.ubeta, args.lbeta),
                trajectories=(offline.trajectory_s, offline.trajectory_a, offline.trajectory_r),
            )
            for _ in range(2)
        ]
        # Trainer 2 keeps the original baseline and receives no offline data.
        safe_trainer.append(
            safeTrainer(env, baseline=(b, actions), N=args.N2, alpha=args.alpha,
                        beta=(args.ubeta, args.lbeta))
        )
    else:
        # Without offline data all three trainers start from the original baseline.
        safe_trainer = [
            safeTrainer(env, baseline=(b, actions), N=args.N2, alpha=args.alpha,
                        beta=(args.ubeta, args.lbeta))
            for _ in range(3)
        ]

    # One history list per trainer. h0l[2] and rhol[2] stay empty because the
    # 'UCB' call below does not supply h0/rho values to record.
    hist_expectation = [[], [], []]
    hist_random = [[], [], []]
    h0l = [[], [], []]
    rhol = [[], [], []]

    for _ in range(args.N2):
        # The trainers are stepped in the fixed order 0, 1, 2 within every
        # episode; they share `env` and the global RNGs, so the order matters.
        for idx, mode in enumerate(('withrho', 'nonmarkov')):
            r, h0, rho = safe_trainer[idx].epoch_train(mode)
            hist_expectation[idx].append(r[0])
            hist_random[idx].append(r[1])
            h0l[idx].append(h0)
            rhol[idx].append(rho)

        r = safe_trainer[2].epoch_train('UCB')
        hist_expectation[2].append(r[0])
        hist_random[2].append(r[1])

    result_path = (
        'Results/SafeRLwithOffline_env' + args.env
        + '_alpha' + str(args.alpha)
        + '_Lbeta' + str(args.lbeta)
        + '_Ubeta' + str(args.ubeta)
        + '_N1' + str(args.N1)
        + '_N2' + str(args.N2)
        + '_tempk' + str(args.k)
        + '_seed' + str(seed)
        + '.pkl'
    )
    # Create the output directory if needed so that a finished training run is
    # not thrown away by a FileNotFoundError at save time.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'wb') as f:
        pickle.dump({'hist_expectation': hist_expectation,
                     'hist_random':      hist_random,
                     'h0_list':          h0l,
                     'rho_list':         rhol}, f)
