import os
from typing import List, Optional, Tuple, Union
import json

import random
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from generation import *
from pyramids import GaussianPyramid, IdentityPyramid
from torch.nn import functional as F
from tqdm.auto import tqdm
import torch.nn as nn
from nn_baselines.src.training_utils import load_model
from config import get_unet_config

from src.utils import (
    SoftmaxWeightedAverage,
    all_translations,
    SoftmaxWeightedAverageLLt,
)


class DenoisingPyramid(torch.nn.Module):
    """Multi-scale patch denoiser built on an image pyramid.

    One patch denoiser is kept per pyramid level. ``train`` feeds each level's
    denoiser with patches extracted from training images (optionally
    concatenated with patches of a downscaled-then-upscaled copy of the
    level), and ``denoise`` runs the levels coarsest-first at sampling time,
    merging their x0 estimates and converting the result to a noise
    prediction for a DDIM scheduler.
    """

    def __init__(
        self,
        resolution: int,
        device: str,
        num_steps: int,
        kernel_size: Union[
            int, List[int]
        ],  # can define different kernel size for each level of the pyramid
        stride: int,
        kernel_overlap: float,  # fraction of the kernel shared by neighbouring generation patches
        temperature: float,
        sigma_correction: bool,
        denoiser: str = "knn",  # "knn", "encoder", "cond_single_index", "cond"; anything else -> OptimalDenoiser
        num_levels: int = -1,
        pyramid_sigma: float = 1.0,
        pyramid_kernel_size: int = 5,
        pyramid_downscale_factor: int = 2,
        level_mixture_alpha: float = 0.5,
        latent_diffusion: bool = False,
        denoiser_args: Optional[dict] = None,
        aggregation_mode: str = "mean",
        random_padding: bool = False,
        fill_in_zeros_in_x0: bool = False,
        embed_w: float = 1.0,
        save_dir: Optional[str] = None,
        save_prefix: str = "",
        stride_gen: Optional[Union[int, List[int]]] = None,
        beta_1: float = 0.0001,
        beta_T: float = 0.02,
        pyramidClass: type = GaussianPyramid,
        in_channels: int = 3,
        **kwargs,
    ):
        """Configure the pyramid, per-level patch geometry, denoiser type and scheduler.

        Args:
            resolution: Pixel resolution of (square) input images.
            device: Torch device for the pyramid, VAE and sampling.
            num_steps: Number of DDIM sampling steps.
            kernel_size: Patch size per level (an int applies to every level;
                -1 means "the whole level").
            stride: Patch stride used when collecting training patches.
            kernel_overlap: Fraction of overlap used to derive ``stride_gen``
                when it is not given.
            temperature: Forwarded to every denoiser.
            sigma_correction: Stored as-is.
            denoiser: Selects the denoiser class (see the parameter comment).
            num_levels: Pyramid depth; -1 infers it from a list ``kernel_size``.
            denoiser_args: Extra keyword arguments for every denoiser
                (``None`` means none).
            stride_gen: Patch stride at generation time, per level or shared;
                -1 entries mean "stride = kernel size".
            pyramidClass: Pyramid implementation to instantiate.
            in_channels: Image channels (overridden by the VAE latent channel
                count when ``latent_diffusion`` is set).

        Raises:
            ValueError: if ``num_levels`` is -1 and ``kernel_size`` is an int.
        """
        super().__init__()

        self.device = device
        self.latent_diffusion = latent_diffusion
        self.n_channels = in_channels
        self.img_resolution = resolution
        if latent_diffusion:
            self.vae = (
                AutoencoderKL.from_pretrained(
                    "CompVis/stable-diffusion-v1-4", subfolder="vae"
                )
                .to(self.device)
                .eval()
            )
            # Probe the VAE once to learn the latent resolution and channel count.
            with torch.no_grad():
                dummy_enc = self.vae.encode(
                    torch.zeros(1, 3, resolution, resolution, device=self.device)
                ).latent_dist.sample()
            self.resolution = dummy_enc.shape[-1]
            self.n_channels = dummy_enc.shape[1]
        else:
            self.resolution = resolution

        # Optional callback(denoiser, level_idx) run on every newly built denoiser.
        self.setup_denoiser = None
        self.num_steps = num_steps
        self.stride = stride
        self.temperature = temperature
        self.sigma_correction = sigma_correction
        self.random_padding = random_padding
        if num_levels == -1:
            if isinstance(kernel_size, int):
                raise ValueError(
                    "Either num_levels or a per-level kernel_size list must be provided"
                )
            num_levels = len(kernel_size)
        self.num_levels = num_levels

        self.pyramid = pyramidClass(
            kernel_size=pyramid_kernel_size,
            kernel_sigma=pyramid_sigma,
            num_levels=num_levels,
            grayscale=self.n_channels == 1,
            device=device,
            resolution=self.resolution,
            n_channels=self.n_channels,
            downscale_factor=pyramid_downscale_factor,
        )
        # Encode a dummy image to discover the resolution of every level.
        dummy_img = torch.randn(
            1, self.n_channels, self.resolution, self.resolution, device=device
        )
        dummy_pyramid = self.pyramid.encode(dummy_img)
        self.pyramid_resolutions = [
            level[0].shape[-1] for level in dummy_pyramid.levels
        ]

        per_level_kernels = (
            [kernel_size] * num_levels if isinstance(kernel_size, int) else kernel_size
        )
        # -1 means "one patch covering the whole level".
        self.kernel_size = [
            int(self.pyramid_resolutions[i]) if ks == -1 else ks
            for i, ks in enumerate(per_level_kernels)
        ]

        if stride_gen is None:
            # Clamp to 1 so a full (or excessive) overlap cannot give stride <= 0.
            self.stride_gen = [
                max(1, int(ks * (1 - kernel_overlap))) for ks in self.kernel_size
            ]
        elif isinstance(stride_gen, int):
            self.stride_gen = [stride_gen] * num_levels
        else:
            # Copy so resolving -1 below never mutates the caller's list.
            self.stride_gen = list(stride_gen)
        # -1 means "non-overlapping": stride equal to the level's kernel size.
        self.stride_gen = [
            self.kernel_size[i] if s == -1 else s
            for i, s in enumerate(self.stride_gen)
        ]

        # Plain if/elif so only the selected class name is looked up.
        if denoiser == "knn":
            self.DenoiserClass = KNNOptimalDenoiser
        elif denoiser == "encoder":
            self.DenoiserClass = EncoderOptimalDenoiser
        elif denoiser == "cond_single_index":
            self.DenoiserClass = KNNCondSingleIndexDenoiser
        elif denoiser == "cond":
            self.DenoiserClass = KNNCondDenoiser
        else:
            self.DenoiserClass = OptimalDenoiser
        self.denoiser = denoiser
        self.denoiser_args = {} if denoiser_args is None else denoiser_args
        self.fill_in_zeros_in_x0 = fill_in_zeros_in_x0
        self.level_mixture_alpha = level_mixture_alpha
        self.scheduler = DDIMScheduler(
            beta_start=beta_1,
            beta_end=beta_T,
            beta_schedule="linear",
            prediction_type="epsilon",
        )
        self.scheduler.set_timesteps(num_steps)
        self.aggregation_mode = aggregation_mode
        self.pyramid_of_denoisers = []
        self.embed_w = embed_w
        self.save_dir = save_dir
        self.save_prefix = save_prefix

    @torch.no_grad()
    def _image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Bring raw images into the model's working space.

        Converts channels (RGB <-> grayscale), resizes to ``img_resolution``,
        rescales intensities to [-1, 1] and, with ``latent_diffusion``,
        encodes them with the VAE.

        Args:
            img: Tensor of shape [B, C, H, W].

        Returns:
            Pixels in [-1, 1] with ``self.n_channels`` channels, or scaled VAE
            latents when ``latent_diffusion`` is set.

        Raises:
            ValueError: if the input channel count cannot be converted.
        """
        if not torch.is_floating_point(img):
            # Integer images (e.g. uint8) would overflow in the arithmetic below.
            img = img.float()

        # With a VAE, self.n_channels is the *latent* channel count; the pixel
        # input must be RGB because the VAE is probed with 3-channel images.
        target_channels = 3 if self.latent_diffusion else self.n_channels
        in_channels = img.shape[1]
        if in_channels != target_channels:
            if target_channels == 1 and in_channels == 3:
                # Convert RGB to grayscale using ITU-R BT.601 coefficients
                img = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
            elif target_channels == 3 and in_channels == 1:
                # Repeat grayscale to RGB
                img = img.repeat(1, 3, 1, 1)
            else:
                raise ValueError(
                    f"Unsupported number of channels: {in_channels}. Expected 1 or 3."
                )

        img = img.to(self.device)
        if img.shape[-2:] != (self.img_resolution, self.img_resolution):
            img = F.interpolate(
                img,
                size=(self.img_resolution, self.img_resolution),
                mode="bilinear",
                align_corners=False,
            )

        # The range test is over the whole batch, not per image.
        lo, hi = img.min(), img.max()
        if lo >= 0 and hi <= 1:
            img = (img - 0.5) * 2
        elif not (lo >= -1 and hi <= 1):
            # Arbitrary range: min-max rescale; the clamp avoids 0/0 on a constant batch.
            span = (hi - lo).clamp_min(torch.finfo(img.dtype).eps)
            img = 2 * (img - lo) / span - 1
        # else: already in [-1, 1]

        if self.latent_diffusion:
            img = (
                self.vae.config.scaling_factor
                * self.vae.encode(img).latent_dist.sample()
            )
        return img

    @torch.no_grad()
    def _image_postprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Map model-space images back to [0, 1] pixels on the CPU.

        Args:
            img: Tensor [B, self.n_channels, H, W] in [-1, 1] (VAE latents when
                ``latent_diffusion`` is set; they are decoded first).

        Returns:
            CPU tensor [B, C, H, W] clamped to [0, 1].
        """
        if self.latent_diffusion:
            img = img / self.vae.config.scaling_factor
            img = self.vae.decode(img).sample

        img = (img + 1) / 2
        return torch.clamp(img, 0, 1).cpu()

    def _upscale(self, image: torch.Tensor, to_resolution: int) -> torch.Tensor:
        """Bilinearly resize ``image`` to a square ``to_resolution``."""
        return F.interpolate(
            image,
            size=(to_resolution, to_resolution),
            mode="bilinear",
            align_corners=False,
        )

    def _embed_w(
        self, patches: torch.Tensor, low_freq_patches: torch.Tensor
    ) -> torch.Tensor:
        """Append ``embed_w``-weighted conditioning patches along the channel dim."""
        return torch.cat([patches, self.embed_w * low_freq_patches], dim=1)

    def _get_denoiser_path(self, level_idx: int) -> str:
        """Get the path (without extension) for saving/loading a level's denoiser.

        The name encodes the denoiser type, level, kernel size, ``self.stride``
        and ``embed_w``.

        Raises:
            ValueError: if ``save_dir`` is not set.
        """
        if self.save_dir is None:
            raise ValueError("save_dir must be set to save/load denoisers")

        denoiser_name = self.denoiser
        if self.denoiser == "cond_single_index":
            # Shares checkpoints with the plain knn denoiser.
            denoiser_name = "knn"

        return os.path.join(
            self.save_dir,
            f"ours/{self.save_prefix}",
            f"{self.save_prefix}_{denoiser_name}_denoiser_level_{level_idx}_k{self.kernel_size[level_idx]}_s{self.stride}_embed_w{self.embed_w:.2f}",
        )

    def _new_denoiser(self, level_idx: int):
        """Build an empty denoiser for ``level_idx`` and run the setup callback."""
        denoiser = self.DenoiserClass(
            None,
            self.scheduler,
            temperature=self.temperature,
            level_idx=level_idx,
            **self.denoiser_args,
        )
        if self.setup_denoiser is not None:
            self.setup_denoiser(denoiser, level_idx)
        return denoiser

    def _try_load_denoiser(self, level_idx: int):
        """Load a saved denoiser for ``level_idx``, or return None if none exists.

        The configured ``self.stride`` is tried first, then the fallback
        strides 1 and 2. ``self.stride`` is always restored afterwards (it is
        only changed because ``_get_denoiser_path``, possibly overridden,
        reads it).
        """
        configured_stride = self.stride
        candidates = [configured_stride] + [
            s for s in (1, 2) if s != configured_stride
        ]
        try:
            for candidate in candidates:
                self.stride = candidate
                base_path = self._get_denoiser_path(level_idx)
                if os.path.exists(base_path + ".data"):
                    print(
                        f"Loading pre-trained denoiser for level {level_idx} from {base_path}.data"
                    )
                    denoiser = self._new_denoiser(level_idx)
                    denoiser.load(base_path)
                    return denoiser
        finally:
            self.stride = configured_stride
        print(
            f"No pre-trained denoiser found for level {level_idx} in {self.save_dir} (strides tried: {candidates})"
        )
        return None

    def train(self, dataloader: torch.utils.data.DataLoader):
        """Build one patch denoiser per pyramid level from a dataloader.

        Denoisers saved under ``save_dir`` are loaded instead of rebuilt; only
        the remaining levels consume data, get trained and are saved.

        Args:
            dataloader: Yields batches whose first element is an image tensor,
                expected in [0, 1] (other ranges are rescaled by
                ``_image_preprocess``). A bool is forwarded to
                ``torch.nn.Module.train`` so ``.eval()`` / ``.train(mode)``
                keep working despite this method shadowing it.
        """
        if isinstance(dataloader, bool):
            # nn.Module.eval() calls self.train(False).
            return super().train(dataloader)

        self.pyramid_of_denoisers = []
        if self.save_dir:
            print("Checking for pre-trained denoisers...")
            for level_idx in range(self.pyramid.num_levels):
                denoiser = self._try_load_denoiser(level_idx)
                if denoiser is None:
                    denoiser = self._new_denoiser(level_idx)
                self.pyramid_of_denoisers.append(denoiser)

            if all(d.is_trained() for d in self.pyramid_of_denoisers):
                print("All denoisers loaded successfully!")
                return
        else:
            self.pyramid_of_denoisers = [
                self._new_denoiser(level_idx)
                for level_idx in range(self.pyramid.num_levels)
            ]

        # Feed patches only to the denoisers that were not loaded.
        for full_batch in tqdm(dataloader, desc="Adding data to denoisers"):
            batch = full_batch[0]  # e.g. CIFAR10 yields (image, label) pairs
            img_batch = self._image_preprocess(batch)
            img_latent = self.pyramid.encode(img_batch)

            for level_idx, level in enumerate(img_latent.levels):
                denoiser = self.pyramid_of_denoisers[level_idx]
                if denoiser.is_trained():
                    continue
                patches = get_patches(
                    level,
                    kernel_size=self.kernel_size[level_idx],
                    stride=self.stride,
                    n_channels=self.n_channels,
                )
                if self.embed_w > 0 and level_idx < self.pyramid.num_levels - 1:
                    # Conditioning: the same level downscaled then upscaled back.
                    low_freq_levels = [
                        self._upscale(self.pyramid.downscale(lvl), lvl.shape[-1])
                        for lvl in level
                    ]
                    low_freq_patches = get_patches(
                        low_freq_levels,
                        kernel_size=self.kernel_size[level_idx],
                        stride=self.stride,
                        n_channels=self.n_channels,
                    )
                    patches = self._embed_w(patches, low_freq_patches)
                denoiser.add_data(patches)

        trained_levels = []
        for level_idx, denoiser in enumerate(self.pyramid_of_denoisers):
            if not denoiser.is_trained():
                print(f"Training denoiser for level {level_idx}")
                denoiser.train()
                trained_levels.append(level_idx)

        # Only newly trained denoisers are written: loaded ones already exist
        # on disk, and re-saving one loaded from a fallback stride would file
        # it under the wrong stride.
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
            self.save_denoisers(levels=trained_levels)

    def save_denoisers(self, levels: Optional[List[int]] = None):
        """Save denoisers to disk.

        Args:
            levels: Level indices to save; ``None`` saves every level.
        """
        print("Saving trained denoisers...")
        if levels is None:
            levels = list(range(len(self.pyramid_of_denoisers)))
        for level_idx in tqdm(levels, desc="Saving denoisers"):
            path = self._get_denoiser_path(level_idx)
            # The path nests under save_dir/ours/<prefix>; create it if needed.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.pyramid_of_denoisers[level_idx].save(path)

    def load_denoisers(self):
        """Load every level's denoiser from disk, built the same way as in ``train``."""
        self.pyramid_of_denoisers = []
        for level_idx in range(self.pyramid.num_levels):
            denoiser = self._new_denoiser(level_idx)
            denoiser.load(self._get_denoiser_path(level_idx))
            self.pyramid_of_denoisers.append(denoiser)

    def _fold(self, predicted_patch_noise, level_idx, padding, level_img_shape):
        """Reassemble overlapping patch predictions into a level-sized tensor.

        ``aggregation_mode`` picks the overlap reduction. "center" falls back
        to "mean" on the coarsest level.

        Raises:
            ValueError: for an unknown ``aggregation_mode``.
        """
        if self.aggregation_mode == "mean":
            fold_func = fold_mean
        elif self.aggregation_mode == "sum":
            fold_func = fold_sum
        elif self.aggregation_mode == "median":
            fold_func = fold_median
        elif self.aggregation_mode == "center":
            if level_idx == self.pyramid.num_levels - 1:
                fold_func = fold_mean
            else:
                fold_func = fold_center
        else:
            raise ValueError(f"Invalid aggregation mode: {self.aggregation_mode}")

        return fold_func(
            predicted_patch_noise,
            self.kernel_size[level_idx],
            self.stride_gen[level_idx],
            padding,
            level_img_shape,
        )

    @torch.no_grad()
    def denoise(
        self,
        img: torch.Tensor,
        timestep: int,
    ):
        """Predict the noise in ``img`` at ``timestep``, coarse-to-fine over the pyramid.

        Each level is cut into patches (stride ``stride_gen``). When
        ``embed_w > 0`` they are concatenated with patches of the coarser
        level's x0 estimate upscaled to this level. The patches go through
        that level's denoiser and are folded back into a level-sized x0
        estimate. Levels are merged as
        ``total_x0 = level_x0 + upscale(total_x0 - downscale(level_x0))``.

        Returns:
            ``(predicted_noise, level_imgs, level_x0s, total_x0)``; the
            per-level lists are ordered coarsest first.
        """
        cur_pyramid = self.pyramid.encode(img)
        coarsest = self.pyramid.num_levels - 1
        padding = 0

        level_imgs = []
        level_x0s = []
        for level_idx in range(coarsest, -1, -1):
            level_img = cur_pyramid.levels[level_idx][0]
            kernel = self.kernel_size[level_idx]
            stride = self.stride_gen[level_idx]
            has_coarser = level_idx < coarsest

            # [num_patches * B, C, kernel_size, kernel_size]
            cur_patches = get_patches(
                [level_img],
                kernel_size=kernel,
                stride=stride,
                n_channels=self.n_channels,
                padding=padding,
            )
            if self.embed_w > 0 and has_coarser:
                # Condition on the coarser level's x0 estimate, upscaled to this level.
                # NOTE(review): an earlier comment said these patches should be
                # scaled by alpha_cumprod here; no scaling is applied.
                low_freq_patches = get_patches(
                    [self._upscale(level_x0s[-1], level_img.shape[-1])],
                    kernel_size=kernel,
                    stride=stride,
                    n_channels=self.n_channels,
                    padding=padding,
                )
                cur_patches = self._embed_w(cur_patches, low_freq_patches)

            predicted_patches = self.pyramid_of_denoisers[level_idx](
                cur_patches,
                timestep,
                noisy_subspace=self.n_channels,
                res_scale=(self.resolution / self.pyramid_resolutions[level_idx]) ** 2,
            )
            if self.embed_w > 0 and has_coarser:
                # Drop the conditioning channels.
                predicted_patches = predicted_patches[:, : self.n_channels, ...]

            # The folded prediction is used as this level's x0 estimate.
            folded = self._fold(predicted_patches, level_idx, padding, level_img.shape)
            level_x0 = folded
            if self.fill_in_zeros_in_x0 and has_coarser:
                # Exact zeros in the fold (presumably uncovered pixels) fall back
                # to the upscaled coarser estimate.
                level_x0 = torch.where(
                    folded == 0,
                    self._upscale(total_x0, self.pyramid_resolutions[level_idx]),
                    folded,
                )

            level_imgs.append(level_img)
            level_x0s.append(level_x0)

            if has_coarser:
                total_x0 = level_x0 + self._upscale(
                    total_x0 - self.pyramid.downscale(level_x0),
                    self.pyramid_resolutions[level_idx],
                )
            else:
                total_x0 = level_x0.clone()

        predicted_noise = get_noise_from_target(self.scheduler, total_x0, img, timestep)
        return (
            predicted_noise,
            level_imgs,
            level_x0s,
            total_x0,
        )

    def _fix_all_seed(self, seed):
        """Seed torch (CPU and CUDA), ``random`` and NumPy, and force deterministic cuDNN."""
        # Local import: do not rely on a star import to provide ``np``.
        import numpy as np

        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @torch.no_grad()
    def sample(
        self,
        batch_size: int = 1,
        return_trajectory: bool = False,
        seed: int = 42,
    ):
        """Generate images by running the scheduler from seeded Gaussian noise.

        Returns:
            ``(images, noise)``, or with ``return_trajectory`` also the
            per-timestep lists ``(noisy, eps, x0, level_imgs, level_x0s)``.
            ``images`` come from ``_image_postprocess``; ``noise`` is the
            initial Gaussian draw before scaling by ``init_noise_sigma``.
        """
        self._fix_all_seed(seed)
        cur_img = torch.randn(
            batch_size,
            self.n_channels,
            self.resolution,
            self.resolution,
            device=self.device,
        )

        noise = cur_img.clone()
        cur_img = cur_img * self.scheduler.init_noise_sigma

        trajectory_noisy = []
        trajectory_eps = []
        trajectory_x0 = []
        trajectory_level_imgs = []
        trajectory_level_x0s = []

        for timestep in tqdm(self.scheduler.timesteps, desc="Sampling"):
            predicted_noise, level_imgs, level_x0s, total_x0 = self.denoise(
                cur_img, timestep
            )

            if return_trajectory:
                trajectory_noisy.append(cur_img.clone())
                trajectory_eps.append(predicted_noise.clone())
                trajectory_x0.append(total_x0.clone())
                trajectory_level_imgs.append(level_imgs)
                trajectory_level_x0s.append(level_x0s)

            cur_img = self.scheduler.step(
                model_output=predicted_noise,
                timestep=timestep,
                sample=cur_img,
                generator=None,
            ).prev_sample

        if return_trajectory:
            return (
                self._image_postprocess(cur_img),
                noise,
                trajectory_noisy,
                trajectory_eps,
                trajectory_x0,
                trajectory_level_imgs,
                trajectory_level_x0s,
            )
        return self._image_postprocess(cur_img), noise


def get_s_matrix_path(
    dataset_name: str, num_images: int = -1, kernel_size: Optional[int] = None
) -> Tuple[str, str]:
    """Return the ``(s_matrix_path, mean_path)`` pair for a dataset and kernel size.

    Args:
        dataset_name: Dataset identifier, e.g. ``"cifar10"``.
        num_images: Number of images the statistics were computed from;
            -1 means the whole dataset and is encoded as ``"full"``.
        kernel_size: Patch size of the statistics. When omitted, the default
            for a known dataset is used.

    Returns:
        Paths under ``data/`` of the form
        ``{dataset}_{num}_s_matrix_ks{k}.pt`` and ``{dataset}_{num}_mean_ks{k}.pt``.

    Raises:
        KeyError: if ``kernel_size`` is omitted and ``dataset_name`` has no
            default.
    """
    # Default kernel size for each known dataset.
    dataset_kernel_sizes = {
        "mnist": 28,
        "fashion_mnist": 28,
        "cifar10": 32,
        "ffhq": 64,
        "celeba_hq": 64,
        "afhq": 64,
    }

    if kernel_size is None:
        try:
            kernel_size = dataset_kernel_sizes[dataset_name]
        except KeyError:
            raise KeyError(
                f"No default kernel size for dataset {dataset_name!r}; pass "
                f"kernel_size explicitly (known: {sorted(dataset_kernel_sizes)})"
            ) from None

    num_images_str = "full" if num_images == -1 else str(num_images)
    base_path = f"data/{dataset_name}_{num_images_str}"
    return (
        f"{base_path}_s_matrix_ks{kernel_size}.pt",
        f"{base_path}_mean_ks{kernel_size}.pt",
    )


class BaseDenoiser(torch.nn.Module):
    """Common scaffolding for the non-pyramid baselines.

    Sets up the DDIM scheduler and image pre/post-processing, and reuses
    ``DenoisingPyramid.sample`` for generation. Subclasses must provide
    ``denoise(img, timestep)`` returning a 4-item sequence whose first item
    is the noise prediction (that is how ``sample`` unpacks it).
    """

    def __init__(
        self,
        resolution: int,
        device: str,
        num_steps: int,
        *args,
        beta_1: float = 0.0001,
        beta_T: float = 0.02,
        dataset_name: str = "cifar10",
        **kwargs,
    ):
        super().__init__()
        self.device = device
        # Extra positional/keyword arguments are accepted for signature
        # compatibility with DenoisingPyramid; only ``in_channels`` is read.
        self.n_channels = kwargs.get("in_channels", 3)
        self.img_resolution = resolution
        self.resolution = resolution
        self.dataset_name = dataset_name

        self.scheduler = DDIMScheduler(
            beta_start=beta_1,
            beta_end=beta_T,
            beta_schedule="linear",
            prediction_type="epsilon",
        )
        self.scheduler.set_timesteps(num_steps)
        self.num_steps = num_steps

    sample = DenoisingPyramid.sample
    _fix_all_seed = DenoisingPyramid._fix_all_seed

    def train(self, dataloader: torch.utils.data.DataLoader):
        """Remember the dataloader; subclasses read ``self.dataloader`` when denoising.

        A bool is forwarded to ``torch.nn.Module.train``. This method shadows
        it, and ``.eval()`` calls ``self.train(False)``, which would otherwise
        store ``False`` as the dataloader.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)
        self.dataloader = dataloader

    @torch.no_grad()
    def _image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Keep at most 3 channels, resize to ``img_resolution`` and map [0, 1] to [-1, 1]."""
        imgs = torch.nn.functional.interpolate(
            img[:, :3, ...].to(self.device),
            size=(self.img_resolution, self.img_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return (imgs - 0.5) * 2

    @torch.no_grad()
    def _image_postprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Map [-1, 1] images back to [0, 1], clamped (kept on the current device)."""
        img_rescaled = (img + 1) / 2
        return img_rescaled.clamp(0, 1)


class KambWithWienerBasedPatches(BaseDenoiser):
    """Training-set-average denoiser with a locality mask from a Wiener filter.

    At every timestep the x0 estimate is a ``SoftmaxWeightedAverage`` over
    the training images. For each candidate image and output pixel, the
    logit is minus the masked sum of squared differences between the noisy
    input and the scaled candidate, divided by ``2 * (1 - alpha_bar)``. The
    binary mask comes from the Wiener operator ``LLt`` built from a
    precomputed S matrix.
    """

    def __init__(
        self,
        *args,
        save_prefix: str = "full",
        mask_threshold: float = 0.02,
        **kwargs,
    ):
        """Load the dataset's S matrix and mean and cache the SVD of S.

        Args:
            save_prefix: Its last ``"_"``-separated token is the image count
                (``"full"`` = whole dataset); dataset names such as
                ``fashion_mnist`` may contain underscores themselves.
            mask_threshold: Fraction of the largest normalized ``LLt`` entry
                below which pixel pairs are masked out.

        Raises:
            ValueError: if the S matrix or mean file is missing.
        """
        super().__init__(*args, **kwargs)

        self.mask_threshold = mask_threshold

        num_images_token = save_prefix.split("_")[-1]
        num_images = -1 if num_images_token == "full" else int(num_images_token)

        S_path, mean_path = get_s_matrix_path(self.dataset_name, num_images)
        if not os.path.exists(S_path) or not os.path.exists(mean_path):
            raise ValueError(
                f"S matrix or mean not found at {S_path} or {mean_path}. Please run s_matrix_calculator.py first."
            )

        S = torch.load(S_path, weights_only=True).to(self.device)
        self.mean = torch.load(mean_path, weights_only=True).to(self.device)
        self.U, self.LA, self.Vh = torch.linalg.svd(S)

    def _get_Lt_Ht(
        self, timestep: int, cov_matrix=None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(Lt, Ht, LLt)`` Wiener operators for ``timestep``.

        ``LLt = U diag(a*λ / (1 - a + a*λ)) Vh``, where
        ``a = alphas_cumprod[timestep]`` and ``(U, λ, Vh)`` is the SVD of
        ``cov_matrix`` (or of the cached S matrix when it is None).
        ``Ht = I - LLt`` and ``Lt = LLt / sqrt(a)``. All three are batched
        (``bmm``).
        """
        if cov_matrix is None:
            U, LA, Vh = self.U, self.LA, self.Vh
        else:
            U, LA, Vh = torch.linalg.svd(cov_matrix)

        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t

        shrink_factors = alpha_prod_t * LA / (beta_prod_t + alpha_prod_t * LA)
        LLt = torch.bmm(U, torch.bmm(torch.diag_embed(shrink_factors), Vh))

        eye = torch.eye(LLt.shape[1], device=LLt.device).expand(LLt.shape[0], -1, -1)
        Ht = eye - LLt
        Lt = LLt / torch.sqrt(alpha_prod_t)
        return Lt, Ht, LLt

    def _locality_mask(self, LLt: torch.Tensor) -> torch.Tensor:
        """Binarize a square ``LLt`` into a pixel-pair mask.

        Each row is divided by its diagonal (1 where the diagonal is below
        1e-6), and every diagonal entry is set to exactly 1. Entries below
        ``max * mask_threshold`` become 0 and the remaining positive entries
        become 1. The earlier implementation produced the all-ones diagonal
        by writing into ``LLt`` through a ``torch.diagonal`` view; this
        version does it explicitly and leaves ``LLt`` untouched.
        """
        diag = torch.diagonal(LLt)
        denom = torch.where(diag < 1e-6, torch.ones_like(diag), diag).unsqueeze(1)
        normalized = LLt / denom
        normalized.fill_diagonal_(1.0)

        threshold = torch.max(normalized) * self.mask_threshold
        masked = normalized.masked_fill(normalized < threshold, 0.0)
        return masked.masked_fill(masked > 0, 1.0)

    @torch.no_grad()
    def denoise(
        self,
        img: torch.Tensor,
        timestep: int,
    ):
        """Estimate x0 for ``img`` at ``timestep`` and the matching noise.

        Training images from ``self.dataloader`` go through
        ``_image_preprocess`` first. That puts them in the same [-1, 1],
        ``img_resolution`` space the sampler works in; raw [0, 1] images
        would bias the average.

        Returns:
            ``[predicted_noise, x0, x0, x0]``; x0 is repeated to match the
            4-item shape ``sample`` unpacks.
        """
        # Flatten the noisy inputs to [b, n].
        xt = img.reshape(img.shape[0], -1).to(self.device)
        n = xt.shape[1]
        device = xt.device

        _, _, LLt = self._get_Lt_Ht(timestep)
        # Only the first operator of the batch is used.
        mask = self._locality_mask(LLt.to(device)[0])  # [n, n]

        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t
        sqrt_alpha = torch.sqrt(alpha_prod_t)

        first_moment = SoftmaxWeightedAverage(mode="images", device=device)

        aug_size = 0  # translation augmentation is disabled
        for x0_batch in tqdm(self.dataloader, desc="Dataloader", leave=False):
            # [k/a, c, w, h] -> [4a^2 * k, c, w, h]
            translations = all_translations(x0_batch[0], aug_size)
            for x0_batch_aug in tqdm(
                translations,
                desc="Augmenting",
                leave=False,
                total=(2 * aug_size + 1) ** 2,
            ):
                # [k_batch, n] in the same space as xt
                x0b = self._image_preprocess(x0_batch_aug).reshape(-1, n).to(device)

                # [b, 1, n] - [1, k_batch, n] -> [b, k_batch, n]
                delta = (xt.unsqueeze(1) - sqrt_alpha * x0b.unsqueeze(0)) ** 2

                # ds[i, j, c] = sum_r delta[i, j, r] * mask[r, c]
                ds_chunk = torch.einsum("ijk,kl->ijl", delta, mask)  # [b, k_batch, n]
                logits = -ds_chunk / 2 / beta_prod_t

                first_moment.add(x0b, logits)

        target_x0 = first_moment.get_average().view(img.shape)
        predicted_noise = get_noise_from_target(
            self.scheduler, target_x0, img, timestep
        )

        return [predicted_noise, target_x0, target_x0, target_x0]


class DenoisingKamb(DenoisingPyramid):
    """Single-resolution patch denoiser whose patch size follows a timestep schedule.

    ``kernel_size`` is read as a schedule. The training-time range
    ``[0, num_train_timesteps)`` is split into ``len(kernel_size)`` equal
    chunks; the highest-noise chunk uses the first entry and the lowest-noise
    chunk the last. With ``simplify_schedule`` one denoiser is built per
    distinct kernel size, on an ``IdentityPyramid``.
    """

    def __init__(
        self,
        *args,
        simplify_schedule: bool = True,
        **kwargs,
    ):
        kernel_size = kwargs["kernel_size"]
        schedule = (
            [kernel_size] if isinstance(kernel_size, int) else list(kernel_size)
        )

        if simplify_schedule:
            # One denoiser per distinct kernel size. The set order (not the
            # schedule order) fixes the level indices; it is kept unchanged
            # for compatibility with previously built models.
            kwargs["kernel_size"] = list(set(schedule))
        super().__init__(*args, pyramidClass=IdentityPyramid, **kwargs)

        # The parent resolved -1 ("whole image") in its kernel list; resolve
        # the schedule identically so kernels can be matched by value.
        if simplify_schedule:
            resolved = dict(zip(kwargs["kernel_size"], self.kernel_size))
            self.ks_schedule = [resolved[ks] for ks in schedule]
        else:
            self.ks_schedule = list(self.kernel_size)

        # The parent's stride_gen is aligned with the distinct-kernel list,
        # not with the schedule; build a schedule-aligned copy.
        self.stride_gen_schedule = self._schedule_strides(kwargs.get("stride_gen"))

    def _schedule_strides(self, requested_stride_gen) -> List[int]:
        """Return the generation stride for every schedule entry.

        A list/tuple ``stride_gen`` with one entry per schedule step is used
        positionally (-1 meaning "stride = kernel size"). Otherwise the stride
        the parent assigned to each distinct kernel size is looked up by
        kernel value.
        """
        if isinstance(requested_stride_gen, (list, tuple)) and len(
            requested_stride_gen
        ) == len(self.ks_schedule):
            return [
                ks if s == -1 else s
                for s, ks in zip(requested_stride_gen, self.ks_schedule)
            ]
        stride_by_kernel = dict(zip(self.kernel_size, self.stride_gen))
        return [stride_by_kernel[ks] for ks in self.ks_schedule]

    def _get_denoiser_path(self, level_idx: int) -> str:
        """Get the path for saving/loading the denoiser of a kernel size.

        The level in the file name is always 0, so checkpoints are shared by
        kernel size rather than by level.
        """
        if self.save_dir is None:
            raise ValueError("save_dir must be set to save/load denoisers")

        denoiser_name = self.denoiser
        if self.denoiser == "cond_single_index":
            denoiser_name = "knn"
        return os.path.join(
            self.save_dir,
            f"ours/{self.save_prefix}",
            f"{self.save_prefix}_{denoiser_name}_denoiser_level_{0}_k{self.kernel_size[level_idx]}_s{self.stride}_embed_w{self.embed_w:.2f}",
        )

    @torch.no_grad()
    def denoise(
        self,
        img: torch.Tensor,
        timestep: int,
    ):
        """Predict the noise in ``img`` using the patch size scheduled for ``timestep``.

        Returns:
            ``(predicted_noise, x0, x0, x0)``; x0 is repeated to match the
            4-tuple shape returned by ``DenoisingPyramid.denoise``.
        """
        num_train_timesteps = self.scheduler.config.num_train_timesteps
        num_phases = len(self.ks_schedule)
        step_size = max(num_train_timesteps // num_phases, 1)
        level_idx = int((num_train_timesteps - 1 - timestep) // step_size)
        level_idx = min(max(level_idx, 0), num_phases - 1)

        kernel = self.ks_schedule[level_idx]
        self.denoisers_kernels = self.kernel_size
        # Index of the denoiser built for this kernel size.
        denoiser_idx = self.denoisers_kernels.index(kernel)

        # _fold indexes kernel_size / stride_gen by level_idx, which here is a
        # schedule position: swap in schedule-aligned lists and always restore.
        denoiser_strides = self.stride_gen
        self.kernel_size = self.ks_schedule
        self.stride_gen = self.stride_gen_schedule
        try:
            cur_patches = get_patches(
                [img],
                kernel_size=kernel,
                stride=self.stride_gen[level_idx],
                n_channels=self.n_channels,
            )
            predicted_patches = self.pyramid_of_denoisers[denoiser_idx](
                cur_patches,
                timestep,
            )
            total_x0 = self._fold(
                predicted_patches,
                level_idx,
                0,
                img.shape,
            )
        finally:
            self.kernel_size = self.denoisers_kernels
            self.stride_gen = denoiser_strides

        predicted_noise = get_noise_from_target(self.scheduler, total_x0, img, timestep)
        return (
            predicted_noise,
            total_x0,
            total_x0,
            total_x0,
        )



# template for Wiener-based denoising
class DenoisingWiener(DenoisingKamb):
    """Closed-form Wiener (linear shrinkage) denoiser.

    x0 is estimated linearly from the noisy image. The estimate uses the SVD of
    a precomputed matrix ``S`` (presumably the data covariance -- TODO confirm)
    and a precomputed dataset mean. Both are located via ``get_s_matrix_path``
    and produced by ``s_matrix_calculator.py``.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Pick the S matrix / mean matching the dataset and the training-subset
        # size. The size is the last "_"-separated token of save_prefix, where
        # "full" means the whole dataset (-1).
        dataset_name = kwargs.get("dataset_name", "cifar10")
        num_images = kwargs.get("save_prefix", "full").split("_")[-1]
        num_images = -1 if num_images == "full" else int(num_images)

        S_path, mean_path = get_s_matrix_path(dataset_name, num_images)

        if not os.path.exists(S_path):
            raise ValueError(
                f"S matrix not found at {S_path}. Please run s_matrix_calculator.py first."
            )

        S = torch.load(S_path, weights_only=True).to(self.device)
        # Only the first entry of the saved mean tensor is used.
        self.mean = torch.load(mean_path, weights_only=True).to(self.device)[0]
        # U / Vh are batched (S is used with torch.bmm-style batch semantics);
        # LA holds the singular values.
        self.U, self.LA, self.Vh = torch.linalg.svd(S)

    def _get_Lt_Ht(
        self, timestep: int
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Build the Wiener filter matrices for ``timestep``.

        Each singular value ``lam`` of ``S`` is shrunk to
        ``alpha_bar * lam / (1 - alpha_bar + alpha_bar * lam)``.

        Returns:
            ``(Lt, Ht, LLt)`` where ``LLt = U diag(shrink) Vh``,
            ``Ht = I - LLt`` and ``Lt = LLt / sqrt(alpha_bar)``.
        """
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t

        shrink_factors = alpha_prod_t * self.LA / (beta_prod_t + alpha_prod_t * self.LA)
        # U @ diag(s) == U * s (scaling column j by s_j). This avoids building a
        # dense diagonal matrix and an extra D x D x D matmul.
        LLt = (self.U * shrink_factors.unsqueeze(-2)) @ self.Vh

        # Identity broadcasts over the batch dimension of LLt.
        eye = torch.eye(LLt.shape[-1], device=LLt.device)
        Ht = eye - LLt
        Lt = LLt / torch.sqrt(alpha_prod_t)

        return Lt, Ht, LLt

    @torch.no_grad()
    def denoise(
        self,
        img: torch.Tensor,
        timestep: int,
    ):
        """Wiener estimate of x0 and the implied noise.

        ``x0 = Lt @ x_t + Ht @ mean``, computed on the flattened image.

        Returns:
            ``(predicted_noise, lx0, total_x0, total_x0)``, where ``lx0`` is the
            linear term ``Lt @ x_t`` alone.
        """
        # Mirrors the kernel-size swap of the patch-based denoisers. Nothing
        # below reads kernel_size, but the swap keeps self.denoisers_kernels
        # populated the same way. It is restored even if an error propagates.
        self.denoisers_kernels = self.kernel_size
        self.kernel_size = self.ks_schedule
        try:
            Lt, Ht, _ = self._get_Lt_Ht(timestep)

            lx0 = (Lt @ img.flatten(start_dim=1).unsqueeze(-1)).squeeze(-1).view_as(img)
            mean_term = (Ht @ self.mean).view(1, *img.shape[1:])
            total_x0 = lx0 + mean_term

            predicted_noise = get_noise_from_target(
                self.scheduler, total_x0, img, timestep
            )
        finally:
            self.kernel_size = self.denoisers_kernels

        return (
            predicted_noise,
            lx0,
            total_x0,
            total_x0,
        )



class AnotherUnet(BaseDenoiser):
    """Denoiser backed by a trained UNet checkpoint.

    The model config comes from ``get_unet_config``. For some datasets it is
    redirected to a specific checkpoint via ``_ALTERNATE_CHECKPOINTS``.
    """

    # Checkpoint overrides: dataset name -> num_images (-1 = full dataset)
    # -> (save_weight_dir, weight file name). Treated as read-only.
    _ALTERNATE_CHECKPOINTS = {
        "mnist": {
            -1: (
                "trained_models/unet/unet_mnist_-1_noattn_20250513_195201",
                "ckpt_epoch_200.pt",
            ),
        },
        "fashion_mnist": {
            -1: (
                "trained_models/unet/unet_fashion_mnist_-1_noattn_20250514_001525",
                "ckpt_epoch_200.pt",
            ),
        },
        "celeba_hq": {
            -1: (
                "trained_models/unet/unet_celeba_hq_-1_noattn_20250514_030841",
                "ckpt_epoch_200.pt",
            ),
        },
        "afhq": {
            -1: (
                "trained_models/unet/unet_afhq_-1_noattn_20250515_004233",
                "ckpt_epoch_200.pt",
            ),
        },
        "cifar10": {
            -1: (
                "trained_models/unet/unet_cifar10_-1_noattn_20250313_175115",
                "ckpt_epoch_180.pt",
            ),
        },
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Dataset name and subset size; the size is the last "_"-separated
        # token of save_prefix (e.g. "cifar10_full" or "mnist_1000").
        dataset_name = kwargs.get("dataset_name", "cifar10")
        save_prefix = kwargs.get("save_prefix", "cifar10_full")
        num_images = save_prefix.split("_")[-1]
        num_images = -1 if num_images == "full" else int(num_images)

        self.config = get_unet_config(dataset_name, num_images)
        # Run the model on this denoiser's device, not the config's default.
        self.config["device"] = self.device

        # Point the config at a specific checkpoint when one is registered.
        override = self._ALTERNATE_CHECKPOINTS.get(dataset_name, {}).get(num_images)
        if override is not None:
            save_weight_dir, test_load_weight = override
            self.config["save_weight_dir"] = save_weight_dir
            self.config["test_load_weight"] = test_load_weight
            self.config["training_load_weight"] = test_load_weight

        self.model = load_model(self.config, self.device)
        self.model.eval()

    @torch.no_grad()
    def denoise(self, img, timestep):
        """Run the UNet and derive x0 through the scheduler.

        Args:
            img: noisy sample.
            timestep: a Python int or a tensor. It is moved to ``self.device``
                and given a leading dimension before being passed to the model.

        Returns:
            ``(model_output, pred_x0, pred_x0, pred_x0)``. ``pred_x0`` is the
            scheduler step's ``pred_original_sample``.
        """
        # as_tensor accepts plain ints as well as tensors (unlike Tensor.to).
        model_timestep = torch.as_tensor(timestep, device=self.device)[None]
        model_output = self.model(img, model_timestep)

        # The scheduler step is used only for its x0 prediction.
        step_output = self.scheduler.step(
            model_output=model_output,
            timestep=timestep,
            sample=img,
            generator=None,
        )
        pred_x0 = step_output.pred_original_sample

        return (
            model_output,
            pred_x0,
            pred_x0,
            pred_x0,
        )

