import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import os
import json
from typing import List, Tuple, Optional

class DifferentiableHistogram(nn.Module):
    """Differentiable (soft) histogram L1 loss with fully vectorized binning.

    Each value is split between its two nearest bin centres by linear
    interpolation, so bin masses are piecewise-linear in the input and
    gradients flow back to it. (``torch.histc``, used previously, has no
    backward implementation, so a loss built on it cannot be differentiated.)

    Compatibility with ``torch.histc``: a value exactly on a bin centre lands
    wholly in that bin, values outside ``[min_, max_]`` are ignored, and values
    in the outer half-bins are assigned wholly to the first/last bin.
    """

    def __init__(self, bins: int = 100, min_: float = 0.0, max_: float = 1.0):
        super().__init__()
        self.bins = bins
        self.min = min_
        self.max = max_
        self.eps = 1e-6  # kept for backward compatibility; not used in this class

        # Kept as a buffer for state_dict compatibility; the binning below
        # derives bin centres analytically from min/max/bins instead.
        self.register_buffer('bin_edges', torch.linspace(min_, max_, bins + 1))

    def hist_vectorized(self, x: torch.Tensor) -> torch.Tensor:
        """Compute normalized soft histograms for every (batch, channel) row.

        Args:
            x: ``[B, C, H, W]`` tensor, or an already flattened ``[B, C, N]``.

        Returns:
            ``[B, C, bins]`` tensor holding the fraction of each row's values
            assigned to each bin (rows sum to the fraction of in-range values).
        """
        x_flat = x.flatten(start_dim=2) if x.dim() == 4 else x
        batch_size, channels, spatial_size = x_flat.shape
        # reshape (not view): the input may be a non-contiguous slice.
        values = x_flat.reshape(batch_size * channels, spatial_size)

        bin_width = (self.max - self.min) / self.bins
        # Continuous bin coordinate in which the centre of bin k sits at k.
        pos = (values - self.min) / bin_width - 0.5
        # NaN/inf would yield invalid scatter indices; such values are masked
        # out below, so any finite placeholder position is fine.
        pos = torch.nan_to_num(pos, nan=0.0).clamp(0.0, self.bins - 1)
        lower = pos.detach().floor()
        upper_frac = pos - lower  # share for the upper neighbour; carries the gradient
        lower_idx = lower.long()
        upper_idx = (lower_idx + 1).clamp(max=self.bins - 1)

        in_range = ((values >= self.min) & (values <= self.max)).to(upper_frac.dtype)
        hist = torch.zeros(batch_size * channels, self.bins,
                           dtype=upper_frac.dtype, device=values.device)
        # Out-of-place scatter_add keeps the op differentiable w.r.t. the weights.
        hist = hist.scatter_add(1, lower_idx, (1.0 - upper_frac) * in_range)
        hist = hist.scatter_add(1, upper_idx, upper_frac * in_range)

        # Normalize by the number of values per row.
        return hist.view(batch_size, channels, self.bins) / spatial_size

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the summed L1 distance between the histograms of ``p`` and ``q``.

        If the shapes differ, ``q`` is bilinearly resized to ``p``'s spatial
        size first (this assumes 4-D ``[B, C, H, W]`` inputs).
        """
        if p.shape != q.shape:
            q = F.interpolate(q, size=p.shape[-2:], mode='bilinear', align_corners=False)

        p_hists = self.hist_vectorized(p)  # [B, C, bins]
        q_hists = self.hist_vectorized(q)  # [B, C, bins]

        return F.l1_loss(p_hists, q_hists, reduction='sum')


class DCTBlockwise(nn.Module):
    """Apply a 2-D DCT to non-overlapping ``block_size`` x ``block_size`` tiles."""

    def __init__(self, block_size: int = 32):
        super().__init__()
        self.block_size = block_size

    def unfold_tensor_optimized(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` into square blocks, zero-padding the bottom/right edges if needed.

        Args:
            x: ``[B, C, H, W]`` tensor.

        Returns:
            ``[B, C, num_blocks_h, num_blocks_w, block_size, block_size]``
            strided view produced by ``Tensor.unfold`` (not contiguous).
        """
        B, C, H, W = x.shape

        # Padding that rounds H and W up to a multiple of block_size.
        pad_h = (self.block_size - H % self.block_size) % self.block_size
        pad_w = (self.block_size - W % self.block_size) % self.block_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)

        return x.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the blockwise DCT of ``x``, reassembled into image layout.

        Args:
            x: ``[B, C, H, W]`` tensor.

        Returns:
            ``[B, C, H_pad, W_pad]`` tensor, where ``H_pad``/``W_pad`` are ``H``/``W``
            rounded up to a multiple of ``block_size``. Each tile holds the DCT
            coefficients of the corresponding (zero-padded) input tile.

        Raises:
            ImportError: if ``torch_dct`` is not installed.
        """
        # Imported lazily so the module can be loaded without torch_dct.
        try:
            import torch_dct as dct
        except ImportError as err:
            raise ImportError("torch_dct is required. Install with: pip install torch_dct") from err

        patches = self.unfold_tensor_optimized(x)
        B, C, num_h, num_w, block_h, block_w = patches.shape

        # reshape, not view: the unfolded tensor is a strided view whose block
        # dimensions cannot be merged without a copy, so .view() raises here.
        patches_flat = patches.reshape(-1, block_h, block_w)

        # One batched call over all blocks (dct_2d transforms the last two dims).
        dct_patches = dct.dct_2d(patches_flat)
        dct_patches = dct_patches.reshape(B, C, num_h, num_w, block_h, block_w)

        # [B, C, nh, nw, bh, bw] -> [B, C, nh, bh, nw, bw] -> [B, C, nh*bh, nw*bw]
        dct_output = dct_patches.permute(0, 1, 2, 4, 3, 5)
        return dct_output.reshape(B, C, num_h * block_h, num_w * block_w)


class DCTHistogram(nn.Module):
    """Histogram loss computed on blockwise DCT coefficients.

    Both inputs are transformed by ``DCTBlockwise`` and the resulting
    coefficient maps are compared with ``DifferentiableHistogram``.
    """

    def __init__(self, bins: int = 250, min_: float = -100.0, max_: float = 100.0, block_size: int = 32):
        super().__init__()
        self.bins, self.min, self.max = bins, min_, max_
        self.block_size = block_size

        # Submodules are registered DCT first, then histogram.
        self.dct_computer = DCTBlockwise(block_size)
        self.hist_computer = DifferentiableHistogram(bins, min_, max_)

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the histogram loss between the DCT coefficients of ``p`` and ``q``.

        When the shapes differ, ``q`` is bilinearly resized to ``p``'s spatial
        size so that both DCT outputs have matching shapes.
        """
        if q.shape != p.shape:
            q = F.interpolate(q, size=p.shape[-2:], mode='bilinear', align_corners=False)

        transform = self.dct_computer
        return self.hist_computer(transform(p), transform(q))


class NeighborhoodLoss(nn.Module):
    """Neighborhood KL loss built on an external ``unfold_tensor`` helper (batched math).

    NOTE(review): ``NeighborhoodLoss`` is defined a second time later in this
    module; that later definition rebinds the name at import time, so this
    class is unreachable under this name. Rename or remove one of the two.
    """

    def __init__(self, neighborhood_size: int = 2, eps: float = 0.1):
        super().__init__()
        self.neighborhood_size = neighborhood_size
        self.eps = eps  # ridge added to every covariance diagonal before inversion

    def unfold_neighborhoods(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` into neighborhoods with ``flatten_faster.unfold_tensor``.

        The result feeds ``compute_gaussian_params_batch``, which expects
        ``[B, C, feature_dim, num_patches]`` - presumably the layout
        ``unfold_tensor`` returns; TODO confirm against ``flatten_faster``.
        """
        # Project-local helper, imported lazily so this module loads without it.
        from flatten_faster import unfold_tensor

        patches, _unfold_shape = unfold_tensor(x, self.neighborhood_size, self.neighborhood_size)
        return patches

    def compute_gaussian_params_batch(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-(batch, channel) mean and sample covariance of the patch vectors.

        Args:
            x: ``[B, C, feature_dim, num_patches]`` tensor.

        Returns:
            ``(mean, cov)`` with shapes ``[B, C, feature_dim]`` (input dtype) and
            ``[B, C, feature_dim, feature_dim]`` (float64).
        """
        B, C, feature_dim, num_patches = x.shape

        mean = x.mean(dim=-1)  # [B, C, feature_dim]
        centered = x - mean.unsqueeze(-1)  # [B, C, feature_dim, num_patches]

        # Double precision for numerical stability of the later inverse/det.
        centered = centered.double()

        # Unbiased covariance for all (B, C) pairs in one batched matmul.
        cov = torch.matmul(centered, centered.transpose(-2, -1)) / (num_patches - 1)

        return mean, cov

    def kl_divergence_batch(self, mean0: torch.Tensor, cov0: torch.Tensor,
                           mean1: torch.Tensor, cov1: torch.Tensor) -> torch.Tensor:
        """Return the summed KL(N(mean0, cov0) || N(mean1, cov1)) over batch and channels.

        Both covariances are regularized with ``eps * I``. If inverting the
        regularized ``cov1`` batch fails, it is retried with ``10 * eps``.
        Non-finite per-pair results are replaced via ``nan_to_num``.
        """
        B, C, feature_dim = mean0.shape

        eye = torch.eye(feature_dim, device=cov0.device, dtype=cov0.dtype)
        cov0_reg = cov0 + self.eps * eye
        cov1_reg = cov1 + self.eps * eye

        try:
            cov1_inv = torch.linalg.inv(cov1_reg)
        except RuntimeError:
            # torch.linalg.LinAlgError (singular input) subclasses RuntimeError;
            # unlike a bare except, this no longer hides KeyboardInterrupt etc.
            cov1_reg = cov1 + (self.eps * 10) * eye
            cov1_inv = torch.linalg.inv(cov1_reg)

        # Trace term: tr(S1^-1 S0)
        trace_term = torch.diagonal(torch.matmul(cov1_inv, cov0_reg), dim1=-2, dim2=-1).sum(dim=-1)

        # Log-determinant term: log(|S1| / |S0|), clamped to avoid log(0).
        det1 = torch.clamp(torch.linalg.det(cov1_reg), min=1e-10)
        det0 = torch.clamp(torch.linalg.det(cov0_reg), min=1e-10)
        det_term = torch.log(det1 / det0)

        # Quadratic term: (mu1 - mu0)^T S1^-1 (mu1 - mu0)
        diff = (mean1 - mean0).double()  # [B, C, feature_dim]
        quad_term = torch.matmul(
            torch.matmul(diff.unsqueeze(-2), cov1_inv),
            diff.unsqueeze(-1)
        ).squeeze(-1).squeeze(-1)

        kl = 0.5 * (trace_term + det_term + quad_term - feature_dim)

        # Keep the loss finite even for degenerate pairs.
        kl = torch.nan_to_num(kl, nan=0.0, posinf=1e6, neginf=-1e6)

        return kl.sum()

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the neighborhood KL loss between ``p`` and ``q``.

        ``q`` is bilinearly resized to ``p``'s spatial size when shapes differ.
        """
        if p.shape != q.shape:
            q = F.interpolate(q, size=p.shape[-2:], mode='bilinear', align_corners=False)

        p_neighborhoods = self.unfold_neighborhoods(p)
        q_neighborhoods = self.unfold_neighborhoods(q)

        p_mean, p_cov = self.compute_gaussian_params_batch(p_neighborhoods)
        q_mean, q_cov = self.compute_gaussian_params_batch(q_neighborhoods)

        return self.kl_divergence_batch(p_mean, p_cov, q_mean, q_cov)


class NeighborhoodLoss(nn.Module):
    """Neighborhood loss: KL divergence between Gaussians fitted to local patches.

    Each input is cut into non-overlapping ``neighborhood_size`` x
    ``neighborhood_size`` patches per channel. The flattened patch vectors of
    every (batch, channel) pair are summarized by a mean and covariance, and
    the loss is the sum of KL(N_p || N_q) over all pairs.
    """

    def __init__(self, neighborhood_size: int = 2, eps: float = 0.1):
        super().__init__()
        self.neighborhood_size = neighborhood_size
        self.eps = eps  # ridge added to every covariance diagonal before inversion

    def unfold_neighborhoods(self, x: torch.Tensor) -> torch.Tensor:
        """Extract non-overlapping square neighborhoods.

        Args:
            x: ``[B, C, H, W]`` tensor.

        Returns:
            ``[B, C, neighborhood_size**2, num_patches]`` tensor.
        """
        B, C, H, W = x.shape

        neighborhoods = F.unfold(
            x,
            kernel_size=self.neighborhood_size,
            stride=self.neighborhood_size,
            padding=0
        )

        # F.unfold yields [B, C * k * k, L] with the channel index varying
        # slowest, so it splits cleanly into [B, C, k * k, L].
        return neighborhoods.view(
            B, C, self.neighborhood_size * self.neighborhood_size, -1
        )

    def compute_gaussian_params_vectorized(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-(batch, channel) mean and unbiased covariance of the patch vectors.

        Args:
            x: ``[B, C, feature_dim, num_patches]`` tensor.

        Returns:
            ``(mean, cov)`` with shapes ``[B, C, feature_dim]`` and
            ``[B, C, feature_dim, feature_dim]``, both in ``x``'s dtype.
        """
        mean = x.mean(dim=-1, keepdim=True)  # [B, C, feature_dim, 1]
        centered = x - mean

        B, C, feature_dim, num_patches = centered.shape

        # One batched matmul over all (B, C) pairs instead of a Python double loop.
        covariances = torch.matmul(centered, centered.transpose(-2, -1)) / (num_patches - 1)

        return mean.squeeze(-1), covariances

    def kl_divergence_vectorized(self, mean0: torch.Tensor, cov0: torch.Tensor,
                                mean1: torch.Tensor, cov1: torch.Tensor) -> torch.Tensor:
        """Return the summed KL(N(mean0, cov0) || N(mean1, cov1)) over batch and channels.

        Covariances are regularized with ``eps * I`` and evaluated in float64.
        Pairs whose regularized ``cov1`` is singular contribute 0, as before.
        Other errors (e.g. shape mismatches) now propagate instead of being
        silently swallowed.
        """
        B, C, feature_dim = mean0.shape

        # Regularize in the input precision, then promote to float64.
        eye = torch.eye(feature_dim, device=cov0.device, dtype=cov0.dtype)
        S0 = (cov0 + self.eps * eye).double()
        S1 = (cov1 + self.eps * eye).double()

        # inv_ex reports singular matrices through `info` instead of raising,
        # so a single bad (b, c) pair cannot abort the whole batch.
        with torch.no_grad():
            _, info = torch.linalg.inv_ex(S1)
        valid = info == 0  # [B, C]

        # Replace singular pairs by the identity so the batched maths (and its
        # backward) stays finite; their KL is zeroed out below.
        eye64 = torch.eye(feature_dim, device=S1.device, dtype=S1.dtype)
        valid_mat = valid[..., None, None]
        S0 = torch.where(valid_mat, S0, eye64)
        S1 = torch.where(valid_mat, S1, eye64)

        S1_inv = torch.linalg.inv(S1)

        # Trace term: tr(S1^-1 S0)
        trace_term = torch.diagonal(S1_inv @ S0, dim1=-2, dim2=-1).sum(dim=-1)

        # Log-determinant term: log(|S1| / |S0|)
        det_term = torch.log(torch.linalg.det(S1) / torch.linalg.det(S0))

        # Quadratic term: (mu1 - mu0)^T S1^-1 (mu1 - mu0)
        diff = (mean1 - mean0).double().unsqueeze(-1)  # [B, C, feature_dim, 1]
        quad_term = (diff.transpose(-2, -1) @ S1_inv @ diff).squeeze(-1).squeeze(-1)

        kl = 0.5 * (trace_term + det_term + quad_term - feature_dim)
        kl = torch.where(valid, kl, torch.zeros_like(kl))

        return kl.sum()

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the neighborhood KL loss between ``p`` and ``q``.

        ``q`` is bilinearly resized to ``p``'s spatial size when shapes differ.
        """
        if p.shape != q.shape:
            q = F.interpolate(q, size=p.shape[-2:], mode='bilinear', align_corners=False)

        p_neighborhoods = self.unfold_neighborhoods(p)
        q_neighborhoods = self.unfold_neighborhoods(q)

        p_mean, p_cov = self.compute_gaussian_params_vectorized(p_neighborhoods)
        q_mean, q_cov = self.compute_gaussian_params_vectorized(q_neighborhoods)

        return self.kl_divergence_vectorized(p_mean, p_cov, q_mean, q_cov)


class TTAPipeline:
    """Single-step test-time adaptation (TTA) on latents.

    Each step decodes the latents, crops the image and garment regions given by
    per-sample coordinates, scores them with an image-space histogram loss, a
    DCT-histogram loss and (optionally) a neighborhood KL loss, and moves the
    latents one gradient step against the weighted total.
    """

    def __init__(self, img_weight: float = 1.0, dct_weight: float = 1.0,
                 nbh_weight: float = 0.0, learning_rate: float = 1.0,
                 dct_block_size: int = 32, neighborhood_size: int = 2):
        self.img_weight = img_weight
        self.dct_weight = dct_weight
        self.nbh_weight = nbh_weight
        self.learning_rate = learning_rate

        # Loss modules; the image histogram range [0, 1] matches the
        # normalization done in compute_losses_vectorized.
        self.img_loss = DifferentiableHistogram(bins=100, min_=0.0, max_=1.0)
        self.dct_loss = DCTHistogram(bins=250, min_=-100.0, max_=100.0,
                                     block_size=dct_block_size)
        # The neighborhood loss is only built when it is weighted in.
        self.nbh_loss = (
            NeighborhoodLoss(neighborhood_size=neighborhood_size) if nbh_weight > 0 else None
        )

    def extract_patches_batch(self, pred_xstart: torch.Tensor, garment: torch.Tensor,
                             pcoords: List, gcoords: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Crop per-sample image and garment regions and batch them.

        Args:
            pred_xstart: decoded images, indexed as ``[B, C, H, W]``.
            garment: garment images, indexed the same way.
            pcoords: ``(ys, xs)`` for the image crops; ``ys[0][i]``/``ys[1][i]``
                are the row start/stop of sample ``i`` (likewise ``xs`` for columns).
            gcoords: same layout for the garment crops.

        Returns:
            ``(image_patches, garment_patches)``, each concatenated along dim 0.

        Note:
            ``torch.cat`` requires every sample's crop to have the same size;
            crops of differing sizes raise a ``RuntimeError`` here.
        """
        ysp, xsp = pcoords
        ysg, xsg = gcoords

        image_patches = []
        garment_patches = []
        for i in range(pred_xstart.shape[0]):
            py_min, py_max = int(ysp[0][i]), int(ysp[1][i])
            px_min, px_max = int(xsp[0][i]), int(xsp[1][i])
            gy_min, gy_max = int(ysg[0][i]), int(ysg[1][i])
            gx_min, gx_max = int(xsg[0][i]), int(xsg[1][i])

            image_patches.append(pred_xstart[i:i + 1, :, py_min:py_max, px_min:px_max])
            garment_patches.append(garment[i:i + 1, :, gy_min:gy_max, gx_min:gx_max])

        return torch.cat(image_patches, dim=0), torch.cat(garment_patches, dim=0)

    def compute_losses_vectorized(self, garment_patches: torch.Tensor,
                                 image_patches: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Compute the weighted total loss and a dict of scalar components.

        Inputs are mapped from [-1, 1] to [0, 1] (and clamped) before scoring.
        Losses with weight <= 0 are skipped and reported as 0.0. If every
        weight is <= 0, the returned total is the Python float ``0.0``.

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` has the keys
            ``img_loss``, ``dct_loss``, ``nbh_loss`` and ``total_loss``.
        """
        garment_norm = torch.clamp((garment_patches + 1) * 0.5, min=0.0, max=1.0)
        image_norm = torch.clamp((image_patches + 1) * 0.5, min=0.0, max=1.0)

        total_loss = 0.0
        loss_dict = {}

        if self.img_weight > 0:
            img_loss_val = self.img_loss(garment_norm, image_norm)
            total_loss += self.img_weight * img_loss_val
            loss_dict['img_loss'] = img_loss_val.item()
        else:
            loss_dict['img_loss'] = 0.0

        if self.dct_weight > 0:
            dct_loss_val = self.dct_loss(garment_norm, image_norm)
            total_loss += self.dct_weight * dct_loss_val
            loss_dict['dct_loss'] = dct_loss_val.item()
        else:
            loss_dict['dct_loss'] = 0.0

        if self.nbh_weight > 0 and self.nbh_loss is not None:
            nbh_loss_val = self.nbh_loss(garment_norm, image_norm)
            total_loss += self.nbh_weight * nbh_loss_val
            loss_dict['nbh_loss'] = nbh_loss_val.item()
        else:
            loss_dict['nbh_loss'] = 0.0

        loss_dict['total_loss'] = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss

        return total_loss, loss_dict

    def optimize_step(self, pred_xstart_latent: torch.Tensor, garment: torch.Tensor,
                     pcoords: List, gcoords: List, decode_fn: callable,
                     timestep_idx: int = 0) -> Tuple[torch.Tensor, dict]:
        """Run one TTA gradient step on the latents.

        Args:
            pred_xstart_latent: latents to adapt. The caller's tensor is not
                modified (its ``requires_grad`` flag is left untouched).
            garment: reference images cropped with ``gcoords``.
            pcoords, gcoords: crop coordinates, see ``extract_patches_batch``.
            decode_fn: maps latents to images; must be differentiable.
            timestep_idx: currently unused; kept for interface compatibility.

        Returns:
            ``(latents, loss_dict)``. ``latents`` is the detached updated tensor.
            It is the detached input unchanged when the loss is non-positive,
            tiny or NaN, or the gradient is missing or non-finite.
        """
        # Work on a detached leaf copy: calling requires_grad_() on the caller's
        # tensor would leave it tracking gradients after this call returns.
        latent = pred_xstart_latent.detach().requires_grad_(True)

        with torch.enable_grad():
            pred_xstart = decode_fn(latent)

            image_patches, garment_patches = self.extract_patches_batch(
                pred_xstart, garment, pcoords, gcoords
            )

            total_loss, loss_dict = self.compute_losses_vectorized(garment_patches, image_patches)

            # A float total means every loss was disabled; NaN compares False.
            if isinstance(total_loss, torch.Tensor) and total_loss.item() > 1e-8:
                grad = torch.autograd.grad(
                    outputs=total_loss,
                    inputs=latent,
                    retain_graph=False,
                    create_graph=False,
                    allow_unused=True
                )[0]

                # isfinite rejects inf as well as NaN; an inf gradient would
                # become NaN after the normalization below.
                if grad is not None and torch.isfinite(grad).all():
                    # Rescale to unit (whole-tensor) norm when it exceeds 1.
                    grad_norm = torch.norm(grad)
                    if grad_norm > 1.0:
                        grad = grad / grad_norm

                    with torch.no_grad():
                        updated_latents = latent - self.learning_rate * grad

                    return updated_latents.detach(), loss_dict

        return pred_xstart_latent.detach(), loss_dict


def create_optimized_tta_step(img_weight: float = 1.0, dct_weight: float = 1.0,
                             nbh_weight: float = 0.0, learning_rate: float = 1.0,
                             dct_block_size: int = 32, neighborhood_size: int = 2):
    """Build a ``TTAPipeline`` once and return a plain function that runs one step.

    All arguments are forwarded unchanged to ``TTAPipeline``. The returned
    callable has the same parameters as ``TTAPipeline.optimize_step`` and
    returns its ``(latents, loss_dict)`` result.
    """
    pipeline = TTAPipeline(
        img_weight=img_weight, dct_weight=dct_weight, nbh_weight=nbh_weight,
        learning_rate=learning_rate, dct_block_size=dct_block_size,
        neighborhood_size=neighborhood_size,
    )

    def optimized_step(pred_xstart_latent, garment, pcoords, gcoords, decode_fn, timestep_idx=0):
        # Delegate to the pipeline captured when the factory ran.
        step = pipeline.optimize_step
        return step(pred_xstart_latent, garment, pcoords, gcoords, decode_fn, timestep_idx)

    return optimized_step


# Drop-in replacement for your existing pipeline loop
def replace_tta_loop_in_pipeline():
    """Build the TTA step function and show how to wire it into a sampling loop.

    Create the step function once, outside the timestep loop. Inside the loop,
    the original per-image loss/gradient code::

        for im_idx in range(b):
            # ... loss computation and gradient update
            latents[im_idx][None] -= scale*norm_grad[im_idx]

    is replaced by::

        decode_fn = lambda x: (
            self.decode_latents(x, intermediate_features) if self.emasc
            else self.decode_latents(x)
        )

        updated_latents, loss_dict = optimized_tta_step(
            pred_xstart_latent, garment, pcoords, gcoords, decode_fn, i
        )
        latents = updated_latents

        # Store losses for logging
        for im_idx in range(b):
            losses["dct_loss"][im_idx].append(loss_dict['dct_loss'])
            losses["histoimage_loss"][im_idx].append(loss_dict['img_loss'])
            losses["total_loss"][im_idx].append(loss_dict['total_loss'])

    NOTE(review): the step returns ``pred_xstart_latent`` minus a gradient
    step, whereas the original loop updated ``latents``. Assigning the result
    to ``latents`` looks like it mixes the two; confirm against the caller.

    Returns:
        The step function from ``create_optimized_tta_step``, built with the
        example settings below.
    """
    # Each value mirrors the pipeline kwarg noted beside it.
    return create_optimized_tta_step(
        img_weight=1.0,      # kwargs['img_w']
        dct_weight=1.0,      # kwargs['dct_w']
        nbh_weight=0.0,      # kwargs.get('nbh_w', 0.0)
        learning_rate=1.0,
        dct_block_size=32,
        neighborhood_size=2,
    )