import os
import json
import torch
from PIL import Image
import numpy as np
from transformers import AutoModel
import gc
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
import re
import warnings
warnings.filterwarnings("ignore", message="You may have used the wrong order for inputs")

# Configuration and model loading
MODEL_NAME = "nvidia/MM-Embed"            # Hugging Face model identifier
DATASET_PATH = "../dataset/dataset_full/"  # root directory walked for candidate files
INPUT_JSON = "Multihaystack.json"          # evaluation questions with ground-truth files
OUTPUT_JSON = "mmembed_results.json"       # where metrics and per-question results are written

# Load model (MM-Embed without separate processor)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading {MODEL_NAME} on {device}")

try:
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # checkpoint; consider pinning a revision if the source is not fully trusted.
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # `device` already resolves to CUDA when it is available, so a single
    # .to(device) covers both the GPU and the CPU case.
    model = model.to(device)
    model.eval()
    print("Model loaded")

    # Smoke test: encode one short text passage to verify the model works
    print("Testing text encoding")
    test_text_passage = [{'txt': 'This is a simple test text.'}]

    with torch.no_grad():
        test_outputs = model.encode(test_text_passage, max_length=512)
    print(f"Text encoding validated: {test_outputs['hidden_states'].shape}")

    # Smoke test: encode a synthetic solid-colour image
    print("Testing image encoding")
    test_image = Image.new('RGB', (336, 336), color='red')
    test_image_passage = [{'img': test_image}]

    # An image-encoding failure is reported but is not fatal at this point
    try:
        with torch.no_grad():
            test_outputs = model.encode(test_image_passage, max_length=512)
        print(f"Image encoding validated: {test_outputs['hidden_states'].shape}")
    except Exception as e:
        print(f"Image encoding failed: {e}")
        print("Proceeding with text-only processing")

except Exception as e:
    print(f"Error during model loading or testing: {e}")
    import traceback
    traceback.print_exc()
    # The bare `exit` name is installed by the `site` module for interactive
    # use and may be missing (e.g. `python -S`); SystemExit is always available.
    raise SystemExit(1)

# Instruction string supplied with every query when it is encoded
TASK_INSTRUCTION = "Retrieve a relevant document or image that matches the given query."

# Read the evaluation questions (parsed JSON) into memory
with open(INPUT_JSON, encoding="utf-8") as dataset_file:
    haystack_data = json.load(dataset_file)

# Collect valid files from dataset
def collect_files(directory):
    """Walk ``directory`` recursively and index its image and text files.

    Only files whose name ends (case-insensitively) in .jpg, .png, .jpeg or
    .txt are kept.

    Returns:
        A ``(valid_files, file_mapping)`` tuple. ``valid_files`` lists the
        joined paths in ``os.walk`` order. ``file_mapping`` maps each of those
        paths to its report label: the bare file name for files directly in
        ``directory``, otherwise the name of the folder that contains the file.
    """
    accepted_suffixes = ('.jpg', '.png', '.jpeg', '.txt')
    valid_files = []
    file_mapping = {}

    for folder, _subdirs, names in os.walk(directory):
        at_top_level = folder == directory
        folder_label = os.path.basename(folder)

        for name in names:
            if not name.lower().endswith(accepted_suffixes):
                continue

            path = os.path.join(folder, name)
            valid_files.append(path)
            # Top-level files are reported by name, nested ones by their folder
            file_mapping[path] = name if at_top_level else folder_label

    return valid_files, file_mapping

# Index the dataset once; both values are read later as module-level state
valid_files, file_mapping = collect_files(directory=DATASET_PATH)
print(f"Found {len(valid_files)} valid files")

# Process files and generate embeddings
def get_file_embedding(file_path):
    """Encode one dataset file with the module-level ``model``.

    Images (.jpg/.png/.jpeg) are converted to RGB and, when either side
    exceeds 1024 px, scaled down by a common factor with each resulting side
    clamped to at least 224 px. Text files (.txt) are read as UTF-8 with
    undecodable bytes ignored and cut to the first 8000 characters.

    Returns:
        The first row of the encoder's ``hidden_states`` as a NumPy array, or
        ``None`` when the extension is unsupported or processing raised (the
        error is printed, not propagated).
    """
    max_length = 4096  # token budget handed to model.encode
    lowered_path = file_path.lower()

    if lowered_path.endswith(('.jpg', '.png', '.jpeg')):
        # Image branch
        try:
            image = Image.open(file_path).convert('RGB')

            # Shrink oversized images while keeping the aspect ratio
            width, height = image.size
            max_dimension = 1024
            if max(width, height) > max_dimension:
                scale = min(max_dimension / width, max_dimension / height)
                # Clamp each side so extreme aspect ratios do not collapse
                target_size = (
                    max(int(width * scale), 224),
                    max(int(height * scale), 224),
                )
                image = image.resize(target_size, Image.LANCZOS)

            # The encoder takes a list of dicts keyed by modality
            with torch.no_grad():
                outputs = model.encode([{'img': image}], max_length=max_length)
            return outputs['hidden_states'].cpu().numpy()[0]
        except Exception as exc:
            print(f"Error processing image file {file_path}: {exc}")
            return None

    if lowered_path.endswith('.txt'):
        # Text branch
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
                text = handle.read()[:8000]  # cap very long documents

            with torch.no_grad():
                outputs = model.encode([{'txt': text}], max_length=max_length)
            return outputs['hidden_states'].cpu().numpy()[0]
        except Exception as exc:
            print(f"Error processing text file {file_path}: {exc}")
            return None

    # Unsupported extension
    return None

def clear_gpu_memory():
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are destroyed before ``torch.cuda.empty_cache()`` is asked to
    release unoccupied cached blocks; in the opposite order that memory would
    stay cached until the next call. The CUDA step is skipped when no GPU is
    available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Embed every collected file; files that fail to encode are left out
file_embeddings = {}
print("Processing dataset files")
for processed_count, file_path in enumerate(tqdm(valid_files), start=1):
    embedding = get_file_embedding(file_path)
    if embedding is not None:
        file_embeddings[file_path] = embedding

    # Free GPU memory after every 10th file
    if processed_count % 10 == 0:
        clear_gpu_memory()

print(f"Processed {len(file_embeddings)} files")

# Process questions and retrieve similar files
def get_question_embedding(question):
    """Encode ``question`` as a retrieval query with the module-level ``model``.

    The text is wrapped in the encoder's dict format and encoded with
    ``is_query=True`` and ``TASK_INSTRUCTION`` as the instruction.

    Returns:
        The first row of the encoder's ``hidden_states`` as a NumPy array.
    """
    token_limit = 4096  # token budget handed to model.encode

    with torch.no_grad():
        encoded = model.encode(
            [{'txt': question}],
            is_query=True,
            instruction=TASK_INSTRUCTION,
            max_length=token_limit,
        )

    hidden_states = encoded['hidden_states']
    return hidden_states.cpu().numpy()[0]

def cosine_similarity(v1, v2):
    """Return the dot product of ``v1`` and ``v2``.

    No normalisation is applied here, so the result equals the cosine
    similarity only when both vectors already have unit length. The script
    assumes the encoder output is normalised — TODO confirm against the model
    documentation.
    """
    score = np.dot(v1, v2)
    return score

def get_top_k_similar(question_embedding, k=5):
    """Rank all embedded files against the query and keep the best ``k``.

    Every entry of the module-level ``file_embeddings`` is scored with
    ``cosine_similarity``. Results are de-duplicated by report label (looked
    up in ``file_mapping``), so each subfolder contributes only its
    highest-scoring file.

    Returns:
        A list of ``(file_path, similarity)`` tuples in descending similarity
        order.
    """
    scored = [
        (path, cosine_similarity(question_embedding, vector))
        for path, vector in file_embeddings.items()
    ]
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)

    top_k_results = []
    used_labels = set()

    for path, score in ranked:
        label = file_mapping[path]
        if label not in used_labels:
            # First (best) hit for this label; later files under it are skipped
            used_labels.add(label)
            top_k_results.append((path, score))

            if len(top_k_results) >= k:
                break

    return top_k_results

# Calculate recall and build results
def basename_without_ext(filename):
    """Return the final path component of ``filename`` minus a known extension.

    The extension is matched case-insensitively against a fixed list of media
    and text suffixes. Names with any other suffix, or none at all, are
    returned unchanged.
    """
    stem = os.path.basename(filename)
    lowered = stem.lower()

    for suffix in ('.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt'):
        if lowered.endswith(suffix):
            return stem[:-len(suffix)]

    return stem

results = []
correct_at_1 = 0
correct_at_3 = 0
correct_at_5 = 0
total = 0

print("Processing questions and calculating similarity")
for item in tqdm(haystack_data):
    # The last "human" turn is the question, the last "gpt" turn the answer
    question = ""
    answer = ""
    for turn in item["conversations"]:
        speaker = turn["from"]
        if speaker == "human":
            question = turn["value"]
        elif speaker == "gpt":
            answer = turn["value"]

    # Ground-truth names, compared with known extensions stripped
    positive_files = item["positive"]
    positive_names = {basename_without_ext(p) for p in positive_files}

    # Retrieve the five best-matching report labels for this question
    ranked = get_top_k_similar(get_question_embedding(question), 5)
    reported_top_5 = [file_mapping[path] for path, _score in ranked]

    # Rank of the first retrieved label that matches a ground-truth name
    first_hit = next(
        (
            rank
            for rank, label in enumerate(reported_top_5)
            if basename_without_ext(label) in positive_names
        ),
        None,
    )

    # A hit at rank r counts for every cutoff larger than r
    if first_hit is not None:
        correct_at_5 += 1
        if first_hit < 3:
            correct_at_3 += 1
        if first_hit == 0:
            correct_at_1 += 1
    total += 1

    results.append({
        "question": question,
        "positive": positive_files,
        "top_5_retrieved": reported_top_5,
        "answer": answer
    })

# Recall@k is the share of questions with a hit in the top k (0 if no questions)
recall_at_1, recall_at_3, recall_at_5 = [
    hits / total if total > 0 else 0
    for hits in (correct_at_1, correct_at_3, correct_at_5)
]

# Bundle the metrics with the per-question results and write them out
metrics = {
    "recall@1": recall_at_1,
    "recall@3": recall_at_3,
    "recall@5": recall_at_5
}
final_result = {"metrics": metrics, "results": results}

with open(OUTPUT_JSON, "w", encoding="utf-8") as out_file:
    json.dump(final_result, out_file, ensure_ascii=False, indent=2)

# Console summary
print(f"Processing complete! Results saved to {OUTPUT_JSON}")
print(f"Processed {len(file_embeddings)} files and {total} questions")
print(f"Recall@1: {recall_at_1:.4f}")
print(f"Recall@3: {recall_at_3:.4f}")
print(f"Recall@5: {recall_at_5:.4f}")