from tqdm import tqdm
import torch
from utils import torch_dct  


@torch.no_grad()
def per_mode_dct_mean_var(
    loader,
    upsample=None,          # callable or None
    device="cuda",
    norm="ortho",
    unbiased=False,
    *,
    accum_dtype=None,
):
    """Compute the per-coefficient mean and variance of the 2-D DCT over a dataset.

    Each batch from ``loader`` is moved to ``device``, optionally passed through
    ``upsample``, transformed with ``torch_dct.dct_2d``, and reduced over the
    batch dimension (``dim=0``). Per-batch statistics are merged into running
    totals with Chan et al.'s pairwise update. The whole dataset therefore never
    has to be held in memory, and the result does not depend on how the data is
    split into batches (up to floating-point rounding).

    Args:
        loader: Iterable yielding ``(imgs, target)`` pairs; only ``imgs`` is
            used. The shape comments below assume ``(B, 1, H, W)`` batches
            (TODO confirm); reduction is only over the leading dimension.
        upsample: Optional callable applied to each batch before the DCT.
        device: Device each batch is moved to.
        norm: Normalization forwarded to ``torch_dct.dct_2d``.
        unbiased: If True, divide the summed squared deviations by ``n - 1``
            (sample variance); otherwise divide by ``n``. With ``n == 1`` the
            unbiased variance is NaN (0 / 0).
        accum_dtype: Optional dtype the DCT coefficients are cast to before
            accumulation (e.g. ``torch.float64`` for long datasets). ``None``
            keeps the coefficients' own dtype.

    Returns:
        ``(mean, var, n)``. ``mean`` and ``var`` have the shape of one sample's
        DCT coefficients (``coeffs.shape[1:]``), and ``n`` is the total number
        of samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    mean = None
    M2 = None
    n = 0

    for imgs, _ in tqdm(loader, desc="Computing DCT statistics", leave=True):
        imgs = imgs.to(device)               # (B, 1, H, W)

        if upsample is not None:
            imgs = upsample(imgs)            # (B, 1, H', W')

        coeffs = torch_dct.dct_2d(imgs, norm=norm)  # (B, 1, H', W')
        if accum_dtype is not None:
            coeffs = coeffs.to(accum_dtype)

        B = coeffs.shape[0]
        if B == 0:
            # An empty batch carries no information, and its mean would be NaN
            # and poison the running statistics.
            continue

        # mean()/sum() return fresh tensors, so they are safe to update in place.
        batch_mean = coeffs.mean(dim=0)      # (1, H', W')
        batch_M2 = ((coeffs - batch_mean) ** 2).sum(dim=0)

        if mean is None:
            mean = batch_mean
            M2 = batch_M2
            n = B
        else:
            # Chan et al. pairwise merge of (n, mean, M2) with (B, batch_mean, batch_M2).
            n_new = n + B
            delta = batch_mean - mean
            mean += delta * (B / n_new)
            M2 += batch_M2 + delta**2 * (n * B / n_new)
            n = n_new

    if mean is None:
        raise ValueError(
            "loader yielded no samples; cannot compute DCT statistics"
        )

    if unbiased:
        var = M2 / (n - 1)
    else:
        var = M2 / n

    return mean, var, n