"""Guide for image generation based on Glow."""
import normflows as nf

from calnf.guides.glow import GlowDistribution
from calnf.guides.balanced_calnf import BalancedCalibratedFlow


class BalancedGlow(BalancedCalibratedFlow):
    """
    A normalizing flow for image generation based on the Glow architecture.

    Uses the balanced+calibrated normalizing flow training model.

    The flow is multiscale: it has ``L`` levels, each made of ``K`` Glow
    blocks followed by a squeeze, and every level gets its own
    class-conditional diagonal Gaussian base distribution with
    ``num_subsamples`` classes.

    Glow paper: https://arxiv.org/abs/1807.03039
    Reference implementation: https://github.com/VincentStimper/normalizing-flows

    Args:
        device: Device to use for the guide.
        dataset: Dataset object.
        input_shape: Shape of the input images as (channels, height, width);
            the two spatial dimensions are treated identically and must both
            be divisible by ``2**L``.
        num_subsamples: Number of subsamples to use for calibration. Also the
            number of classes of each level's class-conditional base
            distribution.
        reg_penalty: The strength of the regularization term.
        calibrate_label: If False, do not update the calibration label.
        hidden_channels: Number of hidden channels in each Glow block.
        L: Number of multiscale levels.
        K: Number of Glow blocks per level.
        num_kl_particles: Number of particles to use for the KL divergence.
        grad_clip: Gradient clipping value.

    Raises:
        ValueError: If ``input_shape`` does not have exactly three entries, or
            if its spatial dimensions are not divisible by ``2**L``.
    """

    def __init__(
        self,
        device,
        dataset,
        input_shape,
        num_subsamples: int = 5,
        reg_penalty: float = 1.0,
        calibrate_label: bool = True,
        hidden_channels=128,
        L=2,
        K=8,
        num_kl_particles=5,
        grad_clip=1.0,
    ):
        # Validate up front: the latent shapes below use integer division by
        # powers of two, so non-divisible spatial sizes would silently produce
        # latent shapes that do not match the data and fail much later.
        if len(input_shape) != 3:
            raise ValueError(
                "input_shape must be (channels, height, width), "
                f"got {tuple(input_shape)}"
            )
        factor = 2**L
        if input_shape[1] % factor or input_shape[2] % factor:
            raise ValueError(
                f"spatial dimensions of input_shape {tuple(input_shape)} "
                f"must be divisible by 2**L = {factor}"
            )

        super().__init__(
            device,
            dataset,
            num_subsamples=num_subsamples,
            reg_penalty=reg_penalty,
            calibrate_label=calibrate_label,
            num_kl_particles=num_kl_particles,
            grad_clip=grad_clip,
        )
        self.device = device
        self.input_shape = input_shape

        # Define the flow (mirrors the normflows multiscale Glow example)
        channels = input_shape[0]
        split_mode = "channel"
        scale = True
        q0 = []
        merges = []
        flows = []
        for i in range(L):
            # K Glow blocks acting on channels * 2**(L + 1 - i) channels,
            # followed by a squeeze. Built in order so parameter
            # initialization consumes the RNG exactly as before.
            level_flows = [
                nf.flows.GlowBlock(
                    channels * 2 ** (L + 1 - i),
                    hidden_channels,
                    split_mode=split_mode,
                    scale=scale,
                )
                for _ in range(K)
            ]
            level_flows.append(nf.flows.Squeeze())
            flows.append(level_flows)
            if i > 0:
                # Only levels after the first get a Merge operation.
                merges.append(nf.flows.Merge())
                latent_shape = (
                    input_shape[0] * 2 ** (L - i),
                    input_shape[1] // 2 ** (L - i),
                    input_shape[2] // 2 ** (L - i),
                )
            else:
                latent_shape = (
                    input_shape[0] * 2 ** (L + 1),
                    input_shape[1] // 2**L,
                    input_shape[2] // 2**L,
                )
            # Class-conditional base distribution with num_subsamples classes
            # (one per calibration subsample). Presumably the context passed
            # to flow() selects the class -- confirm in BalancedCalibratedFlow.
            q0.append(
                nf.distributions.ClassCondDiagGaussian(latent_shape, num_subsamples)
            )

        self.model = nf.MultiscaleFlow(q0, flows, merges).to(self.device)

    def flow(self, context):
        """Return a GlowDistribution over images for the given context.

        The context is moved to this guide's device before being wrapped
        together with the multiscale model and the input image shape.
        """
        return GlowDistribution(self.model, context.to(self.device), self.input_shape)
