import yaml
import argparse
import run_single_learn_phi_ope_main
import pdb
import run_single_gw_ope
import run_single_roy_ope
from utils import set_global_lam

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
# saving
parser.add_argument('--outfile', default = None)

# common setup
parser.add_argument('--env_name', type = str, required = True)
parser.add_argument('--env_type', type = str)
parser.add_argument('--dataset_name', type = str)
parser.add_argument('--normalize_states', default = 'false', type = str2bool)
parser.add_argument('--normalize_state_actions', default = 'false', type = str2bool)
parser.add_argument('--normalize_rewards', default = 'false', type = str2bool)
parser.add_argument('--skip_rate', default = 1, type = int)
parser.add_argument('--ope_method', type = str, default = 'fqe')
parser.add_argument('--encoder_name', type = str, default = 'identity')
parser.add_argument('--gamma', default = 0.999, type = float)
parser.add_argument('--epochs', default = 2000, type = int)
parser.add_argument('--phi_epochs', default = 50000, type = int)
parser.add_argument('--image_state', default = 'false', type = str2bool)
parser.add_argument('--beta', default = 0.1, type = float)
parser.add_argument('--aux_task', default = False, type = str2bool)
parser.add_argument('--aux_alpha', default = 0.1, type = float)
parser.add_argument('--krope_kernel', default = 'dot', type = str)
parser.add_argument('--krope_sigma', default = 1e-2, type = float)
parser.add_argument('--pie_num', type = int)
parser.add_argument('--visual_type', type = str, default = 'tsne')

# variables
parser.add_argument('--seed', default = 0, type = int)
parser.add_argument('--lam_inv', default = 1e-7, type = float)
parser.add_argument('--phi_lr', default = 1e-5, type = float)
parser.add_argument('--Q_lr', default = 1e-5, type = float)
parser.add_argument('--M_lr', default = 1e-5, type = float)
parser.add_argument('--phi_outdim', default = 10, type = int)
parser.add_argument('--rep_loss_function', default = 'huber', type = str)
parser.add_argument('--phi_num_hidden_layers', default = 2, type = int)
parser.add_argument('--phi_hidden_dim', default = 64, type = int)
parser.add_argument('--phi_act_function', default = 'relu', type = str)
parser.add_argument('--phi_norm_type', default = None, type = str)
parser.add_argument('--phi_soft_update_tau', default = 5e-3, type = float)
parser.add_argument('--phi_hard_update_freq', default = 5, type = int)
parser.add_argument('--phi_use_penultimate', default = 'false', type = str2bool)
parser.add_argument('--phi_qpie_eps', default = 0.1, type = float)
parser.add_argument('--Q_num_hidden_layers', default = 2, type = int)
parser.add_argument('--Q_hidden_dim', default = 256, type = int)
parser.add_argument('--Q_act_function', default = 'relu', type = str)
parser.add_argument('--clip_target', default = 'false', type = str2bool)
parser.add_argument('--Q_loss_function', default = 'mse', type = str)
parser.add_argument('--Q_reset_opt_freq', default = -1, type = int)
parser.add_argument('--Q_adam_beta', default = -1, type = float)
parser.add_argument('--Q_use_target_net', default = 'true', type = str2bool)
parser.add_argument('--Q_soft_update_tau', default = 5e-3, type = float)
parser.add_argument('--Q_hard_update_freq', default = int(5e3), type = int)
parser.add_argument('--Q_norm_type', default = None, type = str)
parser.add_argument('--Q_target_update_type', default = 'soft', type = str)
parser.add_argument('--bcrl_norm_selfpred', default = 'false', type = str2bool)
parser.add_argument('--bcrl_logdet', default = 0, type = float)
parser.add_argument('--mini_batch_size', default = 256, type = int)
parser.add_argument('--tr_set_fraction', default = 1., type = float)
parser.add_argument('--pw_dataset', default = False, type = str2bool)
parser.add_argument('--roy_sample_num', default = 0, type = int)
parser.add_argument('--roy_off_type', default = 'bad', type = str)
parser.add_argument('--mix_ratio', default = 1., type = float)
parser.add_argument('--batch_size', default = 500, type = int)
parser.add_argument('--bc_measure_logdet', default = 0, type = float)

# misc
parser.add_argument('--exp_name', default = 'gan', type = str)
parser.add_argument('--print_log', default = 'false', type = str2bool)

FLAGS = parser.parse_args()

with open('cfg.yaml', 'r') as file:
    config = yaml.safe_load(file)

set_global_lam(FLAGS)

if FLAGS.env_name == 'RandomMDP'\
    or FLAGS.env_name == 'Taxi'\
    or FLAGS.env_name == 'ChainMDP':
    run_single_gw_ope.main(FLAGS)
elif FLAGS.env_name == 'Roy' or FLAGS.env_name == 'Bairds':
    run_single_roy_ope.main(FLAGS)
elif FLAGS.env_name == 'Fourrooms':
    run_single_fourrooms_ope.main(FLAGS)
else:
    run_single_learn_phi_ope_main.main(FLAGS, config)
    