from dataclasses import dataclass
from typing import Optional, Sequence
from diffusers.utils import BaseOutput
import torch


@dataclass
class EncoderOutput(BaseOutput):
    """Container for the tensors returned by an encoder.

    Only ``last_hidden_state`` is required; the remaining fields default to
    ``None`` and may be omitted by the producer.

    Attributes:
        last_hidden_state: Required tensor. Presumably the output of the
            encoder's final layer -- TODO confirm against the producer.
        hidden_states: Optional sequence of tensors. Presumably one entry per
            intermediate layer -- TODO confirm against the producer.
        mask: Optional tensor. Its semantics and shape are not established in
            this file -- verify against the caller.
    """

    # The only field without a default, so it must always be supplied.
    last_hidden_state: torch.Tensor
    # Optional extras; both default to None.
    hidden_states: Optional[Sequence[torch.Tensor]] = None
    mask: Optional[torch.Tensor] = None


@dataclass
class MAEEncoderOutput(BaseOutput):
    """Container for the outputs of a masked-autoencoder (MAE) style encoder.

    Every field is optional and defaults to ``None``.

    Attributes:
        prompt_tokens: Prompt tokens, such as a cls token or any other prompt
            tokens. Shape ``[B, N_prompt, C]``.
        patch_tokens: Patch tokens that remain after the mask drop.
        mask: Mask tensor of shape ``[B, L_after_drop]`` or ``[B, H, W]``.
            Expected to be binary.
        ids_restore: Restore-index tensor of shape ``[B, L]``.
    """

    prompt_tokens: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None
    mask: Optional[torch.Tensor] = None
    ids_restore: Optional[torch.Tensor] = None
