import torch
import torch.nn.functional as F
import clip # Assuming the open_clip package is installed

class NearestNeighborSearch:
    """
    A helper class to find the nearest neighbor in a pre-computed
    set of training data embeddings.

    Embeddings are produced by a CLIP "ViT-B/32" image encoder obtained from
    ``clip.load``. Both training and query embeddings are L2-normalised, so
    their dot product is their cosine similarity.
    """

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: Not used by this class; kept for interface compatibility.
                All embeddings come from the CLIP model loaded below.
            device: Torch device on which CLIP runs and on which embeddings
                are stored.
        """
        self.device = device
        # For this example, we use the CLIP model itself for embeddings.
        # The official AMG implementation uses a more efficient, pre-computed index.
        # NOTE(review): ``clip.load(name, device=...)`` returning
        # ``(model, preprocess)`` is the OpenAI ``clip`` package API, not
        # ``open_clip`` as the import comment suggests — confirm the dependency.
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        # L2-normalised training embeddings, one row per training image;
        # None until precompute_training_embeddings() has run.
        self.training_embeds = None

    def precompute_training_embeddings(self, training_images, batch_size=None):
        """Encodes all training images and stores their L2-normalised embeddings.

        Args:
            training_images: Iterable of images accepted by
                ``self.clip_preprocess`` (presumably PIL images for the stock
                CLIP transform — confirm against callers).
            batch_size: Maximum number of images per encoder forward pass.
                ``None`` (default) encodes everything in a single batch, as
                before. A positive integer bounds peak device memory for
                large training sets.

        Raises:
            ValueError: If ``training_images`` is empty or ``batch_size`` is
                not a positive integer.
        """
        images = list(training_images)
        if not images:
            raise ValueError("training_images must contain at least one image.")
        if batch_size is None:
            batch_size = len(images)
        elif batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        print("Pre-computing embeddings for training data...")
        chunks = []
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                batch = images[start:start + batch_size]
                image_inputs = torch.stack([self.clip_preprocess(img) for img in batch]).to(self.device)
                chunks.append(self.clip_model.encode_image(image_inputs))
            embeds = torch.cat(chunks, dim=0)
            # Normalise so that a plain dot product equals cosine similarity.
            embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        # Assign only once fully computed, so a failure mid-way leaves the
        # previous state untouched instead of a half-normalised tensor.
        self.training_embeds = embeds
        print(f"Stored {len(self.training_embeds)} training embeddings.")

    def find_nearest_neighbor(self, image_tensor, return_cosine=False):
        """Finds the training image most similar to ``image_tensor``.

        Args:
            image_tensor: A single image accepted by ``self.clip_preprocess``.
                NOTE(review): the stock CLIP transform expects a PIL image
                despite this parameter's name — confirm what callers pass.
            return_cosine: If False (default, original behaviour), the score
                is the softmax probability of the best match over all
                training embeddings, using logits ``100 * cosine``. This value
                depends on the size and content of the training set. If True,
                the raw cosine similarity of the best match is returned.

        Returns:
            ``(score, index)``: the score as a Python float and the int index
            of the best-matching training embedding.

        Raises:
            RuntimeError: If training embeddings have not been pre-computed.
        """
        if self.training_embeds is None:
            raise RuntimeError("Training embeddings have not been pre-computed.")

        with torch.no_grad():
            image_input = self.clip_preprocess(image_tensor).unsqueeze(0).to(self.device)
            query_embed = self.clip_model.encode_image(image_input)
            query_embed = query_embed / query_embed.norm(dim=-1, keepdim=True)

            if return_cosine:
                # Both sides are L2-normalised, so this is cosine similarity.
                scores = query_embed @ self.training_embeds.T
            else:
                # Temperature-scaled softmax over the training set: the ranking
                # matches cosine similarity, but the score is a probability
                # relative to the other training samples.
                scores = (100.0 * query_embed @ self.training_embeds.T).softmax(dim=-1)
            best_match_score, best_match_idx = torch.max(scores, dim=-1)

        return best_match_score.item(), best_match_idx.item()