from dataclasses import dataclass

from offline import base
from offline.utils.parser import ArgumentParser


@dataclass(frozen=True)
class Arguments(base.Arguments):
    """Immutable hyperparameter record extending ``base.Arguments``.

    Each field has the same name as the destination of an option registered
    by ``build_argument_parser`` in this module; the matching flag is noted
    next to the field.

    NOTE(review): presumably instances are built from that parser's parsed
    namespace -- confirm against the caller. What each value means is
    decided by the code consuming it, which is not visible here.
    """

    batch_size: int  # --batch-size
    bc_steps: int  # --bc-steps
    constant_schedule: bool  # True unless --cosine-schedule is passed
    ensemble_size: int  # --ensemble-size
    gamma: float  # --gamma
    hidden_features: int  # --hidden-features
    layer_norm: bool  # True unless --no-layer-norm is passed
    learning_rate: float  # --learning-rate
    lipschitz_constant: float  # --lipschitz-constant
    max_gradient_norm: float  # --max-gradient-norm
    noise_ratio: float  # --noise-ratio
    num_layers: int  # --num-layers
    ood_probability: float  # --ood-probability
    pretrain_steps: int  # --pretrain-steps
    regularizer_weight: float  # --regularizer-weight
    sparse: bool  # False unless --sparse is passed
    std_multiplier: float  # --std-multiplier
    tau: float  # --tau
    update_every: int  # --update-every
    v_learning_steps: int  # --v-learning-steps
    zero_mean: bool  # False unless --zero-mean is passed


def build_argument_parser(parser: ArgumentParser | None = None, **kwargs):
    if parser is None:
        parser = base.build_argument_parser(**kwargs)

    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--bc-steps", type=int, default=500000)
    parser.add_argument(
        "--cosine-schedule", action="store_false", dest="constant_schedule"
    )
    parser.add_argument("--ensemble-size", type=int, default=4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--hidden-features", type=int, default=256)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--lipschitz-constant", type=float, default=2.0)
    parser.add_argument("--max-gradient-norm", type=float, default=1)
    parser.add_argument(
        "--no-layer-norm", action="store_false", dest="layer_norm"
    )
    parser.add_argument("--noise-ratio", type=float, default=0.5)
    parser.add_argument("--num-layers", type=int, default=4)
    parser.add_argument("--ood-probability", type=float, default=0.5)
    parser.add_argument("--pretrain-steps", type=int, default=0)
    parser.add_argument("--regularizer-weight", type=float, default=0.005)
    parser.add_argument("--sparse", action="store_true")
    parser.add_argument("--std-multiplier", type=float, default=2)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--update-every", type=int, default=2)
    parser.add_argument("--v-learning-steps", type=int, default=500000)
    parser.add_argument("--zero-mean", action="store_true")
    parser.set_defaults(
        normalize_rewards=True, reward_multiplier=100, unsquash=True
    )
    return parser
