from dataclasses import dataclass

from offline import base
from offline.modules.mlp import NONLINEARITIES
from offline.utils.parser import ArgumentParser


@dataclass(frozen=True)
class Arguments(base.Arguments):
    """Hyperparameters added on top of ``base.Arguments``.

    Instances are immutable (``frozen=True``). Each field corresponds to an
    option registered by ``build_argument_parser`` in this module; the
    option's ``dest`` matches the field name.
    """

    batch_size: int  # --batch-size
    beta_schedule: str  # --beta-schedule; one of "cosine", "linear", "vp"
    clip_sample: bool  # True unless --do-not-clip-sample is passed
    diffusion_steps: int  # --diffusion-steps
    hidden_features: int  # --hidden-features
    layer_norm: bool  # True unless --no-layer-norm is passed
    learning_rate: float  # --learning-rate
    nonlinearity: str  # --nonlinearity; a key of NONLINEARITIES
    num_layers: int  # --num-layers
    temperature: float  # --temperature
    time_dim: int  # --time-dim


def build_argument_parser(parser: ArgumentParser | None = None, **kwargs):
    """Register the command-line options that populate ``Arguments``.

    Args:
        parser: An existing parser to extend in place. When ``None``, a new
            parser is created with ``base.build_argument_parser(**kwargs)``.
        **kwargs: Forwarded to ``base.build_argument_parser``. They are only
            used when ``parser`` is ``None``; otherwise they are not read.

    Returns:
        The parser with this module's options added.
    """
    if parser is None:
        parser = base.build_argument_parser(**kwargs)

    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument(
        "--beta-schedule", choices=["cosine", "linear", "vp"], default="vp"
    )
    parser.add_argument("--diffusion-steps", type=int, default=5)
    # Negative flag: with ``store_false`` the destination defaults to True,
    # so clipping is on unless this flag is given.
    parser.add_argument(
        "--do-not-clip-sample", action="store_false", dest="clip_sample"
    )
    parser.add_argument("--hidden-features", type=int, default=256)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    # Negative flag: ``layer_norm`` defaults to True unless this is given.
    parser.add_argument(
        "--no-layer-norm", action="store_false", dest="layer_norm"
    )
    parser.add_argument(
        "--nonlinearity", choices=list(NONLINEARITIES.keys()), default="mish"
    )
    parser.add_argument("--num-layers", type=int, default=5)
    # Float literal on purpose: argparse runs ``type`` only on *string*
    # defaults, so ``default=1`` would leave ``temperature`` an int when the
    # flag is omitted, contradicting ``Arguments.temperature: float``.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--time-dim", type=int, default=16)
    return parser
