# Code Appendix of Vector Quantization using Gaussian Variational Autoencoder

# Usage 
## Prerequisites
* install the dependencies listed in `environment.yaml`
    ```bash
    conda env create --file=environment.yaml
    conda activate tokenizer
    ```
## Installation
* from source
    ```bash
    pip install -e .
    ```
* [optional] CUDA kernel for fast GaussianQuant
    ```bash
    cd gq_cuda_extension
    pip install --no-build-isolation -e .
    ```

## Training a Vanilla Gaussian VAE
* modify the YAML config for your system, paying special attention to `trainer-device`, `trainer-num_nodes`, and `data-train-params-root`

    ```bash
    python main.py --config sd3unet_gaussian_kl_0.64.yaml --wandb
    ```
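
The dash-separated names above read most naturally as paths into the nested YAML config. A minimal sketch of that reading, using a plain dictionary in place of the loaded file; the nesting, the helper name `set_by_path`, and every value shown are assumptions for illustration, not taken from the repository's config:

```python
def set_by_path(cfg, path, value, sep="-"):
    """Set cfg[k1][k2]...[kn] = value, where path is 'k1-k2-...-kn'."""
    keys = path.split(sep)
    node = cfg
    for key in keys[:-1]:
        # Create intermediate levels if they are missing.
        node = node.setdefault(key, {})
    node[keys[-1]] = value
    return cfg


# Hypothetical stand-in for the loaded YAML config.
cfg = {
    "trainer": {"device": 8, "num_nodes": 1},
    "data": {"train": {"params": {"root": "/path/to/dataset"}}},
}

# Adapt the three settings to a hypothetical single-node, 4-GPU machine.
set_by_path(cfg, "trainer-device", 4)
set_by_path(cfg, "trainer-num_nodes", 1)
set_by_path(cfg, "data-train-params-root", "/data/my_dataset/train")
```

In practice the same three values would be edited directly in the YAML file before launching `main.py`.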

## Training a Gaussian VAE for Gaussian Quant
* using the TDC constraint (Sec 3.3)
    ```bash
    python main.py --base configs/sd3unet_gq_0.25_train_tdc.yaml --wandb
    ```
* key code snippet: ./pit/quantization/gaussian.py line 72 class TargetAdaptativeGaussianRegularizer
    ```python
    # get mu, logvar, std, kl (z is (b, 2c, h, w); split along the channel dimension)
    mu, logvar = z.chunk(2, 1)
    logvar = torch.clamp(logvar, self.logvar_range[0], self.logvar_range[1])
    std = torch.exp(0.5 * logvar)
    var = torch.exp(logvar)
    zhat = mu + torch.randn_like(mu) * std
    # per-dimension KL to N(0, 1) in bits: 1.4426 ~= log2(e) converts nats to bits
    kls = 1.4426 * (
        0.5
        * (torch.pow(mu, 2) + var - 1.0 - logvar)
    )
    b, c, h, w = kls.shape
    # sum kl by group 
    kls = torch.sum(kls.reshape(b, self.group, c // self.group, h, w), dim=1)
    # evaluate weighted kl according to Eq. 32
    ge = (kls > self.target + self.tolerance).type(kls.dtype) * self.lam_max
    eq = (kls <= self.target + self.tolerance).type(kls.dtype) * (
        kls >= self.target - self.tolerance
    ).type(kls.dtype)
    le = (kls < self.target - self.tolerance).type(kls.dtype) * self.lam_min
    kl_loss = torch.sum((ge * kls + eq * kls + le * kls), dim=[1,2,3])
    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
    # update lambda according to Eq. 33
    if torch.mean(kls) > self.target:
        self.lam = self.lam * self.lam_factor
    else:
        self.lam = self.lam / self.lam_factor
    if torch.max(kls) > self.target + self.tolerance:
        self.lam_max = self.lam_max * self.lam_factor
    else:
        self.lam_max = self.lam_max / self.lam_factor
    self.lam_max = max(min(self.lam_max, self.lam_range[1]), 1.0)
    if torch.min(kls) < self.target - self.tolerance:
        self.lam_min = self.lam_min / self.lam_factor
    else:
        self.lam_min = self.lam_min * self.lam_factor
    self.lam_min = max(min(self.lam_min, 1.0), self.lam_range[0])
    ```
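
The weighting and update logic in the snippet above can be traced by hand on a few scalar values. The following is a minimal pure-Python sketch, not the repository's implementation: the function names and all numeric values are hypothetical, and it assumes that the weight applied above the tolerance band is divided by the factor whenever no group exceeds the band.

```python
import math

LOG2_E = 1.0 / math.log(2.0)  # about 1.4427; converts nats to bits


def kl_bits(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for one scalar dimension, in bits."""
    return LOG2_E * 0.5 * (mu * mu + math.exp(logvar) - 1.0 - logvar)


def zone_weight(kl, target, tol, lam_min, lam_max):
    """Penalty weight for one group KL: lam_max above the band,
    lam_min below it, and 1 inside it."""
    if kl > target + tol:
        return lam_max
    if kl < target - tol:
        return lam_min
    return 1.0


def update_weights(lam_min, lam_max, kls, target, tol, factor, lam_range):
    """One multiplicative update of both weights, followed by clamping."""
    # Grow lam_max while any group is above the band, otherwise decay it.
    if max(kls) > target + tol:
        lam_max *= factor
    else:
        lam_max /= factor
    lam_max = max(min(lam_max, lam_range[1]), 1.0)
    # Shrink lam_min while any group is below the band, otherwise grow it.
    if min(kls) < target - tol:
        lam_min /= factor
    else:
        lam_min *= factor
    lam_min = max(min(lam_min, 1.0), lam_range[0])
    return lam_min, lam_max


# Hypothetical values for illustration only.
target, tol, factor, lam_range = 0.25, 0.05, 1.01, (0.01, 100.0)
kls = [0.10, 0.25, 0.60]  # group KLs below, inside, and above the band
weights = [zone_weight(k, target, tol, lam_min=0.5, lam_max=2.0) for k in kls]
lam_min, lam_max = update_weights(1.0, 1.0, kls, target, tol, factor, lam_range)
```

Groups above the band are pushed down with a weight of at least 1, groups below it are penalized with a weight of at most 1, and the clamping keeps the two weights on their respective sides of 1.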
* using the SMKL constraint (Appendix B.2)
    ```bash
    python main.py --base configs/sd3unet_gq_0.25_train_smkl.yaml --wandb
    ```
* key code snippet: ./pit/quantization/gaussian.py line 284 class GroupedLambertWRegularizer
    ```python
    # obtain tau, gamma in Eq. 25
    tau, gamma_unnorm = torch.chunk(z, 2, dim=1)
    b, c, h, w = tau.shape
    assert c % self.group == 0
    ng = c // self.group
    tau = tau.reshape(b, self.group, ng, h, w)
    gamma_unnorm = gamma_unnorm.reshape(b, self.group, ng, h, w)
    gamma = F.softmax(gamma_unnorm, dim=1)
    # regularize r1=0.1 in Eq. 26
    kw = gamma * (self.target - 0.1) + 0.1 / self.group
    # compute mean with regularizer r2=0.01 in Eq. 26
    mu = torch.sqrt(2 * kw) * torch.tanh(tau) * (1 - 1e-2)
    # compute variance in Eq. 25 via the Lambert W function
    W = special.lambertw
    var = -W(-torch.exp(mu ** 2 - 2 * kw - 1.0))
    ```
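
The parameterization above can be checked numerically: the softmax weights split the target across the dimensions of a group, and the Lambert W expression gives the variance whose closed-form KL to N(0, 1), measured in nats, equals each per-dimension budget `kw`. The sketch below is a scalar pure-Python illustration under those assumptions; `lambertw0`, `smkl_params`, and the input values are hypothetical and stand in for the tensor code. It uses the principal branch, which yields variances in (0, 1].

```python
import math


def lambertw0(x, iters=100):
    """Principal branch W0(x) for -1/e < x <= 0, via Newton's method on w * exp(w) = x."""
    w = 0.0
    for _ in range(iters):
        ew = math.exp(w)
        step = (w * ew - x) / (ew * (w + 1.0))
        w -= step
        if abs(step) < 1e-15:
            break
    return w


def softmax(xs):
    m = max(xs)
    es = [math.exp(v - m) for v in xs]
    s = sum(es)
    return [e / s for e in es]


def smkl_params(tau, gamma_unnorm, target, r1=0.1, r2=1e-2):
    """Map the raw outputs of one group to (mu, var, kw), with kw the per-dimension KL budget."""
    group = len(tau)
    gamma = softmax(gamma_unnorm)
    # Each dimension gets a floor of r1 / group; the rest is shared by gamma.
    kw = [g * (target - r1) + r1 / group for g in gamma]
    # The factor (1 - r2) keeps mu^2 strictly below 2 * kw, so the
    # Lambert W argument stays above the branch point -1/e.
    mu = [math.sqrt(2 * k) * math.tanh(t) * (1 - r2) for k, t in zip(kw, tau)]
    var = [-lambertw0(-math.exp(m * m - 2 * k - 1.0)) for m, k in zip(mu, kw)]
    return mu, var, kw


def kl_nats(mu, var):
    """KL(N(mu, var) || N(0, 1)) for one scalar dimension, in nats."""
    return 0.5 * (mu * mu + var - 1.0 - math.log(var))


# Hypothetical raw outputs for a single group of 4 dimensions.
tau = [0.5, -1.0, 2.0, 0.0]
gamma_unnorm = [0.2, -0.3, 1.0, 0.0]
target = 1.0
mu, var, kw = smkl_params(tau, gamma_unnorm, target)
total_kl = sum(kl_nats(m, v) for m, v in zip(mu, var))
```

Because the budgets `kw` sum to the target by construction, `total_kl` matches `target` up to numerical precision for any choice of `tau` and `gamma_unnorm`.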

## Inference with a Gaussian VAE as a Gaussian Quant converted VQ-VAE
* evaluate a pretrained checkpoint
    ```bash
    python -m torch.distributed.launch --standalone --use-env \
    --nproc-per-node 8 eval.py \
    --bs 32 \
    --base configs/sd3unet_gq_0.25.yaml \
    --ckpt models_pretrain/sd3unet_gq_0.25.ckpt \
    --dataset $PATH_TO_DATASET_FOLDER
    ```
* key code snippet
    ```python
    # get mu, logvar, std from z of shape (b, l, 2c)
    mu, logvar = z.chunk(2, 2)
    b, l, c = mu.shape
    logvar = torch.clamp(logvar, self.logvar_range[0], self.logvar_range[1])
    std = torch.exp(logvar * 0.5)
    mu = mu.reshape(b, l, self.group, c // self.group).permute(0,1,3,2).reshape(-1, self.group)
    std = std.reshape(b, l, self.group, c // self.group).permute(0,1,3,2).reshape(-1, self.group)
    # quantize in roughly 8 chunks to save memory (chunk size is at least 1)
    bs = max(mu.shape[0] // 8, 1)
    zhat = torch.zeros_like(mu)
    indices = torch.zeros([mu.shape[0]], device=mu.device, dtype=torch.long)
    for i in range(0, mu.shape[0], bs):
        mu_q = mu[i : i + bs]
        std_q = std[i:i+bs]
        q_normal_dist = Normal(mu_q[:, None, :], std_q[:, None, :])
        # actual quantization with grouping in Eq. 31 
        log_ratios = (
            q_normal_dist.log_prob(self.prior_samples[None])
            - self.normal_log_prob[None] * self.beta
        )
        perturbed = torch.sum(log_ratios, dim=2)
        argmax_indices = torch.argmax(perturbed, dim=1)
        zhat[i : i + bs] = torch.index_select(self.prior_samples, 0, argmax_indices)
        indices[i : i + bs] = argmax_indices
    zhat = zhat.reshape(b, l, c // self.group, self.group).permute(0, 1, 3, 2).reshape(b, l, c).float()
    indices = indices.reshape(b, l, c // self.group)
    ```
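
The selection rule in the loop above can be reproduced on a toy problem. The following is a minimal pure-Python sketch for a single group: it assumes the codebook consists of i.i.d. samples from the N(0, 1) prior and that `self.normal_log_prob` holds their log density under that prior. The names `gaussian_quantize` and `codebook` and the sizes used are hypothetical.

```python
import math
import random


def normal_log_prob(x, mu, std):
    """Log density of N(mu, std^2) evaluated at x."""
    return -0.5 * ((x - mu) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def gaussian_quantize(mu, std, codebook, beta=1.0):
    """Return (index, codeword) maximizing sum_d [log q(c_d) - beta * log p(c_d)],
    where q = N(mu, std^2) is the posterior and p = N(0, 1) is the prior."""
    best_idx, best_score = -1, -math.inf
    for idx, code in enumerate(codebook):
        score = sum(
            normal_log_prob(c, m, s) - beta * normal_log_prob(c, 0.0, 1.0)
            for c, m, s in zip(code, mu, std)
        )
        if score > best_score:
            best_idx, best_score = idx, score
    return best_idx, codebook[best_idx]


rng = random.Random(0)
group, codebook_size = 2, 256
# Shared codebook: fixed samples from the prior, one row per codeword.
codebook = [[rng.gauss(0.0, 1.0) for _ in range(group)] for _ in range(codebook_size)]

# Posterior parameters of one group, as the encoder would produce them.
mu, std = [0.8, -0.3], [0.2, 0.2]
idx, zhat = gaussian_quantize(mu, std, codebook)
print(idx, zhat)
```

Only the integer `idx` needs to be stored or transmitted; the decoder recovers `zhat` by indexing the same codebook, which is what makes the Gaussian VAE usable as a VQ-VAE.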