import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import torch as th
from train import TraininingModel
import sys
import argparse
import pickle

# Command-line interface of the training script. Every flag is optional and
# falls back to the default given here. Flags without a `type` are kept as
# strings; argparse's default action already stores the supplied value.
parser = argparse.ArgumentParser()

# Where the run writes its files, which space to load, and the device string.
parser.add_argument("--log_dir", default="tmp")
parser.add_argument("--space", default="CarLike2")
parser.add_argument("--device", default="cuda:0")

# Numeric settings for the run.
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--eval_freq", type=int, default=2000)
parser.add_argument("--total_timesteps", type=int, default=10000000)
parser.add_argument("--depth", type=int, default=6)
parser.add_argument("--eps", type=float, default=0.1)
parser.add_argument("--gae_lambda", type=float, default=1.0)
parser.add_argument("--n_epochs", type=int, default=10)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--learning_rate", type=float, default=1e-6)

# Boolean switch: False unless the flag is present on the command line.
parser.add_argument("--not_midpoint", action="store_true")
parser.add_argument("--alpha", type=float, default=1.0)

args = parser.parse_args()

import random
import numpy as np
# Seed each random number generator in scope (Python's `random`, PyTorch,
# then NumPy) with the same --seed value.
for _seed_rng in (random.seed, th.manual_seed, np.random.seed):
    _seed_rng(args.seed)

# Create the run directory (no error if it already exists).
os.makedirs(args.log_dir, exist_ok=True)

# From here on, Python-level writes to sys.stdout / sys.stderr go to files in
# the run directory instead of the console. buffering=1 selects line
# buffering for these text streams, so every completed line reaches the file
# while the run is in progress; with the default block buffering the logs lag
# behind and their tail (including a traceback) can be lost if the process is
# killed. The streams are intentionally left open for the life of the process.
sys.stdout = open(os.path.join(args.log_dir, "log.txt"), "w", buffering=1)
sys.stderr = open(os.path.join(args.log_dir, "error_log.txt"), "w", buffering=1)

# Save the parsed command-line options as a plain dict next to the logs.
# NOTE(review): the extension is "plk" (possibly meant "pkl"); kept unchanged
# because anything that reads this file depends on the exact name.
with open(os.path.join(args.log_dir, "args.plk"), "wb") as f:
    pickle.dump(vars(args), f)
    
from pick_space import pick_space
# pick_space returns a pair for the --space name; both halves are handed to
# the model below as `space` and `eval_episodes`.
space, eval_episodes = pick_space(args.space)

# Command-line options passed to the model under the same keyword name, in
# this order. --space and --seed are not forwarded here.
_forwarded_options = (
    "log_dir",
    "device",
    "batch_size",
    "eval_freq",
    "total_timesteps",
    "depth",
    "eps",
    "gae_lambda",
    "n_epochs",
    "learning_rate",
    "not_midpoint",
    "alpha",
)
_model_kwargs = {name: getattr(args, name) for name in _forwarded_options}

# Build the model and start training.
model = TraininingModel(space=space, eval_episodes=eval_episodes, **_model_kwargs)
model.learn()
