from argparse import ArgumentParser
from typing import Any, Dict


def add_dataset_args(parser: ArgumentParser):
    """Register dataset / environment-generation options on ``parser``.

    Adds the integer sizing options (``--envs``, ``--envs_eval``, ``--hists``,
    ``--samples``, ``--H``, ``--dim``, ``--lin_d``), the float options
    ``--var`` and ``--cov`` (both ``0.0`` by default), the mandatory
    ``--env`` name, the ``--reward`` type (``"sparse"`` or ``"dense"``), and
    the ``--env_id_start`` / ``--env_id_end`` pair (both ``-1`` by default).
    """
    add = parser.add_argument

    # (flag, default, help) for every optional integer option, in registration order.
    int_options = (
        ("--envs", 100000, "Envs"),
        ("--envs_eval", 100, "Eval Envs"),
        ("--hists", 1, "Histories"),
        ("--samples", 1, "Samples"),
        ("--H", 100, "Context horizon"),
        ("--dim", 10, "Dimension"),
        ("--lin_d", 2, "Linear feature dimension"),
    )
    for flag, default_value, description in int_options:
        add(flag, type=int, required=False, default=default_value, help=description)

    float_options = (
        ("--var", "Bandit arm variance"),
        ("--cov", "Coverage of optimal arm"),
    )
    for flag, description in float_options:
        add(flag, type=float, required=False, default=0.0, help=description)

    add("--env", type=str, required=True, help="Environment")

    add(
        "--reward",
        type=str,
        choices=["sparse", "dense"],
        required=False,
        default="sparse",
        help="Reward type",
    )

    # NOTE(review): -1 presumably means "no bound"; confirm against the caller.
    add("--env_id_start", type=int, required=False, default=-1, help="Start index of envs to sample")
    add("--env_id_end", type=int, required=False, default=-1, help="End index of envs to sample")


def add_model_args(parser: ArgumentParser):
    """Register the victim model's architecture and optimisation options on ``parser``.

    ``--arch`` selects a named preset (resolved by
    ``get_model_params_from_arch``); ``--embd``, ``--head``, ``--layer``,
    ``--lr``, ``--dropout`` and ``--shuffle`` supply the values used when no
    preset applies.
    """
    parser.add_argument("--arch", type=str, required=False, default=None, help="Architecture to use, from a set of standard ones")

    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding size")
    parser.add_argument("--head", type=int, required=False, default=4, help="Number of heads")
    parser.add_argument("--layer", type=int, required=False, default=4, help="Number of layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning Rate")
    # argparse applies ``type`` only to *string* defaults. A non-string default
    # is used as-is, so it must already be a float for ``args.dropout`` to be a
    # float whether or not the flag is passed.
    parser.add_argument("--dropout", type=float, required=False, default=0.0, help="Dropout")
    parser.add_argument("--shuffle", default=False, action="store_true")


def add_adaptive_attacker_args(parser: ArgumentParser):
    """Register the adaptive attacker's architecture options on ``parser``.

    ``--embd_att``, ``--head_att`` and ``--layer_att`` are mandatory.
    ``--dropout_att`` defaults to ``0.0`` and ``--shuffle_att`` is an
    off-by-default flag.
    """
    # The help strings name the attacker so ``--help`` can tell these options
    # apart from the victim-model ones when both sets share one parser.
    parser.add_argument("--embd_att", type=int, required=True, help="Attacker embedding size")
    parser.add_argument("--head_att", type=int, required=True, help="Number of attacker heads")
    parser.add_argument("--layer_att", type=int, required=True, help="Number of attacker layers")
    # The default must be a float: argparse does not run ``type`` on non-string defaults.
    parser.add_argument("--dropout_att", type=float, required=False, default=0.0, help="Attacker dropout")
    parser.add_argument("--shuffle_att", default=False, action="store_true")


def get_model_params_from_arch(args: Dict[str, Any]):
    """Resolve the victim model's hyper-parameters.

    When ``args["arch"]`` equals ``"1"``, a fixed preset is returned and no
    other key of ``args`` is read. For any other value, each hyper-parameter
    is taken verbatim from ``args``.

    Returns:
        Tuple ``(shuffle, lr, dropout, n_embd, n_layer, n_head)``.

    Raises:
        KeyError: if ``"arch"`` is missing, or if no preset applies and one of
        the per-field keys is missing.
    """
    if args["arch"] == "1":
        # Preset "1": shuffle on, lr 1e-4, no dropout, 32-d embeddings, 4 layers, 4 heads.
        return True, 0.0001, 0, 32, 4, 4

    # NOTE(review): any arch other than "1" (including typos) silently falls
    # back to the explicit CLI values below.
    field_keys = ("shuffle", "lr", "dropout", "embd", "layer", "head")
    return tuple(args[key] for key in field_keys)


def get_attacker_params_from_arch(args: Dict[str, Any]):
    """Collect the adaptive attacker's hyper-parameters from ``args``.

    There is no preset lookup here; every value is read directly.
    ``attacker_lr`` is registered by ``add_adv_training_args``; the other
    keys come from ``add_adaptive_attacker_args``.

    Returns:
        Tuple ``(shuffle, lr, dropout, n_embd, n_layer, n_head)``.

    Raises:
        KeyError: if any of the required keys is missing.
    """
    ordered_keys = (
        "shuffle_att",
        "attacker_lr",
        "dropout_att",
        "embd_att",
        "layer_att",
        "head_att",
    )
    return tuple(args[key] for key in ordered_keys)


def add_train_args(parser: ArgumentParser):
    """Register training-loop options on ``parser``: ``--num_epochs`` (int, default 1000)."""
    epoch_option = dict(type=int, required=False, default=1000, help="Number of epochs")
    parser.add_argument("--num_epochs", **epoch_option)


def add_corrupt_train_args(parser: ArgumentParser):
    """Register ``--corrupt_train``, a training-time corruption type string (empty by default)."""
    parser.add_argument(
        "--corrupt_train",
        help="Corruption type for training",
        default="",
        type=str,
    )


def add_eval_args(parser: ArgumentParser):
    """Register evaluation options on ``parser``.

    Adds ``--epoch``, ``--test_cov`` and ``--hor`` (all ``-1`` by default),
    ``--n_eval`` (default 100), the off-by-default flags ``--save_video``,
    ``--eval_online`` and ``--eval_offline``, and the test-time ``--corrupt``
    type string. It also registers ``--corrupt_train`` through
    ``add_corrupt_train_args``.
    """
    add = parser.add_argument

    # NOTE(review): the -1 defaults presumably mean "not overridden"; confirm
    # against the evaluation script that consumes them.
    add("--epoch", type=int, required=False, default=-1, help="Epoch to evaluate")
    add("--test_cov", type=float, required=False, default=-1.0, help="Test coverage (for bandit)")
    add("--hor", type=int, required=False, default=-1, help="Episode horizon (for mdp)")
    add("--n_eval", type=int, required=False, default=100, help="Number of eval trajectories")
    add("--save_video", default=False, action="store_true")

    add("--eval_online", default=False, action="store_true", help="Evaluate online deployment")
    add("--eval_offline", default=False, action="store_true", help="Evaluate offline deployment")
    add("--corrupt", type=str, default="", help="Corruption type")

    # Evaluation also needs to know how the model was corrupted during training.
    add_corrupt_train_args(parser)


def add_logging_args(parser: ArgumentParser):
    """Register ``--log``, the online logging backend: ``"wandb"`` (default) or ``"none"``."""
    backends = ["none", "wandb"]
    parser.add_argument(
        "--log",
        choices=backends,
        required=False,
        default="wandb",
        help="Online logging type to use",
    )


def add_adv_training_args(parser: ArgumentParser):
    """Register adversarial-training options on ``parser``.

    Covers the number of rounds, the poisoned fractions of episodes and of
    steps, per-dataset iteration counts and learning rates for the victim
    and the attacker, the attacker's budget regulariser, and the maximum
    poisoning distance from the original means.
    """
    parser.add_argument("--n_rounds", type=int, required=False, default=20, help="Number of adversarial training rounds to perform")
    parser.add_argument("--eps_episodes", type=float, required=False, default=0.4, help="Fraction of episodes poisoned")
    parser.add_argument("--eps_steps", type=float, required=False, default=0.4, help="Fraction of steps within an episode poisoned")
    parser.add_argument("--victim_iters", type=int, required=False, default=20, help="Number of iterations the victim should be trained for per dataset")
    parser.add_argument("--victim_lr", type=float, required=False, default=0.00003, help="Learning rate for the victim")
    parser.add_argument("--attacker_iters", type=int, required=False, default=20, help="Number of iterations the attacker should be trained for per dataset")
    parser.add_argument("--attacker_lr", type=float, required=False, default=0.03, help="Learning rate for the attacker")
    # A float default keeps ``args.budget_regularizer`` a float. argparse does
    # not apply ``type=float`` to non-string defaults, so ``10`` would stay an int.
    parser.add_argument("--budget_regularizer", type=float, required=False, default=10.0, help="Regularizer constant for the budget in the loss function for the attacker")
    parser.add_argument("--max_poison_diff", type=float, required=False, default=3.0, help="Maximum distance from the original means the attacker is able to poison")
