

from typing import override
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import models as tv

# Alias for ``torch.Tensor``, used in annotations and ``isinstance`` checks in this module.
Tensor = torch.Tensor

def safe_resize(x: torch.Tensor, out_hw: tuple[int, int]) -> torch.Tensor:
    """Resize the last two (spatial) dims of ``x`` to ``out_hw``.

    If neither dimension grows, adaptive average pooling is used (an area
    average, so no aliasing). Otherwise the map is upsampled bilinearly in
    steps of at most 2x per axis until the target is reached. Any axis that
    has to shrink is set to its target size in the first step.

    Args:
        x: Tensor whose last two dims are (height, width). Assumed 4-D
            ``(N, C, H, W)``, as bilinear ``F.interpolate`` requires.
        out_hw: Target ``(height, width)``.

    Returns:
        Tensor with spatial size ``out_hw``.
    """
    h0, w0 = x.shape[-2:]
    h1, w1 = out_hw
    if h1 <= h0 and w1 <= w0:
        return F.adaptive_avg_pool2d(x, (h1, w1))
    y = x
    while y.shape[-2] < h1 or y.shape[-1] < w1:
        # Grow by at most 2x per step. The previous `max(2*h, h1)` made every
        # step jump straight to the target. `max(..., 1)` stops a zero-sized
        # axis from looping forever.
        step_h = min(h1, max(y.shape[-2], 1) * 2)
        step_w = min(w1, max(y.shape[-1], 1) * 2)
        y = F.interpolate(y, size=(step_h, step_w), mode="bilinear", align_corners=False)
    return y


class HaarDWT2D(nn.Module):
    def __init__(self, levels: int = 1):
        super().__init__()
        self.levels = levels

        ll = torch.tensor([[0.5, 0.5],[0.5, 0.5]], dtype=torch.float32)
        lh = torch.tensor([[0.5, 0.5],[-0.5,-0.5]], dtype=torch.float32)  
        hl = torch.tensor([[0.5,-0.5],[0.5,-0.5]], dtype=torch.float32) 
        hh = torch.tensor([[0.5,-0.5],[-0.5, 0.5]], dtype=torch.float32) 
        k = torch.stack([ll, lh, hl, hh], dim=0)  
        self.register_buffer("k2d", k[None, ...]) 

    @override
    def forward(self, x: Tensor) -> list[dict[str, Tensor]]:
        
        N, C, H, W = x.shape
        k = self.k2d  
        out = []
        cur = x
        for _ in range(self.levels):
           
            weight = k.expand(C, -1, -1, -1).reshape(4*C, 1, 2, 2)
            y = F.conv2d(cur, weight, stride=2, padding=0, groups=C) 
            y = y.view(N, C, 4, y.shape[-2], y.shape[-1])             
            LL, LH, HL, HH = y[:, :, 0], y[:, :, 1], y[:, :, 2], y[:, :, 3]  
            out.append({'LL': LL, 'LH': LH, 'HL': HL, 'HH': HH})
            cur = LL  
        return out


class LowpassFilter2D(nn.Module):
    kernel: Tensor

    def __init__(self):
        super().__init__()
        kernel_1d = torch.tensor([0.25, 0.5, 0.25], dtype=torch.float32)
        kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
        self.register_buffer("kernel", kernel_2d[None, None, :, :])

    @override
    def forward(self, x, stride=1):
        kernel = self.kernel.expand((x.shape[1], 1, -1, -1))
        x = F.conv2d(x, kernel, stride=stride, padding=1, groups=x.shape[1])  
        return x


class MultiLevelStats(nn.Module):
    """Pyramid of local means and variances computed with a 3x3 blur.

    At each level the current first and second moments are blurred. The
    variance is taken as ``E[x^2] - E[x]^2``. Both blurred moments are then
    subsampled by 2 (every other row/column) to form the next level.
    """

    def __init__(self, num_levels=4):
        super().__init__()
        self.num_levels = num_levels
        self.lowpass = LowpassFilter2D()

    @override
    def forward(self, x):
        """Return ``(means, variances)``: two lists with ``num_levels`` entries each.

        Entry ``i`` has the spatial size of ``x`` after ``i`` stride-2
        subsamplings. Variances are not clamped here, so rounding can make
        them slightly negative.
        """
        moments = (x, x**2)
        means = []
        variances = []
        for _ in range(self.num_levels):
            first, second = (self.lowpass(t, stride=1) for t in moments)
            means.append(first)
            variances.append(second - first**2)
            moments = (first[..., ::2, ::2], second[..., ::2, ::2])
        return means, variances



class WassersteinDistortionFeature(nn.Module):
    """Spatially varying Wasserstein distortion between two feature maps.

    Map 0 is the pointwise squared difference. Map ``i + 1`` is
    ``(mu_a - mu_b)^2 + (sd_a - sd_b)^2`` built from the level-``i`` local
    statistics of ``MultiLevelStats``. Each map is weighted by the
    triangular kernel ``relu(1 - |log2_sigma - i|)`` and averaged, and the
    averages are summed.
    """

    def __init__(self, num_levels: int = 5):
        super().__init__()
        self.multi_level_stats = MultiLevelStats(num_levels)
        self.num_levels = num_levels
        self.lowpass = LowpassFilter2D()

    @override
    def forward(
        self,
        features_a: Tensor,
        features_b: Tensor,
        log2_sigma: Tensor,
    ) -> Tensor:
        """Return the scalar distortion between ``features_a`` and ``features_b``.

        ``log2_sigma`` is assumed to be 4-D, with the same spatial size as
        the features so that it broadcasts against them (TODO confirm with
        callers).
        """
        means_a, vars_a = self.multi_level_stats(features_a)
        means_b, vars_b = self.multi_level_stats(features_b)

        distance_maps = [torch.square(features_a - features_b)]
        for level in range(self.num_levels):
            sd_a = torch.sqrt(torch.clamp(vars_a[level], min=1e-8))
            sd_b = torch.sqrt(torch.clamp(vars_b[level], min=1e-8))
            distance_maps.append(
                torch.square(means_a[level] - means_b[level]) + torch.square(sd_a - sd_b)
            )

        total = 0
        sigma = log2_sigma
        for index, dist_map in enumerate(distance_maps):
            # Maps 0 and 1 share the input resolution; each later map is half
            # the size of the previous one, so sigma is blurred and decimated
            # from index 2 onward.
            if index >= 2:
                sigma = self.lowpass(sigma, stride=2)
            weight = F.relu(1 - torch.abs(sigma - index))
            total = total + (weight * dist_map).mean()

        assert isinstance(total, Tensor)
        return total
    

class MultiscaleTruncatedVGG16(nn.Module):
    """VGG16 feature extractor evaluated on a small image pyramid.

    The first ``truncate_slice`` of five VGG16 stages are kept. They end at
    relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 respectively. Max-pooling
    layers can be swapped for 2x2 average pooling.

    ``forward`` returns the mean/std-normalized input image, followed by the
    output of every kept stage for each of ``num_scales`` scales. Between
    scales the image is blurred and decimated by 2 with ``LowpassFilter2D``.
    """

    mean: Tensor
    std: Tensor

    def __init__(
        self,
        requires_grad=False,
        pretrained=True,
        truncate_slice=5,
        replace_with_avg_pooling=True,
    ):
        """Build the truncated backbone.

        Args:
            requires_grad: If False, all VGG parameters are frozen.
            pretrained: Load torchvision's ImageNet weights
                (``VGG16_Weights.IMAGENET1K_V1``). These are the weights the
                deprecated ``pretrained=True`` flag selected.
            truncate_slice: Number of stages to keep, in ``[1, 5]``.
            replace_with_avg_pooling: Replace every ``nn.MaxPool2d`` with
                ``nn.AvgPool2d(kernel_size=2, stride=2)``.

        Raises:
            ValueError: If ``truncate_slice`` is out of range. The check runs
                before any weights are built or downloaded.
        """
        super().__init__()
        self.num_slices = 5
        if not 1 <= truncate_slice <= self.num_slices:
            raise ValueError(
                f"truncate_slice must be between 1 and {self.num_slices}, inclusive, "
                f"but is {truncate_slice}."
            )
        self.truncate_slice = truncate_slice

        # `pretrained=` is deprecated in torchvision; `weights=` is the supported form.
        weights = tv.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_pretrained_features = tv.vgg16(weights=weights).features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()

        # Half-open index ranges into `vgg16().features`; slice k ends at the
        # layer named `self.slice_names[k - 1]`. Layers keep their original
        # index as module name, so state_dict keys are `slice{k}.{index}.*`.
        bounds = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))
        slices = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        for block, (start, stop) in zip(slices[: self.truncate_slice], bounds):
            for idx in range(start, stop):
                layer = vgg_pretrained_features[idx]
                if replace_with_avg_pooling and isinstance(layer, nn.MaxPool2d):
                    layer = nn.AvgPool2d(kernel_size=2, stride=2)
                block.add_module(str(idx), layer)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        self.slice_names = ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        self.valid_slices = self.slice_names[: self.truncate_slice]

        # ImageNet channel statistics; this suggests RGB inputs in [0, 1] (TODO confirm).
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.lowpass = LowpassFilter2D()

    @override
    def forward(self, x: Tensor, num_scales: int = 3) -> list[Tensor]:
        """Return ``1 + num_scales * truncate_slice`` feature maps.

        The first entry is the normalized full-resolution image. Then, for
        each scale from fine to coarse, come the outputs of slices
        ``1..truncate_slice`` in order.
        """
        x = (x - self.mean) / self.std
        features = [x]
        active = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)[
            : self.truncate_slice
        ]
        for _ in range(num_scales):
            h = x
            for block in active:
                h = block(h)
                features.append(h)
            x = self.lowpass(x, stride=2)
        return features


class VGG16WassersteinDistortion(nn.Module):
    """Wasserstein distortion on Haar subbands of multiscale VGG16 features.

    Both images are passed through ``MultiscaleTruncatedVGG16``. Every
    returned feature map is split by ``HaarDWT2D`` into LL/LH/HL/HH
    subbands. Matching subbands are compared by
    ``WassersteinDistortionFeature`` under a ``log2_sigma`` map that is
    resized to the subband and shifted by the resolution ratio. The per-band
    distances are combined with the static ``subband_weights``.

    NOTE(review): ``adaptive_subband``, ``energy_ema`` and ``uniform_mix``
    are stored, and ``_current_subband_weights`` / ``_update_energy`` exist,
    but ``forward`` never calls them. They therefore do not affect the
    returned loss. ``k_logm``, ``k_logm_by_band`` and ``_band_precise_loose``
    are likewise not used by ``forward``.
    """

    def __init__(
        self,
        feature_net: str = "vgg16",
        num_levels: int = 5,
        grayscale: bool = False,
        normalize_center_to_zero: bool = False,
        adaptive_subband: bool = True,
        energy_ema: float = 0.95,
        uniform_mix: float = 1.0,
    ):
        """Configure the loss.

        Args:
            feature_net: Only ``"vgg16"`` is supported.
            num_levels: Statistic levels for ``WassersteinDistortionFeature``.
            grayscale: Expand single-channel inputs to 3 channels.
            normalize_center_to_zero: Map inputs with ``x * 2 - 1`` before
                the backbone.
            adaptive_subband: Read by ``_current_subband_weights`` /
                ``_update_energy`` only.
            energy_ema: EMA decay used by ``_update_energy``.
            uniform_mix: Blend toward uniform weights in
                ``_current_subband_weights``.

        Raises:
            ValueError: For an unsupported ``feature_net``.
        """
        super().__init__()
        self.wasserstein_distortion_feature = WassersteinDistortionFeature(num_levels)
        self.grayscale = grayscale
        self.normalize_center_to_zero = normalize_center_to_zero
        if feature_net == "vgg16":
            truncate_slice = 5
            self.feature_backbone = MultiscaleTruncatedVGG16(
                requires_grad=False, pretrained=True, truncate_slice=truncate_slice
            )
            self.truncate_slice = truncate_slice
        else:
            raise ValueError(f"Unsupported feature network: {feature_net}.")
        self.dwt_levels = 1
        self.dwt = HaarDWT2D(levels=self.dwt_levels)

        # Static per-band weights applied in `forward`.
        self.subband_weights = {'LL': 0.25, 'LH': 0.25, 'HL': 0.25, 'HH': 0.25}
        self.adaptive_subband = adaptive_subband
        self.energy_ema = float(energy_ema)
        # 1.0 -> uniform weights; 0.0 -> purely energy-proportional weights.
        self.uniform_mix = float(uniform_mix)

        # Running per-band energies (read by `_current_subband_weights`).
        self.register_buffer("energy_LL", torch.tensor(1.0))
        self.register_buffer("energy_LH", torch.tensor(1.0))
        self.register_buffer("energy_HL", torch.tensor(1.0))
        self.register_buffer("energy_HH", torch.tensor(1.0))
        # Subbands whose height or width is below this are skipped; 0 disables the check.
        self.MIN_BAND_HW = 0
        self.sigma_max = 16
        # Upper clamp for the per-band log2 sigma map.
        self.max_log2_sigma = float(np.log2(self.sigma_max))

        self.k_logm = 0.0
        self.k_logm_by_band = {'LL': 0.2, 'LH': 0.15, 'HL': 0.15, 'HH': 0.1}
        # Multiplier on the log2 resolution ratio used to shift log2_sigma.
        self.k_ratio = 1.0

    def _band_precise_loose(self, band_name: str, level: int, *,
                            base=(0.95, 1.15),
                            mid=(0.88, 1.25),
                            hi=(0.8, 1.35),
                            depth_decay=0.06):
        """Return a ``(lower, upper)`` pair for ``band_name`` that widens with ``level``.

        The starting pair is ``base`` for LL, ``mid`` for LH/HL and ``hi``
        for any other name. Both ends move outward by
        ``(level - 1) * depth_decay``. The lower end is floored at 0.3 and
        the upper end is capped at 3.0.
        """
        if band_name == 'LL':
            p, q = base
        elif band_name in ('LH', 'HL'):
            p, q = mid
        else:  # 'HH'
            p, q = hi

        decay = (level - 1) * depth_decay
        p = max(0.3, p - decay)
        q = min(3.0, q + decay)
        return p, q

    def _current_subband_weights(self) -> dict[str, float]:
        """Return per-band weights as Python floats.

        With ``adaptive_subband`` off, this returns ``self.subband_weights``
        itself. Otherwise it returns
        ``(1 - uniform_mix) * E / sum(E) + uniform_mix * 0.25``, where ``E``
        are the running band energies clamped to at least 1e-8.
        """
        if not self.adaptive_subband:
            return self.subband_weights
        with torch.no_grad():
            energies = torch.stack([self.energy_LL, self.energy_LH, self.energy_HL, self.energy_HH])
            energies = torch.clamp(energies, min=1e-8)
            energies = energies / energies.sum()
            mix = self.uniform_mix
            w = (1.0 - mix) * energies + mix * torch.tensor(
                [0.25, 0.25, 0.25, 0.25], device=energies.device
            )
        ll, lh, hl, hh = w.tolist()
        return {'LL': ll, 'LH': lh, 'HL': hl, 'HH': hh}

    def _update_energy(self, band_name: str, wd_band: Tensor):
        """EMA-update the energy buffer for ``band_name`` with ``wd_band.mean()``.

        Does nothing when ``adaptive_subband`` is off or the name is unknown.
        """
        if not self.adaptive_subband:
            return
        with torch.no_grad():
            e = wd_band.mean().detach()
            a = self.energy_ema
            if band_name == 'LL':
                self.energy_LL.mul_(a).add_(e * (1.0 - a))
            elif band_name == 'LH':
                self.energy_LH.mul_(a).add_(e * (1.0 - a))
            elif band_name == 'HL':
                self.energy_HL.mul_(a).add_(e * (1.0 - a))
            elif band_name == 'HH':
                self.energy_HH.mul_(a).add_(e * (1.0 - a))

    def _band_log2_sigma(
        self,
        log2_sigma: Tensor,
        band_hw: tuple[int, int],
        mean_log_ratio: float,
        lvl: int,
    ) -> Tensor:
        """Resize ``log2_sigma`` to ``band_hw``, shift it down, and clamp it to ``[0, max_log2_sigma]``."""
        ls = safe_resize(log2_sigma.detach(), band_hw)
        # NOTE(review): each DWT level halves the resolution relative to the
        # feature map, yet the level shift is `lvl - 1` (0 at the first
        # level). Confirm this offset is intended.
        delta = self.k_ratio * mean_log_ratio + (lvl - 1)
        return torch.clamp(ls - delta, min=0.0, max=self.max_log2_sigma)

    @override
    def forward(
        self,
        pred: Tensor,
        gt: Tensor,
        log2_sigma: Tensor,
        saliency: Tensor | None = None,
        num_scales: int = 3,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        """Compute the subband Wasserstein distortion between ``pred`` and ``gt``.

        Args:
            pred: Predicted images. Must match ``gt`` in shape after the
                optional grayscale expansion / rescaling.
            gt: Reference images.
            log2_sigma: Map of log2 summary radii. Assumed 4-D ``(N, 1, H, W)``
                (TODO confirm). It is resized to every subband and shifted by
                the log2 ratio between its own size and each feature map.
            saliency: Accepted for call compatibility; currently unused.
            num_scales: Image-pyramid scales used by the backbone.

        Returns:
            ``(total, per_band)``. ``total`` is the ``subband_weights``-weighted
            sum over all feature maps, DWT levels and subbands. ``per_band``
            maps each band name to the unweighted distance of the last
            feature map processed; entries are overwritten, not accumulated.

        Raises:
            ValueError: If ``pred`` and ``gt`` shapes differ.
        """
        wd_information: dict[str, Tensor] = {}
        if self.grayscale:
            pred = pred.expand(-1, 3, -1, -1)
            gt = gt.expand(-1, 3, -1, -1)
        if self.normalize_center_to_zero:
            pred = pred * 2 - 1
            gt = gt * 2 - 1
        if pred.shape != gt.shape:
            raise ValueError(
                f"Predicted and ground truth images must have the same shape, "
                f"but got {pred.shape} and {gt.shape}."
            )
        feats_pred = self.feature_backbone(pred, num_scales=num_scales)
        feats_gt = self.feature_backbone(gt, num_scales=num_scales)

        wasserstein_dist = torch.zeros((), device=pred.device, dtype=pred.dtype)

        assert len(feats_pred) == len(feats_gt)
        for fp, fgt in zip(feats_pred, feats_gt):
            # log2 of how much coarser this feature map is than the sigma map.
            log_ratio_h = np.log2(log2_sigma.shape[-2] / fgt.shape[-2])
            log_ratio_w = np.log2(log2_sigma.shape[-1] / fgt.shape[-1])
            mean_log_ratio = float((log_ratio_h + log_ratio_w) / 2)

            bands_p = self.dwt(fp)
            bands_g = self.dwt(fgt)
            for lvl, (bp, bg) in enumerate(zip(bands_p, bands_g), start=1):
                # All four subbands of a level share one shape, so the sigma
                # map is prepared once per level.
                band_hw = bg['LL'].shape[-2:]
                if min(band_hw) < self.MIN_BAND_HW:
                    continue
                ls = self._band_log2_sigma(log2_sigma, band_hw, mean_log_ratio, lvl)
                for band_name, fp_band in bp.items():
                    wd_band = self.wasserstein_distortion_feature(fp_band, bg[band_name], ls)
                    wd_information[band_name] = wd_band
                    wasserstein_dist = wasserstein_dist + self.subband_weights[band_name] * wd_band

        assert isinstance(wasserstein_dist, Tensor)
        return wasserstein_dist, wd_information
