import tensorflow as tf
from datetime import datetime
import pathlib
import os
from utils.utils_tools import AttrDict
from config.lompo_config import build_lompo_config
from config.oraac_config import build_oraac_config
from config.combo_config import build_combo_config
from config.lodac_config import build_lodac_config

# Presumably meant to silence TensorFlow's native (C++) log output.
# NOTE(review): `tensorflow` is already imported at the top of this file, so messages
# emitted during that import may not be filtered -- consider setting this before the import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')


# Reward threshold stored in ``conf.max_reward``, keyed by ``args.dataset_type``.
# Per the original author's note, each value was chosen so that about half of the
# states in the corresponding training dataset are risky.
_MAX_REWARD_BY_DATASET_TYPE = {
    'expert': 1.99545,
    'medium': 0.93,
    # NOTE(review): the original note listed 0.7 under "medium_expert", but the key
    # actually matched here is 'expert_replay' -- confirm which dataset this targets.
    'expert_replay': 0.7,
}


def _experiment_dirs(args):
    """Return ``(datadir, logdir, trained_latent_dir, videos_dir)`` for this run.

    Every directory sits under ``<root>/<Risky|Risk_free>/<env>/<dataset_type>``;
    the log and video directories are further split by ``args.latent_algo``.
    """
    run_subdir = pathlib.Path('Risky' if args.risky_env else 'Risk_free',
                              args.env, args.dataset_type)
    datadir = pathlib.Path('Datasets') / run_subdir
    logdir = pathlib.Path('logdir') / run_subdir / args.latent_algo
    trained_latent_dir = pathlib.Path('trained_latent_models') / run_subdir
    videos_dir = pathlib.Path('Videos') / run_subdir / args.latent_algo
    return datadir, logdir, trained_latent_dir, videos_dir


def define_config(args):
    """Build the experiment configuration for one training run.

    Args:
        args: parsed command-line options. Reads ``risky_env``, ``env``,
            ``dataset_type`` and ``latent_algo``; ``load_imit_actor`` is read
            only when ``latent_algo == 'oraac'``.

    Returns:
        The config object returned by the algorithm-specific ``build_*_config``
        builder, with the shared environment, risk and training fields set on it.

    Raises:
        NotImplementedError: if ``args.latent_algo`` is not one of 'lompo',
            'oraac', 'combo' or 'lodac', or if no reward threshold is defined
            for ``args.dataset_type``.
    """
    conf = AttrDict()
    (conf.datadir, conf.logdir,
     conf.trained_latent_dir, conf.videos_dir) = _experiment_dirs(args)
    conf.latent_algo = args.latent_algo

    # parameters for latent model
    conf.stoch = 64
    conf.deter = 256
    conf.num_models = 7
    conf.shape_action = 6
    conf.max_action = 1

    conf.num_train_step_latent_model = 25000
    conf.test_every = 10000
    conf.amount_action_noise = 0.2
    conf.noise_type = 'additive_gaussian'
    conf.lmbd = 5
    conf.horizon = 5

    # Algorithm-specific settings; each builder receives the partially filled
    # config and returns the config that is used from here on.
    if args.latent_algo == 'lompo':
        conf = build_lompo_config(conf, env_name=args.env, dataset_type=args.dataset_type, risky=args.risky_env)
    elif args.latent_algo == 'oraac':
        conf = build_oraac_config(conf, env_name=args.env, dataset_type=args.dataset_type, risky=args.risky_env,
                                  load_imit_agent=args.load_imit_actor)
    elif args.latent_algo == 'combo':
        conf = build_combo_config(conf, env_name=args.env, dataset_type=args.dataset_type, risky=args.risky_env)
    elif args.latent_algo == 'lodac':
        conf = build_lodac_config(conf, env_name=args.env, dataset_type=args.dataset_type, risky=args.risky_env)
    else:
        raise NotImplementedError(args.latent_algo)

    # environment
    conf.task = args.env
    conf.action_repeat = 2
    conf.time_limit = 1000
    conf.precision = 32

    # risky behaviour of the environment
    conf.risky = args.risky_env
    # NOTE(review): presumably the penalty size and the probability of the risky
    # event -- confirm against the environment wrapper that reads these.
    conf.penal = 8
    conf.prob = 0.1
    # since in our code, we are working with the cumulative reward, we should consider (1 - alpha) instead of alpha
    # -> conf.alpha_cvar of 0.3 corresponds to 0.7 in the paper
    conf.alpha_cvar = 0.3

    # Assigned after the algorithm builder (as before), so it overrides any
    # max_reward a builder might have set.
    try:
        conf.max_reward = _MAX_REWARD_BY_DATASET_TYPE[args.dataset_type]
    except KeyError:
        raise NotImplementedError(
            "No max reward threshold is defined for dataset type {!r}.".format(args.dataset_type)
        ) from None

    conf.num_first_synthetic_data = 50000
    conf.num_actor_critic_training_per_loop = 200
    conf.num_evaluate = 50

    return conf












