import math
import os
from itertools import product
from typing import Callable, Tuple, Union

import time
import faiss
import numpy as np
import torch
import torchvision
from folding_functions import *
from pyramids import Pyramid
from torch.nn import functional as F
from tqdm.auto import tqdm

import wandb


def get_noise_from_target(scheduler, target, cur_xt, t):
    """Solve the forward-noising relation for the noise term.

    With ``a_t = scheduler.alphas_cumprod[t]`` this inverts
    ``cur_xt = sqrt(a_t) * target + sqrt(1 - a_t) * noise``.

    Args:
        scheduler: Object exposing an indexable ``alphas_cumprod``.
        target: Clean-sample estimate.
        cur_xt: Noisy sample at timestep ``t``.
        t: Index into ``scheduler.alphas_cumprod``.

    Returns:
        The noise consistent with ``target`` and ``cur_xt`` at step ``t``.
    """
    signal_level = scheduler.alphas_cumprod[t]
    noise_level = 1 - signal_level
    residual = cur_xt - target * signal_level ** (0.5)
    return residual / (noise_level ** (0.5))


def get_patches(ref_imgs, kernel_size=5, stride=1, n_channels=3, padding=0):
    """Extract every sliding-window patch from each image and stack them.

    Args:
        ref_imgs: Iterable of image batches; each element is passed to
            ``torch.nn.functional.unfold`` and then permuted as a 3-D result,
            so each is presumably shaped ``[B, C, H, W]`` (confirm at callers).
        kernel_size: Side length of the square patches.
        stride: Step between neighbouring patches.
        n_channels: Channel count used when reshaping the unfolded columns.
        padding: Padding forwarded to ``unfold``.

    Returns:
        Tensor of shape ``[total_patches, n_channels, kernel_size, kernel_size]``
        with the patches of all images concatenated along dim 0.
    """
    patch_groups = []
    for img in ref_imgs:
        # [B, C*k*k, num_patches]
        columns = torch.nn.functional.unfold(
            img,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # [B, num_patches, C*k*k] so that each row is one flattened patch
        per_patch = columns.permute(0, 2, 1)
        # [B*num_patches, C, k, k]
        patch_groups.append(
            per_patch.reshape(-1, n_channels, kernel_size, kernel_size)
        )
    return torch.cat(patch_groups, dim=0)


class OptimalDenoiser:
    """Denoiser that returns a softmax-weighted average of stored samples.

    Given a noisy input ``x`` at timestep ``t``, every stored sample is scored
    by its (scaled) squared L2 distance to ``x``; the softmax of the negated
    scores is used to average the stored samples into an ``x0`` estimate.

    Attributes:
        data: Stored samples flattened to ``[N, D]`` (or ``None`` until added).
        data_shape: Shape of the most recently supplied (unflattened) data.
        scheduler: Object exposing an indexable ``alphas_cumprod``.
        temperature: Divisor applied to the softmax logits.
        trained: Flag set by ``train`` / ``load``.
    """

    def __init__(
        self,
        dataset: Union[list[torch.Tensor], torch.Tensor],
        scheduler,
        *args,
        temperature=1.0,
        **kwargs,
    ):
        """Store the dataset (if any), the scheduler and the temperature.

        Extra positional and keyword arguments are accepted and ignored so
        that subclasses and callers can share one construction signature.
        """
        if dataset is None:
            self.data = None
        else:
            if isinstance(dataset, list):
                data = torch.cat(dataset, dim=0)
            else:
                data = dataset
            # Remember the unflattened shape, mirroring what add_data records.
            self.data_shape = data.shape
            self.data = data.flatten(start_dim=1)
        self.scheduler = scheduler
        self.temperature = temperature
        self.trained = False

    def is_trained(self) -> bool:
        """Check if the denoiser has been trained"""
        return self.trained

    def add_data(self, data: Union[list[torch.Tensor], torch.Tensor]):
        """Append samples to the stored dataset and return ``self``.

        Args:
            data: A tensor, or a list of tensors that is concatenated along
                dim 0 first. Everything after dim 0 is flattened for storage.
        """
        # Concatenate BEFORE touching ``.shape``: a list has no ``shape``
        # attribute, so the previous order crashed on list input.
        if isinstance(data, list):
            data = torch.cat(data, dim=0)
        self.data_shape = data.shape
        flat = data.flatten(start_dim=1)
        if self.data is None:
            self.data = flat
        else:
            self.data = torch.cat([self.data, flat], dim=0)
        return self

    def save(self, path: str):
        """Save denoiser data

        Raises:
            ValueError: If the denoiser has not been trained/loaded yet.
        """
        if not self.trained:
            raise ValueError("Denoiser must be trained before saving")
        torch.save(
            {
                "data": self.data,
                "temperature": self.temperature,
            },
            path,
        )

    def load(self, path: str):
        """Load denoiser data

        Only the stored samples are restored; the saved temperature is not
        read back, so the value given at construction stays in effect.
        """
        saved_data = torch.load(path, weights_only=True)
        self.data = saved_data["data"]
        self.trained = True
        return self

    def train(self):
        """Mark the denoiser as ready; the exact variant needs no fitting."""
        self.trained = True
        return self

    def _get_distances_and_images(self, x, alpha):
        """Return scaled squared distances and the candidate samples.

        Args:
            x: Query batch of shape ``[B, D]``.
            alpha: Factor multiplied into the squared distances.

        Returns:
            ``(scaled_distances [B, N], candidates [B, N, D])``.
        """
        # Move computation to x's device
        d = self.data.to(x.device)

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed with one matmul
        x_sq = torch.sum(x * x, dim=1, keepdim=True)  # [B, 1]
        d_sq = torch.sum(d * d, dim=1)  # [N]
        xd = torch.matmul(x, d.t())  # [B, N]
        sq_distances = x_sq + d_sq.unsqueeze(0) - 2 * xd  # [B, N]

        # The query was divided by sqrt(alpha); since the distance is squared
        # the compensating factor is alpha rather than sqrt(alpha).
        scaled_distances = sq_distances * alpha
        # expand() creates a broadcast view, not B copies of the dataset
        expanded_data = d.unsqueeze(0).expand(x.size(0), -1, -1)  # [B, N, D]

        return scaled_distances, expanded_data

    def _get_scaled_x(self, x, t, noisy_subspace=-1):
        """Rescale ``x`` into data space and flatten it.

        ``x`` is divided by ``sqrt(alphas_cumprod[t])``. When
        ``noisy_subspace > -1`` the channels from that index onward are
        multiplied back, i.e. they are left at their original scale.

        Returns:
            ``(x_scaled [B, D], alpha_prod_t, beta_prod_t)`` where
            ``beta_prod_t = 1 - alpha_prod_t``.
        """
        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        beta_prod_t = 1 - alpha_prod_t

        # If data * sqrt(alpha) = x, then data = x / sqrt(alpha).
        # The division already yields a fresh tensor, so the in-place slice
        # assignment below never touches the caller's ``x``.
        x_scaled = x / (alpha_prod_t**0.5)
        if noisy_subspace > -1:
            x_scaled[:, noisy_subspace:, ...] = x_scaled[:, noisy_subspace:, ...] * (
                alpha_prod_t**0.5
            )

        x_scaled = x_scaled.flatten(start_dim=1)  # [B, D]
        return x_scaled, alpha_prod_t, beta_prod_t

    def __call__(self, x, t, noisy_subspace=-1, res_scale=1.0):
        """Estimate the clean sample for noisy input ``x`` at timestep ``t``.

        Args:
            x: Noisy batch; dim 0 is the batch, the rest is flattened.
            t: Index into ``scheduler.alphas_cumprod``.
            noisy_subspace: Forwarded to ``_get_scaled_x``.
            res_scale: Squared and multiplied into the distances
                (adjustment for a downscaled resolution).

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        x_scaled, alpha_prod_t, beta_prod_t = self._get_scaled_x(x, t, noisy_subspace)

        sq_diffs, data = self._get_distances_and_images(x_scaled, alpha_prod_t)

        sq_diffs = sq_diffs * (res_scale**2)  # adjust for the downscaled resolution
        weights = F.softmax(
            -sq_diffs / 2 / beta_prod_t / self.temperature,
            dim=1,
        )  # [B, N]
        # Weighted average of the candidates: [B, 1, N] @ [B, N, D] -> [B, D]
        x0_mean = torch.bmm(weights.unsqueeze(1), data).squeeze(1)
        return x0_mean.view(orig_shape)


class KNNOptimalDenoiser(OptimalDenoiser):
    """``OptimalDenoiser`` that only weighs the nearest neighbours of a query.

    Neighbours are retrieved from a FAISS index built over the stored
    samples, which are kept on the CPU.
    """

    def __init__(
        self,
        *args,
        num_neighbors=2000,
        **kwargs,
    ):
        """Forward construction to the base class and keep the data on CPU.

        Args:
            num_neighbors: Maximum number of neighbours retrieved per query.
        """
        super().__init__(*args, **kwargs)
        self.num_neighbors = num_neighbors
        if self.data is not None:
            self.data = self.data.cpu()

    def is_trained(self) -> bool:
        """Check if the KNN denoiser has been trained"""
        return self.trained and hasattr(self, "index")

    def save(self, path: str):
        """Save KNN denoiser data and index

        Writes ``{path}.index`` (FAISS index) and ``{path}.data`` (tensors
        and settings).

        Raises:
            ValueError: If the denoiser has not been trained/loaded yet.
        """
        if not self.trained:
            raise ValueError("Denoiser must be trained before saving")

        # ``os.makedirs("")`` raises, so only create the directory when the
        # path actually has a directory component.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        faiss.write_index(self.index, f"{path}.index")
        torch.save(
            {
                "data": self.data,
                "temperature": self.temperature,
                "num_neighbors": self.num_neighbors,
                "dim": self.dim,
            },
            f"{path}.data",
        )

    def load(self, path: str):
        """Load KNN denoiser data and index

        Restores the index, the stored samples and ``dim``; the saved
        temperature and neighbour count are not read back.
        """
        self.index = faiss.read_index(f"{path}.index")
        saved_data = torch.load(f"{path}.data", weights_only=True)
        self.data = saved_data["data"]
        self.dim = saved_data["dim"]
        self.trained = True
        return self

    def add_data(self, data: Union[list[torch.Tensor], torch.Tensor]):
        """Append samples (moved to CPU) to the dataset and return ``self``.

        The FAISS index is not updated here; call ``train`` afterwards to
        make newly added samples searchable.
        """
        # Concatenate first: a list has no ``.shape``, so reading it before
        # this check crashed on list input.
        if isinstance(data, list):
            data = torch.cat(data, dim=0)
        self.data_shape = data.shape
        flat = data.flatten(start_dim=1).cpu()
        if self.data is None:
            self.data = flat
        else:
            self.data = torch.cat([self.data, flat], dim=0)
        return self

    def train(self):
        """Create a new FAISS index"""
        self.dim = self.data.shape[1]
        num_patches = self.data.shape[0]

        if num_patches > 1_000_000:
            print(f"Training KNN denoiser with {num_patches} patches")
            # For very large datasets, use an inverted-file (IVF) index
            nlist = min(4096, num_patches // 39)
            quantizer = faiss.IndexFlatL2(self.dim)
            index = faiss.IndexIVFFlat(quantizer, self.dim, nlist)
            # Fit the coarse quantizer on the leading rows only
            train_size = min(100_000, num_patches)
            train_data = self.data[:train_size].flatten(start_dim=1).cpu().numpy()
            index.train(train_data)
            print("Done training index")
        else:
            # For smaller datasets, use exact search
            index = faiss.IndexFlatL2(self.dim)

        print("Adding data to index")
        index.add(self.data.numpy())
        print("Done adding data to index")
        self.index = index
        self.trained = True
        return self

    def _get_distances_and_images(self, x, alpha):
        """Return scaled squared distances and neighbours for each query.

        Args:
            x: Query batch of shape ``[B, D]`` (already rescaled by caller).
            alpha: Factor multiplied into the squared distances.

        Returns:
            ``(sq_diffs [B, K], neighbours [B, K, D])`` on ``x``'s device,
            with ``K = min(num_neighbors, number of stored samples)``.
        """
        # NOTE: the query is scaled instead of the database vectors, so a
        # single index serves every timestep.
        num_requested = min(self.num_neighbors, self.data.shape[0])
        D, I = self.index.search(x.cpu().numpy(), num_requested)
        # The index returns squared L2 distances, hence alpha rather than
        # sqrt(alpha) as the compensating factor.
        sq_diffs = torch.from_numpy(D).to(x.device) * alpha

        data = self.data[I].to(x.device)
        return sq_diffs, data


class KNNCondSingleIndexDenoiser(KNNOptimalDenoiser):
    """KNN denoiser that scores candidates against ``x`` and ``cond_x``.

    Both queries are searched in the same FAISS index over ``self.data``.
    The two neighbour sets are merged, and every candidate receives a
    distance to ``x`` (weighted by ``alpha``) plus a distance to ``cond_x``
    (weighted by ``cond_gamma``).
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def init_cond(self, cond_name: str, cond_mapping: Callable, sample_patch_size):
        """Record the conditioning settings and reset conditioning state.

        ``sample_patch_size[0]`` is stored as ``n_patches_per_img``.
        """
        self.cond_name, self.cond_mapping = cond_name, cond_mapping
        self.cond_data = self.cond_dim = self.cond_index = None
        self.n_patches_per_img = sample_patch_size[0]

    def _get_distances_and_images(self, x, cond_x, alpha, cond_gamma):
        """Return combined distances and candidates for each query row.

        Args:
            x: Rescaled noisy queries, ``[B, D]``.
            cond_x: Conditioning queries, ``[B, D]``.
            alpha: Weight of the distance to ``x``.
            cond_gamma: Weight of the distance to ``cond_x``.

        Returns:
            ``(combined_distances [B, 2K], candidates [B, 2K, D])`` where the
            first ``K`` columns come from the ``x`` search and the last ``K``
            from the ``cond_x`` search.
        """
        device = x.device
        num_requested = min(self.num_neighbors, self.data.shape[0])

        # Neighbours of x keep the distances reported by the index; for the
        # neighbours of cond_x only the indices are needed because their
        # distances are recomputed below.
        dist_to_x, idx_near_x = self.index.search(x.cpu().numpy(), num_requested)
        _, idx_near_cond = self.index.search(cond_x.cpu().numpy(), num_requested)

        dist_to_x = torch.from_numpy(dist_to_x).to(device)
        idx_near_x = torch.from_numpy(idx_near_x)
        idx_near_cond = torch.from_numpy(idx_near_cond)

        x_query = x[:, None, :]
        cond_query = cond_x.to(device)[:, None, :]

        def squared_dist(query, candidates):
            # [B, 1, D] vs [B, K, D] -> [B, K]
            return torch.cdist(query, candidates, p=2).pow(2).squeeze(1)

        near_x = self.data[idx_near_x].to(device)
        near_cond = self.data[idx_near_cond].to(device)

        # Neighbours of x measured against cond_x, and vice versa
        cond_dist_of_near_x = squared_dist(cond_query, near_x)
        x_dist_of_near_cond = squared_dist(x_query, near_cond)
        # Neighbours of cond_x measured against cond_x
        cond_dist_of_near_cond = squared_dist(cond_query, near_cond)

        candidates = torch.cat([near_x, near_cond], dim=1)
        x_term = alpha * torch.cat([dist_to_x, x_dist_of_near_cond], dim=1)
        cond_term = cond_gamma * torch.cat(
            [cond_dist_of_near_x, cond_dist_of_near_cond], dim=1
        )
        return x_term + cond_term, candidates

    def __call__(self, x, cond_x, t, cond_gamma=1.0, noisy_subspace=-1, res_scale=1.0):
        """Estimate the clean sample for ``x`` given conditioning ``cond_x``.

        Returns a tensor with the same shape as ``x``.
        """
        target_shape = x.shape
        x_flat, alpha_prod_t, beta_prod_t = self._get_scaled_x(x, t, noisy_subspace)

        distances, candidates = self._get_distances_and_images(
            x_flat, cond_x.flatten(start_dim=1), alpha_prod_t, cond_gamma
        )

        # adjust for the downscaled resolution
        distances = distances * (res_scale**2)
        logits = -distances / 2 / beta_prod_t / self.temperature
        weights = F.softmax(logits, dim=1)  # [B, N]

        # [B, 1, N] @ [B, N, D] -> [B, D]
        x0_mean = torch.bmm(weights.unsqueeze(1), candidates).squeeze(1)
        return x0_mean.view(target_shape)


class KNNCondDenoiser(KNNOptimalDenoiser):
    """KNN denoiser with a second FAISS index over conditioning features.

    ``self.index`` covers the stored samples and ``self.cond_index`` covers
    ``cond_mapping`` applied to them. ``init_cond`` must be called before
    ``add_data``, ``load`` or ``save``, because those read the conditioning
    attributes it sets up.

    Set ``self.debug = True`` to print per-query neighbour diagnostics.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Datasets with more rows than this get an IVF index, not a flat one.
        self.flat_idx_threshold = 1_000_000  # 100_000_000
        # Diagnostics in _get_distances_and_images are printed only when True.
        self.debug = False

    def init_cond(self, cond_name: str, cond_mapping: Callable, sample_patch_size):
        """Record the conditioning settings and reset conditioning state.

        Args:
            cond_name: Used in the file names of the conditional index/data.
            cond_mapping: Callable producing conditioning features from data.
            sample_patch_size: Its first entry is stored as
                ``n_patches_per_img`` and scales the mapping batch size.
        """
        self.cond_name = cond_name
        self.cond_mapping = cond_mapping
        self.cond_data = None
        self.cond_dim = None
        self.cond_index = None
        self.n_patches_per_img = sample_patch_size[0]

    def _new_index(self, vectors, num_patches, train_size, done_message):
        """Create an empty FAISS index for ``vectors`` (IVF indices are trained)."""
        dim = vectors.shape[1]
        if num_patches <= self.flat_idx_threshold:
            # For smaller datasets, use exact search
            return faiss.IndexFlatL2(dim)

        nlist = min(4096, num_patches // 39)
        quantizer = faiss.IndexFlatL2(dim)
        index = faiss.IndexIVFFlat(quantizer, dim, nlist)
        # Fit the coarse quantizer on the leading rows only
        sample_size = min(train_size, num_patches)
        sample = vectors[:sample_size].flatten(start_dim=1).cpu().numpy()
        index.train(sample)
        print(done_message)
        return index

    def _save_cond(self, path: str):
        """Write the conditional index and conditional data next to ``path``."""
        faiss.write_index(self.cond_index, f"{path}_cond_{self.cond_name}.index")
        torch.save(
            {
                "cond_data": self.cond_data,
                "cond_dim": self.cond_dim,
            },
            f"{path}_cond_{self.cond_name}.data",
        )

    def save(self, path: str):
        """Save KNN denoiser data and index

        Writes ``{path}.index`` / ``{path}.data`` and the conditional pair
        ``{path}_cond_{cond_name}.index`` / ``.data``.

        Raises:
            ValueError: If the denoiser has not been trained/loaded yet.
        """
        if not self.trained:
            raise ValueError("Denoiser must be trained before saving")

        # ``os.makedirs("")`` raises, so skip it for bare file names.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        faiss.write_index(self.index, f"{path}.index")
        torch.save(
            {
                "data": self.data,
                "temperature": self.temperature,
                "num_neighbors": self.num_neighbors,
                "dim": self.dim,
            },
            f"{path}.data",
        )
        # The conditional file must hold ``cond_index``; writing the main
        # index there made a later ``load`` search the wrong index.
        self._save_cond(path)

    def _train_and_save_cond_index(self, path: str, device: str = "cuda"):
        """Derive conditioning features from ``self.data``, index and save them.

        NOTE(review): here ``cond_mapping`` receives flattened rows of
        ``self.data``, whereas ``add_data`` passes the data in its original
        shape — confirm the mapping supports both.
        """
        # Map the data in batches to bound device memory usage
        batch_size = 1000 * self.n_patches_per_img
        cond_chunks = []
        for start in range(0, len(self.data), batch_size):
            batch = self.data[start : start + batch_size].to(device)
            cond_chunks.append(self.cond_mapping(batch).cpu())

        self.cond_data = torch.cat(cond_chunks, dim=0)
        self.cond_dim = self.cond_data.shape[1]

        num_patches = self.data.shape[0]
        print(f"Training cond index with {num_patches} patches")
        cond_index = self._new_index(
            self.cond_data, num_patches, 200_000, "Done training cond index"
        )

        print("Adding data to cond index")
        cond_index.add(self.cond_data.numpy())
        print("Done adding data to cond index")
        self.cond_index = cond_index

        self._save_cond(path)

    def load(self, path: str, device: str = "cuda"):
        """Load KNN denoiser data and index

        If the conditional data file for ``cond_name`` is missing, the
        conditional index is built from the loaded data (running
        ``cond_mapping`` on ``device``) and saved for next time.
        """
        time_start = time.time()
        print(f"Loading index from {path}.index")
        self.index = faiss.read_index(f"{path}.index")
        print(f"Done loading index in {time.time() - time_start} seconds")

        time_start = time.time()
        print(f"Loading data from {path}.data")
        saved_data = torch.load(
            f"{path}.data",
            weights_only=True,
        )
        print(f"Done loading in {time.time() - time_start} seconds")
        self.data = saved_data["data"]
        self.dim = saved_data["dim"]

        if os.path.exists(f"{path}_cond_{self.cond_name}.data"):
            print(f"Loading cond index for {self.cond_name}")
            self.cond_index = faiss.read_index(f"{path}_cond_{self.cond_name}.index")
            saved_cond_data = torch.load(
                f"{path}_cond_{self.cond_name}.data", weights_only=True
            )
            self.cond_data = saved_cond_data["cond_data"]
            self.cond_dim = saved_cond_data["cond_dim"]
        else:
            print(
                f"Training cond index for {self.cond_name} and will save it to {path}_cond_{self.cond_name}.data"
            )
            self._train_and_save_cond_index(path, device)

        self.trained = True
        return self

    def add_data(self, data: Union[list[torch.Tensor], torch.Tensor]):
        """Append samples and their conditioning features; return ``self``.

        Neither FAISS index is updated here; call ``train`` afterwards.
        """
        # Concatenate first: a list has no ``.shape``, so reading it before
        # this check crashed on list input.
        if isinstance(data, list):
            data = torch.cat(data, dim=0)
        self.data_shape = data.shape

        # Compute both pieces before mutating state so that a failing
        # ``cond_mapping`` cannot leave data and cond_data out of step.
        flat = data.flatten(start_dim=1).cpu()
        cond_flat = self.cond_mapping(data).flatten(start_dim=1).cpu()

        if self.data is None:
            self.data = flat
        else:
            self.data = torch.cat([self.data, flat], dim=0)

        if self.cond_data is None:
            self.cond_data = cond_flat
        else:
            self.cond_data = torch.cat([self.cond_data, cond_flat], dim=0)
        return self

    def train(self):
        """Create a new FAISS index"""
        self.dim = self.data.shape[1]
        self.cond_dim = self.cond_data.shape[1]
        num_patches = self.data.shape[0]

        print(f"Training KNN denoiser with {num_patches} patches")
        index = self._new_index(
            self.data, num_patches, 100_000, "Done training index"
        )

        print(f"Training cond index with {num_patches} patches")
        cond_index = self._new_index(
            self.cond_data, num_patches, 100_000, "Done training cond index"
        )

        print("Adding data to index")
        index.add(self.data.numpy())
        cond_index.add(self.cond_data.numpy())
        print("Done adding data to index")
        self.index = index
        self.cond_index = cond_index
        self.trained = True
        return self

    def _get_distances_and_images(self, x, cond_x, alpha, cond_gamma):
        """Return combined distances and candidates for each query row.

        Args:
            x: Rescaled noisy queries, ``[B, D]``.
            cond_x: Conditioning queries, ``[B, D_cond]``.
            alpha: Weight of the data-space distance.
            cond_gamma: Weight of the conditioning-space distance.

        Returns:
            ``(combined_distances [B, K1 + K2], candidates [B, K1 + K2, D])``;
            the first ``K1`` columns come from the data index and the
            remaining ``K2`` from the conditional index.
        """
        device = x.device

        # Get nearest neighbors from both indices
        D1, I1 = self.index.search(
            x.cpu().numpy(), min(self.num_neighbors, self.data.shape[0])
        )
        D2, I2 = self.cond_index.search(
            cond_x.cpu().numpy(), min(self.num_neighbors, self.cond_data.shape[0])
        )

        D1 = torch.from_numpy(D1).to(device)
        D2 = torch.from_numpy(D2).to(device)
        I1 = torch.from_numpy(I1)
        I2 = torch.from_numpy(I2)

        # Neighbours found in data space, measured in conditioning space
        cond_data_1 = self.cond_data[I1].to(device)
        cross_D1 = (
            torch.cdist(cond_x.to(device)[:, None, :], cond_data_1, p=2)
            .pow(2)
            .squeeze(1)
        )

        # Neighbours found in conditioning space, measured in data space
        data_2 = self.data[I2].to(device)
        cross_D2 = torch.cdist(x[:, None, :], data_2, p=2).pow(2).squeeze(1)

        # Candidates of both searches, in the same order as the distances
        data_1 = self.data[I1].to(device)
        combined_data = torch.cat([data_1, data_2], dim=1)

        if getattr(self, "debug", False):
            print(torch.mean(x), torch.mean(cond_x))
            print(
                f"Dataset sizes: {self.data.shape[0]}, {self.cond_data.shape[0]} \n"
                f"Num of shapes requested: {min(self.num_neighbors, self.cond_data.shape[0])} \n"
                f"num entries equal to -1 in I1: {(I1 == -1).sum().item()}, \n"
                f"Some samples: {I1[:10]} \n"
                f"num entries equal to -1 in I2: {(I2 == -1).sum().item()} \n"
                f"Some samples: {I2[:10]} \n"
            )
            print(
                torch.mean(D1),
                torch.mean(cross_D2),
                torch.mean(cross_D1),
                torch.mean(D2),
                alpha,
                cond_gamma,
            )

        # Only the data-space part is scaled by alpha
        combined_D = alpha * torch.cat([D1, cross_D2], dim=1) + cond_gamma * torch.cat(
            [cross_D1, D2], dim=1
        )

        return combined_D, combined_data

    def __call__(self, x, cond_x, t, cond_gamma=1.0, noisy_subspace=-1, res_scale=1.0):
        """Estimate the clean sample for ``x`` given conditioning ``cond_x``.

        Args:
            x: Noisy batch; dim 0 is the batch, the rest is flattened.
            cond_x: Conditioning batch, flattened after dim 0.
            t: Index into ``scheduler.alphas_cumprod``.
            cond_gamma: Weight of the conditioning-space distance.
            noisy_subspace: Forwarded to ``_get_scaled_x``.
            res_scale: Squared and multiplied into the distances.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        x_scaled, alpha_prod_t, beta_prod_t = self._get_scaled_x(x, t, noisy_subspace)

        sq_diffs, data = self._get_distances_and_images(
            x_scaled, cond_x.flatten(start_dim=1), alpha_prod_t, cond_gamma
        )

        sq_diffs = sq_diffs * (res_scale**2)  # adjust for the downscaled resolution
        weights = F.softmax(
            -sq_diffs / 2 / beta_prod_t / self.temperature,
            dim=1,
        )  # [B, N]
        # [B, 1, N] @ [B, N, D] -> [B, D]
        x0_mean = torch.bmm(weights.unsqueeze(1), data).squeeze(1)
        return x0_mean.view(orig_shape)


@torch.no_grad()
def match_img_histograms(src_img: torch.Tensor, dst_img: torch.Tensor) -> torch.Tensor:
    """Give ``src_img`` the value distribution of ``dst_img``, per channel.

    All dimensions after the first two are flattened. Within each
    (batch, channel) row, the k-th smallest source position receives the
    k-th smallest destination value.

    Returns:
        Tensor shaped like ``src_img``.
    """
    src_flat = src_img.flatten(2)
    dst_flat = dst_img.flatten(2)

    # Where each rank lives in the source, and which value each rank has in
    # the destination
    src_order = torch.sort(src_flat, dim=-1).indices
    dst_sorted = torch.sort(dst_flat, dim=-1).values

    remapped = torch.zeros_like(src_flat).scatter_(-1, src_order, dst_sorted)
    return remapped.view_as(src_img)


@torch.no_grad()
def heeger_bergen_synthesis(
    reference_image: torch.Tensor,
    pyramid: Pyramid,
    n_iters: int,
    display_callback: Callable = None,
) -> torch.Tensor:
    """Synthesize an image by repeated pyramid-band histogram matching.

    Starting from Gaussian noise, each iteration encodes the current image,
    matches the histogram of every band to the corresponding band of the
    reference, decodes, and finally matches the image histogram itself.

    Args:
        reference_image: Image whose statistics are imitated.
        pyramid: Provides ``encode`` / ``decode``.
        n_iters: Number of iterations.
        display_callback: Optional; called after each iteration as
            ``(iteration, n_iters, image, source_pyramid, target_pyramid)``.

    Returns:
        The image after the last iteration.
    """
    target_pyramid = pyramid.encode(reference_image)
    cur_image = torch.randn_like(target_pyramid.initial_image)

    for step in tqdm(range(n_iters)):
        synth_pyramid = pyramid.encode(cur_image)

        # Levels are visited from the last one down to the first
        for level_idx in reversed(range(len(synth_pyramid.levels))):
            n_bands = len(synth_pyramid.levels[level_idx])
            for band_idx in range(n_bands):
                matched_band = match_img_histograms(
                    synth_pyramid.levels[level_idx][band_idx],
                    target_pyramid.levels[level_idx][band_idx],
                )
                synth_pyramid.levels[level_idx][band_idx] = matched_band

        cur_image = match_img_histograms(
            pyramid.decode(synth_pyramid), target_pyramid.initial_image
        )

        if display_callback is not None:
            display_callback(step, n_iters, cur_image, synth_pyramid, target_pyramid)

    return cur_image


@torch.no_grad()
def heeger_bergen_coarse_to_fine(
    reference_image: torch.Tensor,
    pyramid: Pyramid,
    n_iters: int,
    display_callback: Callable = None,
) -> torch.Tensor:
    """Variant of the synthesis loop that re-encodes after every level.

    Within an iteration, levels are visited from the last one down to the
    first. For each level the current image is encoded afresh, only that
    level's bands are histogram-matched to the reference, and the pyramid is
    decoded back into the image before moving on.

    Args:
        reference_image: Image whose statistics are imitated.
        pyramid: Provides ``encode`` / ``decode``.
        n_iters: Number of iterations.
        display_callback: Optional; called after each iteration as
            ``(iteration, n_iters, image, source_pyramid, target_pyramid)``.

    Returns:
        The image after the last iteration.
    """
    target_pyramid = pyramid.encode(reference_image)
    cur_image = torch.randn_like(target_pyramid.initial_image)

    for step in tqdm(range(n_iters)):
        # This encode is only needed to learn how many levels there are
        n_levels = len(pyramid.encode(cur_image).levels)

        for level_idx in reversed(range(n_levels)):
            synth_pyramid = pyramid.encode(cur_image)
            n_bands = len(synth_pyramid.levels[level_idx])
            for band_idx in range(n_bands):
                matched_band = match_img_histograms(
                    synth_pyramid.levels[level_idx][band_idx],
                    target_pyramid.levels[level_idx][band_idx],
                )
                synth_pyramid.levels[level_idx][band_idx] = matched_band

            cur_image = pyramid.decode(synth_pyramid)

        # One more encode/decode round trip before the final image-level match
        synth_pyramid = pyramid.encode(cur_image)
        cur_image = match_img_histograms(
            pyramid.decode(synth_pyramid), target_pyramid.initial_image
        )

        if display_callback is not None:
            display_callback(step, n_iters, cur_image, synth_pyramid, target_pyramid)

    return cur_image


def get_autoencoders(
    encoder_hidden_dim: int, input_dim: int = 75, multipliers=(32, 16, 8)
):
    """Build a mirrored MLP encoder/decoder pair.

    The encoder maps ``input_dim`` through hidden widths
    ``encoder_hidden_dim * m`` (for each ``m`` in ``multipliers``) down to
    ``encoder_hidden_dim`` followed by ``Tanh``. The decoder walks the same
    widths in reverse and ends with a plain linear layer back to
    ``input_dim``.

    Args:
        encoder_hidden_dim: Latent width; also the base unit of hidden widths.
        input_dim: Width of the encoder input / decoder output.
        multipliers: Iterable of hidden-width multipliers in encoder order.
            The default is an immutable tuple so it cannot be shared and
            mutated across calls.

    Returns:
        ``(encoder, decoder)`` as ``torch.nn.Sequential`` modules.
    """
    # Materialize once: lets callers pass any iterable (even a generator)
    # and still allows the reversed pass for the decoder.
    multipliers = tuple(multipliers)

    def hidden_stage(in_dim, out_dim):
        # One hidden stage: affine -> normalization -> activation -> dropout
        return [
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(0.1),
        ]

    # Layers are created in the same order as before (encoder first, then
    # decoder) so that seeded weight initialization is unchanged.
    encoder_layers = []
    current_dim = input_dim
    for multiplier in multipliers:
        next_dim = encoder_hidden_dim * multiplier
        encoder_layers.extend(hidden_stage(current_dim, next_dim))
        current_dim = next_dim
    encoder_layers.extend(
        [torch.nn.Linear(current_dim, encoder_hidden_dim), torch.nn.Tanh()]
    )
    encoder = torch.nn.Sequential(*encoder_layers)

    decoder_layers = []
    current_dim = encoder_hidden_dim
    for multiplier in reversed(multipliers):
        next_dim = encoder_hidden_dim * multiplier
        decoder_layers.extend(hidden_stage(current_dim, next_dim))
        current_dim = next_dim
    decoder_layers.append(torch.nn.Linear(current_dim, input_dim))
    decoder = torch.nn.Sequential(*decoder_layers)
    return encoder, decoder


class ParallelModel(torch.nn.Module):
    """Run two sub-models side by side on the two halves of dim 1.

    The input is split at ``x.shape[1] // 2``; ``model1`` receives the leading
    part, ``model2`` the remainder (one element longer when the width is
    odd), and the two outputs are concatenated along dim 1.
    """

    def __init__(self, model1, model2):
        super().__init__()

        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        split_at = x.shape[1] // 2
        branches = (
            (self.model1, x[:, :split_at]),
            (self.model2, x[:, split_at:]),
        )
        outputs = [branch(part) for branch, part in branches]
        return torch.cat(outputs, dim=1)


class EncoderOptimalDenoiser(OptimalDenoiser):
    """Denoiser that operates in the latent space of a learned autoencoder.

    An MLP autoencoder (see ``get_autoencoders``) is fitted to the flattened
    data. A second, internal denoiser is then fitted on the encoded data.
    At call time the input is encoded, denoised by the internal denoiser, and
    decoded back to the original shape.
    """

    def __init__(
        self,
        dataset: Union[list[torch.Tensor], torch.Tensor],
        scheduler,
        encoder_hidden_dim: int = 32,
        use_knn: bool = True,
        num_neighbors: int = 2000,
        temperature: float = 1.0,
        device: str = "cuda",
        level_idx: int = 0,  # Add level index for prefixing metrics
    ):
        """Store the configuration and create the (untrained) latent denoiser.

        Args:
            dataset: Data forwarded to the base-class constructor.
            scheduler: Scheduler forwarded to the base class and to the
                latent denoiser.
            encoder_hidden_dim: Latent size of the autoencoder.
            use_knn: If true the latent denoiser is a ``KNNOptimalDenoiser``,
                otherwise a plain ``OptimalDenoiser``.
            num_neighbors: Forwarded to the latent denoiser.
            temperature: Stored on this object and forwarded to the latent
                denoiser.
            device: Device on which the autoencoder runs.
            level_idx: Index used to prefix log messages and metric names.
        """
        # ``temperature`` is keyword-only in the base constructor (it follows
        # ``*args``); passed positionally it would be swallowed by ``*args``
        # and ``self.temperature`` would silently stay at the default.
        super().__init__(dataset, scheduler, temperature=temperature)
        self.use_knn = use_knn
        self.num_neighbors = num_neighbors
        self.encoder_hidden_dim = encoder_hidden_dim
        self.device = device
        self.level_idx = level_idx  # Store level index

        # Autoencoder components are built lazily in ``_init_autoencoder``.
        self.encoder = None
        self.decoder = None

        # Internal denoiser for the latent space; its data is set in ``train``.
        latent_cls = KNNOptimalDenoiser if use_knn else OptimalDenoiser
        self.latent_denoiser = latent_cls(
            dataset=None,  # Will be set during training
            scheduler=scheduler,
            num_neighbors=num_neighbors,
            temperature=temperature,
        )

    def _init_autoencoder(self, input_dim):
        """Create ``self.encoder`` / ``self.decoder`` on ``self.device``.

        When ``self.data_shape[1] == 6`` the features are split in two halves,
        each handled by its own autoencoder (wrapped in ``ParallelModel``);
        otherwise a single autoencoder covers all ``input_dim`` features.
        """
        if self.data_shape[1] == 6:
            half_dim = input_dim // 2
            enc1, dec1 = get_autoencoders(self.encoder_hidden_dim, half_dim)
            enc2, dec2 = get_autoencoders(self.encoder_hidden_dim, half_dim)
            self.encoder = ParallelModel(enc1, enc2).to(self.device)
            self.decoder = ParallelModel(dec1, dec2).to(self.device)
        else:
            enc, dec = get_autoencoders(self.encoder_hidden_dim, input_dim)
            self.encoder = enc.to(self.device)
            self.decoder = dec.to(self.device)

    @torch.no_grad()
    def _reconstruction_grid(self, val_data, max_samples=5):
        """Build an image grid of validation samples next to reconstructions.

        Each row of the returned grid holds an original and its
        reconstruction, both min-max normalised for display.
        """
        num_vis = min(max_samples, len(val_data))
        vis_data = val_data[:num_vis]
        vis_recon = self.decoder(self.encoder(vis_data))

        # Reshape if the data is flattened image patches
        if vis_data.dim() == 2:
            # Non-3-channel data is displayed as two 3-channel images per sample.
            num_images = num_vis if self.data_shape[1] == 3 else 2 * num_vis
            image_shape = (num_images, 3, self.data_shape[2], self.data_shape[3])
            vis_data = vis_data.view(image_shape)
            vis_recon = vis_recon.view(image_shape)

        def _to_unit_range(img):
            # Clamp the denominator so constant images map to 0 instead of NaN.
            lo = img.min()
            return (img - lo) / (img.max() - lo).clamp(min=1e-8)

        vis_grid = []
        for orig, recon in zip(vis_data, vis_recon):
            vis_grid.extend([_to_unit_range(orig), _to_unit_range(recon)])

        return torchvision.utils.make_grid(
            vis_grid,
            nrow=2,  # original and reconstruction side by side
            normalize=False,
        )

    def train(self, num_epochs=100, batch_size=1024):
        """Train the autoencoder, then fit the latent denoiser.

        The data is split 90/10 (at least one validation sample) into
        training and validation sets. Metrics are logged to wandb when a run
        is active. After training, all data is encoded and handed to the
        latent denoiser, which is then trained.

        Args:
            num_epochs: Number of passes over the training split.
            batch_size: Mini-batch size for the autoencoder updates.

        Returns:
            ``self``, to allow chaining.
        """
        data = self.data.to(self.device)
        self._init_autoencoder(self.data.shape[1])

        # Split data into train and validation sets (90/10 split)
        num_val = max(1, int(0.1 * len(data)))
        indices = torch.randperm(len(data))
        train_data = data[indices[:-num_val]]
        val_data = data[indices[-num_val:]]

        # Make sure models are in training mode and gradients are enabled.
        # NOTE: set_grad_enabled is called as a function, so it changes the
        # current grad mode rather than scoping it to a block.
        self.encoder.train()
        self.decoder.train()
        torch.set_grad_enabled(True)

        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=1e-4
        )

        print(f"Training autoencoder for level {self.level_idx}...")
        best_val_loss = float("inf")
        prefix = f"level_{self.level_idx}_denoiser"

        for epoch in range(num_epochs):
            # Training
            self.encoder.train()
            self.decoder.train()
            perm = torch.randperm(len(train_data))
            total_train_loss = 0.0

            for start in range(0, len(train_data), batch_size):
                # Only the parameters need gradients, not the input batch.
                batch = train_data[perm[start : start + batch_size]]

                recon = self.decoder(self.encoder(batch))
                loss = F.mse_loss(recon, batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch average is per-sample.
                total_train_loss += loss.item() * len(batch)

            avg_train_loss = total_train_loss / len(train_data)

            # Validation
            self.encoder.eval()
            self.decoder.eval()
            with torch.no_grad():
                val_recon = self.decoder(self.encoder(val_data))
                # Keep a plain float so no tensor ends up in logs or summaries.
                val_loss = F.mse_loss(val_recon, val_data).item()

            # Log to parent run; skip all metric/image work without a run.
            if wandb.run is not None:
                metrics = {
                    f"{prefix}_train_loss": avg_train_loss,
                    f"{prefix}_val_loss": val_loss,
                    f"{prefix}_epoch": epoch + 1,
                }

                # Every 10 epochs, log reconstruction images
                if (epoch + 1) % 10 == 0:
                    grid = self._reconstruction_grid(val_data)
                    metrics[f"{prefix}_reconstructions"] = wandb.Image(
                        grid.cpu(),
                        caption=f"Level {self.level_idx} - Epoch {epoch+1} - Original vs Reconstruction",
                    )

                wandb.log(metrics)

            # Track the best validation loss (no checkpoint is written here).
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if wandb.run is not None:
                    wandb.run.summary[f"{prefix}_best_val_loss"] = best_val_loss

            if (epoch + 1) % 10 == 0:
                print(
                    f"Level {self.level_idx} - Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}"
                )

        # Dropout must be off for the final encoding and for later inference,
        # including the case where the epoch loop did not run at all.
        self.encoder.eval()
        self.decoder.eval()

        # Encode all data points for latent denoiser
        print(f"Preparing latent denoiser for level {self.level_idx}...")
        with torch.no_grad():
            encoded_data = self.encoder(data)
            # Move encoded data to CPU for the latent denoiser
            self.latent_denoiser.add_data(encoded_data.cpu())

        # Train latent denoiser
        print(f"Training latent space denoiser for level {self.level_idx}...")
        self.latent_denoiser.train()

        self.trained = True
        return self

    def __call__(self, x, t, noisy_subspace=-1, res_scale=1.0):
        """Denoise ``x`` by encoding, denoising in latent space, and decoding.

        Args:
            x: Input tensor; everything after dimension 0 is flattened before
                encoding.
            t: Forwarded unchanged to the latent denoiser.
            noisy_subspace: Forwarded unchanged to the latent denoiser.
            res_scale: Forwarded unchanged to the latent denoiser.

        Returns:
            The decoded result, viewed with the same shape as ``x``.
        """
        orig_shape = x.shape
        # Flatten and make sure the input is on the autoencoder's device.
        x_flat = x.flatten(start_dim=1).to(self.device)

        # Encode input to latent space
        with torch.no_grad():
            # The latent denoiser is fed CPU tensors.
            x_encoded = self.encoder(x_flat).cpu()

        # Denoise in latent space using internal denoiser
        latent_denoised = self.latent_denoiser(x_encoded, t, noisy_subspace, res_scale)

        # Decode back to original space
        with torch.no_grad():
            x_decoded = self.decoder(latent_denoised.to(self.device))

        return x_decoded.view(orig_shape)

    def save(self, path: str):
        """Save the data, autoencoder weights, settings and latent denoiser.

        Writes ``{path}.data`` and delegates to the latent denoiser with the
        prefix ``{path}_latent``.

        Raises:
            ValueError: If the denoiser has not been trained.
        """
        if not self.trained:
            raise ValueError("Denoiser must be trained before saving")

        # ``dirname`` is empty for a bare file name, and ``os.makedirs("")``
        # raises, so only create the directory when there is one.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save models and data
        torch.save(
            {
                "data": self.data,  # Save the data
                "encoder_state": self.encoder.state_dict(),
                "decoder_state": self.decoder.state_dict(),
                "temperature": self.temperature,
                "num_neighbors": self.num_neighbors,
                "use_knn": self.use_knn,
                "encoder_hidden_dim": self.encoder_hidden_dim,
                "device": self.device,
                "data_shape": self.data_shape,
            },
            f"{path}.data",
        )

        # Save latent denoiser
        self.latent_denoiser.save(f"{path}_latent")

    def load(self, path: str):
        """Restore the state written by ``save`` and mark the denoiser trained.

        Reads ``{path}.data`` and delegates to the latent denoiser with the
        prefix ``{path}_latent``.

        Returns:
            ``self``, to allow chaining.
        """
        saved_data = torch.load(f"{path}.data")

        # Load the data first
        self.data = saved_data["data"]
        self.data_shape = saved_data["data_shape"]

        # Set device from saved data
        self.device = saved_data.get("device", self.device)

        # Restore hyper-parameters BEFORE building the autoencoder: its layer
        # sizes depend on ``encoder_hidden_dim``, so building it with the
        # constructor value would make ``load_state_dict`` fail whenever the
        # checkpoint was trained with a different latent size.
        self.temperature = saved_data["temperature"]
        self.num_neighbors = saved_data["num_neighbors"]
        self.use_knn = saved_data["use_knn"]
        self.encoder_hidden_dim = saved_data["encoder_hidden_dim"]

        # Initialize and load autoencoder
        if self.encoder is None:
            self._init_autoencoder(self.data.shape[1])
        else:
            # Move existing models to correct device
            self.encoder = self.encoder.to(self.device)
            self.decoder = self.decoder.to(self.device)

        self.encoder.load_state_dict(saved_data["encoder_state"])
        self.decoder.load_state_dict(saved_data["decoder_state"])

        # Freshly built modules start in training mode; switch to eval so the
        # Dropout layers are inactive when the loaded model is used.
        self.encoder.eval()
        self.decoder.eval()

        # Load latent denoiser
        self.latent_denoiser.load(f"{path}_latent")

        self.trained = True
        return self
