from dataclasses import dataclass

from offline.diffusion import arguments as diffusion
from offline.utils.parser import ArgumentParser


@dataclass(frozen=True)
class Arguments(diffusion.Arguments):
    """Frozen argument record extending ``diffusion.Arguments`` with three fields.

    The field names match the destinations of the ``--eta``,
    ``--inference-steps`` and ``--timestep-spacing`` options registered by
    ``build_argument_parser`` in this module.
    """

    # NOTE(review): presumably the sampler's stochasticity coefficient
    # (DDIM-style eta) — confirm against the code that consumes it.
    eta: float
    # NOTE(review): presumably the number of sampling steps used at inference
    # time, as opposed to the base ``diffusion_steps`` — confirm with the sampler.
    inference_steps: int
    # The parser in this module restricts this to "leading", "linspace" or
    # "trailing"; the dataclass itself does not validate the value.
    timestep_spacing: str

def build_argument_parser(parser: ArgumentParser | None = None, **kwargs):
    if parser is None:
        parser = diffusion.build_argument_parser(**kwargs)

    parser.add_argument("--eta", type=float, default=0)
    parser.add_argument("--inference-steps", type=int, default=5)
    parser.add_argument(
        "--timestep-spacing",
        choices=["leading", "linspace", "trailing"],
        default="leading",
    )
    parser.set_defaults(diffusion_steps=100)
    return parser
