from dataclasses import dataclass
from typing import Optional, Union

import torch as t
from einops import rearrange
from jaxtyping import Float, Int

from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics


@dataclass
class VanillaAEOutput:
    """Plain container for the tensors produced by the vanilla auto-encoder.

    One of the two members of the ``AutoEncoderOutput`` union defined in
    this module.

    NOTE(review): the ``_BSN`` / ``_BSF`` / ``_BS`` suffixes look like shape
    tags (presumably batch, seq, neuron / feature) -- confirm against the
    producing model; nothing in this class checks shapes or dtypes.
    """

    # Presumably the reconstruction of the input activations -- confirm.
    x_hat_BSN: t.Tensor
    # Presumably the learned feature activations -- confirm.
    features_BSF: t.Tensor

    # Presumably the per-position constants used to normalise the input --
    # confirm against the producer.
    normalising_constants_BS: t.Tensor


@dataclass
class MutualChoiceAEOutput:
    """Plain container for the tensors produced by the mutual-choice auto-encoder.

    One of the two members of the ``AutoEncoderOutput`` union defined in
    this module. Shares ``x_hat_BSN``, ``features_BSF`` and
    ``normalising_constants_BS`` with ``VanillaAEOutput`` and adds several
    loss terms.

    NOTE(review): the ``_BSN`` / ``_BSF`` / ``_BS`` suffixes look like shape
    tags (presumably batch, seq, neuron / feature) -- confirm against the
    producing model; nothing in this class checks shapes or dtypes.
    """

    # Presumably the reconstruction of the input activations -- confirm.
    x_hat_BSN: t.Tensor
    # Presumably the learned feature activations -- confirm.
    features_BSF: t.Tensor
    # Presumably an L1 sparsity penalty on the features -- confirm.
    l1_sparsity_loss: t.Tensor

    # Auxiliary loss terms; by name these target "dead" and "dying" features.
    # NOTE(review): exact definitions live in the producer -- confirm there.
    aux_dead_loss: t.Tensor
    aux_dying_loss: t.Tensor

    # NOTE(review): the meaning of "nfm" is not visible in this file --
    # confirm with the code that computes these losses.
    nfm_loss: t.Tensor
    nfm_inf_loss: t.Tensor

    # Presumably the per-position constants used to normalise the input --
    # confirm against the producer.
    normalising_constants_BS: t.Tensor


# Type alias covering every auto-encoder output container declared above;
# use it to annotate code that accepts either variant.
AutoEncoderOutput = Union[
    VanillaAEOutput,
    MutualChoiceAEOutput,
]


@dataclass
class EncoderExpertOutput:
    """Tensors produced by a single encoder expert.

    Instances are collected in ``EncoderOutputCollection``, which stacks each
    field across experts with ``t.vstack``.

    NOTE(review): the fields carry a ``_BSF`` suffix, but
    ``EncoderOutputCollection`` later reshapes the stacked result from
    ``(batch seq_len) dict_size``, which suggests these are 2-D
    ``(tokens, dict_size)`` tensors at this stage -- confirm with the producer.
    """

    # Presumably the final feature activations of this expert -- confirm.
    features_BSF: t.Tensor
    # By name, the feature-activity values before they are binarised.
    active_features_pre_binarisation_BSF: t.Tensor
    # By name, the binarised feature-activity mask (annotated as an int tensor).
    active_features_BSF: t.IntTensor
    # Presumably the magnitudes paired with the activity mask -- confirm.
    feature_magnitudes_BSF: t.Tensor


class EncoderOutputCollection(list[EncoderExpertOutput]):
    """List of per-expert encoder outputs plus their row-wise concatenation.

    On construction, each field of the contained ``EncoderExpertOutput``
    objects is stacked along dim 0 (in list order) with ``t.vstack`` and stored
    on the instance as ``features``, ``active_features_pre_binarisation``,
    ``active_features`` and ``feature_magnitudes``.

    The stacked tensors are computed once in ``__init__``; mutating the list
    afterwards does not update them.
    """

    def __init__(self, encoder_expert_outputs: list[EncoderExpertOutput]):
        """Store the expert outputs and build the stacked tensors.

        Args:
            encoder_expert_outputs: The per-expert outputs, in the order their
                rows should appear in the stacked tensors. Must be non-empty,
                since ``t.vstack`` cannot stack an empty sequence.
        """
        super().__init__(encoder_expert_outputs)

        # Iterate over ``self`` (already materialised by ``list.__init__``)
        # rather than the argument, so a one-shot iterable is not exhausted
        # before the tensors are stacked.
        self.features = t.vstack([e.features_BSF for e in self])
        self.active_features_pre_binarisation = t.vstack(
            [e.active_features_pre_binarisation_BSF for e in self]
        )
        self.active_features = t.vstack([e.active_features_BSF for e in self])
        self.feature_magnitudes = t.vstack(
            [e.feature_magnitudes_BSF for e in self]
        )

    def seq_ordered_outputs(
        self, order: t.Tensor, batch_size: int
    ) -> tuple[t.Tensor, t.Tensor, t.Tensor, t.Tensor]:
        """Reorder the stacked rows and reshape them to ``batch seq_len dict_size``.

        Every stacked tensor is indexed along dim 0 with ``order.argsort()``,
        the active-features tensor is cast to ``t.int``, and each result is
        then split from ``(batch seq_len) dict_size`` into
        ``batch seq_len dict_size``.

        Args:
            order: 1-D index tensor. Presumably the permutation that put the
                tokens into expert order, so that its argsort restores the
                original sequence order -- confirm against the caller.
            batch_size: Size of the leading ``batch`` axis after reshaping.

        Returns:
            ``(features, active_features_pre_binarisation, active_features_gate,
            feature_magnitudes)``, each reshaped to ``batch seq_len dict_size``.
        """
        # Computed once and shared by all four gathers; the argsort depends
        # only on ``order``.
        inverse_order = order.argsort()

        # ``.int()`` guarantees the gate dtype is ``t.int`` (int32), so no
        # separate dtype check is needed afterwards.
        active_features_gate = self.active_features[inverse_order].int()

        return self._rearrange_uncouple_batch_seq_len(
            features=self.features[inverse_order],
            active_features_pre_binarisation=(
                self.active_features_pre_binarisation[inverse_order]
            ),
            active_features_gate=active_features_gate,
            feature_magnitudes=self.feature_magnitudes[inverse_order],
            batch_size=batch_size,
        )  # batch seq_len dict_size

    @staticmethod
    def _rearrange_uncouple_batch_seq_len(
        features: Float[t.Tensor, "(batch seq) dict_size"],
        active_features_pre_binarisation: Float[t.Tensor, "(batch seq) dict_size"],
        active_features_gate: Int[t.Tensor, "(batch seq) dict_size"],
        feature_magnitudes: Float[t.Tensor, "(batch seq) dict_size"],
        batch_size: int,
    ) -> tuple[t.Tensor, t.Tensor, t.Tensor, t.Tensor]:
        """Split the fused leading axis of each tensor into ``batch`` and ``seq_len``.

        Args:
            features: Tensor laid out as ``(batch seq_len) dict_size``.
            active_features_pre_binarisation: Same layout as ``features``.
            active_features_gate: Same layout; the visible caller passes an
                integer tensor.
            feature_magnitudes: Same layout as ``features``.
            batch_size: Size of the ``batch`` axis; ``seq_len`` is inferred by
                ``rearrange`` from the leading dimension.

        Returns:
            The four tensors, in argument order, each reshaped to
            ``batch seq_len dict_size``.
        """
        # One shared pattern keeps the four reshapes guaranteed-identical.
        pattern = "(batch seq_len) dict_size -> batch seq_len dict_size"

        return (
            rearrange(features, pattern, batch=batch_size),
            rearrange(active_features_pre_binarisation, pattern, batch=batch_size),
            rearrange(active_features_gate, pattern, batch=batch_size),
            rearrange(feature_magnitudes, pattern, batch=batch_size),
        )


@dataclass
class EncoderFinalOutput:
    """Container for the encoder's combined output, losses and usage statistics.

    NOTE(review): the ``_BSF`` suffix looks like a shape tag (presumably
    batch, seq, feature) -- confirm against the producing encoder; nothing in
    this class checks shapes or dtypes.
    """

    # Presumably the final feature activations -- confirm.
    features_BSF: t.Tensor
    # By name, the feature-activity values before they are binarised.
    active_features_pre_binarisation_BSF: t.Tensor
    # By name, the binarised feature-activity gate.
    active_features_gate_BSF: t.Tensor
    # Presumably the magnitudes paired with the gate -- confirm.
    feature_magnitudes_BSF: t.Tensor
    # May be None.
    # NOTE(review): the suffix here is ``_Bs`` while the rest of the file uses
    # ``_BS`` -- confirm whether the lowercase ``s`` is intentional.
    sparse_mixer_multiplier_Bs: Optional[t.Tensor]

    # Loss terms; by name these relate to expert routing (load balancing,
    # router z-loss, expert importance) -- confirm definitions in the producer.
    switch_load_balancing_loss: t.Tensor
    router_z_loss: t.Tensor
    expert_importance_loss: t.Tensor

    # Optional statistics; either may be None.
    # Presumably per-expert usage (``_E``) and the fraction of tokens routed
    # to some expert -- confirm.
    expert_usage_E: Optional[t.Tensor]
    proportion_tokens_hit: Optional[t.Tensor]


@dataclass
class SupermodelOutput:
    """Container for the supermodel's output: loss, activations and metrics.

    NOTE(review): the ``_BSF`` / ``_BSN`` / ``_BS`` suffixes look like shape
    tags (presumably batch, seq, feature / neuron) -- confirm against the
    producing model; nothing in this class checks shapes or dtypes.
    """

    # By name, the scalar loss value.
    scalar_loss: t.Tensor
    # Presumably the auto-encoder feature activations -- confirm.
    feature_activations_BSF: t.Tensor
    # By name, the neuron activations before and after reconstruction.
    initial_neuron_activations_BSN: t.Tensor
    reconstructed_neuron_activations_BSN: t.Tensor
    # Presumably the per-position constants used to normalise the input --
    # confirm against the producer.
    normalising_constants_BS: t.Tensor
    # May be None; presumably the un-reduced per-position loss -- confirm.
    batched_loss_BS: Optional[t.Tensor]
    # Metrics object imported from auto_encoder.helpers.ae_metrics.
    metrics: AutoEncoderMetrics

    # Optional extra loss term; defaults to None when not supplied.
    secondary_loss: Optional[t.Tensor] = None
