import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft
import math
import os
from transformers import BertTokenizer, BertModel
from dinov2.models import vit_large
from scipy.ndimage import sobel


# ------------------------------
# Utility Functions (Matches Paper Notation)
# ------------------------------
def to_2tuple(x):
    """Turn a scalar ``int`` into the pair ``(x, x)``.

    Any non-``int`` value (tuple, list, float, ...) is returned unchanged.
    """
    if not isinstance(x, int):
        return x
    return x, x


def normalize(x, dim=-1):
    """Scale ``x`` to unit L2 norm along ``dim`` (embedding normalisation, §3.2.3).

    Thin wrapper over ``torch.nn.functional.normalize`` with ``p=2``. That
    function clamps the norm away from zero, so all-zero vectors stay zero.
    """
    order = 2
    return F.normalize(x, p=order, dim=dim)


def spectral_flatness(x):
    """Spectral-flatness penalty E(t) of §3.3.1, one score per token.

    The spectrum is the absolute value of the real part of an FFT along the
    feature axis. This is an FFT-based stand-in, not a true DCT. The spectrum
    is normalised to sum to one per token, and the score is
    ``-mean(log(1 + alpha * flat))``.

    Args:
        x: [B, N, d] token features.

    Returns:
        [B, N] score per token.
    """
    alpha = 1.0  # the paper leaves α unspecified
    spectrum = fft.fft(x, dim=-1).real.abs()  # [B, N, d]
    mass = spectrum.sum(dim=-1, keepdim=True) + 1e-8  # guard against all-zero tokens
    flat = spectrum / mass
    penalty = torch.log(1 + alpha * flat).mean(dim=-1)
    return penalty.neg()


# ------------------------------
# 1. Sinkhorn Solver (Eq.2: Low-Rank Entropy-Regularized OT)
# ------------------------------
class SinkhornSolver(nn.Module):
    """Entropy-regularised optimal transport via Sinkhorn iterations (Eq.2).

    The iterations run in the log domain. This keeps ``exp(-C / reg)`` from
    underflowing for small ``reg``. The returned plan is always non-negative.
    After the final sweep its column sums equal ``b`` exactly, and its row sums
    approach ``a`` as ``max_iter`` grows.

    Args:
        reg: ε in Eq.2, the entropy-regularisation strength.
        max_iter: Fixed number of Sinkhorn sweeps. There is no early stopping.
        rank: Retained for backward compatibility (§3.2.1 low-rank r). It is
            not used by this exact solver. The previous low-rank update
            broadcast ``b`` ([B, 1, m]) against a [B, m, rank] tensor and
            failed whenever ``m != rank``.
    """

    def __init__(self, reg=0.05, max_iter=50, rank=32):
        super().__init__()
        self.reg = reg
        self.max_iter = max_iter
        self.rank = rank

    def forward(self, C, a=None, b=None):
        """Solve the OT problem for a batch of cost matrices.

        Args:
            C: [B, n, m] cost matrix (e.g. 1 - cosine similarity, Eq.1).
            a: Row marginal, [B, n, 1] (or any shape reshapeable to
                [-1, n, 1]). Defaults to uniform 1/n.
            b: Column marginal, [B, 1, m] (or any shape reshapeable to
                [-1, 1, m]). Defaults to uniform 1/m.

        Returns:
            T: [B, n, m] transport plan ``diag(u) · exp(-C/reg) · diag(v)``.
        """
        B, n, m = C.shape
        if a is None:
            a = torch.full((B, n, 1), 1.0 / n, device=C.device, dtype=C.dtype)
        if b is None:
            b = torch.full((B, 1, m), 1.0 / m, device=C.device, dtype=C.dtype)

        log_a = torch.log(a.reshape(-1, n, 1))
        log_b = torch.log(b.reshape(-1, 1, m))
        log_K = -C / self.reg  # log of the Gibbs kernel

        log_u = C.new_zeros(B, n, 1)
        log_v = C.new_zeros(B, 1, m)
        for _ in range(self.max_iter):
            # u = a / (K v): rescale rows toward marginal a.
            log_u = log_a - torch.logsumexp(log_K + log_v, dim=2, keepdim=True)
            # v = b / (Kᵀ u): rescale columns to marginal b.
            log_v = log_b - torch.logsumexp(log_K + log_u, dim=1, keepdim=True)

        return torch.exp(log_u + log_K + log_v)


# ------------------------------
# 2. RF-HGR Approximator (Eq.5-6: No Extra Terms)
# ------------------------------
class HGRApproximator(nn.Module):
    """Random-Fourier-feature approximation of HGR correlation (Eq.5-6).

    Both inputs go through the same feature map φ(x) = √(2/k)·cos(xΩᵀ + b).
    Every (x_i, y_j) pair is scored by the cosine similarity of its feature
    vectors, so the result lies within [-1, 1].

    Args:
        d: Input feature width.
        k: Number of random features (§3.2.2: k=256).
        sigma: Scale of Ω's entries (Ω = sigma · N(0, 1), §3.2.2: σ=0.1).
    """

    def __init__(self, d, k=256, sigma=0.1):
        super().__init__()
        self.d = d
        self.k = k
        self.sigma = sigma
        # Ω is drawn before b, so the RNG stream is consumed in that order.
        # Both are registered as trainable nn.Parameters.
        omega = torch.randn(k, d) * sigma
        phase = torch.rand(k) * 2 * math.pi
        self.Omega = nn.Parameter(omega)
        self.b = nn.Parameter(phase)

    def rff_projection(self, x):
        """Eq.5: map [B, N, d] features to [B, N, k] random Fourier features."""
        scale = math.sqrt(2 / self.k)
        z = torch.matmul(x, self.Omega.T)
        z += self.b.view(1, 1, -1)
        return torch.cos(z) * scale

    def forward(self, x, y):
        """Eq.6: normalised RFF inner products for every pair of rows.

        Args:
            x: [B, N, d] features.
            y: [B, M, d] features.

        Returns:
            [B, N, M] scores; the 1e-8 in the denominator avoids 0/0.
        """
        feats_x = self.rff_projection(x)
        feats_y = self.rff_projection(y)

        norm_x = feats_x.norm(dim=2, keepdim=True)  # [B, N, 1]
        norm_y = feats_y.norm(dim=2, keepdim=True)  # [B, M, 1]
        inner = feats_x @ feats_y.transpose(1, 2)  # [B, N, M]

        return inner / (norm_x * norm_y.transpose(1, 2) + 1e-8)


# ------------------------------
# 3. Differentiable Adaptive Routing (Eq.9-12: DAR)
# ------------------------------
class DualAdaptiveRouter(nn.Module):
    """Differentiable adaptive token routing (DAR, Eq.9-12).

    Token scores combine three terms:
    ``lambda_A * A(t) + lambda_G * G(t) - lambda_E * E(t)``, where A(t) is
    cross-modal agreement, G(t) is the gradient norm and E(t) is spectral
    flatness. A per-sample quantile of these scores sets the threshold.

    * Training: returns a soft mask ``sigmoid((score + logistic_noise -
      threshold) / temp)`` (Gumbel-Sigmoid). No straight-through estimator is
      applied; the mask stays soft.
    * Eval: returns a hard 0/1 mask ``score >= threshold``.

    Args:
        lambda_A: Weight for cross-modal agreement A(t).
        lambda_G: Weight for gradient importance G(t).
        lambda_E: Weight for spectral flatness E(t).
        rho: Sparsity parameter. The threshold is the
            ``1 - rho * (1 - noise_level)`` quantile (§3.3.1: ρ=0.3).
    """

    def __init__(self, lambda_A=0.5, lambda_G=0.3, lambda_E=0.2, rho=0.3):
        super().__init__()
        self.lambda_A = lambda_A
        self.lambda_G = lambda_G
        self.lambda_E = lambda_E
        self.rho = rho
        # Annealing state. These defaults give temperature 1.0 until
        # set_training_params() is called. Previously forward() raised
        # AttributeError in training mode without that call.
        self.training_epoch = 0
        self.max_epochs = 1

    def cross_modal_agreement(self, x, cross_x):
        """A(t): best cosine similarity of each token to any cross-modal token.

        Args:
            x: [B, N, d] tokens.
            cross_x: [B, M, d] tokens from the other modality.

        Returns:
            [B, N] maximum cosine similarity.
        """
        # A normalised matmul gives the same cosine matrix as a broadcast
        # cosine_similarity, without materialising a [B, N, M, d] tensor.
        x_unit = F.normalize(x, dim=-1)
        cross_unit = F.normalize(cross_x, dim=-1)
        sim = torch.matmul(x_unit, cross_unit.transpose(1, 2))  # [B, N, M]
        return sim.max(dim=-1).values

    def gradient_importance(self, x, loss):
        """G(t): L2 norm of d(loss)/d(token).

        Returns zeros of shape [B, N] in these cases:
        * outside training mode;
        * when ``loss`` is None;
        * when ``x`` does not require grad;
        * when ``loss`` does not depend on ``x``.
        """
        zeros = torch.zeros_like(x[:, :, 0])
        if not (self.training and loss is not None and x.requires_grad):
            return zeros
        grad = torch.autograd.grad(loss, x, retain_graph=True, allow_unused=True)[0]
        if grad is None:
            return zeros
        return torch.norm(grad, dim=-1)  # [B, N]

    def _temperature(self):
        """Linearly anneal 1.0 → 0.1 over training, clamped to that range."""
        max_epochs = max(self.max_epochs, 1)  # avoid division by zero
        progress = min(max(self.training_epoch / max_epochs, 0.0), 1.0)
        return 1.0 - 0.9 * progress

    def forward(self, x, cross_x=None, loss=None, noise_level=0.0):
        """Score, threshold and mask tokens (Eq.9-12).

        Args:
            x: [B, N, d] tokens.
            cross_x: Optional [B, M, d] tokens of the other modality. If None,
                A(t) is zero.
            loss: Optional scalar loss for G(t). Only used in training mode.
            noise_level: Scalar that relaxes the sparsity target. The quantile
                is clamped to [0, 1].

        Returns:
            x_pruned: [B, N, d] tokens scaled by the mask.
            p_t: [B, N] mask (soft in training, hard 0/1 in eval).
            mod_score: [B] mean token score.
            token_scores: [B, N] raw scores.
        """
        B, N, _ = x.shape

        # Step 1: three-component score (Eq.9).
        if cross_x is not None:
            A_t = self.cross_modal_agreement(x, cross_x)
        else:
            A_t = x.new_zeros(B, N)
        G_t = self.gradient_importance(x, loss)
        E_t = spectral_flatness(x)
        token_scores = self.lambda_A * A_t + self.lambda_G * G_t - self.lambda_E * E_t

        # Step 2: adaptive threshold (Eq.10). torch.quantile rejects q outside [0, 1].
        q = 1.0 - self.rho * (1.0 - float(noise_level))
        q = min(max(q, 0.0), 1.0)
        threshold = torch.quantile(token_scores, q, dim=1, keepdim=True)

        # Step 3: Gumbel-Sigmoid relaxation (Eq.11).
        if self.training:
            # The difference of two Gumbel(0, 1) draws is Logistic(0, 1):
            # log(u) - log(1 - u). A single Gumbel draw would bias the mask.
            u = torch.rand_like(token_scores).clamp_(1e-6, 1.0 - 1e-6)
            logistic_noise = torch.log(u) - torch.log1p(-u)
            p_t = torch.sigmoid(
                (token_scores + logistic_noise - threshold) / self._temperature()
            )
        else:
            p_t = (token_scores >= threshold).to(x.dtype)

        # Step 4: modality-level score (Eq.12).
        mod_score = token_scores.mean(dim=1)  # [B]
        x_pruned = x * p_t.unsqueeze(-1)

        return x_pruned, p_t, mod_score, token_scores

    def set_training_params(self, training_epoch, max_epochs):
        """Set the temperature-annealing state (call once per epoch)."""
        self.training_epoch = training_epoch
        self.max_epochs = max_epochs


# ------------------------------
# 4. Scalable Multi-Granular Alignment (§3.2: Pixel→Patch→Global)
# ------------------------------
class ScalableMultiGranularAlignment(nn.Module):
    """Pixel → patch → global image/text alignment (§3.2).

    * Pixel level (Eq.1-3):
      - pixels are projected to ``d_model`` and compared with projected text
        tokens inside overlapping square windows;
      - each window gets an OT plan;
      - the per-window text reconstructions are averaged back onto the grid.
    * Patch level (Eq.4-7): registered strided convolutions pool the pixel
      grid into patches at every size in ``patch_sizes``. The patches then
      attend to text attributes through RFF-HGR scores.
    * Global level (Eq.8): one Transformer encoder, shared by both modalities,
      encodes a CLS token prepended to each modality. The two CLS outputs are
      compared with RFF-HGR.

    Args:
        img_size: Stored on the instance; not used by the computations here.
        patch_sizes: Kernel/stride sizes of the patch-extraction convolutions.
        d_model: Shared embedding width.
        window_size: Side length M of the pixel-level OT windows. At run time
            it is clamped to the feature-map size.
        window_stride: Step S between windows. Defaults to ``window_size // 4``.
    """

    def __init__(self, img_size=224, patch_sizes=(16, 32), d_model=768,
                 window_size=32, window_stride=None):
        super().__init__()
        self.img_size = img_size
        # Copy, so a caller-owned sequence is never shared or mutated.
        self.patch_sizes = list(patch_sizes)
        self.d_model = d_model
        self.window_size = window_size
        self.window_stride = window_size // 4 if window_stride is None else window_stride

        # Eq.1: pixel/text projections (W_v, W_t).
        self.pixel_proj = nn.Linear(6, d_model)  # 6 input channels per pixel
        self.text_proj = nn.Linear(768, d_model)  # 768-wide text tokens → d_model

        # Eq.4: patch extractors. They are registered here, not built inside
        # forward, so their weights are trained, saved in state_dict and moved
        # by .to().
        # NOTE: a dense d_model→d_model conv with a ps×ps kernel holds
        # d_model²·ps² weights (≈151M for ps=16, ≈604M for ps=32 at d_model=768).
        self.patch_convs = nn.ModuleList(
            [nn.Conv2d(d_model, d_model, kernel_size=ps, stride=ps, padding=ps // 2)
             for ps in self.patch_sizes]
        )

        self.ot_solver = SinkhornSolver()
        self.hgr_calculator = HGRApproximator(d=d_model)
        # Global Transformer (§3.2.3: 3 layers, 8 heads).
        self.global_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=8, dim_feedforward=4 * d_model,
                dropout=0.1, batch_first=True,
            ),
            num_layers=3,
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

    @staticmethod
    def _axis_starts(size, window, stride):
        """Window offsets along one axis.

        Offsets step by ``stride``. A final window flush with the far edge is
        added, so no row or column is left uncovered.
        """
        last = max(size - window, 0)
        starts = list(range(0, last + 1, max(1, stride)))
        if starts[-1] != last:
            starts.append(last)
        return starts

    @classmethod
    def _window_origins(cls, H, W, window_size, stride):
        """Top-left (row, col) of every window, in row-major order."""
        rows = cls._axis_starts(H, window_size, stride)
        cols = cls._axis_starts(W, window_size, stride)
        return [(h0, w0) for h0 in rows for w0 in cols]

    @staticmethod
    def _window_index(h0, w0, window_size, W, device):
        """Flat row-major pixel indices of the window whose top-left is (h0, w0)."""
        rows = torch.arange(h0, h0 + window_size, device=device)
        cols = torch.arange(w0, w0 + window_size, device=device)
        return (rows[:, None] * W + cols[None, :]).reshape(-1)

    def forward_pixel_alignment(self, img_features, text_tokens):
        """Eq.1-3: pixel-level OT alignment over overlapping windows.

        Args:
            img_features: [B, 6, H, W] per-pixel features. The channel count
                must match ``pixel_proj``'s input width.
            text_tokens: [B, N_t, 768] text token features.

        Returns:
            refined_pixel: [B, H*W, d_model] layer-normalised OT reconstruction
                of the text features at each pixel.
            T_list: Per-window transport plans, each [B, M², N_t].
            C_list: Matching per-window costs (1 - cosine similarity).
        """
        B, n_ch, H, W = img_features.shape
        pixel_feat = img_features.permute(0, 2, 3, 1).reshape(B, H * W, n_ch)
        pixel_feat_proj = self.pixel_proj(pixel_feat)  # [B, H*W, d]
        text_feat_proj = self.text_proj(text_tokens)  # [B, N_t, d]
        d = pixel_feat_proj.shape[-1]

        # Clamp so that small feature maps still get at least one window.
        window_size = min(self.window_size, H, W)
        stride = max(1, min(self.window_stride, window_size))

        # Unit-normalise once. Each window's cosine cost is then one matmul,
        # instead of a [B, M², N_t, d] broadcast.
        pixel_grid = F.normalize(pixel_feat_proj, dim=-1).reshape(B, H, W, d)
        text_unit_t = F.normalize(text_feat_proj, dim=-1).transpose(1, 2)  # [B, d, N_t]

        T_list, C_list = [], []
        for h0, w0 in self._window_origins(H, W, window_size, stride):
            # Take a true 2-D window. A flat slice of the H*W sequence would
            # span whole image rows.
            window = pixel_grid[:, h0:h0 + window_size, w0:w0 + window_size, :]
            window = window.reshape(B, window_size * window_size, d)
            cost = 1 - torch.matmul(window, text_unit_t)  # [B, M², N_t]
            C_list.append(cost)
            T_list.append(self.ot_solver(cost))  # Eq.2

        # Eq.3: boundary-aware fusion of overlapping windows.
        refined_pixel = self._window_fusion(T_list, text_feat_proj, H, W, window_size, stride)
        refined_pixel = F.layer_norm(refined_pixel, (d,))
        return refined_pixel, T_list, C_list

    def _window_fusion(self, T_list, text_feat, H, W, window_size, stride):
        """Eq.3: average each pixel's OT reconstruction over all covering windows.

        Args:
            T_list: Transport plans [B, M², N_t], in ``_window_origins`` order.
            text_feat: [B, N_t, d] projected text tokens.
            H, W, window_size, stride: Window geometry used to build T_list.

        Returns:
            [B, H*W, d] tensor; any uncovered pixel stays zero.

        Raises:
            ValueError: if ``T_list`` length differs from the number of windows.
        """
        B, _, d = text_feat.shape
        device, dtype = text_feat.device, text_feat.dtype
        origins = self._window_origins(H, W, window_size, stride)
        if len(origins) != len(T_list):
            raise ValueError(
                f"expected {len(origins)} transport plans, got {len(T_list)}"
            )

        result = torch.zeros(B, H * W, d, device=device, dtype=dtype)
        # Coverage counts depend only on geometry, so one row serves the whole batch.
        count = torch.zeros(1, H * W, 1, device=device, dtype=dtype)
        ones = torch.ones(1, window_size * window_size, 1, device=device, dtype=dtype)

        for (h0, w0), T in zip(origins, T_list):
            idx = self._window_index(h0, w0, window_size, W, device)
            result = result.index_add(1, idx, torch.matmul(T, text_feat))
            count = count.index_add(1, idx, ones)

        return result / count.clamp(min=1.0)

    def forward_patch_alignment(self, refined_pixel, text_attr, H, W):
        """Eq.4-7: multi-scale patch extraction plus HGR-weighted text refinement.

        Args:
            refined_pixel: [B, H*W, d] pixel features.
            text_attr: [B, N_attr, d] text attribute features.
            H, W: Pixel grid size.

        Returns:
            refined_patch: [B, N_patch_total, d] layer-normalised patches.
            hgr_scores: [B, N_patch_total, N_attr] HGR scores.
        """
        B, _, d = refined_pixel.shape
        grid = refined_pixel.reshape(B, H, W, d).permute(0, 3, 1, 2)  # [B, d, H, W]
        patches = torch.cat(
            [conv(grid).flatten(2).transpose(1, 2) for conv in self.patch_convs],
            dim=1,
        )  # [B, N_total, d]

        hgr_scores = self.hgr_calculator(patches, text_attr)  # Eq.6
        attn_weights = F.softmax(hgr_scores / 0.1, dim=-1)  # Eq.7, τ=0.1
        refined_patch = patches + torch.matmul(attn_weights, text_attr)
        refined_patch = F.layer_norm(refined_patch, (d,))
        return refined_patch, hgr_scores

    def forward_global_alignment(self, refined_patch, text_attr):
        """Eq.8: CLS-token global embeddings and their HGR score.

        Returns:
            img_global: [B, d] L2-normalised image embedding.
            txt_global: [B, d] L2-normalised text embedding.
            global_hgr: [B] HGR score per pair (stays 1-D even for B=1).
        """
        B = refined_patch.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, d]

        img_feat = torch.cat([cls_tokens, refined_patch], dim=1)
        txt_feat = torch.cat([cls_tokens, text_attr], dim=1)
        img_global = self.global_transformer(img_feat)[:, 0, :]
        txt_global = self.global_transformer(txt_feat)[:, 0, :]

        global_hgr = self.hgr_calculator(
            img_global.unsqueeze(1), txt_global.unsqueeze(1)
        )[:, 0, 0]
        return normalize(img_global), normalize(txt_global), global_hgr


# ------------------------------
# 5. Composite Loss (Eq.13: Full Loss Function)
# ------------------------------
class UnifiedInfoLoss(nn.Module):
    """Composite training objective (Eq.13).

    ``total = L_pixel + λ_patch·L_patch + λ_glob·L_glob + λ_cont·L_cont
    + λ_sparse·L_sparse``

    Args:
        lambda_patch: Weight of the patch-HGR term.
        lambda_glob: Weight of the global-HGR term.
        lambda_cont: Weight of the symmetric contrastive term.
        lambda_sparse: Weight of the token-score regulariser.
        tau_cont: Contrastive temperature (CLIP-style).
    """

    def __init__(self, lambda_patch=0.4, lambda_glob=0.6, lambda_cont=1.0, lambda_sparse=0.1, tau_cont=0.07):
        super().__init__()
        # Loss weights (§3.4: tuned on validation).
        self.lambda_patch = lambda_patch
        self.lambda_glob = lambda_glob
        self.lambda_cont = lambda_cont
        self.lambda_sparse = lambda_sparse
        self.tau_cont = tau_cont

    @staticmethod
    def _to_float(value):
        """Return a Python float for a 0-d tensor or a plain number."""
        return value.item() if torch.is_tensor(value) else float(value)

    def forward(self, T_list, C_list, patch_hgr, global_hgr, img_global, txt_global, token_scores_list):
        """Compute the total loss and a float breakdown of its terms.

        Args:
            T_list, C_list: Per-window transport plans and costs, each
                [B, n, m]. May be empty.
            patch_hgr: Patch HGR scores, or None to skip the term.
            global_hgr: Global HGR scores, or None to skip the term.
            img_global, txt_global: [B, d] embeddings. Row i of each is
                treated as a matching pair.
            token_scores_list: Router token-score tensors. May be empty.

        Returns:
            (total_loss, dict of per-term floats). A skipped or empty term is
            reported as 0.0. Previously such terms raised AttributeError on
            ``.item()``.
        """
        B = img_global.shape[0]

        # 1. L_pixel: mean OT cost <T, C> over windows.
        if T_list:
            pixel_loss = sum(
                torch.mean(torch.sum(T * C, dim=(1, 2))) for T, C in zip(T_list, C_list)
            ) / len(T_list)
        else:
            pixel_loss = 0.0

        # 2. L_patch / 3. L_glob: maximise HGR, so minimise its negative.
        patch_loss = -torch.mean(patch_hgr) if patch_hgr is not None else 0.0
        global_loss = -torch.mean(global_hgr) if global_hgr is not None else 0.0

        # 4. L_cont: symmetric CLIP-style contrastive loss.
        logits = torch.matmul(img_global, txt_global.t()) / self.tau_cont
        labels = torch.arange(B, device=img_global.device)
        cont_loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

        # 5. L_sparse: mean token score, averaged over routing stages.
        if token_scores_list:
            sparse_loss = sum(torch.mean(s) for s in token_scores_list) / len(token_scores_list)
        else:
            sparse_loss = 0.0

        total_loss = (pixel_loss
                      + self.lambda_patch * patch_loss
                      + self.lambda_glob * global_loss
                      + self.lambda_cont * cont_loss
                      + self.lambda_sparse * sparse_loss)

        return total_loss, {
            'pixel_loss': self._to_float(pixel_loss),
            'patch_loss': self._to_float(patch_loss),
            'global_loss': self._to_float(global_loss),
            'cont_loss': self._to_float(cont_loss),
            'sparse_loss': self._to_float(sparse_loss),
        }


# ------------------------------
# 6. InfoCLIP++ Main Model (§3: Full Architecture)
# ------------------------------
class InfoCLIPPP(nn.Module):
    """InfoCLIP++ end-to-end model (§3).

    Pipeline:
      1. DINOv2 and BERT backbones extract image and text features.
      2. Pixel → patch → global alignment runs, with adaptive token routing
         after the pixel and patch stages.
      3. In training, the composite loss is computed.

    Args:
        img_size: Forwarded to the alignment module. The pixel-grid size
            actually used comes from the input images.
        d_model: Shared embedding width of the alignment modules.
        patch_sizes: Patch-extraction kernel sizes for the alignment module.
        max_epochs: Horizon for the router's temperature annealing.
    """

    def __init__(self, img_size=224, d_model=768, patch_sizes=(16, 32), max_epochs=30):
        super().__init__()
        self.img_size = img_size
        self.d_model = d_model
        self.max_epochs = max_epochs

        # Image backbone (§3.1).
        # NOTE(review): confirm that this dinov2 entry point accepts `pretrained=`.
        self.dino = vit_large(pretrained=True)
        # Freezes every parameter *tensor* except the last 20 returned by
        # .parameters(). This is not the same as "the last 20 layers".
        for param in list(self.dino.parameters())[:-20]:
            param.requires_grad = False
        # Project DINO patch tokens to the 3 image-derived channels that the
        # 6-channel pixel projection expects (3 DINO-derived + 3 Sobel).
        # 1024 (the ViT-L width) is used only if the backbone has no `embed_dim`.
        dino_dim = getattr(self.dino, "embed_dim", 1024)
        self.dino_to_pixel = nn.Conv2d(dino_dim, 3, kernel_size=1)

        # Text backbone.
        self.text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')

        # Core modules.
        self.multi_align = ScalableMultiGranularAlignment(img_size, patch_sizes, d_model)
        self.dar = DualAdaptiveRouter()
        self.loss_fn = UnifiedInfoLoss()

        # Text attribute projection (§3.2.2). It is applied to every token;
        # the paper selects attributes by dependency parsing.
        self.attr_extractor = nn.Linear(768, d_model)

    @staticmethod
    def _sobel_edges(images):
        """Per-channel Sobel gradient magnitude with replicate padding.

        Implemented with torch ops, so it runs on any device and stays
        differentiable. This replaces ``scipy.ndimage.sobel`` on ``.numpy()``,
        which rejected ``axis=(2, 3)`` and failed on CUDA/grad tensors.

        Args:
            images: [B, C, H, W].

        Returns:
            [B, C, H, W] edge magnitudes ``sqrt(gx² + gy²)``.
        """
        B, channels = images.shape[:2]
        gx = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]],
                          device=images.device, dtype=images.dtype)
        kernels = torch.stack([gx, gx.t()]).unsqueeze(1)  # [2, 1, 3, 3]
        kernels = kernels.repeat(channels, 1, 1, 1)  # [2C, 1, 3, 3], (gx, gy) per channel
        padded = F.pad(images, (1, 1, 1, 1), mode="replicate")
        grads = F.conv2d(padded, kernels, groups=channels)  # [B, 2C, H, W]
        grads = grads.reshape(B, channels, 2, *images.shape[-2:])
        # The small epsilon keeps sqrt's gradient finite on flat regions.
        return torch.sqrt(grads.pow(2).sum(dim=2) + 1e-12)

    def _dino_pixel_features(self, images):
        """DINO patch tokens → 3-channel map at the input image resolution.

        Assumes the backbone returns a square grid of patch tokens; verify this
        for non-square inputs.
        """
        tokens = self.dino.forward_features(images)['x_norm_patchtokens']  # [B, N, D]
        B, N, D = tokens.shape
        side = int(round(math.sqrt(N)))
        if side * side != N:
            raise ValueError(f"expected a square patch grid, got {N} tokens")
        grid = tokens.transpose(1, 2).reshape(B, D, side, side)
        grid = self.dino_to_pixel(grid)  # [B, 3, side, side]
        return F.interpolate(grid, size=images.shape[-2:], mode="bilinear", align_corners=False)

    def extract_text_features(self, texts):
        """Tokenize and encode ``texts``.

        Returns:
            token_feat: [B, N_t, 768] BERT last hidden states.
            attr_feat: [B, N_t, d_model] projected attribute features.
            inputs: The tokenizer output, moved to the model's device.
        """
        device = next(self.parameters()).device
        inputs = self.text_tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=64
        ).to(device)

        outputs = self.text_encoder(**inputs)
        token_feat = outputs.last_hidden_state
        attr_feat = self.attr_extractor(token_feat)
        return token_feat, attr_feat, inputs

    def forward(self, images, texts, noise_level=0.0, training_epoch=0):
        """Full forward pass (§3).

        Args:
            images: [B, 3, H, W] image batch.
            texts: List of B captions.
            noise_level: Forwarded to the router's threshold.
            training_epoch: Current epoch, used for temperature annealing.

        Returns:
            Eval mode: ``(img_global, txt_global)``.
            Training mode: ``(img_global, txt_global, total_loss, loss_dict)``.
        """
        # Step 1: base features. 3 DINO-derived + 3 Sobel channels per pixel (Eq.1).
        img_feat = torch.cat(
            [self._dino_pixel_features(images), self._sobel_edges(images)], dim=1
        )  # [B, 6, H, W]
        H, W = img_feat.shape[-2:]

        text_token_feat, text_attr_feat, _ = self.extract_text_features(texts)
        text_token_proj = self.multi_align.text_proj(text_token_feat)

        # Step 2: pixel level, OT then DAR.
        refined_pixel, T_list, C_list = self.multi_align.forward_pixel_alignment(img_feat, text_token_feat)
        self.dar.set_training_params(training_epoch, self.max_epochs)
        refined_pixel, _, _, pixel_scores = self.dar(
            refined_pixel, cross_x=text_token_proj, loss=None, noise_level=noise_level
        )

        # Patch level: HGR, then DAR.
        refined_patch, patch_hgr = self.multi_align.forward_patch_alignment(refined_pixel, text_attr_feat, H, W)
        refined_patch, _, _, patch_scores = self.dar(
            refined_patch, cross_x=text_attr_feat, loss=None, noise_level=noise_level
        )

        # Global level: HGR.
        img_global, txt_global, global_hgr = self.multi_align.forward_global_alignment(refined_patch, text_attr_feat)

        if not self.training:
            return img_global, txt_global

        # Step 3: loss. The router scores are recomputed with the gradient
        # term G(t) from a first loss pass (§3.3.1).
        loss_args = (T_list, C_list, patch_hgr, global_hgr, img_global, txt_global)
        total_loss, _ = self.loss_fn(*loss_args, [pixel_scores, patch_scores])
        _, _, _, pixel_scores = self.dar(
            refined_pixel, cross_x=text_token_proj, loss=total_loss, noise_level=noise_level
        )
        _, _, _, patch_scores = self.dar(
            refined_patch, cross_x=text_attr_feat, loss=total_loss, noise_level=noise_level
        )
        total_loss, loss_dict = self.loss_fn(*loss_args, [pixel_scores, patch_scores])
        return img_global, txt_global, total_loss, loss_dict