import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
import numpy as np


class NetVLAD(nn.Module):
    """NetVLAD aggregation layer.

    Turns a feature map of shape (N, dim, H, W) into one L2-normalised
    descriptor of length ``num_clusters * dim`` per sample. Every local
    descriptor is soft-assigned to learnable centroids and the assignment-
    weighted residuals (descriptor - centroid) are summed per cluster.
    """

    def __init__(self, num_clusters=64, dim=128, alpha=100.0,
                 normalize_input=True):
        """
        Args:
            num_clusters : int
                The number of clusters (K).
            dim : int
                Dimension of the local descriptors (C).
            alpha : float
                Parameter of initialization. Larger value is harder assignment.
            normalize_input : bool
                If true, descriptor-wise L2 normalization is applied to input.
        """
        super().__init__()
        self.num_clusters = num_clusters
        self.dim = dim
        self.alpha = alpha
        self.normalize_input = normalize_input
        # 1x1 conv producing per-location soft-assignment logits for K clusters.
        self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
        self._init_params()

    def _init_params(self):
        """Initialise the assignment conv from the current centroids.

        Expanding ``-alpha * ||x - c_k||^2`` gives
        ``2*alpha * x.c_k - alpha * ||c_k||^2 - alpha * ||x||^2``. The last
        term does not depend on k and cancels in the softmax over clusters,
        so the conv weight is ``2*alpha*c_k`` and the bias is
        ``-alpha * ||c_k||^2`` (the *squared* norm).

        The existing conv parameters are overwritten in place (not replaced)
        so an optimizer that already holds references to them stays valid.
        """
        with torch.no_grad():
            self.conv.weight.copy_(
                (2.0 * self.alpha * self.centroids)[:, :, None, None]
            )
            self.conv.bias.copy_(
                -self.alpha * self.centroids.pow(2).sum(dim=1)
            )

    def forward(self, x):
        """Aggregate a feature map into a NetVLAD descriptor.

        Args:
            x : torch.Tensor of shape (N, dim, H, W)

        Returns:
            torch.Tensor of shape (N, num_clusters * dim) with unit L2 rows.
        """
        N, C = x.shape[:2]

        if self.normalize_input:
            x = F.normalize(x, p=2, dim=1)  # across descriptor dim

        # soft-assignment of every spatial location to every cluster: [N, K, H*W]
        soft_assign = self.conv(x).reshape(N, self.num_clusters, -1)
        soft_assign = F.softmax(soft_assign, dim=1)

        # reshape (not view) so non-contiguous feature maps are accepted
        x_flatten = x.reshape(N, C, -1)  # [N, C, H*W]

        # sum_i a_k(i) * (x_i - c_k) == sum_i a_k(i) * x_i - c_k * sum_i a_k(i)
        # Computing it this way avoids materialising an [N, K, C, H*W] tensor.
        vlad = torch.bmm(soft_assign, x_flatten.transpose(1, 2))  # [N, K, C]
        assign_mass = soft_assign.sum(dim=-1, keepdim=True)  # [N, K, 1]
        vlad = vlad - assign_mass * self.centroids.unsqueeze(0)

        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization
        vlad = vlad.reshape(N, -1)  # flatten to [N, K*C]
        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize
        return vlad

    def init_centroids_with_kmeans(self, features):
        """
        Initialize centroids using K-means clustering on provided features.

        When ``normalize_input`` is true the descriptors are L2-normalised
        before clustering, so the centroids live in the same space as the
        descriptors that ``forward`` computes residuals against.

        Args:
            features: torch.Tensor of shape (N, C, H, W) or (M, C), where
                C must equal ``self.dim``. A 4-D map is split into N*H*W
                local descriptors; a 2-D tensor is taken as M descriptors.

        Raises:
            ValueError: if ``features`` is not 2-D/4-D or C != ``self.dim``.
        """
        if features.dim() == 4:
            # (N, C, H, W) -> (N*H*W, C): one row per spatial location
            N, C, H, W = features.shape
            descriptors = features.reshape(N, C, H * W).transpose(1, 2).reshape(-1, C)
        elif features.dim() == 2:
            descriptors = features
        else:
            raise ValueError(
                f"features must be 2-D (M, C) or 4-D (N, C, H, W), "
                f"got shape {tuple(features.shape)}"
            )

        if descriptors.shape[1] != self.dim:
            raise ValueError(
                f"descriptor dim {descriptors.shape[1]} does not match "
                f"NetVLAD dim {self.dim}"
            )

        descriptors = descriptors.detach()
        if self.normalize_input:
            # forward() clusters normalised descriptors; match that space here
            descriptors = F.normalize(descriptors, p=2, dim=1)

        # Convert to numpy for sklearn (float() also covers half/bfloat16)
        features_numpy = descriptors.float().cpu().numpy()

        # Perform K-means clustering
        kmeans = KMeans(n_clusters=self.num_clusters, random_state=42, n_init=10)
        kmeans.fit(features_numpy)

        # Update the centroid parameter in place (keeps optimizer references);
        # copy_ handles the float64 -> param dtype/device conversion.
        cluster_centers = torch.from_numpy(kmeans.cluster_centers_)
        with torch.no_grad():
            self.centroids.copy_(cluster_centers)

        # Re-initialize conv layer weights and bias based on new centroids
        self._init_params()

        print(f"NetVLAD centroids initialized with K-means. Centroids shape: {self.centroids.shape}")

