from torch import Tensor
from typing import Optional, List
from dataclasses import dataclass
from transformers.utils import ModelOutput


@dataclass
class MiMoEOutput(ModelOutput):
    """Structured result container for a MiMoE forward pass.

    Only ``last_hidden_state`` must be supplied; every routing diagnostic
    defaults to ``None``. As a ``transformers.ModelOutput`` subclass, fields
    can be read by attribute or by string key. Fields left as ``None`` are
    skipped by ``to_tuple()`` and by integer indexing.

    Attributes:
        last_hidden_state: Hidden states returned by the model. The shape is
            not fixed here; presumably (batch, seq_len, hidden). TODO confirm.
        router_scores: Optional list of router score tensors, presumably one
            entry per MoE layer. Verify against the producing model.
        router_logits: Optional list of router logit tensors, presumably one
            entry per MoE layer. Verify against the producing model.
        buffer_ratios: Optional list of float ratios. The meaning is defined
            by the producing model (assumed per layer); confirm.
        routing_rate: Optional scalar routing statistic. The meaning is
            defined by the producing model; confirm.
    """

    # The sole required field. ModelOutput refuses more than one field
    # without a default, so everything after this must default to None.
    last_hidden_state: Tensor

    # Optional routing diagnostics; all default to None so callers can omit them.
    router_scores: Optional[List[Tensor]] = None
    router_logits: Optional[List[Tensor]] = None
    buffer_ratios: Optional[List[float]] = None
    routing_rate: Optional[float] = None