"""X0-level Gaussian guidance using Wiener denoiser difference.

Simplified from efficient_diffusion_steering/_shared/gaussian_x0_guidance.py
and denoisers/gaussian_denoiser.py
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Dict

import numpy as np
import torch
import torch.nn as nn


@dataclass
class GuidanceWindow:
    """Range of sampler steps (or noise levels) in which guidance is active.

    The sigma range takes precedence over the step range, but only when both
    sigma bounds are configured *and* a sigma value is supplied to
    :meth:`contains`. In every other case the step range decides.
    """
    start_step: int = 0
    end_step: int = 10**9
    start_sigma: Optional[float] = None
    end_sigma: Optional[float] = None

    def contains(self, step: int, sigma: Optional[float]) -> bool:
        """Report whether ``step`` / ``sigma`` lies inside the window.

        Args:
            step: Step index, compared inclusively against
                ``start_step`` and ``end_step``.
            sigma: Noise level, or ``None`` to force the step-based test.

        Returns:
            ``True`` when the relevant range (bounds inclusive) contains the
            query, ``False`` otherwise.
        """
        sigma_test_unavailable = (
            self.start_sigma is None
            or self.end_sigma is None
            or sigma is None
        )
        if sigma_test_unavailable:
            return self.start_step <= step <= self.end_step

        # The two sigma bounds may be given in either order.
        lower = min(self.start_sigma, self.end_sigma)
        upper = max(self.start_sigma, self.end_sigma)
        return lower <= sigma <= upper


class GaussianDenoiser(nn.Module):
    """Full-image Gaussian Denoiser using PCA/Wiener filtering.

    The input is flattened, centered with the stored mean, projected onto the
    leading PCA eigenvectors, shrunk per component, and projected back.

    Args:
        filter_dict: Dictionary containing:
            - 'eigenvectors': PCA eigenvectors [D, K]
            - 'eigenvalues': PCA eigenvalues [K]
            - 'mean_patch': Mean image [D]
            - 'img_resolution' (optional): Image resolution. 'patch_size' is
              accepted as a fallback key; if neither is present the
              resolution is inferred from D assuming a square image with
              3, 1 or 4 channels (tried in that order).
            - 'img_channels' (optional): Number of channels. When absent it
              is derived from D and the resolution, falling back to 3 if D
              is not a whole multiple of resolution**2.
        num_components: Number of PCA components to use (capped at K).
            ``None`` uses all K components.
        denoising_mode: 'wiener' (σ-dependent) or 'one_step'. The value is
            lower-cased; anything other than 'wiener' selects the one-step
            scaling.
        epsilon: Small constant added to the shrinkage denominator.

    Raises:
        ValueError: If the resolution is not given and cannot be inferred.
    """

    def __init__(
        self,
        filter_dict: Dict[str, torch.Tensor],
        num_components: Optional[int] = None,
        denoising_mode: str = 'wiener',
        epsilon: float = 1e-8
    ):
        super().__init__()
        self.epsilon = epsilon
        self.denoising_mode = denoising_mode.lower()

        # Extract filter components
        self.register_buffer('mean', filter_dict['mean_patch'])
        self.register_buffer('eigenvectors', filter_dict['eigenvectors'])
        self.register_buffer('eigenvalues', filter_dict['eigenvalues'])

        # Image properties
        D = self.eigenvectors.shape[0]
        inferred_channels = None
        if 'img_resolution' in filter_dict:
            self.img_resolution = int(filter_dict['img_resolution'])
        elif 'patch_size' in filter_dict:
            self.img_resolution = int(filter_dict['patch_size'])
        else:
            # Infer from eigenvector dimension: D must equal res * res * C.
            for C in (3, 1, 4):
                if D % C != 0:
                    continue
                # round() instead of truncation so a square root that lands
                # a hair below the integer is not rejected.
                res = round((D // C) ** 0.5)
                if res * res * C == D:
                    self.img_resolution = res
                    inferred_channels = C
                    break
            else:
                raise ValueError("Cannot infer img_resolution")

        if 'img_channels' in filter_dict:
            self.img_channels = filter_dict['img_channels']
        elif inferred_channels is not None:
            # Keep the channel count consistent with the inferred resolution
            # instead of unconditionally assuming RGB.
            self.img_channels = inferred_channels
        else:
            pixels = self.img_resolution * self.img_resolution
            if pixels > 0 and D % pixels == 0:
                self.img_channels = D // pixels
            else:
                self.img_channels = 3

        # Select components
        total_components = self.eigenvectors.shape[1]
        if num_components is None:
            self.num_components = total_components
        else:
            self.num_components = min(num_components, total_components)

        self.register_buffer('top_eigenvectors', self.eigenvectors[:, :self.num_components])
        self.register_buffer('top_eigenvalues', self.eigenvalues[:self.num_components])

    def denoise(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Apply full-image Gaussian denoising.

        Args:
            x: Noisy images [B, C, H, W]
            sigma: Noise level

        Returns:
            Denoised images [B, C, H, W], cast back to the dtype of ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Expected input of shape [B, C, H, W], got {tuple(x.shape)}"
            )
        original_shape = x.shape
        batch_size = original_shape[0]

        # Flatten to [B, D] and work in the dtype of the PCA basis.
        x_flat = x.reshape(batch_size, -1).to(self.eigenvectors.dtype)

        # Center
        mean_row = self.mean.reshape(1, -1)
        centered = x_flat - mean_row

        # Project to PCA space
        coeffs = torch.matmul(centered, self.top_eigenvectors)

        # Per-component shrinkage ratio lambda / (lambda + sigma^2 + eps)
        ratio = self.top_eigenvalues / (self.top_eigenvalues + sigma**2 + self.epsilon)
        if self.denoising_mode == 'wiener':
            gain = ratio
        else:  # one_step
            # Negative ratios (possible with negative eigenvalues) would give
            # NaN under the square root; those components are zeroed.
            gain = torch.nan_to_num(torch.sqrt(ratio), nan=0.0)
        filtered_coeffs = coeffs * gain.reshape(1, -1)

        # Project back
        denoised_centered = torch.matmul(filtered_coeffs, self.top_eigenvectors.T)
        denoised_flat = denoised_centered + mean_row

        return denoised_flat.reshape(original_shape).to(x.dtype)


def load_pca_filter(pca_paths: dict, device: str = 'cuda') -> dict:
    """Load PCA filter from numpy files and convert to CHW format.

    The mean and each eigenvector are interpreted as flattened HWC images
    (square, 3 channels) and re-flattened in CHW order.

    Args:
        pca_paths: dict with keys 'eigenvectors', 'eigenvalues', 'mean',
            each a path accepted by ``np.load``
        device: target device

    Returns:
        Filter dictionary for GaussianDenoiser with keys 'eigenvectors'
        [D, K], 'eigenvalues', 'mean_patch' [D] (all float32 tensors on
        ``device``), plus 'img_resolution' and 'img_channels' (always 3).

    Raises:
        ValueError: If D is not ``res * res * 3`` for some positive integer
            ``res``, or if the eigenvector matrix is not of shape [D, K].
    """
    eigenvectors = np.load(pca_paths['eigenvectors'])
    eigenvalues = np.load(pca_paths['eigenvalues'])
    mean = np.load(pca_paths['mean'])

    num_channels = 3

    # Infer image dimensions: accept any square RGB resolution rather than a
    # fixed list of sizes.
    D = mean.shape[0]
    img_resolution = int(round(float(np.sqrt(D / num_channels)))) if D > 0 else 0
    if img_resolution <= 0 or img_resolution * img_resolution * num_channels != D:
        raise ValueError(f"Cannot infer image resolution from D={D}")

    if eigenvectors.ndim != 2 or eigenvectors.shape[0] != D:
        raise ValueError(
            f"Expected eigenvectors of shape [{D}, K], got {eigenvectors.shape}"
        )
    num_vectors = eigenvectors.shape[1]

    # Convert mean from HWC to CHW
    mean_chw = mean.reshape(
        img_resolution, img_resolution, num_channels
    ).transpose(2, 0, 1).flatten()

    # Convert all eigenvectors from HWC to CHW in one shot: split the row
    # axis into (H, W, C), move C to the front, and merge back. The trailing
    # K axis (one column per eigenvector) is left untouched.
    eigenvectors_chw = eigenvectors.reshape(
        img_resolution, img_resolution, num_channels, num_vectors
    ).transpose(2, 0, 1, 3).reshape(D, num_vectors)

    return {
        'eigenvectors': torch.from_numpy(eigenvectors_chw).float().to(device),
        'eigenvalues': torch.from_numpy(eigenvalues).float().to(device),
        'mean_patch': torch.from_numpy(mean_chw).float().to(device),
        'img_resolution': img_resolution,
        'img_channels': num_channels,
    }


def load_gaussian_filter(path: str, device: str = 'cuda') -> Dict[str, torch.Tensor]:
    """Read a saved Gaussian filter from a ``.pt`` file.

    Args:
        path: Location of the file, handed to ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``
        (expected to be a filter dictionary).

    NOTE(review): depending on the installed torch version, ``torch.load``
    may unpickle arbitrary objects; only load files from trusted sources.
    """
    filter_dict = torch.load(path, map_location=device)
    return filter_dict


class GaussianX0Guidance:
    """X0-level Gaussian guidance using Wiener denoiser difference.

    This applies guidance at the x0 (denoised) level:
        guidance = strength * (denoiser_class(x_edm, sigma) - denoiser_full(x_edm, sigma))

    Args:
        class_filter_path: Path to class-specific PCA filter. Either a '.pt'
            file, or the eigenvectors numpy file whose path contains the
            token 'eigenvectors'; the eigenvalues and mean files are located
            by replacing that token with 'eigenvalues' / 'mean'.
        full_filter_path: Path to full-dataset PCA filter (same conventions)
        strength: Guidance strength (typically 3.0-7.0)
        window: GuidanceWindow for step/sigma range control
        num_components: Number of PCA components to use
        device: Target device

    Raises:
        ValueError: If a non-'.pt' filter path does not contain the token
            'eigenvectors', so the companion files cannot be derived.
    """

    def __init__(
        self,
        class_filter_path: str,
        full_filter_path: str,
        strength: float = 5.0,
        window: Optional[GuidanceWindow] = None,
        num_components: Optional[int] = None,
        device: str = "cuda",
    ):
        self.strength = strength
        self.window = window or GuidanceWindow()
        self.device = device

        # Load filters
        class_filter = self._load_filter(class_filter_path, device)
        full_filter = self._load_filter(full_filter_path, device)

        # Create denoisers with Wiener mode
        self.class_denoiser = GaussianDenoiser(
            class_filter,
            num_components=num_components,
            denoising_mode='wiener'
        ).to(device)

        self.full_denoiser = GaussianDenoiser(
            full_filter,
            num_components=num_components,
            denoising_mode='wiener'
        ).to(device)

    @staticmethod
    def _load_filter(filter_path: str, device: str) -> dict:
        """Load one filter dictionary from a '.pt' file or a numpy file triple."""
        path = Path(filter_path)
        if path.suffix == '.pt':
            return load_gaussian_filter(str(path), device=device)

        path_str = str(path)
        if 'eigenvectors' not in path_str:
            # Without the token the substitutions below would be no-ops and
            # the eigenvector file would be read as eigenvalues and mean too.
            raise ValueError(
                f"Cannot derive eigenvalues/mean paths from '{filter_path}': "
                "expected 'eigenvectors' in the path"
            )
        return load_pca_filter(
            {'eigenvectors': filter_path,
             'eigenvalues': path_str.replace('eigenvectors', 'eigenvalues'),
             'mean': path_str.replace('eigenvectors', 'mean')},
            device=device
        )

    def compute_guidance(
        self,
        x: torch.Tensor,
        alpha_prod_t: Union[torch.Tensor, float],
        step_idx: int,
        sigma: Optional[float] = None,
    ) -> Optional[torch.Tensor]:
        """Compute x0-level guidance.

        Args:
            x: Current DDPM-space sample x_t [B, C, H, W]
            alpha_prod_t: Alpha bar at current timestep (scalar or
                single-element tensor)
            step_idx: Current DDIM step index
            sigma: EDM-style sigma. When omitted it is derived from
                ``alpha_prod_t`` as sqrt((1 - ab) / ab).

        Returns:
            Guidance tensor to add to x0 estimate (same dtype as ``x``),
            or None if outside window

        Raises:
            ValueError: If ``alpha_prod_t`` is not in (0, 1] when sigma has
                to be derived from it, or is not positive when the sample
                has to be rescaled.
        """
        # Convert alpha_prod_t to scalar
        if torch.is_tensor(alpha_prod_t):
            ab_t = float(alpha_prod_t.item())
        else:
            ab_t = float(alpha_prod_t)

        # Compute sigma if not provided
        if sigma is None:
            # Outside (0, 1] the ratio is undefined or negative and would
            # yield a division error or a NaN sigma.
            if not 0.0 < ab_t <= 1.0:
                raise ValueError(
                    f"alpha_prod_t must be in (0, 1] to derive sigma, got {ab_t}"
                )
            sigma = float(np.sqrt((1 - ab_t) / ab_t))

        # Check window
        if not self.window.contains(step_idx, sigma):
            return None

        if not ab_t > 0.0:
            raise ValueError(f"alpha_prod_t must be positive, got {ab_t}")

        # Apply Wiener denoisers in EDM space
        with torch.no_grad():
            # Convert DDPM space to EDM space
            # DDPM: x_t = sqrt(ab) * x0 + sqrt(1-ab) * eps
            # EDM: x = x0 + sigma * eps
            # Therefore: x_edm = x_t / sqrt(ab)
            x_edm = (x / float(np.sqrt(ab_t))).float()  # cast once, shared by both denoisers
            d_class = self.class_denoiser.denoise(x_edm, sigma)
            d_full = self.full_denoiser.denoise(x_edm, sigma)

        # Compute guidance (difference of denoised estimates)
        guidance = self.strength * (d_class - d_full)

        return guidance.to(x.dtype)
