import argparse
from functools import wraps
from typing import Callable, List


def output_added_arguments(add_algo_arguments: Callable) -> Callable:
    """Wrap an argument-registering function so that it reports what it registered.

    The wrapped function receives an ``argparse.ArgumentParser`` and adds options
    to it; its own return value is discarded. The wrapper calls it and returns the
    long options (those starting with ``--``) that were newly added, with the
    leading dashes removed, in registration order. ``--help`` is never reported.

    Args:
        add_algo_arguments: Function that adds arguments to the parser it is given.

    Returns:
        A wrapper (carrying the wrapped function's metadata via ``functools.wraps``)
        that takes the parser and returns the list of added long-option names.
    """

    @wraps(add_algo_arguments)
    def decorated(parser: argparse.ArgumentParser) -> List[str]:
        # Snapshot the option strings present before registration; a set keeps the
        # membership test below O(1). NOTE: relies on argparse's private
        # ``_option_string_actions`` mapping (option string -> action).
        old_option_strings = set(parser._option_string_actions)

        add_algo_arguments(parser)

        # Dict keys preserve insertion order, so names come out in registration order.
        return [
            # Remove only the leading prefix (as argparse does when deriving ``dest``);
            # ``strip`` would also remove trailing dashes.
            option_string.lstrip("-")
            for option_string in parser._option_string_actions
            if option_string not in old_option_strings
            and option_string.startswith("--")
            and option_string != "--help"
        ]

    return decorated


@output_added_arguments
def add_base_arguments(parser: argparse.ArgumentParser):
    """Register the options common to every algorithm on ``parser``.

    ``--experiment_name`` and ``--seed`` are required; every other option has a
    default. Because of ``output_added_arguments``, calling this returns the
    names of the long options it added, without their leading dashes.
    """
    # Each entry: (option strings, keyword arguments for ``add_argument``),
    # listed in registration order.
    option_specs = (
        (
            ("-en", "--experiment_name"),
            dict(help="Experiment name.", type=str, required=True),
        ),
        (
            ("-s", "--seed"),
            dict(help="Seed of the experiment.", type=int, required=True),
        ),
        (
            ("-dw", "--disable_wandb"),
            dict(help="Disable wandb.", default=False, action="store_true"),
        ),
        (
            ("-rbc", "--replay_buffer_capacity"),
            dict(help="Replay Buffer capacity.", type=int, default=1_000_000),
        ),
        (
            ("-bs", "--batch_size"),
            dict(help="Batch size for training.", type=int, default=256),
        ),
        (
            ("-nis", "--n_initial_samples"),
            dict(
                help="Number of initial samples before the training starts.",
                type=int,
                default=5_000,
            ),
        ),
        (
            ("-uh", "--update_horizon"),
            dict(help="Value of n in n-step TD update.", type=int, default=1),
        ),
        (
            ("-gamma", "--gamma"),
            dict(help="Discounting factor.", type=float, default=0.99),
        ),
        (
            ("-horizon", "--horizon"),
            dict(help="Horizon for truncation.", type=int, default=1_000),
        ),
        (
            ("-n", "--n_samples"),
            dict(type=int, help="Number of collected samples.", default=500_000),
        ),
        (
            ("-tau", "--tau"),
            dict(help="Tau in target update.", type=float, default=5e-3),
        ),
        (
            ("-bn", "--batch_norm"),
            dict(help="Flag to add batch norm.", default=False, action="store_true"),
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)



def add_learning_rate_schedule(parser: argparse.ArgumentParser):
    """Add the learning-rate schedule options to ``parser``.

    Registers ``--learning_rate_init`` (float, default 1e-4),
    ``--learning_rate_end`` (float, default 1e-5) and
    ``--learning_rate_decay_steps`` (int, default 1_000_000).
    """
    schedule_options = (
        ("-lri", "--learning_rate_init", "Initial learning rate.", float, 1e-4),
        ("-lre", "--learning_rate_end", "Ending learning rate.", float, 1e-5),
        (
            "-lrds",
            "--learning_rate_decay_steps",
            "Length of the learning rate decay.",
            int,
            1_000_000,
        ),
    )
    for short_flag, long_flag, description, value_type, default_value in schedule_options:
        parser.add_argument(
            short_flag,
            long_flag,
            help=description,
            type=value_type,
            default=default_value,
        )


def add_n_bellman_iterations(parser: argparse.ArgumentParser):
    """Add the Bellman-iteration options to ``parser``.

    Registers ``--n_bellman_iterations`` (int, default 3): the number K of
    Bellman iterations trained in parallel. Also registers ``--loss_discount``
    (float, default 1.0): the discount value used for Q-value aggregation.
    """
    iterations_kwargs = {
        "help": "Number of bellman iterations to train in parallel. (K)",
        "type": int,
        "default": 3,
    }
    discount_kwargs = {
        "help": "Discount value for q_value aggregation",
        "type": float,
        "default": 1.0,
    }
    parser.add_argument("-nbi", "--n_bellman_iterations", **iterations_kwargs)
    parser.add_argument("-ld", "--loss_discount", **discount_kwargs)


@output_added_arguments
def add_simbav2_arguments(parser: argparse.ArgumentParser):
    """Register the SimbaV2-specific options on ``parser``.

    Currently this is only the learning-rate schedule. Because of
    ``output_added_arguments``, calling this returns the names of the long
    options it added, without their leading dashes.
    """
    add_learning_rate_schedule(parser=parser)


@output_added_arguments
def add_tfsimbav2_arguments(parser: argparse.ArgumentParser):
    """Register the TFSimbaV2-specific options on ``parser``.

    Currently this is only the learning-rate schedule. Because of
    ``output_added_arguments``, calling this returns the names of the long
    options it added, without their leading dashes.
    """
    add_learning_rate_schedule(parser=parser)


@output_added_arguments
def add_issimbav2_arguments(parser: argparse.ArgumentParser):
    """Register the iSSimbaV2-specific options on ``parser``.

    These are the learning-rate schedule followed by the Bellman-iteration
    options. Because of ``output_added_arguments``, calling this returns the
    names of the long options it added, without their leading dashes.
    """
    # Order matters: it determines the order of the reported option names.
    for register_options in (add_learning_rate_schedule, add_n_bellman_iterations):
        register_options(parser)
