from dataclasses import dataclass

from offline import base


@dataclass(frozen=True)
class Arguments(base.Arguments):
    """Immutable run configuration that extends ``base.Arguments``.

    Adds the algorithm-specific hyperparameters declared below. Each field
    corresponds to a command-line option registered by
    ``build_argument_parser`` in this module (for example ``--batch-size``
    populates ``batch_size``). Because the dataclass is frozen, assigning to
    a field after construction raises ``dataclasses.FrozenInstanceError``.

    NOTE(review): field order determines the positional order of the
    generated ``__init__`` (after any fields inherited from the base class),
    so keep it stable.
    """

    # NOTE(review): the per-field meanings below are inferred from the names
    # and the CLI defaults in build_argument_parser; confirm them against the
    # training code that consumes this config.
    alpha: float  # CLI default 2.5; a weighting coefficient (TODO confirm)
    batch_size: int  # CLI default 256; presumably samples per update batch
    gamma: float  # CLI default 0.99; presumably the discount factor
    hidden_features: int  # CLI default 256; presumably hidden-layer width
    layer_norm: bool  # parsed as True unless --no-layer-norm is given
    learning_rate: float  # CLI default 3e-4
    noise_clip: float  # CLI default 0.5; presumably clip bound on noise
    num_layers: int  # CLI default 4
    policy_noise: float  # CLI default 0.2; presumably policy noise scale
    tau: float  # CLI default 0.005; presumably a soft-update rate
    update_every: int  # CLI default 2; presumably an update interval


def build_argument_parser(**kwargs):
    """Return the base argument parser extended with algorithm options.

    All keyword arguments are forwarded unchanged to
    ``base.build_argument_parser``. The options added here map one-to-one
    onto the fields of ``Arguments``. ``--no-layer-norm`` uses the
    ``store_false`` action, so ``layer_norm`` parses as ``True`` unless the
    flag is given.
    """
    parser = base.build_argument_parser(**kwargs)

    # Ordered (flag, add_argument keyword options) pairs. Registration order
    # is kept so the generated --help listing is unchanged.
    option_specs = (
        ("--alpha", {"type": float, "default": 2.5}),
        ("--batch-size", {"type": int, "default": 256}),
        ("--gamma", {"type": float, "default": 0.99}),
        ("--hidden-features", {"type": int, "default": 256}),
        ("--learning-rate", {"type": float, "default": 3e-4}),
        ("--no-layer-norm", {"action": "store_false", "dest": "layer_norm"}),
        ("--noise-clip", {"type": float, "default": 0.5}),
        ("--num-layers", {"type": int, "default": 4}),
        ("--policy-noise", {"type": float, "default": 0.2}),
        ("--tau", {"type": float, "default": 0.005}),
        ("--update-every", {"type": int, "default": 2}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser
