
import json
import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
# Pillow treats images above MAX_IMAGE_PIXELS as potential decompression bombs.
# None removes that limit so very large dataset images can be loaded; only
# appropriate when the dataset is trusted.
Image.MAX_IMAGE_PIXELS = None

# --- Configuration ---
# Relative paths below are resolved against the current working directory.
MULTI_HAYSTACK_FILE = "Multihaystack.json"  # Input JSON: a list of items, each with "conversations" and (optionally) "positive"
DATASET_DIR = "../dataset/dataset_full/"  # Directory walked recursively for .jpg/.jpeg/.png images to retrieve
OUTPUT_FILE = "nomic_results.json"  # Output JSON holding recall metrics and per-question top-5 results
VISION_MODEL_NAME = "nomic-ai/nomic-embed-vision-v1.5"  # Hugging Face model id used to embed images
TEXT_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"  # Hugging Face model id used to embed questions
EXTENSIONS_TO_STRIP = ('.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt')  # Suffixes normalize_name() removes before comparing ground-truth names with retrieval keys
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Use GPU if available, otherwise CPU
print(f"Using device: {DEVICE}")

# --- Model Initialization ---
# Models are loaded once at import time and switched to eval mode. embed_images()
# and embed_texts() read these module-level objects directly.
print("Loading models")
# Vision model for image embedding. trust_remote_code=True lets transformers run
# the custom modeling code published in the model repository.
image_processor = AutoImageProcessor.from_pretrained(VISION_MODEL_NAME)
vision_model = AutoModel.from_pretrained(VISION_MODEL_NAME, trust_remote_code=True).to(DEVICE).eval()

# Tokenizer and text model for query embedding (same remote-code requirement).
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True).to(DEVICE).eval()
print("Models loaded")

# --- Helper Functions ---

def mean_pooling(model_output, attention_mask):
    """
    Average the token embeddings of each sequence over its unmasked positions.

    Args:
        model_output: Transformer output whose element 0 holds the token
            embeddings, laid out as (batch, seq_len, hidden) so that the
            (batch, seq_len) mask can be broadcast over the hidden axis.
        attention_mask: (batch, seq_len) mask; 1 marks tokens to average,
            0 marks padding to ignore.

    Returns:
        Tensor of shape (batch, hidden): the masked mean of every sequence.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * weights).sum(dim=1)
    # Clamp so an all-padding row divides by a tiny positive number instead of 0.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def embed_texts(texts, batch_size=32):
    """
    Turn query strings into unit-length vectors with the Nomic text model.

    Each text gets a "search_query: " prefix, is tokenized (padded and
    truncated), mean-pooled over its tokens, layer-normalized and finally
    L2-normalized, so dot products between results are cosine similarities.

    Args:
        texts: Sequence of query strings (sliced in chunks of ``batch_size``).
        batch_size: How many texts are encoded per forward pass.

    Returns:
        CPU tensor of shape (len(texts), embedding_dim), or ``torch.empty(0)``
        when ``texts`` is empty.
    """
    chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, len(texts), batch_size), desc="Embedding texts"):
            # Task prefix expected for query-side embeddings.
            queries = [f"search_query: {text}" for text in texts[start:start + batch_size]]

            batch = tokenizer(queries, padding=True, truncation=True, return_tensors='pt').to(DEVICE)
            outputs = text_model(**batch)

            pooled = mean_pooling(outputs, batch['attention_mask'])
            pooled = F.layer_norm(pooled, normalized_shape=(pooled.shape[1],))
            unit = F.normalize(pooled, p=2, dim=1)
            chunks.append(unit.cpu())

    if not chunks:
        return torch.empty(0)
    return torch.cat(chunks, dim=0)


def embed_images(image_paths, batch_size=32):
    """
    Embed a list of images using the Nomic vision model.

    Images are read in batches of ``batch_size`` and converted to RGB. A file
    that is missing, not a recognizable image, or otherwise fails to load is
    reported with a warning and skipped, so one bad file does not abort the run.

    Args:
        image_paths: Sequence of image file paths (must support ``len`` and
            slicing, e.g. a list).
        batch_size: Number of images sent through the model per forward pass.

    Returns:
        ``(embeddings, processed_paths)``. ``embeddings`` is a CPU tensor whose
        row ``j`` is the L2-normalized first-token embedding of the image at
        ``processed_paths[j]``. Only successfully loaded images appear, in input
        order. If no image could be embedded, returns ``(torch.empty(0, 0), [])``.
    """
    all_embeddings = []
    processed_image_paths = []
    with torch.no_grad():
        for i in tqdm(range(0, len(image_paths), batch_size), desc="Embedding images"):
            pil_images = []
            current_batch_valid_paths = []
            for img_path in image_paths[i:i + batch_size]:
                try:
                    # Close the source file deterministically instead of relying on
                    # garbage collection; convert() returns a separate in-memory
                    # image, so it stays usable after the file is closed.
                    with Image.open(img_path) as src:
                        img = src.convert("RGB")
                except FileNotFoundError:
                    print(f"Warning: Image file not found {img_path}, skipping.")
                    continue
                except UnidentifiedImageError:
                    print(f"Warning: Cannot identify image file {img_path}, skipping.")
                    continue
                except Exception as e:  # best-effort: log and skip any other load failure
                    print(f"Warning: Error loading image {img_path}: {e}, skipping.")
                    continue
                pil_images.append(img)
                current_batch_valid_paths.append(img_path)

            if not pil_images:
                continue

            inputs = image_processor(images=pil_images, return_tensors="pt").to(DEVICE)
            img_emb = vision_model(**inputs).last_hidden_state
            # The first token's hidden state ([CLS]-style) is the image representation.
            # L2-normalize it so dot products with normalized text embeddings are cosines.
            batch_img_embeddings = F.normalize(img_emb[:, 0], p=2, dim=1)
            all_embeddings.append(batch_img_embeddings.cpu())
            processed_image_paths.extend(current_batch_valid_paths)

    if not all_embeddings:
        return torch.empty(0, 0), []  # nothing could be embedded

    return torch.cat(all_embeddings, dim=0), processed_image_paths


def list_dataset_files(dataset_dir):
    """
    Walk ``dataset_dir`` recursively and build one record per image file.

    Only names ending in .jpg, .jpeg or .png (case-insensitive) are kept. Each
    record is ``{'file_path': ..., 'retrieval_key': ...}``:
    - ``file_path`` joins the walked directory with the file name;
    - ``retrieval_key`` is the file name for files directly in ``dataset_dir``;
      otherwise it is the name of the file's immediate parent directory.

    Returns:
        List of records. If ``dataset_dir`` is not a directory, an error is
        printed and an empty list is returned.
    """
    if not os.path.isdir(dataset_dir):
        print(f"Error: Dataset directory {dataset_dir} not found.")
        return []

    image_suffixes = ('.jpg', '.jpeg', '.png')
    dataset_root = os.path.abspath(dataset_dir)
    records = []
    for dir_path, _subdirs, names in os.walk(dataset_dir):
        for name in names:
            if name.lower().endswith(image_suffixes):
                full_path = os.path.join(dir_path, name)
                containing_dir = os.path.dirname(os.path.abspath(full_path))
                key = name if containing_dir == dataset_root else os.path.basename(containing_dir)
                records.append({'file_path': full_path, 'retrieval_key': key})
    return records

def normalize_name(filename):
    """
    Drop a known extension so names can be compared regardless of suffix.

    The extension (as split off by ``os.path.splitext``) is removed only when
    its lowercase form is listed in EXTENSIONS_TO_STRIP; otherwise the name is
    returned unchanged. Falsy input (``None``, ``""``) yields ``""``.
    """
    if not filename:
        return ""
    stem, suffix = os.path.splitext(filename)
    return stem if suffix.lower() in EXTENSIONS_TO_STRIP else filename

# --- Main Logic ---

# 1. Load the MultiHaystack conversations (a JSON list of items)
print(f"Loading {MULTI_HAYSTACK_FILE}")
try:
    with open(MULTI_HAYSTACK_FILE, 'r', encoding='utf-8') as f:
        multi_haystack_data = json.load(f)
except FileNotFoundError:
    print(f"Error: {MULTI_HAYSTACK_FILE} not found.")
    # Non-zero status so callers/shell pipelines see the failure; unlike the
    # interactive exit() helper, SystemExit needs no `site` module.
    raise SystemExit(1) from None
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {MULTI_HAYSTACK_FILE}.")
    raise SystemExit(1) from None
print(f"Loaded {len(multi_haystack_data)} conversations.")

# 2. Prepare dataset items and embed all of them
print(f"Scanning dataset directory: {DATASET_DIR}")
dataset_items_info = list_dataset_files(DATASET_DIR)

if not dataset_items_info:
    print("No image files found in the dataset directory. Exiting.")
    raise SystemExit(1)

print(f"Found {len(dataset_items_info)} image items in dataset.")
dataset_embeddings, processed_paths = embed_images([item['file_path'] for item in dataset_items_info])

if dataset_embeddings.nelement() == 0:  # every image failed to load
    print("No images were successfully embedded from the dataset. Exiting.")
    raise SystemExit(1)

# Row j of dataset_embeddings belongs to processed_paths[j]. Build the item list
# from that order (O(n) dict lookups) so items and embedding rows are aligned by
# construction, not by relying on embed_images preserving the input order.
item_by_path = {item['file_path']: item for item in dataset_items_info}
aligned_dataset_items_info = [item_by_path[path] for path in processed_paths]

print(f"Successfully embedded {dataset_embeddings.shape[0]} images from dataset.")
dataset_embeddings = dataset_embeddings.to(DEVICE)  # similarity search runs on DEVICE

# 3. Collect one question per conversation and embed all of them
results_data = []
questions_to_embed = []
# questions_to_embed[i] came from multi_haystack_data[original_data_indices[i]]
original_data_indices = []

for idx, item in enumerate(multi_haystack_data):
    # If a conversation has several "human" turns, the last one is the query.
    question = None
    for conv in item.get("conversations", []):
        if conv.get("from") == "human":
            question = conv.get("value")

    if question:  # conversations without a non-empty question are not evaluated
        questions_to_embed.append(question)
        original_data_indices.append(idx)

if not questions_to_embed:
    print(f"No questions found in {MULTI_HAYSTACK_FILE}. Exiting.")
    raise SystemExit(1)

print(f"Embedding {len(questions_to_embed)} questions")
question_embeddings = embed_texts(questions_to_embed).to(DEVICE)

print("Calculating similarities and retrieving top 5")
TOP_K_RETRIEVE = 5  # number of distinct retrieval keys reported per question

# Several images can share one retrieval key (e.g. every image inside one
# sub-directory), so ranking is done per key. Give each distinct key an integer
# id (first-seen order) and record the key id of every embedded image; entries of
# aligned_dataset_items_info line up with the rows of dataset_embeddings.
retrieval_key_list = []
retrieval_key_to_id = {}
image_key_id_list = []
for dataset_item in aligned_dataset_items_info:
    key = dataset_item['retrieval_key']
    if key not in retrieval_key_to_id:
        retrieval_key_to_id[key] = len(retrieval_key_list)
        retrieval_key_list.append(key)
    image_key_id_list.append(retrieval_key_to_id[key])
image_key_ids = torch.tensor(image_key_id_list, dtype=torch.long, device=DEVICE)
num_keys = len(retrieval_key_list)

for i, original_idx in tqdm(enumerate(original_data_indices), total=len(original_data_indices), desc="Processing questions"):
    item = multi_haystack_data[original_idx]
    question = questions_to_embed[i]  # question_embeddings[i] was computed from this text

    # Both embeddings are L2-normalized, so the dot product is cosine similarity:
    # (1, D) @ (D, N_images) -> (N_images,)
    similarities = torch.matmul(question_embeddings[i].unsqueeze(0), dataset_embeddings.T).squeeze(0)

    # A key's score is the best score of any image carrying it. Ranking keys by that
    # score yields the distinct keys in order of first appearance in the full image
    # ranking. A fixed-size prefix of top images could contain fewer than
    # TOP_K_RETRIEVE distinct keys when many images share a key.
    key_scores = similarities.new_full((num_keys,), float("-inf")).scatter_reduce(
        0, image_key_ids, similarities, reduce="amax"
    )
    top_key_ids = torch.topk(key_scores, k=min(TOP_K_RETRIEVE, num_keys), largest=True).indices
    top_5_retrieved_keys = [retrieval_key_list[key_id] for key_id in top_key_ids.tolist()]

    # "positive" may be a list (first element is used), a scalar, or missing.
    positive_value = item.get("positive", [])
    if isinstance(positive_value, list):
        positive_value = positive_value[0] if positive_value else ""

    # The reference answer is the first "gpt" turn of the conversation, if any.
    answer = next(
        (conv.get("value") for conv in item.get("conversations", []) if conv.get("from") == "gpt"),
        None,
    )

    results_data.append({
        "question": question,
        "positive": positive_value,
        "top_5_retrieved": top_5_retrieved_keys,
        "answer": answer if answer is not None else ""
    })

# 4. Calculate recall metrics
# recall@k = share of questions that have a ground-truth "positive" whose
# normalized name equals one of the first k normalized retrieved keys.
hit_counts = {1: 0, 3: 0, 5: 0}
total_valid_questions = 0

for res_item in results_data:
    positive = res_item["positive"]
    if not positive:  # no ground truth (empty or None): not evaluated
        continue
    total_valid_questions += 1
    target = normalize_name(positive)
    ranked_names = [normalize_name(key) for key in res_item['top_5_retrieved']]
    for cutoff in hit_counts:
        if target in ranked_names[:cutoff]:
            hit_counts[cutoff] += 1

hits_at_1, hits_at_3, hits_at_5 = hit_counts[1], hit_counts[3], hit_counts[5]
recall_at_1, recall_at_3, recall_at_5 = (
    hit_counts[k] / total_valid_questions if total_valid_questions > 0 else 0
    for k in (1, 3, 5)
)

metrics = {
    "recall@1": recall_at_1,
    "recall@3": recall_at_3,
    "recall@5": recall_at_5,
    "total_evaluated_questions": total_valid_questions
}

print("Metrics:", metrics)

# 5. Save the metrics and per-question results as pretty-printed UTF-8 JSON
final_output = dict(metrics=metrics, results=results_data)

print(f"Saving results to {OUTPUT_FILE}")
with open(OUTPUT_FILE, mode='w', encoding='utf-8') as out_file:
    # ensure_ascii=False keeps non-ASCII question/answer text readable in the file
    json.dump(final_output, out_file, ensure_ascii=False, indent=2)

print("Processing complete.")
