"""
Data model for the riffusion API.
"""
from __future__ import annotations

import typing as T
from dataclasses import dataclass


@dataclass(frozen=True)
class PromptInput:
    """
    Settings for a single endpoint of a prompt interpolation.

    Instances are immutable (``frozen=True``), so they are hashable and
    compare by value.

    Attributes:
        prompt: Text prompt that is fed into a CLIP model.
        seed: Random seed used for denoising.
        negative_prompt: Optional text describing what to avoid; ``None``
            when no negative prompt is given.
        denoising: Denoising strength (defaults to 0.75).
        guidance: Classifier-free guidance strength (defaults to 7.0).
    """

    prompt: str
    seed: int
    negative_prompt: T.Optional[str] = None
    denoising: float = 0.75
    guidance: float = 7.0


@dataclass(frozen=True)
class InferenceInput:
    """
    Request payload required by the riffusion model server.

    Each request describes one run of the riffusion model that interpolates
    between a start and an end ``PromptInput`` according to ``alpha``.
    Instances are immutable (``frozen=True``).

    Attributes:
        start: Prompt settings at the start point of the interpolation.
        end: Prompt settings at the end point of the interpolation.
        alpha: Interpolation position, intended to lie in [0, 1]. A value
            of 0 uses ``start`` fully and a value of 1 uses ``end`` fully.
            The range is not validated by this class.
        num_inference_steps: Number of inner loops of the diffusion model
            (defaults to 50).
        seed_image_id: ID of the seed image to use (defaults to
            ``"og_beat"``).
        mask_image_id: Optional ID of a mask image; ``None`` when unused.
    """

    start: PromptInput
    end: PromptInput
    alpha: float
    num_inference_steps: int = 50
    seed_image_id: str = "og_beat"
    mask_image_id: T.Optional[str] = None


@dataclass(frozen=True)
class InferenceOutput:
    """
    Response returned by the model inference server.

    Instances are immutable (``frozen=True``).

    Attributes:
        image: Spectrogram image encoded as a JPEG, then base64-encoded.
        audio: Audio clip encoded as an MP3, then base64-encoded.
        duration_s: Duration of the audio clip (the ``_s`` suffix
            presumably means seconds -- confirm against the producer).
    """

    image: str
    audio: str
    duration_s: float
