import argparse
from typing import Any, Dict, List, Union

import flax.linen as nn
import jax
import optax


# Optimizer constructors selectable by name through the --optimizer_classes option.
OPTIMIZER_CLASSES = {
    "adam": optax.adam,
    "adamw": optax.adamw,
    "rmsprop": optax.rmsprop,
}
# Activation functions selectable by name through the --activation_fns option.
ACTIVATION_FNS = {
    "relu": nn.relu,
    "tanh": nn.tanh,
    "sigmoid": nn.sigmoid,
}


def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser that declares every command-line option read by `parse`."""
    # The text must be passed as `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would replace the program name in the usage line.
    parser = argparse.ArgumentParser(description="Arguments for the aSAC library.")

    parser.add_argument("-e", "--env", type=str, required=True, help="environment name.")
    parser.add_argument("-s", "--seed", type=int, required=True, help="seed.")
    parser.add_argument(
        "-w",
        "--wandb_project",
        type=str,
        help="wandb project name. If not given, the log are not sent to wandb.",
    )
    parser.add_argument("-n", "--n_samples", type=int, required=True, help="number of collected samples.")
    # List-valued options use nargs="+" so that at least one value must be supplied;
    # nargs="*" would silently accept an empty list even with required=True.
    parser.add_argument("-lr", "--learning_rates", type=float, required=True, nargs="+", help="critic learning rates.")
    parser.add_argument(
        "-nis",
        "--n_initial_samples",
        type=int,
        required=True,
        help="number of samples before learning starts.",
    )
    parser.add_argument(
        "-utd",
        "--utd",
        type=int,
        required=True,
        help="update to data.",
    )
    parser.add_argument(
        "-oc",
        "--optimizer_classes",
        type=str,
        required=True,
        nargs="+",
        choices=["adam", "adamw", "rmsprop"],
        help="optimizer class of the critics.",
    )
    parser.add_argument(
        "-naq",
        "--net_archs_qf",
        type=str,
        required=True,
        nargs="+",
        help="architecture of the critics.",
    )
    parser.add_argument(
        "-af",
        "--activation_fns",
        type=str,
        required=True,
        nargs="+",
        choices=["relu", "tanh", "sigmoid"],
        help="activation function of the critics.",
    )
    parser.add_argument(
        "-mc",
        "--m_critics",
        type=int,
        required=True,
        help="number of critics to sample from for computing the target and for the policy in the actor's loss (in case all_policy_qf is False).",
    )
    parser.add_argument(
        "-rt",
        "--random_target_qf",
        default=False,
        action="store_true",
        help="if true, the target networks will be selected randomly.",
    )
    parser.add_argument(
        "-at",
        "--aggregate_target_qf",
        type=str,
        required=True,
        choices=["min", "mean"],
        help="function to aggregate the selected networks to compute the target.",
    )
    parser.add_argument(
        "-allp",
        "--all_policy_qf",
        default=False,
        action="store_true",
        help="if true, all the network will be used in the actor's loss.",
    )
    parser.add_argument(
        "-ee",
        "--end_epsilon",
        type=float,
        required=True,
        help="end episilon for uniformly sampling the critics to compute the policy in the actor's loss.",
    )
    parser.add_argument(
        "-de",
        "--duration_epsilon",
        type=int,
        required=True,
        help="duration of the episilon schedule for uniformly sampling the critics to compute the policy in the actor's loss.",
    )
    parser.add_argument(
        "-ap",
        "--aggregate_policy_qf",
        type=str,
        required=True,
        choices=["min", "mean"],
        help="function to aggregate the selected networks in the actor's loss.",
    )

    return parser


def parse(argvs: List[str]) -> Dict[str, Any]:
    """Parse command-line arguments and convert them into ready-to-use settings.

    Args:
        argvs: argument strings to parse (program name not included).

    Returns:
        The parsed options as a dictionary, post-processed as follows:

        - ``"group"`` is added: a label built from the raw learning rates,
          optimizer names, architecture strings, activation names, the
          random-target flag, the end epsilon and the raw environment name.
        - ``"env"`` gets the suffix ``"-v4"``.
        - ``"optimizer_classes"`` names are replaced by the matching entries
          of ``OPTIMIZER_CLASSES``.
        - ``"net_archs_qf"`` strings such as ``"64_64"`` are split on ``"_"``
          into lists of ints; the resulting list is duplicated and sorted.
        - ``"activation_fns"`` names are replaced by the matching entries of
          ``ACTIVATION_FNS``.

    Raises:
        SystemExit: raised by argparse when arguments are missing or invalid.
        ValueError: if an architecture string contains a non-integer layer size.
    """
    args = vars(_build_parser().parse_args(argvs))

    # The group label must be built before the raw strings are replaced by objects below.
    group_parts = [
        f"lr{'-'.join(map(str, args['learning_rates']))}",
        f"oc{'-'.join(args['optimizer_classes'])}",
        f"naq{'-'.join(args['net_archs_qf'])}",
        f"af{'-'.join(args['activation_fns'])}",
        f"rt{int(args['random_target_qf'])}",
        f"ee{args['end_epsilon']}",
    ]
    args["group"] = "_".join(group_parts) + f"/{args['env']}"

    args["env"] = args["env"] + "-v4"
    args["optimizer_classes"] = [OPTIMIZER_CLASSES[name] for name in args["optimizer_classes"]]
    # Times 2 to double the number of critics in case only one setting is performing well
    layer_sizes = jax.tree.map(int, [arch.split("_") for arch in args["net_archs_qf"]])
    args["net_archs_qf"] = sorted(2 * layer_sizes)
    args["activation_fns"] = [ACTIVATION_FNS[name] for name in args["activation_fns"]]

    return args
