from dataclasses import dataclass
from typing import Optional

import torch as t


@dataclass
class CombinedLoss:
    """Bundle a scalar loss, its batched form, and an optional extra term.

    Only ``secondary_loss`` has a default (``None``); the other two fields
    must always be supplied.
    """

    # Single loss value; presumably the one that is back-propagated —
    # TODO confirm at the call site.
    scalar_loss: t.Tensor
    # Batched loss; the name suggests one entry per batch element —
    # TODO confirm shape.
    batched_loss: t.Tensor
    # Optional auxiliary loss; stays None unless the producer provides one.
    secondary_loss: Optional[t.Tensor] = None


@dataclass
class VanillaLoss:
    """Named loss terms for the "vanilla" variant: reconstruction and L1."""

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L1 loss term.
    l1_loss_batched: t.Tensor


@dataclass
class FSQLoss:
    """Named loss terms for the FSQ variant: reconstruction and L2."""

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L2 loss term.
    l2_loss_batched: t.Tensor


@dataclass
class TopMLoss:
    """Named loss terms for the TopM variant.

    Holds reconstruction, L1, and decorrelation terms, in that field order.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L1 loss term.
    l1_loss_batched: t.Tensor
    # NOTE(review): no "_batched" suffix unlike the other fields — possibly
    # already reduced to a scalar; confirm with the producer.
    decorrelation_loss: t.Tensor


@dataclass
class GatedLoss:
    """Named loss terms for the gated variant.

    Holds a reconstruction loss plus a sparsity term and a reconstruction
    term attributed to the gating path.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched sparsity loss on the gating path.
    gating_sparsity_loss_batched: t.Tensor
    # Batched reconstruction loss on the gating path.
    gating_reconstruction_loss_batched: t.Tensor


@dataclass
class MoaeLoss:
    """Named loss terms for the MoAE variant.

    The field set currently mirrors ``GatedLoss`` exactly; it is kept as a
    separate type so the two variants can diverge independently.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched sparsity loss on the gating path.
    gating_sparsity_loss_batched: t.Tensor
    # Batched reconstruction loss on the gating path.
    gating_reconstruction_loss_batched: t.Tensor


@dataclass
class JumpSparsityLoss:
    """Named loss terms for the jump-sparsity variant."""

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched jump-sparsity penalty.
    jump_sparsity_loss_batched: t.Tensor


@dataclass
class BinaryPlusLoss:
    """Named loss terms for the BinaryPlus variant.

    Holds reconstruction, L1, commitment, and codebook terms, in that order.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L1 loss term.
    l1_loss_batched: t.Tensor
    # Batched commitment loss term.
    commitment_loss_batched: t.Tensor
    # Batched codebook loss term.
    codebook_loss_batched: t.Tensor


@dataclass
class CausalLoss:
    """Named loss terms for the causal variant.

    Same leading fields as ``BinaryPlusLoss`` plus a trailing sparsity term.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L1 loss term.
    l1_loss_batched: t.Tensor
    # Batched commitment loss term.
    commitment_loss_batched: t.Tensor
    # Batched codebook loss term.
    codebook_loss_batched: t.Tensor
    # Batched sparsity loss term.
    sparsity_loss_batched: t.Tensor


@dataclass
class PushThroughTopMLoss:
    """Named loss terms for the push-through TopM variant.

    Same leading fields as ``TopMLoss`` plus a trailing feature
    reconstruction term.
    """

    # Batched MSE reconstruction loss — TODO confirm per-sample shape.
    mse_reconstruction_loss_batched: t.Tensor
    # Batched L1 loss term.
    l1_loss_batched: t.Tensor
    # NOTE(review): no "_batched" suffix unlike the other fields — possibly
    # already reduced to a scalar; confirm with the producer.
    decorrelation_loss: t.Tensor
    # Batched feature reconstruction loss term.
    feature_reconstruction_loss_batched: t.Tensor
