from dataclasses import dataclass
from typing import Optional, TypeVar

import torch as t
from collectibles import ListCollection as LC

T = TypeVar("T")


@dataclass
class AutoEncoderMetrics:
    """Container for the losses and metrics recorded for an autoencoder.

    Only ``overall_loss``, ``mse_reconstruction_loss`` and
    ``l0_sparsity_metric`` are required. Every other field is optional and
    defaults to ``None``; ``naive_description_length_bits`` is additionally
    back-filled in ``__post_init__`` when it is not supplied.

    Field order is significant: the dataclass-generated ``__init__`` accepts
    the fields positionally in the order they are declared below.
    """

    # --- Required metrics ---------------------------------------------------
    overall_loss: float

    mse_reconstruction_loss: float
    l0_sparsity_metric: float
    # Derived from ``l0_sparsity_metric`` when left as None (see __post_init__).
    naive_description_length_bits: Optional[float] = None

    # --- Optional loss terms (None when not reported) -----------------------
    l1_sparsity_loss: Optional[float] = None
    gating_reconstruction_loss: Optional[float] = None
    l2_unquantised_sparsity_loss: Optional[float] = None

    switch_load_balancing_loss: Optional[float] = None
    router_z_loss: Optional[float] = None
    expert_importance_loss: Optional[float] = None

    # The only tensor-valued field; all other fields are plain floats.
    expert_usage: Optional[t.Tensor] = None
    proportion_tokens_routed: Optional[float] = None

    # ``downstream_loss_recovered`` is required by ``proxy_sweep_metric``.
    downstream_loss_recovered: Optional[float] = None
    downstream_entropy: Optional[float] = None

    decorr_score: Optional[float] = None

    nfm_loss: Optional[float] = None
    nfm_inf_loss: Optional[float] = None
    multi_info_loss: Optional[float] = None
    hessian_penalty_loss: Optional[float] = None

    jump_l0_loss: Optional[float] = None

    codebook_loss: Optional[float] = None
    commitment_loss: Optional[float] = None

    feature_reconstruction_loss: Optional[float] = None

    aux_dead_reconstruction_loss: Optional[float] = None
    aux_dying_reconstruction_loss: Optional[float] = None

    def __post_init__(self):
        """Back-fill ``naive_description_length_bits`` when it was not given."""
        if self.naive_description_length_bits is not None:
            return

        # Default is 8 times the L0 metric (presumably 8 bits per active
        # feature -- confirm with the metric's definition).
        self.naive_description_length_bits = self.l0_sparsity_metric * 8

    def proxy_sweep_metric(self, num_features: int, density_penalty: float) -> float:
        """Return ``downstream_loss_recovered`` penalised by feature density.

        Higher is better: recovering more downstream loss raises the score,
        while a larger fraction of active features lowers it. Each unit of
        feature density costs ``density_penalty`` units of score, e.g. with
        ``density_penalty=5`` a 1% rise in density must be offset by a 5%
        rise in downstream loss recovered.

        Args:
            num_features: Total number of features; the L0 metric is divided
                by it.
            density_penalty: Weight applied to the feature density.

        Returns:
            ``downstream_loss_recovered - (l0 / num_features) * density_penalty``.

        Raises:
            AssertionError: If ``downstream_loss_recovered`` is ``None``.
            ZeroDivisionError: If ``num_features`` is zero.
        """
        # 0 when no features are active, 1 when all of them are.
        active_fraction = self.l0_sparsity_metric / num_features

        assert self.downstream_loss_recovered is not None
        density_cost = active_fraction * density_penalty
        return self.downstream_loss_recovered - density_cost

    def feature_density(self, num_features: int) -> float:
        """Return the L0 metric divided by ``num_features``.

        Raises:
            ZeroDivisionError: If ``num_features`` is zero.
        """
        return self.l0_sparsity_metric / num_features


class MetricsCollection(LC[AutoEncoderMetrics]):
    """A collection of ``AutoEncoderMetrics`` that can be averaged into one.

    The annotations below declare, for type checkers, the per-field lists
    that ``reduce`` reads off the collection (presumably supplied by
    ``ListCollection`` attribute access -- confirm against ``collectibles``).
    """

    overall_loss: list[float]
    mse_reconstruction_loss: list[float]
    l0_sparsity_metric: list[float]
    naive_description_length_bits: list[float]
    l1_sparsity_loss: list[Optional[float]]
    gating_reconstruction_loss: list[Optional[float]]
    l2_unquantised_sparsity_loss: list[Optional[float]]
    switch_load_balancing_loss: list[Optional[float]]
    router_z_loss: list[Optional[float]]
    expert_importance_loss: list[Optional[float]]
    expert_usage: list[Optional[t.Tensor]]
    proportion_tokens_routed: list[Optional[float]]
    downstream_loss_recovered: list[Optional[float]]
    downstream_entropy: list[Optional[float]]
    decorr_score: list[Optional[float]]
    nfm_loss: list[Optional[float]]
    nfm_inf_loss: list[Optional[float]]
    multi_info_loss: list[Optional[float]]
    hessian_penalty_loss: list[Optional[float]]
    jump_l0_loss: list[Optional[float]]
    codebook_loss: list[Optional[float]]
    commitment_loss: list[Optional[float]]
    feature_reconstruction_loss: list[Optional[float]]
    aux_dead_reconstruction_loss: list[Optional[float]]
    aux_dying_reconstruction_loss: list[Optional[float]]

    def reduce(self) -> AutoEncoderMetrics:
        """Average every metric across the collection.

        Required metrics are averaged with ``mean`` (which raises when there
        is nothing to average). Optional metrics are averaged with
        ``mean_without_nones``, so a metric that was never reported stays
        ``None`` in the result.

        Returns:
            A single ``AutoEncoderMetrics`` holding the per-field means.

        Raises:
            ValueError: If a required metric has no values to average.
        """
        return AutoEncoderMetrics(
            overall_loss=self.mean(self.overall_loss),
            mse_reconstruction_loss=self.mean(self.mse_reconstruction_loss),
            l0_sparsity_metric=self.mean(self.l0_sparsity_metric),
            naive_description_length_bits=self.mean(self.naive_description_length_bits),
            l1_sparsity_loss=self.mean_without_nones(self.l1_sparsity_loss),
            gating_reconstruction_loss=self.mean_without_nones(
                self.gating_reconstruction_loss
            ),
            # Previously omitted, which silently dropped the field on reduce.
            l2_unquantised_sparsity_loss=self.mean_without_nones(
                self.l2_unquantised_sparsity_loss
            ),
            switch_load_balancing_loss=self.mean_without_nones(
                self.switch_load_balancing_loss
            ),
            router_z_loss=self.mean_without_nones(self.router_z_loss),
            expert_importance_loss=self.mean_without_nones(self.expert_importance_loss),
            expert_usage=self.mean_without_nones(self.expert_usage),
            downstream_loss_recovered=self.mean_without_nones(self.downstream_loss_recovered),
            downstream_entropy=self.mean_without_nones(self.downstream_entropy),
            proportion_tokens_routed=self.mean_without_nones(self.proportion_tokens_routed),
            decorr_score=self.mean_without_nones(self.decorr_score),
            nfm_loss=self.mean_without_nones(self.nfm_loss),
            nfm_inf_loss=self.mean_without_nones(self.nfm_inf_loss),
            multi_info_loss=self.mean_without_nones(self.multi_info_loss),
            hessian_penalty_loss=self.mean_without_nones(self.hessian_penalty_loss),
            jump_l0_loss=self.mean_without_nones(self.jump_l0_loss),
            codebook_loss=self.mean_without_nones(self.codebook_loss),
            commitment_loss=self.mean_without_nones(self.commitment_loss),
            feature_reconstruction_loss=self.mean_without_nones(
                self.feature_reconstruction_loss
            ),
            aux_dead_reconstruction_loss=self.mean_without_nones(
                self.aux_dead_reconstruction_loss
            ),
            aux_dying_reconstruction_loss=self.mean_without_nones(
                self.aux_dying_reconstruction_loss
            ),
        )

    @staticmethod
    def mean_without_nones(values: list[Optional[T]]) -> Optional[T]:
        """Return the mean of the non-``None`` entries of ``values``.

        ``None`` entries are excluded from both the sum and the count, so
        they do not drag the average towards zero.

        Args:
            values: Values to average; entries may be ``None``.

        Returns:
            The mean of the present values, or ``None`` when ``values`` is
            empty or contains only ``None``.
        """
        present = [value for value in values if value is not None]

        # Guard on the filtered list: an all-None input has nothing to average.
        if not present:
            return None

        return sum(present) / len(present)  # type: ignore

    @staticmethod
    def mean(values: list[T]) -> T:
        """Return the mean of ``values``, ignoring any ``None`` entries.

        Args:
            values: Values to average.

        Returns:
            The mean of the non-``None`` values.

        Raises:
            ValueError: If there are no non-``None`` values to average.
        """
        present = [value for value in values if value is not None]

        # Checking the filtered list avoids a ZeroDivisionError on all-None input.
        if not present:
            raise ValueError("Cannot calculate mean of empty list")

        return sum(present) / len(present)  # type: ignore


if __name__ == "__main__":
    # Smoke test: build two metric records, average them, and print the
    # reduced record together with its two derived scores.
    first_metrics = AutoEncoderMetrics(
        overall_loss=0.1,
        mse_reconstruction_loss=0.2,
        l0_sparsity_metric=0.3,
        l1_sparsity_loss=0.4,
        gating_reconstruction_loss=0.5,
        switch_load_balancing_loss=0.6,
        router_z_loss=0.7,
        expert_importance_loss=0.8,
        expert_usage=t.tensor([0.9, 1.0]),
        proportion_tokens_routed=1.1,
        downstream_loss_recovered=1.2,
        downstream_entropy=1.3,
    )
    second_metrics = AutoEncoderMetrics(
        overall_loss=0.2,
        mse_reconstruction_loss=0.3,
        l0_sparsity_metric=0.4,
        l1_sparsity_loss=0.5,
        gating_reconstruction_loss=0.6,
        switch_load_balancing_loss=0.7,
        router_z_loss=0.8,
        expert_importance_loss=0.9,
        expert_usage=t.tensor([1.0, 1.1]),
        proportion_tokens_routed=1.2,
        downstream_loss_recovered=1.3,
        downstream_entropy=1.4,
    )

    metrics_collection = MetricsCollection([first_metrics, second_metrics])
    reduced_metrics = metrics_collection.reduce()

    # Each value is computed immediately before it is printed.
    print(reduced_metrics)
    sweep_score = reduced_metrics.proxy_sweep_metric(10, 5.0)
    print(sweep_score)
    density = reduced_metrics.feature_density(10)
    print(density)
