"""Bagged ensemble normalizing flow."""
import torch
import zuko

from calnf.guides.guide import Guide
from calnf.datasets.dataset import Dataset


class FlowEnsembleDistribution(torch.distributions.Distribution):
    """Weighted mixture of one conditional flow evaluated at several contexts.

    Every row of ``contexts`` selects one mixture component (the flow
    conditioned on that row); ``weights`` holds the matching mixture weights.
    Events are assumed to be 1-D vectors, which is what the einsum in
    ``rsample_and_log_prob`` and the broadcasting in ``log_prob`` rely on.
    """

    has_rsample = True
    # This distribution registers no constrained tensor parameters. Declaring
    # an empty mapping keeps the generic ``Distribution`` machinery (for
    # example ``__repr__``) from raising ``NotImplementedError``.
    arg_constraints = {}

    def __init__(self, flow, contexts, weights, temperature):
        """Initialize the ensemble distribution.

        Args:
            flow (zuko.core.LazyDistribution): Callable mapping a batch of
                contexts to a (batched) component distribution.
            contexts (torch.Tensor): Contexts for the flow, one row per
                mixture component.
            weights (torch.Tensor): Mixture weights, one per component.
            temperature (float): Temperature of the relaxed one-hot
                categorical used to blend component samples.

        Raises:
            ValueError: If the number of contexts and weights differ.
        """
        if contexts.shape[0] != weights.shape[0]:
            raise ValueError("Number of contexts must match the number of weights.")

        # Run the base initializer so that ``batch_shape`` / ``event_shape``
        # and friends exist. The shapes are left at their defaults because
        # the event shape is determined by ``flow`` and is not known here.
        # Argument validation is disabled (there are no constraints to check).
        super().__init__(validate_args=False)

        self.flow = flow
        self.contexts = contexts
        self.weights = weights
        self.temperature = temperature

    def rsample_and_log_prob(self, sample_shape=torch.Size()):
        """Draw reparameterized samples and return them with a log probability.

        NOTE: the returned samples are a relaxed (soft one-hot) blend of one
        sample per component, while the returned log probability is the
        weighted log-sum-exp of each component's log density at *its own*
        sample. It is therefore not the exact mixture density of the blended
        sample.

        Args:
            sample_shape (torch.Size): Leading shape of the draw.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Blended samples and the
            combined log probabilities.
        """
        # Relaxed one-hot selection over the mixture components.
        selector = torch.distributions.RelaxedOneHotCategorical(
            self.temperature, probs=self.weights
        )
        soft_assignment = selector.rsample(sample_shape)

        # One sample (and its log density) from every component.
        component_distribution = self.flow(self.contexts)
        samples, logprobs = component_distribution.rsample_and_log_prob(sample_shape)
        # samples is expected to be [*sample_shape, num_components, event_dim]
        # logprobs is expected to be [*sample_shape, num_components]

        # Blend the component samples with the relaxed assignment.
        samples = torch.einsum("...i, ...ij -> ...j", soft_assignment, samples)

        # Mixture in log space: log sum_k w_k * exp(logprob_k).
        logprobs = torch.logsumexp(logprobs + torch.log(self.weights), dim=-1)

        return samples, logprobs

    def log_prob(self, value):
        """Compute the mixture log probability of ``value``.

        Args:
            value (torch.Tensor): Event(s) of shape ``[..., event_dim]``. A
                single 1-D event is treated as a batch of one, so the result
                for it has shape ``[1]``.

        Returns:
            torch.Tensor: Log probabilities with the leading (batch) shape of
            ``value``.
        """
        # Add a batch dimension if necessary
        if value.dim() == 1:
            value = value.unsqueeze(0)

        component_distribution = self.flow(self.contexts)
        # Insert the component axis right before the event axis so the value
        # broadcasts against every component, whatever the number of leading
        # batch dimensions. For 2-D input this equals the [batch, 1, event]
        # layout.
        logprobs = component_distribution.log_prob(value.unsqueeze(-2))
        # logprobs is expected to be [..., num_components]

        # Mixture in log space: log sum_k w_k * exp(logprob_k).
        return torch.logsumexp(logprobs + torch.log(self.weights), dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples, discarding the log probability."""
        return self.rsample_and_log_prob(sample_shape)[0]


class BaggedFlow(Guide):
    """Guide built on one conditional flow trained on bootstrap resamples.

    A single ``zuko.flows.NSF`` is conditioned on a context vector of length
    ``num_subsamples``. The all-zeros context is used for nominal data, and
    the i-th one-hot context is used for the i-th resample of the target
    data.
    """

    def __init__(
        self,
        device,
        dataset,
        num_subsamples: int = 5,
        grad_clip: float = 1.0,
    ):
        """Build the conditional flow and the one-hot context labels.

        Args:
            device (torch.device): Device to use for the guide.
            dataset (Dataset): Dataset object; its ``latent_dims`` sets the
                number of flow features.
            num_subsamples (int): Number of resamples of the target data
                (also the size of the flow's context vector).
            grad_clip (float): Gradient clipping value, forwarded to the
                base class.
        """
        super().__init__(grad_clip=grad_clip)
        self.device = device
        self.num_subsamples = num_subsamples

        spline_flow = zuko.flows.NSF(
            features=dataset.latent_dims,
            context=num_subsamples,
            hidden_features=(64, 64),
        )
        self.model = spline_flow.to(device)

        # Filled in lazily on the first call to ``loss``.
        self.target_subsamples = None
        # Row i is the one-hot context identifying ensemble member i.
        self.labels = torch.eye(num_subsamples, device=device)

    def flow(self, context):
        """Return the flow conditioned on ``context``."""
        return self.model(context)

    def loss(
        self,
        dataset: Dataset,
        n_nominal: int,
        obs_nominal: torch.Tensor,
        n_target: int,
        obs_target: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the training loss for one step.

        The loss is the negative single-particle ELBO on the nominal data
        plus the average negative single-particle ELBO over the target
        resamples.

        Args:
            dataset (Dataset): Dataset providing ``single_particle_elbo``.
            n_nominal (int): Number of nominal observations.
            obs_nominal (torch.Tensor): Nominal observations.
            n_target (int): Number of target observations.
            obs_target (torch.Tensor): Target observations.

        Returns:
            loss: tensor with the loss value
            dict[str, float]: Dictionary with loss values to log.
        """
        # The resample indices are drawn once, on the first call, and reused
        # on every later call.
        if self.target_subsamples is None:
            self.target_subsamples = []
            # Each entry holds n_target indices drawn with replacement.
            self.target_subsamples.extend(
                torch.randint(n_target, (n_target,), device=self.device)
                for _ in range(self.num_subsamples)
            )

        # Nominal term: flow conditioned on the all-zeros context.
        nominal_elbo = dataset.single_particle_elbo(
            self.nominal_distribution(), n_nominal, obs_nominal
        )
        nominal_loss = -nominal_elbo

        # Calibration term: mean over ensemble members of the negative ELBO
        # on that member's resample of the target observations.
        calibration_loss = torch.zeros((), device=self.device)
        for member, indices in enumerate(self.target_subsamples):
            member_elbo = dataset.single_particle_elbo(
                self.flow(self.labels[member]), len(indices), obs_target[indices]
            )
            calibration_loss += -member_elbo / self.num_subsamples

        loss = nominal_loss + calibration_loss

        logged_terms = {
            "nominal_loss": nominal_loss,
            "calibration_loss": calibration_loss,
            "loss": loss,
        }
        log_values = {
            key: term.detach().cpu().item() for key, term in logged_terms.items()
        }
        return loss, log_values

    def nominal_distribution(self) -> torch.distributions.Distribution:
        """Return the flow conditioned on the all-zeros (nominal) context."""
        zero_context = torch.zeros(self.num_subsamples, device=self.device)
        return self.flow(zero_context)

    def target_distribution(self) -> torch.distributions.Distribution:
        """Return the uniform-weight ensemble over all one-hot contexts."""
        uniform_weights = (
            torch.ones(self.num_subsamples, device=self.device) / self.num_subsamples
        )
        unit_temperature = torch.tensor(1.0, device=self.device)
        return FlowEnsembleDistribution(
            self.flow,
            self.labels,
            uniform_weights,
            unit_temperature,
        )
