import torch
import torch.nn as nn
from typing import List, Optional
from .mixture_utils import get_log_prob


class FlowMixtureExperts(nn.Module):
    """
    A module to manage a mixture of normalizing flow experts.

    This module encapsulates a list of flow models, including an optional
    background component. It computes the log-probabilities for a batch of data
    across all experts in a single forward pass, correctly handling disabled
    components.

    Only the regular components can be disabled; the background component (if
    any) is always evaluated and occupies the last column of the output.
    """

    # Log-probability reported for disabled components. A large finite negative
    # value (rather than -inf) is used so downstream reductions such as
    # logsumexp stay finite. NOTE(review): this overflows to -inf in float16.
    LOG_PROB_DISABLED = -1e10

    def __init__(
        self,
        component_flows: List[nn.Module],
        background_flow: Optional[nn.Module] = None,
    ):
        """
        Args:
            component_flows: The regular expert flows, one per component.
            background_flow: Optional extra flow appended as the last component.
        """
        super().__init__()
        self.n_components = len(component_flows)
        self.use_background_component = background_flow is not None

        self.component_flows = nn.ModuleList(component_flows)
        if self.use_background_component:
            self.background_flow = background_flow
            self.n_total_components = self.n_components + 1
        else:
            self.background_flow = None
            self.n_total_components = self.n_components

        # One flag per regular component; registered as a buffer so it follows
        # the module across devices and is saved in the state dict.
        self.register_buffer(
            "disabled_mask", torch.zeros(self.n_components, dtype=torch.bool)
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Calculates the log-probabilities for each component for the input data y.

        Args:
            y (torch.Tensor): The input data tensor of shape [batch_size, y_dim].

        Returns:
            torch.Tensor: A tensor of component log-probabilities of shape
                          [batch_size, n_total_components]. Columns of disabled
                          components are filled with ``LOG_PROB_DISABLED``; the
                          background component, if present, is the last column.
        """
        batch_size = y.shape[0]
        # Match the caller's floating precision for the filler values instead of
        # silently falling back to the global default dtype.
        dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()

        # Read the mask once: truth-testing elements of a (possibly CUDA) bool
        # buffer inside the loop would force one device->host sync per component.
        disabled = self.disabled_mask.tolist()

        log_probs = []
        for flow, is_disabled in zip(self.component_flows, disabled):
            if is_disabled:
                # Skip evaluating the flow entirely for disabled components.
                log_p = torch.full(
                    (batch_size,),
                    self.LOG_PROB_DISABLED,
                    dtype=dtype,
                    device=y.device,
                )
            else:
                log_p = get_log_prob(flow, y)
            log_probs.append(log_p)

        if self.use_background_component:
            log_probs.append(get_log_prob(self.background_flow, y))

        if not log_probs:
            # torch.stack rejects an empty list; an empty mixture yields no columns.
            return torch.empty((batch_size, 0), dtype=dtype, device=y.device)

        return torch.stack(log_probs, dim=1)

    def _set_disabled(self, indices: List[int], value: bool) -> None:
        """Set the disabled flag to ``value`` for every in-range index."""
        for i in indices:
            # Indices outside [0, n_components) -- including negatives -- are
            # ignored, so the background component can never be toggled here.
            if 0 <= i < self.n_components:
                self.disabled_mask[i] = value

    def disable_components(self, indices: List[int]):
        """Disables one or more components by index (out-of-range indices are ignored)."""
        self._set_disabled(indices, True)

    def enable_components(self, indices: List[int]):
        """Enables one or more components by index (out-of-range indices are ignored)."""
        self._set_disabled(indices, False)
