import argparse
import sys


def _build_parser(default_args=None):
    """Create the argument parser, optionally overriding option defaults."""
    parser = argparse.ArgumentParser()

    # Options that must always be supplied on the command line.
    required = parser.add_argument_group("required named arguments")
    required.add_argument(
        "-hl",
        "--hidden_layers",
        nargs="*",
        default=[],
        help="List of models to use for aDQN",
        required=True,
    )
    required.add_argument(
        "-ms",
        "--m_seeds",
        nargs="*",
        default=[],
        help="List of seeds for model initialization",
        required=True,
    )
    required.add_argument(
        "-s", "--t_seed", type=int, help="Training seed", required=True
    )

    # Optional settings, each with a default.
    parser.add_argument(
        "-p", "--path", action="store", type=str, help="Path to save results"
    )
    parser.add_argument(
        "-c",
        "--criterion",
        default="min",
        const="min",
        nargs="?",
        choices=["min", "max", "random", "eps_min"],
        help="Criterion for model selection in aDQN",
    )
    parser.add_argument(
        "-g",
        "--gamma",
        type=float,
        default=0.99,
        help="discount factor of the environment",
    )
    parser.add_argument(
        "-hz", "--horizon", type=int, default=1000, help="horizon of the environment"
    )
    parser.add_argument(
        "-bs", "--batch_size", type=int, default=32, help="batch size for training"
    )
    parser.add_argument(
        "-is",
        "--n_initial_samples",
        type=int,
        default=1000,
        help="number of initial samples in the replay buffer",
    )
    parser.add_argument(
        "-ef",
        "--eps_final",
        type=float,
        default=0.01,
        help="final value of epsilon for exploration",
    )
    parser.add_argument(
        "-efs",
        "--eps_final_steps",
        type=int,
        default=1000,
        help="number of steps over which epsilon is annealed",
    )
    parser.add_argument(
        "-rbs",
        "--replay_buffer_size",
        type=int,
        default=10000,
        help="size of the replay buffer",
    )
    parser.add_argument(
        "-ne",
        "--n_epochs",
        type=int,
        default=35,
        help="number of epochs for the training",
    )
    parser.add_argument(
        "-mspe",
        "--min_steps_per_epoch",
        type=int,
        default=20000,
        help="minimum number of steps per epoch",
    )
    parser.add_argument(
        "-tuf",
        "--target_update_frequency",
        type=int,
        default=5000,
        help="frequency of updating the target network",
    )
    parser.add_argument(
        "-ve",
        "--val_episodes",
        type=int,
        default=10,
        help="number of episodes used for validation",
    )
    parser.add_argument(
        "-lr", "--learning_rate", type=float, default=0.0003, help="learning rate"
    )

    if default_args is not None:
        parser.set_defaults(**default_args)
    return parser


def parse_args(string=None, default_args=None):
    """Parse the command-line flags and normalise the list-valued options.

    Args:
        string: Arguments to parse. ``None`` makes argparse read
            ``sys.argv[1:]``. A list of tokens is parsed as given. A single
            ``str`` is split into tokens with shell-like rules
            (``shlex.split``) before parsing.
        default_args: Optional mapping of option destination -> value that
            replaces the built-in defaults via ``parser.set_defaults``.
            Options declared ``required=True`` must still be passed
            explicitly; argparse does not treat a default as satisfying them.

    Returns:
        The ``argparse.Namespace`` with two fields post-processed:
        ``hidden_layers`` becomes a list of lists of ints (each
        command-line value such as ``"64-64"`` is split on ``"-"``), and
        ``m_seeds`` becomes a list of ints.

    Raises:
        ValueError: If a layer size or model seed is not a valid integer, or
            if the number of model seeds differs from the number of models.
        SystemExit: Raised by argparse on invalid or missing options.
    """
    parser = _build_parser(default_args)

    if string is None:
        print("Flags:\n", sys.argv)
    else:
        print("Flags:\n", string)

    argv = string
    if isinstance(string, str):
        # argparse would iterate a bare str character by character, so
        # tokenize it first; the import is local to keep this block
        # self-contained.
        import shlex

        argv = shlex.split(string)

    args = parser.parse_args(argv)

    # "64-32" -> [64, 32]: one list of layer widths per model.
    args.hidden_layers = [
        [int(neurons) for neurons in spec.split("-")] for spec in args.hidden_layers
    ]
    args.m_seeds = [int(seed) for seed in args.m_seeds]

    if len(args.m_seeds) != len(args.hidden_layers):
        raise ValueError(
            f"Number of model seeds {len(args.m_seeds)} does not match number of models {len(args.hidden_layers)}"
        )
    return args


# When run as a script, parse sys.argv; parse_args prints the received flags
# and the resulting namespace is discarded.
if __name__ == "__main__":
    parse_args()
