import math
import torch
import torch.nn as nn

from PIL import Image
from torchvision import transforms

# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling in log space for differentiable optimal transport.

    Alternately updates the log dual potentials ``u`` (one per row) and ``v``
    (one per column) so that ``exp(M / reg + u[:, :, None] + v[:, None, :])``
    has row sums ``exp(log_a)`` and column sums ``exp(log_b)``. Because ``v`` is
    updated last, the column marginals hold (up to float error) after every
    iteration; the row marginals are only approached as iterations increase.

    Args:
        log_a : torch.Tensor
            Log source (row) weights, shape [B, R].
        log_b : torch.Tensor
            Log target (column) weights, shape [B, C].
        M : torch.Tensor
            Score matrix, shape [B, R, C]; larger entries receive more mass.
        num_iters : int, default=20
            The number of Sinkhorn iterations.
        reg : float, default=1.0
            Entropic regularization value; ``M`` is divided by it.

    Returns:
        torch.Tensor: the log transport plan, shape [B, R, C].
    """
    M = M / reg  # regularization

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # logsumexp already removes the reduced dim. No .squeeze() here: it would
        # also drop a size-1 row/column dim and make u/v broadcast to [B, B].
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)

# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py
def get_matching_probs(S, dustbin_score = 1.0, num_iters=3, reg=1.0):
    """Soft-assign the columns of ``S`` to its rows plus a dustbin row via Sinkhorn.

    A dustbin row filled with ``dustbin_score`` is appended to ``S``. The
    transport marginals are:

    * ``1 / (n + m)`` for each of the ``m`` real rows;
    * ``(n - m) / (n + m)`` for the dustbin row;
    * ``1 / (n + m)`` for each of the ``n`` columns.

    Rows and columns therefore both carry total mass ``n / (n + m)``.

    Args:
        S: score matrix of shape [B, m, n].
        dustbin_score: score written into every dustbin entry; a float or a
            0-dim tensor (e.g. a learnable parameter).
        num_iters: number of Sinkhorn iterations.
        reg: entropic regularization forwarded to ``log_otp_solver``.

    Returns:
        torch.Tensor: log assignment matrix of shape [B, m + 1, n], shifted by
        ``log(n + m)`` so that, after ``exp``, each column sums to 1.

    Raises:
        ValueError: if ``n <= m``, since the dustbin mass ``n - m`` must be positive.
    """
    batch_size, m, n = S.size()
    if n <= m:
        raise ValueError(
            "get_matching_probs needs more columns than rows (n > m) so the "
            f"dustbin row gets positive mass; got m={m}, n={n}"
        )

    # Augment the score matrix with a dustbin row.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    # Normalized source and target log-weights: 1/(n+m) per real row and per
    # column; the dustbin row absorbs the surplus (n - m) columns' worth of mass.
    norm = -torch.tensor(math.log(n + m), device=S.device)
    # .contiguous() copies the expanded view, so the in-place write below is safe.
    log_a = norm.expand(m + 1).contiguous()
    log_b = norm.expand(n).contiguous()
    log_a[-1] = log_a[-1] + math.log(n - m)
    log_a, log_b = log_a.expand(batch_size, -1), log_b.expand(batch_size, -1)

    log_P = log_otp_solver(
        log_a,
        log_b,
        S_aug,
        num_iters=num_iters,
        reg=reg
    )
    # Undo the 1/(n+m) scaling so each column is a distribution over m + 1 rows.
    return log_P - norm


class SALAD(nn.Module):
    """
    Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD) aggregation head.

    Aggregates a grid of local features plus a global token into a single
    L2-normalized descriptor of size ``token_dim + num_clusters * cluster_dim``.

    Args:
        num_channels (int): The number of channels of the inputs (d).
        num_clusters (int): The number of clusters in the model (m).
        cluster_dim (int): The number of channels of the clusters (l).
        token_dim (int): The dimension of the global scene token (g).
        dropout (float): The dropout rate; 0 disables dropout.
    """
    def __init__(self,
            num_channels=768,
            num_clusters=64,
            cluster_dim=128,
            token_dim=256,
            dropout=0.3,
        ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim

        # A single dropout module is reused in both branches below; it holds no
        # parameters, so sharing it has no effect on the state dict.
        drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # MLP for global scene token g
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_dim)
        )
        # MLP for local features f_i (1x1 convs: applied per spatial location)
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            drop,
            nn.ReLU(),
            nn.Conv2d(512, self.cluster_dim, 1)
        )
        # MLP for score matrix S (one score per cluster per location)
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            drop,
            nn.ReLU(),
            nn.Conv2d(512, self.num_clusters, 1),
        )
        # Dustbin parameter z
        self.dust_bin = nn.Parameter(torch.tensor(1.))

    def forward(self, x):
        """
        Compute the global descriptor.

        Args:
            x (tuple): ``(features, token)`` where
                features (torch.Tensor): local feature map (t_i), [B, C, H, W],
                    e.g. [B, C, H_img // 14, W_img // 14] patch features;
                token (torch.Tensor): global token (t_{n+1}), [B, C].

        Returns:
            torch.Tensor: L2-normalized global descriptor [B, g + m*l], laid out
            as the normalized token followed by the flattened [l, m] cluster block.
        """
        features, token = x

        f = self.cluster_features(features).flatten(2)  # [B, l, N], N = H * W
        p = self.score(features).flatten(2)             # [B, m, N]
        t = self.token_features(token)                  # [B, g]

        # Sinkhorn: log assignment of each location to m clusters + dustbin.
        p = get_matching_probs(p, self.dust_bin, 3)     # [B, m + 1, N]
        p = torch.exp(p)
        # Drop the dustbin row; mass assigned to it is discarded.
        p = p[:, :-1, :]                                # [B, m, N]

        # v[b, j, i] = sum_k f[b, j, k] * p[b, i, k], computed as a batched
        # matmul instead of materializing two [B, l, m, N] repeated tensors.
        # bmm rejects mixed dtypes (which the elementwise product accepted),
        # so promote both operands to their common dtype first.
        dtype = torch.promote_types(f.dtype, p.dtype)
        v = torch.bmm(f.to(dtype), p.transpose(1, 2).to(dtype))  # [B, l, m]

        out = torch.cat([
            nn.functional.normalize(t, p=2, dim=-1),                # [B, g]
            nn.functional.normalize(v, p=2, dim=1).flatten(1),      # [B, l * m]
        ], dim=-1)                                                  # [B, g + l * m]

        return nn.functional.normalize(out, p=2, dim=-1)

    def visualize_attentionmap(self, x):
        """
        Return per-location cluster assignment maps for visualization.

        Args:
            x (tuple): ``(features, token)`` as in ``forward``; only
                ``features`` ([B, C, H, W]) is used.

        Returns:
            torch.Tensor: [B, num_clusters, H, W] Sinkhorn assignment
            probabilities (dustbin removed), L2-normalized across clusters.
        """
        features, _ = x
        height, width = features.shape[-2:]

        # Scores per cluster and location, then Sinkhorn assignment probabilities.
        p = self.score(features).flatten(2)           # [B, m, H * W]
        p = get_matching_probs(p, self.dust_bin, 3)   # [B, m + 1, H * W]
        p = torch.exp(p)[:, :-1, :]                   # drop dustbin: [B, m, H * W]

        # Restore the spatial layout (1x1 convs keep H and W unchanged).
        p = p.reshape(p.size(0), p.size(1), height, width)  # [B, m, H, W]

        # Normalize across clusters.
        return nn.functional.normalize(p, p=2, dim=1)


# Smoke test: run SALAD on random backbone-shaped inputs and print the descriptor shape.
if __name__ == "__main__":
    import sys

    # Optionally preprocess an image passed on the command line (e.g. a Nordland
    # query frame). The tensor is only shape-checked: SALAD consumes backbone
    # features, not pixels, so a backbone is needed to go further.
    if len(sys.argv) > 1:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        with Image.open(sys.argv[1]) as image:
            # convert("RGB") so grayscale/RGBA images match the 3-channel Normalize.
            preprocessed_image = transform(image.convert("RGB")).unsqueeze(0)
        print(preprocessed_image.shape)

    # A 224x224 image through a patch-14 backbone gives a 16x16 feature grid.
    B, C, H, W = 4, 768, 16, 16
    x_local = torch.randn(B, C, H, W)
    t = torch.randn(B, C)

    # Create the SALAD model and perform inference (dropout disabled, no autograd).
    model = SALAD()
    model.eval()
    with torch.no_grad():
        output = model((x_local, t))

    # Expected with default sizes: [B, 256 + 64 * 128] = [4, 8448]
    print(output.shape)