import os
import json
import torch
from PIL import Image
import numpy as np
from transformers import AutoProcessor, AutoModel
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
import re

# Configuration and model loading
MODEL_NAME = "google/siglip2-so400m-patch14-384"  # Hugging Face checkpoint id
DATASET_PATH = "../dataset/dataset_full/"  # root folder scanned for images/texts
INPUT_JSON = "Multihaystack.json"  # evaluation items (conversations + positives)
OUTPUT_JSON = "siglip2_results.json"  # destination for metrics and per-question results

# Run on the default CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint weights plus the matching pre-processor, which handles both the
# text (processor(text=...)) and image (processor(images=...)) inputs below.
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Evaluation items; each is read later for its "conversations" and "positive".
with open(INPUT_JSON, encoding="utf-8") as f:
    haystack_data = json.load(f)

# Collect valid files from dataset
def collect_files(directory):
    """Recursively collect supported files under *directory*.

    A file is supported when its name ends (case-insensitively) in
    ``.jpg``, ``.jpeg``, ``.png`` or ``.txt``.

    Returns:
        ``(valid_files, file_mapping)``. ``valid_files`` lists the full paths
        in ``os.walk`` order. ``file_mapping`` maps each full path to its
        report key. The key is the file name for files sitting directly in
        *directory*; otherwise it is the name of the folder that contains the
        file. Retrieval de-duplicates results on this key.

    NOTE(review): for folders nested more than one level deep, the key is the
    innermost folder name, not the top-level subfolder. Confirm the dataset
    layout is only one level deep.
    """
    supported_exts = ('.jpg', '.png', '.jpeg', '.txt')
    valid_files = []
    file_mapping = {}

    for current_dir, _subdirs, names in os.walk(directory):
        in_root = current_dir == directory
        parent_name = os.path.basename(current_dir)
        for name in names:
            if not name.lower().endswith(supported_exts):
                continue
            full_path = os.path.join(current_dir, name)
            valid_files.append(full_path)
            file_mapping[full_path] = name if in_root else parent_name

    return valid_files, file_mapping

# Scan the dataset once; the embedding and retrieval stages below use both results.
valid_files, file_mapping = collect_files(DATASET_PATH)
print("Found {} valid files".format(len(valid_files)))

# Process files and generate embeddings
def get_file_embedding(file_path):
    """Generate an embedding vector for a dataset file based on its extension.

    * ``.jpg`` / ``.jpeg`` / ``.png`` (case-insensitive): opened with PIL,
      converted to RGB and encoded with ``model.get_image_features``.
    * ``.txt``: read as UTF-8 (undecodable bytes dropped), lowercased,
      padded/truncated to 64 tokens and encoded with
      ``model.get_text_features``.

    Returns:
        The first row of the model output as a NumPy array. Returns ``None``
        if the extension is unsupported or processing failed. Failures are
        printed rather than raised, so one bad file does not stop the run.
    """
    lowered = file_path.lower()
    if lowered.endswith(('.jpg', '.png', '.jpeg')):
        try:
            # The context manager releases the file handle even when
            # decoding fails. convert() returns a fully loaded image, so
            # `image` remains usable after the handle is closed.
            with Image.open(file_path) as raw:
                image = raw.convert('RGB')
            inputs = processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.get_image_features(**inputs)
            return outputs.cpu().numpy()[0]
        except Exception as e:
            print(f"Error processing image file {file_path}: {e}")
            return None
    elif lowered.endswith('.txt'):
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
            # Lowercase and pad to a fixed 64 tokens, matching the question encoder.
            inputs = processor(
                text=text.lower(),
                return_tensors="pt",
                padding="max_length",
                max_length=64,
                truncation=True,
            ).to(device)
            with torch.no_grad():
                outputs = model.get_text_features(**inputs)
            return outputs.cpu().numpy()[0]
        except Exception as e:
            print(f"Error processing text file {file_path}: {e}")
            return None
    return None

# Embed every collected file once; files whose embedding failed (None) are skipped.
print_banner = None  # placeholder removed below
del print_banner
file_embeddings = {}
print("Processing dataset files...")
for path in tqdm(valid_files):
    vector = get_file_embedding(path)
    if vector is None:
        continue
    file_embeddings[path] = vector

# Process questions and retrieve similar files
def get_question_embedding(question):
    """Embed a natural-language question with the model's text tower.

    The question is lowercased and padded/truncated to exactly 64 tokens.
    This is the same text preprocessing ``get_file_embedding`` applies to
    ``.txt`` files. Returns the first row of ``model.get_text_features`` as a
    NumPy array.
    """
    encoded = processor(
        text=question.lower(),
        padding="max_length",
        max_length=64,
        truncation=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)
    with torch.no_grad():
        features = model.get_text_features(**encoded)
    return features.cpu().numpy()[0]

def cosine_similarity(v1, v2):
    """Return the cosine similarity between vectors *v1* and *v2*.

    If either vector has zero norm, the similarity is undefined. In that case
    0.0 is returned, instead of the NaN (plus RuntimeWarning) that the plain
    formula would produce, so that sorting by similarity stays well-defined.
    """
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if norm_product == 0:
        return 0.0
    return np.dot(v1, v2) / norm_product

def get_top_k_similar(question_embedding, k=5, embeddings=None, mapping=None):
    """Return up to *k* best-matching files, keeping at most one per report key.

    Every stored embedding is scored against *question_embedding* by cosine
    similarity; vectors with zero norm score 0.0. The files are ranked
    best-first. The ranking is then walked, keeping only the first file seen
    for each report key in *mapping*, so one subfolder/report cannot fill
    several slots.

    Args:
        question_embedding: 1-D query vector.
        k: Maximum number of results. ``k <= 0`` yields an empty list.
        embeddings: ``{file_path: vector}``. Defaults to the module-level
            ``file_embeddings``.
        mapping: ``{file_path: report_key}``. Defaults to the module-level
            ``file_mapping``.

    Returns:
        List of ``(file_path, similarity)`` tuples, best first. Files with
        equal similarity keep their insertion order from *embeddings*.
    """
    if embeddings is None:
        embeddings = file_embeddings
    if mapping is None:
        mapping = file_mapping
    if k <= 0 or not embeddings:
        return []

    paths = list(embeddings)
    matrix = np.stack([embeddings[p] for p in paths])
    query = np.asarray(question_embedding)

    # Score all files in one vectorised pass instead of a Python loop per file.
    dots = matrix @ query
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
    scores = np.zeros_like(norms)
    np.divide(dots, norms, out=scores, where=norms > 0)

    # Stable sort on negated scores = descending order with ties kept in
    # insertion order (same as list.sort(reverse=True)).
    order = np.argsort(-scores, kind="stable")

    top_k_results = []
    seen_folders = set()
    for idx in order:
        file_path = paths[idx]
        report_path = mapping[file_path]
        if report_path in seen_folders:
            continue
        seen_folders.add(report_path)
        top_k_results.append((file_path, float(scores[idx])))
        if len(top_k_results) >= k:
            break

    return top_k_results

# Calculate recall and build results
def basename_without_ext(filename):
    """Return the last path component of *filename* minus a known extension.

    Only the listed media/document extensions are stripped, matched
    case-insensitively. Any other suffix is left in place, so
    ``"a/b.tar.gz"`` becomes ``"b.tar.gz"``.
    """
    name = os.path.basename(filename)
    lowered = name.lower()
    suffix = next(
        (
            ext
            for ext in ('.mp4', '.jpg', '.png', '.wav', '.mp3',
                        '.pdf', '.jpeg', '.gif', '.txt')
            if lowered.endswith(ext)
        ),
        None,
    )
    if suffix is None:
        return name
    return name[:-len(suffix)]

results = []
correct_at_1 = 0
correct_at_3 = 0
correct_at_5 = 0
total = 0

print("Processing questions and calculating similarity...")
for item in tqdm(haystack_data):
    # Keep the last human turn as the query and the last gpt turn as the answer.
    question = answer = ""
    for turn in item["conversations"]:
        role = turn["from"]
        if role == "human":
            question = turn["value"]
        elif role == "gpt":
            answer = turn["value"]

    # Ground truth, compared without extensions against retrieved report keys.
    positive_files = item["positive"]
    positive_base_names = [basename_without_ext(p) for p in positive_files]

    # Retrieve the five best reports for this question.
    top_5 = get_top_k_similar(get_question_embedding(question), 5)
    reported_top_5 = [file_mapping[path] for path, _score in top_5]

    # 0-based rank of the first retrieved report matching any positive, or None.
    first_hit = next(
        (
            rank
            for rank, report in enumerate(reported_top_5)
            if basename_without_ext(report) in positive_base_names
        ),
        None,
    )

    # A hit at rank r counts toward every recall@K with K > r.
    if first_hit is not None:
        if first_hit < 1:
            correct_at_1 += 1
        if first_hit < 3:
            correct_at_3 += 1
        correct_at_5 += 1
    total += 1

    results.append(
        {
            "question": question,
            "positive": positive_files,
            "top_5_retrieved": reported_top_5,
            "answer": answer,
        }
    )

# Recall@K = fraction of questions with a positive among the top K reports.
if total > 0:
    recall_at_1 = correct_at_1 / total
    recall_at_3 = correct_at_3 / total
    recall_at_5 = correct_at_5 / total
else:
    recall_at_1 = recall_at_3 = recall_at_5 = 0

final_result = {
    "metrics": {
        "recall@1": recall_at_1,
        "recall@3": recall_at_3,
        "recall@5": recall_at_5,
    },
    "results": results,
}

with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(final_result, f, ensure_ascii=False, indent=2)

print(f"Processing complete. Results saved to {OUTPUT_JSON}")
for label, value in (
    ("Recall@1", recall_at_1),
    ("Recall@3", recall_at_3),
    ("Recall@5", recall_at_5),
):
    print(f"{label}: {value:.4f}")