import argparse
import os, sys
import datetime
import random
import pandas as pd
import numpy as np
from env.env import *
from algorithms.lsvi_phe import *
from algorithms.lsvi_ucb import *
from algorithms.uc_matrixrl import *
from algorithms.ucrl_vtr import *
from algorithms.optimal_policy import *
from algorithms.uc_hrl_naive import *
from algorithms.uc_hrl import *

import matplotlib.pyplot as plt
import seaborn as sns

def moving_avg(reward_list, window_size):
    """Return the trailing moving average of ``reward_list``.

    Element ``i`` of the result is the mean of
    ``reward_list[max(0, i - window_size + 1) : i + 1]`` rounded to 2 decimals,
    so the first ``window_size - 1`` entries average over the shorter prefix
    that is available so far.

    Args:
        reward_list: Sliceable sequence of numbers (e.g. a list or 1-D array).
        window_size: Number of trailing elements to average; must be >= 1.

    Returns:
        A list with one rounded average per input element (empty for empty
        input).

    Raises:
        ValueError: If ``window_size`` is less than 1. Without this check an
            empty slice would be averaged, yielding NaN and a RuntimeWarning.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    mv_avg_list = []
    for i, _ in enumerate(reward_list):
        # Clamp the window start so early elements use whatever prefix exists.
        left = max(i - window_size + 1, 0)
        window = reward_list[left:i + 1]
        mv_avg_list.append(round(np.mean(window), 2))
    return mv_avg_list

def Agent(agent_name, env, K, c=1e-3):
    """Construct the agent registered under ``agent_name``.

    Args:
        agent_name: One of "LSVI_UCB", "UC_MatrixRL", "LSVI_PHE", "UCRL_VTR",
            "UC_HRL_naive", "UC_HRL", "Optimal".
        env: Environment instance forwarded to the agent constructor.
        K: Episode count forwarded to the agent constructor.
        c: Confidence constant. Forwarded to LSVI_UCB, UC_MatrixRL, UCRL_VTR,
            UC_HRL_naive and UC_HRL; not used by LSVI_PHE or Optimal.

    Returns:
        The constructed agent object.

    Raises:
        ValueError: If ``agent_name`` is not one of the names above. (The
            previous if/elif chain fell through and surfaced as an
            UnboundLocalError instead.)
    """
    # Lambdas defer construction so only the requested agent is built.
    factories = {
        "LSVI_UCB": lambda: LSVI_UCB(env, K, c=c),
        # NOTE(review): this constructor takes (K, env), unlike the others;
        # kept as-is — confirm against UC_MatrixRL's signature.
        "UC_MatrixRL": lambda: UC_MatrixRL(K, env, c=c, lam=1),
        "LSVI_PHE": lambda: LSVI_PHE(env, K, M=5, sigma=2),
        "UCRL_VTR": lambda: UCRL_VTR(env, K, c=c),
        "UC_HRL_naive": lambda: UC_HRL_naive(env, K, c=c, lam=1),
        "UC_HRL": lambda: UC_HRL(env, K, c=c, lam=1),
        "Optimal": lambda: Optimal_Policy(env, K),
    }
    try:
        factory = factories[agent_name]
    except KeyError:
        raise ValueError(
            f"Unknown agent name {agent_name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()

def _str2bool(value):
    """Parse a command-line string such as 'True', 'false', '1' or 'no' into a bool.

    argparse's ``type=bool`` treats every non-empty string (including "False")
    as True, so an explicit parser is needed for boolean options.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


parser = argparse.ArgumentParser(description='HRL')
parser.add_argument('--max-episodes', type=int, default=100, metavar='EPISODES', help='Number of training episodes')
parser.add_argument('--n_states', type=int, default=5, help='Number of States')
parser.add_argument('--horizon_length', type=int, default=12, help='Horizon Length')
parser.add_argument('--c', type=float, default=5e-3, help='Confidence Constant')
# _str2bool instead of type=bool so that "--record-result False" actually disables recording.
parser.add_argument('--record-result', type=_str2bool, default=True, help='Record the result or not')

# Setup
args = parser.parse_args()

# Experiment configuration taken from the parsed command-line arguments.
K, nState, H = args.max_episodes, args.n_states, args.horizon_length

# Environment built with horizon H and nState states; shared by every agent/run below.
env = block_make_riverSwim(epLen=H, nState=nState)

# Number of independent repetitions, each with a fixed seed: 41, 82, ..., 41 * runs.
runs = 10
seeds = list(range(41, 41 * runs + 1, 41))

# Save directory: a fresh timestamped folder per invocation, only when recording.
if args.record_result:
    time_str = datetime.datetime.now().strftime('%Y%m%dT%H%M%S.%f')
    out_dir = f'./results/S={nState},H={H}/{time_str}/'
    # Never overwrite the results of an earlier run.
    if os.path.exists(out_dir):
        raise RuntimeError('{} exists'.format(out_dir))
    os.makedirs(out_dir)

agent_list = ["LSVI_UCB", "UC_MatrixRL", "LSVI_PHE", "UCRL_VTR", "UC_HRL_naive", "UC_HRL", "Optimal"]

# Accumulate plain row dicts and build the DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0, and appending row by row
# copies the whole frame on every call (quadratic in the number of rows).
rows = []
for agent_name in agent_list:
    for run in range(runs):
        # Seed Python's RNG and numpy's global RNG so each run is reproducible
        # whichever one the agents/environment draw from.
        random.seed(seeds[run])
        np.random.seed(seeds[run])
        # NOTE(review): the same `env` object is reused across agents and runs;
        # confirm agents reset it between episodes.
        agent = Agent(agent_name, env, K, args.c)
        episodic_return = agent.run()
        # Smooth each run's per-episode returns over a 20-episode trailing window.
        mv_avg_return = moving_avg(episodic_return, 20)
        rows.extend(
            {'Agent': agent_name, 'Episode': k, 'Return': r}
            for k, r in enumerate(mv_avg_return)
        )

results = pd.DataFrame(rows, columns=['Agent', 'Episode', 'Return'])

# Persist the collected moving-average returns when recording is enabled.
if args.record_result:
    csv_path = out_dir + 'result.csv'
    results.to_csv(csv_path, index=False)
