import gym
import argparse
import yaml
import os

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.logger import configure
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
import torch

import modified_model
import optimizers
import torch
import time

debug = False

def load_default_hyperparams(model_name, env_name):
    """Return the default hyperparameters stored for ``env_name``.

    The values are read from ``hyperparams/<model_name>.yml``, resolved
    relative to the current working directory. The file is expected to be a
    mapping from environment id to that environment's hyperparameters.

    Args:
        model_name: Base name of the YAML file (without the ``.yml`` suffix).
        env_name: Key of the entry to return.

    Returns:
        The value stored under ``env_name`` in the YAML file.

    Raises:
        OSError: If the YAML file cannot be opened (e.g. it does not exist).
        KeyError: If the file holds no entry for ``env_name``.
    """
    yml_path = os.path.join(os.getcwd(), 'hyperparams', '{}.yml'.format(model_name))
    with open(yml_path, 'r') as file:
        # NOTE(review): FullLoader accepts more than plain data types; only
        # point this at configuration files you trust.
        hyperparam_list = yaml.load(file, Loader=yaml.FullLoader)

    # An empty document loads as None and a top-level list has no named
    # entries; both are reported the same way as a missing environment.
    entries = hyperparam_list if isinstance(hyperparam_list, dict) else {}
    if env_name not in entries:
        raise KeyError(
            'No hyperparameters for environment {!r} in {} (available: {})'.format(
                env_name, yml_path, list(entries)))
    return entries[env_name]

def create_dir(model_name, env_name, id=None, debug=False):
    """Create and return the output directory for one experiment run.

    The directory is ``output/<env_name>/<model_name>/exp_<id>`` under the
    current working directory, with an extra ``debug`` level right after
    ``output`` when ``debug`` is true.

    Args:
        model_name: Name of the sub-directory below the environment directory.
        env_name: Name of the environment directory.
        id: Experiment number. When given, ``exp_<id>`` is created if needed
            and reused if it already exists. When ``None``, the lowest number
            whose ``exp_<n>`` entry does not exist yet is used.
        debug: Place the run under ``output/debug`` instead of ``output``.

    Returns:
        Path of the experiment directory.
    """
    root = os.path.join(os.getcwd(), 'output')
    if debug:
        root = os.path.join(root, 'debug')
    env_model_path = os.path.join(root, env_name, model_name)
    os.makedirs(env_model_path, exist_ok=True)

    if id is not None:
        # An explicit id may deliberately point at an existing run directory.
        exp_path = os.path.join(env_model_path, 'exp_{}'.format(id))
        os.makedirs(exp_path, exist_ok=True)
    else:
        # List the directory once instead of once per candidate id.
        taken = set(os.listdir(env_model_path))
        id = 0
        while True:
            exp_name = 'exp_{}'.format(id)
            if exp_name not in taken:
                exp_path = os.path.join(env_model_path, exp_name)
                try:
                    # mkdir fails if the entry exists, so a run started at the
                    # same moment cannot end up sharing this directory.
                    os.mkdir(exp_path)
                    break
                except FileExistsError:
                    pass
            id += 1

    print('Experiment ID: {}'.format(id))
    return exp_path


# Command-line interface. Every option that defaults to None means "keep the
# value from the YAML defaults"; only explicitly passed values override them.
parser = argparse.ArgumentParser()
# --model selects both the hyperparameter file and the agent class, so it must
# always be given; without it the script would look for hyperparams/None.yml.
parser.add_argument('--model', choices=['Adam', 'LKTD'], required=True,
                    help="Optimizer variant to train with")
parser.add_argument('--env', type=str, default='CartPole-v1', help="Gym environment id")
parser.add_argument('--id', type=int, default=None,
                    help="Experiment id; defaults to the first unused exp_<id>")
parser.add_argument('--dir', type=str, default=None,
                    help="Output sub-directory name used in place of the model name")
parser.add_argument('--exploration_final_eps', type=float, default=None,
                    help="Override exploration_final_eps")
parser.add_argument('--exploration_fraction', type=float, default=None,
                    help="Override exploration_fraction")

parser.add_argument('--threshold', type=float, default=None,
                    help="Reward threshold passed to StopTrainingOnRewardThreshold")

parser.add_argument('--learning_rate', type=float, default=None, help="Learning rate")
parser.add_argument('--sgld_temperature', type=float, default=None, help="SGLD temperature")
parser.add_argument('--prior_sd', type=float, default=None, help="prior sd")
parser.add_argument('--obs_sd', type=float, default=None, help="observation sd")
parser.add_argument('--alpha', type=float, default=None, help="variance split")
parser.add_argument('--sparse_sd', type=float, default=None, help="sparse sd")
parser.add_argument('--sparse_ratio', type=float, default=None, help="sparse ratio")


args = parser.parse_args()
args_dict = vars(args)



# Start from the per-environment defaults in hyperparams/<model>.yml.
hyperparam = load_default_hyperparams(model_name=args.model, env_name=args.env)

# Values given on the command line take precedence over the YAML defaults.
_COMMON_OVERRIDES = ('learning_rate', 'exploration_final_eps', 'exploration_fraction')
for _name in _COMMON_OVERRIDES:
    _value = getattr(args, _name)
    if _value is not None:
        hyperparam[_name] = _value

if args.model == 'LKTD':
    # The optimizer settings live under their own key in the YAML file and are
    # handed to the policy as optimizer_kwargs. A missing or empty section is
    # treated as "no defaults" so command-line values can still be applied.
    _lktd_kwargs = hyperparam.pop('LKTD_kwargs', None) or {}
    _LKTD_OVERRIDES = ('sgld_temperature', 'prior_sd', 'obs_sd',
                       'alpha', 'sparse_sd', 'sparse_ratio')
    for _name in _LKTD_OVERRIDES:
        _value = getattr(args, _name)
        if _value is not None:
            _lktd_kwargs[_name] = _value

    # policy_kwargs may be absent from the YAML entry; create it on demand.
    _policy_kwargs = hyperparam.get('policy_kwargs') or {}
    _policy_kwargs['optimizer_class'] = optimizers.LKTD
    _policy_kwargs['optimizer_kwargs'] = _lktd_kwargs
    hyperparam['policy_kwargs'] = _policy_kwargs

# n_timesteps is the training budget, not a model argument, so it is removed
# from the dict here. Fail now rather than with a NameError further down.
if 'n_timesteps' not in hyperparam:
    raise ValueError(
        "hyperparams/{}.yml has no 'n_timesteps' entry for {}".format(args.model, args.env))
total_timesteps = hyperparam.pop('n_timesteps')



# An explicit --dir replaces the model name as the output sub-directory.
run_label = args.dir if args.dir is not None else args.model
output_dir = create_dir(run_label, args.env, args.id, debug)

# Keep a plain-text record of the resolved hyperparameters next to the logs.
hyperparam_file = os.path.join(output_dir, 'hyperparams.txt')
with open(hyperparam_file, 'w') as f:
    f.writelines('{}: {} \n'.format(name, value) for name, value in hyperparam.items())
        
# Separate environment instances for training and for periodic evaluation.
train_env = gym.make(args.env)
eval_env = gym.make(args.env)

# Optionally stop training once evaluation reaches the requested reward.
if args.threshold is not None:
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=args.threshold)
else:
    callback_on_best = None

# Evaluate 100 times over the run. For budgets below 100 steps the division
# truncates to 0, which EvalCallback treats as "never evaluate", so clamp to 1.
eval_freq = max(1, int(total_timesteps / 100))
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=callback_on_best,
                            #  best_model_save_path=os.path.join(output_dir, 'log'),
                            #  log_path=os.path.join(output_dir, 'log'),
                             eval_freq=eval_freq,
                             n_eval_episodes=10,
                             deterministic=True,
                             render=False)


# model_class = getattr(modified_model, 'DQN_{}'.format(args.model))
# Pick the agent implementation that matches the chosen optimizer variant.
if args.model == 'Adam':
    model_class = DQN
elif args.model == 'LKTD':
    model_class = modified_model.DQN_LKTD

# NOTE(review): create_eval_env is forwarded unchanged; confirm the installed
# stable-baselines3 version and DQN_LKTD constructor accept it.
model = model_class(env=train_env, create_eval_env=True, **hyperparam)

# Send training logs (CSV and TensorBoard) to the experiment directory.
new_logger = configure(output_dir, ["csv", "tensorboard"])
model.set_logger(new_logger)

# Run training with the periodic evaluation callback attached.
model.learn(callback=eval_callback, total_timesteps=int(total_timesteps))



# Save the agent
# model.save(os.path.join(output_dir,args.env))
