import argparse
import json
from optimal_agents.utils.loader import ModelParams

def boolean(item):
    """Argparse ``type=`` converter for boolean flags.

    Only the exact spellings ``'true'``/``'True'`` and ``'false'``/``'False'``
    are accepted; anything else is reported to argparse as a type error.

    Args:
        item: the raw command-line token.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if ``item`` is not one of the four
            accepted spellings.
    """
    if item in ('true', 'True'):
        return True
    if item in ('false', 'False'):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

# Top-level experiment options, stored directly on the params object by
# args_to_params. Each key becomes a ``--dashed-name`` flag.
# ``env`` and ``alg`` are intentionally absent: they are registered on their
# own because ModelParams needs them to be constructed.
BASE_ARGS = dict(
    evo_alg=str,
    env_wrapper=str,
    morphology=str,
    arena=str,
    timesteps=int,
    policy=str,
    early_reset=boolean,
    normalize=boolean,
    seed=int,
    log_interval=int,
    tensorboard=str,
    name=str,
    num_proc=int,
    eval_freq=int,
    checkpoint_freq=int,
    save_gif=boolean,
    finetune=str,
    graph_vec=boolean,
    population_size=int,
    num_generations=int,
)

# Environment options; args_to_params places them under params['env_args'].
ENV_ARGS = dict(
    pad_actions=boolean,
    time_limit=float,
    control_timestep=float,
    n_sub_steps=int,
    allow_exceptions=boolean,
    action_penalty=float,
    direction=int,
)

# Environment-wrapper options; stored under params['env_wrapper_args'].
# A ``(type, nargs)`` tuple makes the flag accept several values.
ENV_WRAPPER_ARGS = dict(
    keys=(str, '+'),
    include_embeddings=boolean,
    include_pos=boolean,
)

# Learning-algorithm options; args_to_params stores them under
# params['alg_args']. Keys must be unique: a repeated key in a dict literal
# silently overwrites the earlier entry.
ALG_ARGS = {
    'learning_rate': float,
    'batch_size': int,
    'buffer_size': int,
    'learning_starts': int,
    'ent_coef': float,
    'n_epochs' : int,
    'n_updates' : int,
    'n_steps' : int,
    'aux_loss_type': str,
    'aux_loss_coef' : float,
}

# Policy options; stored under params['policy_args'].
# ``net_arch`` is converted with json.loads, so its CLI value must be a JSON
# literal (for example ``[64, 64]``).
POLICY_ARGS = dict(
    net_arch=json.loads,
    graph_conv_class=str,
    activation_fn=str,
    rbf=int,
    bandwidth=float,
    seperate_value_net=boolean,  # (sic) spelling is part of the CLI flag name
    num_samples=int,
    horizon=int,
    cache_steps=int,
)

# Node options; stored under params['node_args'].
# The ``*_range`` entries given as ``(type, '+')`` accept one or more values.
NODE_ARGS = dict(
    extent_range=float,
    gear_range=(int, '+'),
    radius_range=(float, '+'),
    joint_range=(int, '+'),
    discrete_bins=int,
    only_ends=boolean,
)

# Morphology options; stored under params['morphology_args'].
# The ``*_kwargs`` entries are parsed with json.loads, so they take JSON text.
MORPHOLOGY_ARGS = dict(
    two_dim=boolean,
    geom_kwargs=json.loads,
    joint_kwargs=json.loads,
    global_kwargs=json.loads,
    child_prob=float,
)

# Mutation options; stored under params['mutation_args'].
MUTATION_ARGS = dict(
    # Size limits.
    min_nodes=int,
    max_nodes=int,
    max_children=int,
    # Standard deviations.
    extent_std=float,
    radius_std=float,
    gear_std=float,
    joint_std=float,
    attachment_std=float,
    # Probabilities / rates.
    remove_prob=float,
    gen_prob=float,
    geom_mut=float,
    node_prob=float,
    joint_mut=float,
    joint_prob=float,
    joint_type_mut=float,
)

# Arena options; stored under params['arena_args']. No options are defined
# yet, so this group currently registers no flags.
ARENA_ARGS = dict()

# Evolution-algorithm options; args_to_params stores them under
# params['evo_alg_args']. Several keys carry historical misspellings
# (``gumble_*``, ``morphology_ignore_collisons``, ``seperate_eval``); they are
# part of the CLI surface and must not be "corrected" here.
EVO_ARGS = dict(
    num_cores=int,
    keep_percent=float,
    random_percent=float,
    new_percent=float,
    retrain=boolean,
    remut_prob=float,
    cpus_per_ind=int,
    use_taskset=boolean,
    # Pruning.
    pruning_multiplier=float,
    pruning_start=int,
    pruning_lr=float,
    pruning_batch_size=int,
    pruning_data_size=int,
    pruning_n_epochs=int,
    pruning_eval_steps=int,
    pruning_buffer_size=int,
    thompson=boolean,
    save_freq=int,
    eval_ep=int,
    nge_mutation=boolean,
    gumble_noise=boolean,
    gumble_temperature=float,
    use_hidden_layer=boolean,
    use_negatives=boolean,
    use_thompson=boolean,
    fitness_window=int,
    # Ascent.
    ascent_structure_freq=int,
    use_segment_embeddings=boolean,
    ascent_lr=float,
    ascent_steps=int,
    ascent_all=boolean,
    mutate_structure_freq=int,
    normalize_per_gen=boolean,
    # Morphology learning.
    learn_morphology=boolean,
    morphology_lr=float,
    morphology_horizon=int,
    morphology_steps=int,
    morphology_ignore_collisons=boolean,
    log_smoothing=boolean,
    end_pool=boolean,
    global_state=boolean,
    ignore_collisions=boolean,
    # Arguments for entropy evolution.
    state_noise=float,
    sample_freq=float,
    num_phases=int,
    num_freqs=int,
    matching_noise=boolean,
    reset_freq=int,
    eval_envs=(str, '+'),
    random_policy=str,
    pruning_arch=json.loads,
    classifier=str,
    seperate_eval=boolean,
    include_end=boolean,
    include_segments=boolean,
    include_start_state=boolean,
    action_ent_coef=float,
)

def add_args_from_dict(parser, arg_dict):
    """Register one optional flag on ``parser`` for every entry of ``arg_dict``.

    Keys are converted from ``snake_case`` to ``--kebab-case`` flags. A value
    is either a converter callable used as ``type=``, or a 2-tuple
    ``(type, nargs)`` for multi-value options. Every flag defaults to ``None``.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend (mutated in place).
        arg_dict: mapping of option name to converter or ``(type, nargs)``.
    """
    for key, spec in arg_dict.items():
        flag = "--{}".format(key.replace('_', '-'))
        extra = {}
        # A (type, nargs) pair splits into the converter and an nargs kwarg.
        if isinstance(spec, tuple) and len(spec) == 2:
            spec, extra['nargs'] = spec
        parser.add_argument(flag, type=spec, default=None, **extra)

def base_parser():
    """Create a parser holding ``--env``, ``--alg`` and every BASE_ARGS flag.

    Returns:
        A new ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # env and alg are needed to construct ModelParams, so they are kept out
    # of BASE_ARGS and registered here as plain string options.
    for flag in ("--env", "--alg"):
        parser.add_argument(flag, type=str)
    add_args_from_dict(parser, BASE_ARGS)
    return parser

def train_parser():
    """Create the full training parser: base flags plus every option group.

    Groups are registered in a fixed order, which also fixes the order in
    which flags appear in ``--help``.

    Returns:
        A new ``argparse.ArgumentParser``.
    """
    parser = base_parser()
    option_groups = (
        ENV_ARGS,
        ENV_WRAPPER_ARGS,
        ALG_ARGS,
        POLICY_ARGS,
        MORPHOLOGY_ARGS,
        MUTATION_ARGS,
        NODE_ARGS,
        ARENA_ARGS,
        EVO_ARGS,
    )
    for group in option_groups:
        add_args_from_dict(parser, group)
    return parser

def args_to_params(args):
    """Build a ``ModelParams`` object from a parsed argparse namespace.

    Unset options (value ``None`` or the string ``'None'``) are skipped.
    ``env``, ``alg`` and BASE_ARGS keys are assigned at the top level of the
    params object. Every other option is written into the sub-dictionary of
    the first option group that declares it.

    Args:
        args: an ``argparse.Namespace``, e.g. from ``train_parser().parse_args()``.

    Returns:
        The populated ``ModelParams`` instance.

    Raises:
        ValueError: if a set option belongs to none of the option groups.
    """
    params = ModelParams(args.env, args.alg)
    # Checked in order; the first matching group wins, which preserves the
    # precedence of the original if/elif chain.
    routing = (
        (ENV_ARGS, 'env_args'),
        (ENV_WRAPPER_ARGS, 'env_wrapper_args'),
        (ALG_ARGS, 'alg_args'),
        (POLICY_ARGS, 'policy_args'),
        (MORPHOLOGY_ARGS, 'morphology_args'),
        (NODE_ARGS, 'node_args'),
        (MUTATION_ARGS, 'mutation_args'),
        (ARENA_ARGS, 'arena_args'),
        (EVO_ARGS, 'evo_alg_args'),
    )
    for arg_name, arg_value in vars(args).items():
        # The literal string 'None' is treated as unset too, e.g. when values
        # were round-tripped through str().
        if arg_value is None or arg_value == 'None':
            continue
        if arg_name in BASE_ARGS or arg_name in ("env", "alg"):
            params[arg_name] = arg_value
            continue
        for group, params_key in routing:
            if arg_name in group:
                params[params_key][arg_name] = arg_value
                break
        else:
            raise ValueError(
                "Provided argument does not fit into categories: %r" % (arg_name,))
    return params

def convert_kwargs_args(kwargs, parser):
    """Parse a mapping of option values with ``parser`` as if given on the CLI.

    Each ``key: value`` pair becomes ``--key-with-dashes`` followed by the
    value rendered as text. The resulting token list is handed to
    ``parser.parse_known_args``.

    Rendering rules:
      * ``None`` for a flag the parser knows is skipped, so the parser default
        (``None``) applies. Passing the text ``'None'`` would be rejected by
        non-``str`` converters such as ``int`` or ``boolean``.
      * A non-string value for an option whose ``type`` is ``json.loads`` is
        serialised with ``json.dumps``. ``str()`` would give a Python repr
        that is not valid JSON, and a list would be split into several tokens.
      * Any other list or tuple is expanded into one token per item (for
        ``nargs='+'`` options). Everything else goes through ``str()``.

    Args:
        kwargs: mapping from underscore-style option names to values.
        parser: an ``argparse.ArgumentParser`` such as ``train_parser()``.

    Returns:
        The parsed ``argparse.Namespace``. Tokens the parser does not
        recognise are printed in an error banner and otherwise ignored.
    """
    # NOTE(review): argparse exposes no public flag -> action lookup, so this
    # reads the private ``_option_string_actions`` map. If it is missing,
    # every value falls back to the plain str() rendering.
    actions = getattr(parser, '_option_string_actions', {})
    arg_list = []
    for key, value in kwargs.items():
        flag = '--' + key.replace('_', '-')
        action = actions.get(flag)
        if value is None and action is not None:
            continue
        arg_list.append(flag)
        if (action is not None and action.type is json.loads
                and not isinstance(value, str)):
            arg_list.append(json.dumps(value))
        elif isinstance(value, (list, tuple)):
            arg_list.extend(str(item) for item in value)
        else:
            arg_list.append(str(value))
    args, unknown_args = parser.parse_known_args(arg_list)
    if len(unknown_args) > 0:
        print("############# ERROR #####################################")
        print("Unknown Arguments:", unknown_args)
        print("#########################################################")
    return args
