import os
import json
import clip
import torch
import objaverse
import numpy as np
import pandas as pd
from PIL import Image
import torchvision.transforms as T
from typing import List, Dict, Tuple, Optional


def uids_to_prompts(uids: List[str],
                    annotations: Optional[Dict[str, Dict]] = None) -> List[Dict[str, str]]:
    """Build tag-concatenated text prompts for Objaverse objects.

    For each uid that has an annotation, the prompt is the object's ``name``,
    followed by the ``name`` of each of its ``tags``, followed by the literal
    ``'3d asset'``, all joined with ``', '``. Empty/missing names are skipped.
    Uids without an annotation are dropped.

    Args:
        uids: Object identifiers, processed in the given order.
        annotations: Optional pre-loaded mapping ``uid -> annotation dict``.
            When omitted, annotations are fetched with
            ``objaverse.load_annotations(uids)``.

    Returns:
        One ``{"Caption": prompt, "uid": uid}`` dict per annotated uid.
    """
    if annotations is None:
        annotations = objaverse.load_annotations(uids)

    prompts_data = []
    for uid in uids:
        annotation = annotations.get(uid)
        if annotation is None:
            continue
        attributes = [annotation.get('name', '')]
        # `tags` may be present but null, so fall back to an empty list.
        for tag in annotation.get('tags') or []:
            attributes.append(tag.get('name', ''))
        attributes.append('3d asset')
        # filter(None, ...) drops empty strings / None names before joining.
        prompt = ', '.join(filter(None, attributes))
        prompts_data.append({
            "Caption": prompt,
            "uid": uid
        })
    return prompts_data


def _cluster_sort_key(item: Tuple[str, object]) -> Tuple[int, int, str]:
    """Sort key for ``(cluster_id, uids)`` pairs: numeric ids first (numerically), then the rest lexically."""
    cluster_id = item[0]
    try:
        return (0, int(cluster_id), "")
    except (TypeError, ValueError):
        return (1, 0, str(cluster_id))


def load_uids_from_clusters(clusters_file: str = 'data/objaverse-dupes/aggregated_clusters.json',
                            concept_key: str = 'teddy_bear') -> List[str]:
    """Load a deterministic, sorted list of UIDs for one concept from a clusters JSON file.

    ``clusters_file[concept_key]`` may be either:
      * a dict ``cluster_id -> list of uids``: clusters are visited in cluster-id
        order (numeric ids numerically, e.g. ``-1, 2, 10``; non-numeric ids after
        them, lexically) and each cluster's uids are sorted before concatenation;
      * a plain list of uids, which is sorted.
    Any other value, or a missing key, yields an empty list.

    This is a best-effort loader: any error (missing file, bad JSON, unexpected
    structure) is printed and an empty list is returned.

    Args:
        clusters_file: Path to the JSON file.
        concept_key: Top-level key of the concept to load.

    Returns:
        The sorted list of UIDs, or ``[]`` on error.
    """
    try:
        with open(clusters_file, 'r', encoding='utf-8') as f:
            clusters_data = json.load(f)

        concept_data = clusters_data.get(concept_key, {})

        all_uids: List[str] = []
        if isinstance(concept_data, dict):
            # Cluster ids are JSON object keys (strings); a plain int() key would
            # abort the whole load on the first non-numeric id.
            for _cluster_id, uid_list in sorted(concept_data.items(), key=_cluster_sort_key):
                if isinstance(uid_list, list):
                    all_uids.extend(sorted(uid_list))
        elif isinstance(concept_data, list):
            all_uids = sorted(concept_data)

        print(f"Loaded and sorted {len(all_uids)} UIDs for concept '{concept_key}' from {clusters_file}")
        return all_uids

    except Exception as e:
        # Deliberate best-effort boundary: report and fall back to no UIDs.
        print(f"Error occurred while loading UIDs: {e}")
        return []
    

def load_prompts_from_csv(file_path: str) -> Tuple[List[str], List[Dict]]:
    """Load prompts and their ground-truth URLs from a ``;``-separated CSV file.

    The file must contain ``Caption`` and ``URL`` columns. This is best-effort:
    a missing file, missing columns, or a parse error is reported with a printed
    message and yields two empty lists.

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        A tuple containing:
        - List[str]: The values of the ``Caption`` column, in file order.
        - List[Dict]: A parallel list of ``{"ground_truth_url": <URL value>}`` dicts.
    """
    if not os.path.exists(file_path):
        print(f"Warning: CSV file not found at {file_path}. Skipping.")
        return [], []

    try:
        df = pd.read_csv(file_path, sep=';')
    except Exception as e:
        # pandas can raise many different parser/decoding errors; treat all as "skip".
        print(f"Error loading CSV file {file_path}: {e}")
        return [], []

    if "Caption" not in df.columns or "URL" not in df.columns:
        print(f"Warning: CSV {file_path} must contain 'Caption' and 'URL' columns. Skipping.")
        return [], []

    prompts = df["Caption"].tolist()
    # Build metadata straight from the URL column rather than via a row-wise
    # DataFrame.apply; the two lists stay aligned row-for-row.
    metadata = [{"ground_truth_url": url} for url in df["URL"].tolist()]

    print(f"Loaded {len(prompts)} prompts from {file_path}")
    return prompts, metadata


class NearestNeighborSearch:
    """
    Nearest neighbor search for AMG implementation.

    Embeds images with a CLIP image encoder, L2-normalises the embeddings, and
    ranks precomputed training embeddings by dot product (cosine similarity)
    against a query embedding. The ``'sscd'`` metric is accepted but currently
    uses the same CLIP ViT-B/32 encoder as ``'clip'``.
    """

    _SUPPORTED_METRICS = ('clip', 'sscd')

    def __init__(self, model, similarity_metric='sscd'):
        """
        Args:
            model: The diffusion model. It is stored on the instance, and its
                ``device`` attribute selects where the CLIP encoder runs.
            similarity_metric: 'sscd' or 'clip'. Both currently use CLIP ViT-B/32.

        Raises:
            ValueError: If ``similarity_metric`` is not 'sscd' or 'clip'
                (otherwise no encoder would be loaded and every later call would fail).
        """
        if similarity_metric not in self._SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported similarity_metric {similarity_metric!r}; "
                f"expected one of {self._SUPPORTED_METRICS}"
            )
        self.model = model
        self.similarity_metric = similarity_metric
        self.device = model.device

        # TODO: load a real SSCD model for 'sscd'; CLIP is a placeholder for both metrics.
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)

        # Float32 CPU matrix with one row per training image, set by precompute_training_embeddings.
        self.training_embeddings: Optional[torch.Tensor] = None
        # Optional captions aligned row-for-row with training_embeddings.
        self.training_captions: Optional[List[str]] = None

    def _encode_image(self, image: "Image.Image") -> torch.Tensor:
        """Return the L2-normalised CLIP embedding of one image (leading batch dim of 1), on ``self.device``."""
        image_tensor = self.clip_preprocess(image).unsqueeze(0).to(self.device)
        embedding = self.clip_model.encode_image(image_tensor)
        return embedding / embedding.norm(dim=-1, keepdim=True)

    def precompute_training_embeddings(self, training_images: List["Image.Image"],
                                       training_captions: Optional[List[str]] = None):
        """
        Precompute embeddings for training images to speed up NN search.

        Args:
            training_images: Non-empty list of PIL Images from the training set.
            training_captions: Optional list of captions, one per image.

        Raises:
            ValueError: If ``training_images`` is empty, or if ``training_captions``
                is given with a different length than ``training_images``.
        """
        if len(training_images) == 0:
            raise ValueError("training_images must contain at least one image")
        if training_captions is not None and len(training_captions) != len(training_images):
            raise ValueError(
                f"training_captions has {len(training_captions)} entries but "
                f"training_images has {len(training_images)}"
            )

        with torch.no_grad():
            embeddings = [self._encode_image(img).cpu() for img in training_images]

        # Cast once here so each query does not re-copy the whole matrix to float32.
        self.training_embeddings = torch.cat(embeddings, dim=0).float()
        self.training_captions = training_captions
        print(f"Precomputed {len(training_images)} training embeddings")

    def find_nearest_neighbor(self, query_image: "Image.Image") -> Tuple[float, int, Optional[str]]:
        """
        Find the nearest training image to a query image.

        Args:
            query_image: PIL Image to find nearest neighbor for.

        Returns:
            similarity_score: Cosine similarity with the nearest neighbor.
            neighbor_idx: Index of the nearest neighbor in the training set.
            neighbor_caption: Caption of the nearest neighbor, or None if no captions were given.

        Raises:
            ValueError: If ``precompute_training_embeddings`` has not been called.
        """
        if self.training_embeddings is None:
            raise ValueError("Must call precompute_training_embeddings first")

        with torch.no_grad():
            # Compare on CPU in float32: the encoder may emit half precision on GPU.
            query_embedding = self._encode_image(query_image).cpu().float()
            similarities = torch.mm(query_embedding, self.training_embeddings.float().t())
            max_sim, max_idx = torch.max(similarities, dim=1)

        similarity_score = max_sim.item()
        neighbor_idx = max_idx.item()

        neighbor_caption = None
        if self.training_captions is not None:
            neighbor_caption = self.training_captions[neighbor_idx]

        return similarity_score, neighbor_idx, neighbor_caption

    def compute_similarity_score(self, image1: "Image.Image", image2: "Image.Image") -> float:
        """
        Compute the cosine similarity between two images.

        Args:
            image1, image2: PIL Images to compare.

        Returns:
            similarity_score: Cosine similarity of the two normalised embeddings.
        """
        with torch.no_grad():
            emb1 = self._encode_image(image1).float()
            emb2 = self._encode_image(image2).float()
            similarity = torch.mm(emb1, emb2.t()).item()

        return similarity
