from typing import Any, Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.fft
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from scipy import signal
import pywt

class FrequencyDomainAnalyzer:
    """
    Comprehensive frequency domain analysis for image deblurring evaluation.

    Analyzes spectral characteristics, frequency responses, and blur patterns
    in both spatial and frequency domains for detailed model assessment.

    Image tensors may be passed as ``(H, W)``, ``(C, H, W)`` or
    ``(1, C, H, W)``; they are converted to ``(C, H, W)`` numpy arrays on the
    CPU before analysis.
    """

    def __init__(self, device: Optional[torch.device] = None):
        self.device = device or torch.device('cpu')
        # Not read by any method of this class.
        self.analysis_cache = {}

    @staticmethod
    def _to_channel_first(image: torch.Tensor) -> np.ndarray:
        """
        Convert an image tensor to a ``(C, H, W)`` numpy array.

        Args:
            image: Tensor of shape ``(H, W)``, ``(C, H, W)`` or ``(1, C, H, W)``.

        Returns:
            Detached CPU numpy array of shape ``(C, H, W)``; a 2-D input becomes
            a single channel.

        Raises:
            ValueError: For a batch size other than 1 or an unsupported rank.
        """
        img_np = image.detach().cpu().numpy()
        if img_np.ndim == 4:
            if img_np.shape[0] != 1:
                raise ValueError(
                    f"expected a batch size of 1, got shape {tuple(img_np.shape)}"
                )
            img_np = img_np[0]  # Remove batch dimension
        if img_np.ndim == 2:
            img_np = img_np[np.newaxis]  # Single grayscale channel
        if img_np.ndim != 3:
            raise ValueError(
                "expected an image of shape (H, W), (C, H, W) or (1, C, H, W), "
                f"got {tuple(img_np.shape)}"
            )
        return img_np

    def analyze_spectrum(self, image: torch.Tensor) -> Dict[str, Dict[str, np.ndarray]]:
        """
        Comprehensive spectral analysis of an input image.

        Args:
            image: Image tensor, ``(H, W)``, ``(C, H, W)`` or ``(1, C, H, W)``.

        Returns:
            Dictionary with one ``'channel_{i}'`` entry per channel holding the
            centred magnitude/phase spectra, the power spectral density and its
            radial profile (mean PSD per integer radius from the centre). For
            multi-channel input a ``'cross_channel'`` entry is added.
        """
        img_np = self._to_channel_first(image)

        results = {}
        for ch, channel_img in enumerate(img_np):
            # 2D FFT with the DC component moved to the centre
            fft_shifted = np.fft.fftshift(np.fft.fft2(channel_img))
            magnitude_spectrum = np.abs(fft_shifted)
            phase_spectrum = np.angle(fft_shifted)

            # Power spectral density
            psd = magnitude_spectrum ** 2

            # Radial average: bin every frequency by its integer distance from
            # the centre bin (h // 2, w // 2) and average the PSD per bin.
            h, w = channel_img.shape
            y, x = np.ogrid[:h, :w]
            r = np.sqrt((x - w // 2) ** 2 + (y - h // 2) ** 2).astype(int)
            tbin = np.bincount(r.ravel(), psd.ravel())
            nr = np.bincount(r.ravel())
            # Guard against empty radius bins, which would otherwise give NaN.
            radial_profile = tbin / np.maximum(nr, 1)

            results[f'channel_{ch}'] = {
                'magnitude_spectrum': magnitude_spectrum,
                'phase_spectrum': phase_spectrum,
                'power_spectral_density': psd,
                'radial_profile': radial_profile
            }

        # Cross-channel analysis
        if img_np.shape[0] > 1:
            results['cross_channel'] = self._analyze_cross_channel_spectrum(img_np)

        return results

    def _analyze_cross_channel_spectrum(self, img_np: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Analyze spectral relationships between every pair of channels.

        Args:
            img_np: Array of shape ``(C, H, W)``.

        Returns:
            ``'coherence_{i}_{j}'`` -> centred per-frequency ratio
            ``|F_i conj(F_j)| / (|F_i| |F_j| + 1e-8)`` for each ``i < j``.
        """
        cross_results = {}

        # FFT each channel once (fft2 transforms the last two axes).
        ffts = np.fft.fft2(img_np)
        num_channels = img_np.shape[0]

        for i in range(num_channels):
            for j in range(i + 1, num_channels):
                cross_psd = ffts[i] * np.conj(ffts[j])
                # NOTE(review): without averaging over segments/realisations,
                # |F_i * conj(F_j)| == |F_i| * |F_j| pointwise, so this ratio is
                # ~1 wherever both spectra are non-negligible. A meaningful
                # coherence estimate needs averaging (e.g. Welch-style); the
                # existing formula is kept to preserve current outputs.
                coherence = np.abs(cross_psd) / (np.abs(ffts[i]) * np.abs(ffts[j]) + 1e-8)

                cross_results[f'coherence_{i}_{j}'] = np.fft.fftshift(coherence)

        return cross_results

    def estimate_blur_kernel_frequency(
        self,
        blurred: torch.Tensor,
        sharp: torch.Tensor,
        regularization: float = 1e-8,
    ) -> Dict[str, Dict[str, np.ndarray]]:
        """
        Estimate blur kernel characteristics in the frequency domain.

        Under the model ``blurred = sharp (*) kernel`` (circular convolution)
        the spectra satisfy ``B = S * K``. ``K`` is therefore recovered by a
        Tikhonov-regularised division
        ``K = B * conj(S) / (|S|^2 + regularization)``, which stays bounded
        where ``S`` is close to zero.

        Args:
            blurred: Blurred image, ``(H, W)``, ``(C, H, W)`` or ``(1, C, H, W)``.
            sharp: Sharp reference image in the same layout as ``blurred``.
            regularization: Non-negative constant added to ``|S|^2``.

        Returns:
            Dictionary with one ``'channel_{i}'`` entry per channel present in
            both images, holding the centred kernel spectrum magnitude/phase,
            the centred spatial kernel and a scalar ``blur_severity``.
        """
        blurred_np = self._to_channel_first(blurred)
        sharp_np = self._to_channel_first(sharp)

        results = {}

        for ch in range(min(blurred_np.shape[0], sharp_np.shape[0])):
            blurred_fft = np.fft.fft2(blurred_np[ch])
            sharp_fft = np.fft.fft2(sharp_np[ch])

            # Regularised spectral division (see docstring).
            kernel_fft = blurred_fft * np.conj(sharp_fft) / (np.abs(sharp_fft) ** 2 + regularization)

            # Back to the spatial domain, with the kernel origin centred.
            kernel_spatial = np.fft.fftshift(np.fft.ifft2(kernel_fft).real)

            # Centred spectra (DC in the middle).
            kernel_magnitude = np.fft.fftshift(np.abs(kernel_fft))
            kernel_phase = np.fft.fftshift(np.angle(kernel_fft))

            results[f'channel_{ch}'] = {
                'kernel_fft_magnitude': kernel_magnitude,
                'kernel_fft_phase': kernel_phase,
                'kernel_spatial': kernel_spatial,
                # Severity must see the centred spectrum so that its central
                # region corresponds to the low frequencies.
                'blur_severity': self._compute_blur_severity(kernel_magnitude)
            }

        return results

    def _compute_blur_severity(self, kernel_magnitude: np.ndarray) -> float:
        """
        Compute blur severity from a centred (fftshift-ed) kernel magnitude spectrum.

        Returns the fraction of spectral energy inside the central half-size
        window, i.e. the low frequencies once the spectrum is centred. A value
        near 1 means the kernel removes most high-frequency energy.
        """
        h, w = kernel_magnitude.shape
        center_region = kernel_magnitude[h // 4:3 * h // 4, w // 4:3 * w // 4]

        center_energy = np.sum(center_region ** 2)
        total_energy = np.sum(kernel_magnitude ** 2)

        return float(center_energy / (total_energy + 1e-8))

    def analyze_restoration_quality(
        self,
        original: torch.Tensor,
        restored: torch.Tensor
    ) -> Dict[str, float]:
        """
        Comprehensive frequency domain quality assessment.

        For each channel, compares the power spectral densities of
        ``original`` and ``restored``. It reports their Pearson correlation,
        their MSE, and the ratio of restored to original high-frequency
        energy. It also reports the mean correlation across channels as
        ``'overall_spectral_quality'``.
        """
        orig_spectrum = self.analyze_spectrum(original)
        rest_spectrum = self.analyze_spectrum(restored)

        quality_metrics = {}

        for ch_key in orig_spectrum:
            if not ch_key.startswith('channel_'):
                continue  # skip 'cross_channel'
            orig_psd = orig_spectrum[ch_key]['power_spectral_density']
            rest_psd = rest_spectrum[ch_key]['power_spectral_density']

            # Spectral similarity metrics
            spectral_correlation = np.corrcoef(orig_psd.flatten(), rest_psd.flatten())[0, 1]
            spectral_mse = np.mean((orig_psd - rest_psd) ** 2)

            # High-frequency restoration quality
            h, w = orig_psd.shape
            hf_mask = self._create_high_freq_mask(h, w)

            orig_hf_energy = np.sum(orig_psd * hf_mask)
            rest_hf_energy = np.sum(rest_psd * hf_mask)
            hf_restoration_ratio = rest_hf_energy / (orig_hf_energy + 1e-8)

            quality_metrics[f'{ch_key}_spectral_correlation'] = spectral_correlation
            quality_metrics[f'{ch_key}_spectral_mse'] = spectral_mse
            quality_metrics[f'{ch_key}_hf_restoration_ratio'] = hf_restoration_ratio

        # Overall quality
        quality_metrics['overall_spectral_quality'] = np.mean([
            v for k, v in quality_metrics.items() if 'spectral_correlation' in k
        ])

        return quality_metrics

    def _create_high_freq_mask(self, h: int, w: int, cutoff: float = 0.7) -> np.ndarray:
        """
        Create a binary high-frequency mask for a centred ``(h, w)`` spectrum.

        The mask is 1.0 where the distance from the centre bin exceeds
        ``cutoff * (min(h, w) // 2)`` and 0.0 elsewhere.
        """
        center = (h // 2, w // 2)
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - center[1]) ** 2 + (y - center[0]) ** 2)
        max_r = min(h, w) // 2

        mask = (r > cutoff * max_r).astype(float)
        return mask

class SpectralCharacteristics:
    """
    Advanced spectral characteristics analyzer for blur pattern recognition.

    Implements heuristic spectral measures for blur analysis: directional
    energy (motion blur), radial decay (defocus), spectral isotropy
    (Gaussian blur) and local spectral variation (atmospheric turbulence).
    """

    def __init__(self):
        # Names of the blur classes scored by classify_blur_type.
        self.blur_types = ['motion', 'defocus', 'gaussian', 'atmospheric']

    def classify_blur_type(self, image: torch.Tensor) -> Dict[str, float]:
        """
        Score each blur type for an image using spectral analysis.

        Args:
            image: Image tensor. All size-1 dimensions are squeezed; if the
                result is 3-D it is averaged over its first axis to obtain a
                grayscale image (assumes channel-first layout — confirm).

        Returns:
            Mapping from blur type to score. The scores are normalised to sum
            to 1 when their total is positive.
        """
        img_np = image.squeeze().detach().cpu().numpy()
        if img_np.ndim == 3:
            img_np = np.mean(img_np, axis=0)  # Convert to grayscale

        # Centred magnitude spectrum (DC in the middle)
        magnitude_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(img_np)))

        blur_scores = {
            'motion': self._detect_motion_blur(magnitude_spectrum),
            'defocus': self._detect_defocus_blur(magnitude_spectrum),
            'gaussian': self._detect_gaussian_blur(magnitude_spectrum),
            'atmospheric': self._detect_atmospheric_blur(magnitude_spectrum),
        }

        # Normalize scores
        total_score = sum(blur_scores.values())
        if total_score > 0:
            blur_scores = {k: v / total_score for k, v in blur_scores.items()}

        return blur_scores

    def _detect_motion_blur(self, magnitude_spectrum: np.ndarray) -> float:
        """
        Detect motion blur from directional energy in a centred magnitude spectrum.

        Samples the spectrum along lines through the centre every 10 degrees.
        It then compares the largest per-line mean magnitude with the global
        mean.

        Returns:
            Score clamped to [0, 1]; 0 when no direction exceeds the average.
        """
        h, w = magnitude_spectrum.shape
        center_y, center_x = h // 2, w // 2
        length = min(h, w) // 4
        offsets = np.arange(-length, length)
        max_directional_energy = 0

        for angle in range(0, 180, 10):
            rad = np.radians(angle)

            # Sample along the line through the centre at this angle.
            # NOTE(review): offset 0 is always sampled, so every line includes
            # the centre (DC) bin, which can dominate the per-line mean.
            y_coords = center_y + offsets * np.sin(rad)
            x_coords = center_x + offsets * np.cos(rad)

            # Keep coordinates in bounds
            valid_mask = (y_coords >= 0) & (y_coords < h) & (x_coords >= 0) & (x_coords < w)
            y_coords = y_coords[valid_mask].astype(int)
            x_coords = x_coords[valid_mask].astype(int)

            if len(y_coords) > 0:
                directional_energy = np.mean(magnitude_spectrum[y_coords, x_coords])
                max_directional_energy = max(max_directional_energy, directional_energy)

        # Compare against the average energy of the whole spectrum
        avg_energy = np.mean(magnitude_spectrum)
        motion_score = max_directional_energy / (avg_energy + 1e-8)

        return min(1.0, max(0.0, (motion_score - 1.0) / 2.0))

    def _detect_defocus_blur(self, magnitude_spectrum: np.ndarray) -> float:
        """
        Detect defocus blur from the radial decay of a centred magnitude spectrum.

        Builds the mean magnitude per integer radius (radii below
        ``min(h, w) // 2``). It then fits a line to the log-profile, skipping
        radius 0, and scales the absolute slope by 1 / 0.1.

        Returns:
            Score clamped to [0, 1]; 0 when fewer than 11 radii are available.
        """
        h, w = magnitude_spectrum.shape

        # Integer radius of every bin from the centre (h // 2, w // 2)
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - w // 2) ** 2 + (y - h // 2) ** 2).astype(int).ravel()

        # Radial average via bincount (one pass instead of one mask per radius)
        max_r = min(h, w) // 2
        inside = r < max_r
        sums = np.bincount(r[inside], weights=magnitude_spectrum.ravel()[inside], minlength=max_r)
        counts = np.bincount(r[inside], minlength=max_r)
        populated = counts > 0
        radial_profile = sums[populated] / counts[populated]

        # Analyze decay pattern
        if len(radial_profile) > 10:
            x_vals = np.arange(len(radial_profile))
            decay_rate = np.polyfit(x_vals[1:], np.log(radial_profile[1:] + 1e-8), 1)[0]
            defocus_score = abs(decay_rate) / 0.1  # Normalize
        else:
            defocus_score = 0.0

        return min(1.0, max(0.0, defocus_score))

    def _detect_gaussian_blur(self, magnitude_spectrum: np.ndarray) -> float:
        """
        Detect Gaussian blur from the isotropy of a centred magnitude spectrum.

        Computes magnitude-weighted second moments in coordinates normalised
        by the half-size of each axis. The score is the ratio of the smaller to
        the larger axial moment, multiplied by ``1 - |cross moment|``.
        """
        h, w = magnitude_spectrum.shape

        y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        center_y, center_x = h // 2, w // 2

        # Normalized coordinates; clamp half-sizes to 1 so a 1-pixel-wide axis
        # does not divide by zero.
        y_norm = (y - center_y) / max(h // 2, 1)
        x_norm = (x - center_x) / max(w // 2, 1)

        # Weighted second moments
        weights = magnitude_spectrum / (np.sum(magnitude_spectrum) + 1e-8)

        moment_xx = np.sum(weights * x_norm ** 2)
        moment_yy = np.sum(weights * y_norm ** 2)
        moment_xy = np.sum(weights * x_norm * y_norm)

        # Isotropy measure (closer to 1 = more isotropic = more Gaussian-like)
        isotropy = min(moment_xx, moment_yy) / (max(moment_xx, moment_yy) + 1e-8)
        gaussian_score = isotropy * (1 - abs(moment_xy))

        return gaussian_score

    def _detect_atmospheric_blur(self, magnitude_spectrum: np.ndarray) -> float:
        """
        Detect atmospheric-turbulence-like irregularity in a magnitude spectrum.

        Computes the 5x5 local standard deviation of the spectrum (reflect
        boundary) and divides its mean by the mean magnitude. The ratio is
        then divided by 0.5 and clamped to [0, 1].
        """
        from scipy import ndimage

        # Work in float64 so the E[x^2] - E[x]^2 variance form keeps precision.
        mag = np.asarray(magnitude_spectrum, dtype=np.float64)

        # Vectorised local std over a 5x5 window. This is equivalent to
        # generic_filter(mag, np.std, size=5) but without a Python callback
        # per pixel; both default to mode='reflect'.
        local_mean = ndimage.uniform_filter(mag, size=5)
        local_sq_mean = ndimage.uniform_filter(mag * mag, size=5)
        # Clip tiny negative variances caused by floating-point round-off.
        local_std = np.sqrt(np.maximum(local_sq_mean - local_mean ** 2, 0.0))

        # Turbulence is characterized by high local variation
        turbulence_measure = np.mean(local_std) / (np.mean(mag) + 1e-8)

        # Normalize and clamp
        return min(1.0, max(0.0, turbulence_measure / 0.5))

class BlurPatternAnalyzer:
    """
    Specialized analyzer for detailed blur pattern characterization.

    Provides tools for analyzing blur kernel shapes, directions,
    and characteristics for comprehensive blur understanding.
    """

    def __init__(self):
        # Not read by any method of this class.
        self.pattern_cache = {}

    def analyze_blur_pattern(self, blur_kernel: np.ndarray) -> Dict[str, Any]:
        """
        Comprehensive blur pattern analysis.

        Args:
            blur_kernel: Estimated or known 2-D blur kernel.

        Returns:
            Dictionary with ``shape_analysis``, ``directional_analysis``,
            ``size_analysis`` and ``symmetry_analysis`` sub-dicts plus a
            ``pattern_type`` label.
        """
        results = {
            'shape_analysis': self._analyze_kernel_shape(blur_kernel),
            'directional_analysis': self._analyze_kernel_direction(blur_kernel),
            'size_analysis': self._analyze_kernel_size(blur_kernel),
            'symmetry_analysis': self._analyze_kernel_symmetry(blur_kernel)
        }

        # Overall pattern classification
        results['pattern_type'] = self._classify_pattern_type(results)

        return results

    def _analyze_kernel_shape(self, kernel: np.ndarray) -> Dict[str, float]:
        """
        Analyze the geometric shape of the kernel support.

        The support is the set of pixels above 10% of the kernel maximum.

        Returns:
            ``eccentricity`` ((l1 - l2) / (l1 + l2) of the second-moment
            eigenvalues), ``orientation`` (radians), and ``compactness``
            (4*pi*area / perimeter^2). All three are 0.0 when the support is
            empty.
        """
        # Threshold kernel to get main structure
        threshold = 0.1 * np.max(kernel)
        binary_kernel = (kernel > threshold).astype(float)

        moments = cv2.moments(binary_kernel)

        # Always expose every key so downstream lookups cannot KeyError.
        shape_metrics = {'eccentricity': 0.0, 'orientation': 0.0, 'compactness': 0.0}
        if moments['m00'] <= 0:
            return shape_metrics

        # Normalised central moments for shape characterization
        mu20 = moments['mu20'] / moments['m00']
        mu02 = moments['mu02'] / moments['m00']
        mu11 = moments['mu11'] / moments['m00']

        # Eccentricity and orientation
        denominator = 4 * mu11 ** 2 + (mu20 - mu02) ** 2
        if denominator > 0:
            shape_metrics['eccentricity'] = np.sqrt(denominator) / (mu20 + mu02)
            shape_metrics['orientation'] = 0.5 * np.arctan2(2 * mu11, mu20 - mu02)

        # cv2.arcLength expects a contour (point list), not an image, so the
        # outer contours are extracted first. [-2] selects the contour list
        # for both OpenCV 3.x (3 return values) and 4.x (2 return values).
        contours = cv2.findContours(
            binary_kernel.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )[-2]
        perimeter = sum(cv2.arcLength(contour, True) for contour in contours)
        shape_metrics['compactness'] = (moments['m00'] * 4 * np.pi) / (perimeter ** 2 + 1e-8)

        return shape_metrics

    def _analyze_kernel_direction(self, kernel: np.ndarray) -> Dict[str, float]:
        """
        Analyze directional characteristics of the kernel's gradients.

        Only pixels with a non-zero gradient contribute; in flat regions the
        gradient direction is undefined (arctan2(0, 0) would report 0 rad and
        bias the histogram).

        Returns:
            ``dominant_direction`` (centre of the most populated 10-degree
            bin, radians), ``directional_strength`` (max / mean bin count)
            and ``directional_variance`` (variance of the gradient angles).
            All three are 0.0 for a constant kernel.
        """
        gy, gx = np.gradient(np.asarray(kernel, dtype=np.float64))

        active = np.hypot(gx, gy) > 0
        if not np.any(active):
            return {
                'dominant_direction': 0.0,
                'directional_strength': 0.0,
                'directional_variance': 0.0
            }

        gradient_angles = np.arctan2(gy[active], gx[active])

        # Histogram of gradient directions
        hist, bin_edges = np.histogram(gradient_angles, bins=36, range=(-np.pi, np.pi))

        # Find dominant direction
        dominant_bin = int(np.argmax(hist))
        dominant_angle = (bin_edges[dominant_bin] + bin_edges[dominant_bin + 1]) / 2

        # Measure directional consistency
        directional_strength = np.max(hist) / (np.mean(hist) + 1e-8)

        return {
            'dominant_direction': dominant_angle,
            'directional_strength': directional_strength,
            'directional_variance': np.var(gradient_angles)
        }

    def _analyze_kernel_size(self, kernel: np.ndarray) -> Dict[str, float]:
        """
        Analyze the effective size of the blur kernel.

        Returns:
            Bounding-box height/width of the pixels above 5% of the maximum,
            their aspect ratio (width / height), and the mass-weighted radius
            of gyration about the kernel's centroid (0 when the sum is <= 0).
        """
        # Effective support size
        threshold = 0.05 * np.max(kernel)
        support_mask = kernel > threshold

        # Bounding box
        rows, cols = np.where(support_mask)
        if len(rows) > 0:
            effective_height = np.max(rows) - np.min(rows) + 1
            effective_width = np.max(cols) - np.min(cols) + 1
        else:
            effective_height = effective_width = 0

        # Radius of gyration
        total_mass = np.sum(kernel)
        if total_mass > 0:
            y, x = np.mgrid[:kernel.shape[0], :kernel.shape[1]]
            cy = np.sum(kernel * y) / total_mass
            cx = np.sum(kernel * x) / total_mass

            radius_of_gyration = np.sqrt(np.sum(kernel * ((y - cy) ** 2 + (x - cx) ** 2)) / total_mass)
        else:
            radius_of_gyration = 0

        return {
            'effective_height': effective_height,
            'effective_width': effective_width,
            'radius_of_gyration': radius_of_gyration,
            'aspect_ratio': effective_width / (effective_height + 1e-8)
        }

    @staticmethod
    def _safe_correlation(a: np.ndarray, b: np.ndarray) -> float:
        """
        Pearson correlation of two arrays that is also defined in degenerate cases.

        Returns 0.0 for fewer than two samples or mismatched sizes. If either
        array is constant (correlation undefined), returns 1.0 when the arrays
        are element-wise close and 0.0 otherwise.
        """
        a = np.ravel(a)
        b = np.ravel(b)
        if a.size < 2 or a.size != b.size:
            return 0.0
        if np.ptp(a) == 0 or np.ptp(b) == 0:
            return 1.0 if np.allclose(a, b) else 0.0
        return float(np.corrcoef(a, b)[0, 1])

    def _analyze_kernel_symmetry(self, kernel: np.ndarray) -> Dict[str, float]:
        """
        Analyze mirror and radial symmetry of the blur kernel.

        Returns:
            ``horizontal_symmetry`` / ``vertical_symmetry``: correlation between
            one half and the mirrored other half. For odd sizes the central
            column/row is its own mirror image and is excluded, so both halves
            have the same shape.
            ``radial_symmetry``: correlation between the kernel and its radial
            average around ``(h // 2, w // 2)``.
        """
        h, w = kernel.shape

        # Horizontal (left/right) symmetry
        left_half = kernel[:, :w // 2]
        right_half = kernel[:, (w + 1) // 2:]
        horizontal_symmetry = self._safe_correlation(left_half, np.fliplr(right_half))

        # Vertical (top/bottom) symmetry
        top_half = kernel[:h // 2, :]
        bottom_half = kernel[(h + 1) // 2:, :]
        vertical_symmetry = self._safe_correlation(top_half, np.flipud(bottom_half))

        # Radial symmetry: average over every integer radius (including the
        # corners) and broadcast the per-radius mean back onto the grid.
        y, x = np.ogrid[:h, :w]
        r = np.sqrt((x - w // 2) ** 2 + (y - h // 2) ** 2).astype(int)
        radius_sums = np.bincount(r.ravel(), weights=kernel.ravel())
        radius_counts = np.bincount(r.ravel())
        radial_avg = (radius_sums / np.maximum(radius_counts, 1))[r]

        radial_symmetry = self._safe_correlation(kernel, radial_avg)

        return {
            'horizontal_symmetry': horizontal_symmetry,
            'vertical_symmetry': vertical_symmetry,
            'radial_symmetry': radial_symmetry
        }

    def _classify_pattern_type(self, analysis_results: Dict) -> str:
        """
        Classify the overall pattern type from the individual analyses.

        Uses a fixed decision order: defocus, motion, gaussian, atmospheric,
        then 'complex' as the fallback.
        """
        shape = analysis_results['shape_analysis']
        direction = analysis_results['directional_analysis']
        symmetry = analysis_results['symmetry_analysis']

        # Decision tree for pattern classification; a missing eccentricity
        # (empty support) is treated as 0.
        if symmetry['radial_symmetry'] > 0.8:
            return 'defocus'
        elif direction['directional_strength'] > 2.0 and shape.get('eccentricity', 0.0) > 0.5:
            return 'motion'
        elif symmetry['horizontal_symmetry'] > 0.7 and symmetry['vertical_symmetry'] > 0.7:
            return 'gaussian'
        elif direction['directional_variance'] > 1.0:
            return 'atmospheric'
        else:
            return 'complex'