import os
import json
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import gc
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", message=".*use_fast.*")

# Configuration
MODEL_NAME = "royokong/e5-v"               # model id passed to from_pretrained
DATASET_PATH = "../dataset/dataset_full/"  # folder walked for candidate files
INPUT_JSON = "Multihaystack.json"          # question records to evaluate
OUTPUT_JSON = "e5v_results.json"           # metrics + per-question retrievals
# NOTE: not read anywhere below; files are always embedded one at a time.
BATCH_SIZE = 1  # Process one at a time to avoid errors
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    # NOTE: total_memory is the device's total capacity, not its free memory.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Available GPU memory: {gpu_props.total_memory / 1024**3:.1f} GB")

def clear_memory():
    """Free unreachable Python objects, then return cached CUDA blocks to the driver.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are released before ``torch.cuda.empty_cache()`` is called; in the
    opposite order their blocks would stay in PyTorch's cache until the next
    call.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def load_e5v_model():
    """Load the E5-V processor and model and build the two prompt templates.

    Returns:
        (model, processor, img_prompt, text_prompt). ``img_prompt`` contains
        an ``<image>`` placeholder; ``text_prompt`` contains a ``<sent>``
        placeholder that callers replace with the text to embed.
    """
    print("Loading E5-V model")

    # Load processor with use_fast=False to avoid warnings and ensure compatibility
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME, use_fast=False)

    # Some processor configs leave patch_size unset
    if hasattr(processor, 'patch_size') and processor.patch_size is None:
        processor.patch_size = 14  # Default patch_size for CLIP-based models

    # Half precision saves memory on GPU, but on the CPU fallback many
    # float16 kernels are slow or unimplemented, so use float32 there.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
    )

    # Prompt templates (from the official E5-V example)
    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'

    img_prompt = llama3_template.format('<image>\nSummary above image in one word: ')
    text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ')

    print("Model loaded")
    return model, processor, img_prompt, text_prompt

def process_single_image_safe(processor, prompt, image):
    """Run *processor* on a prompt/image pair, retrying with fallbacks.

    Attempts, in order:
      1. ``processor(text=prompt, images=image, ...)`` exactly as given;
      2. prompt and image wrapped in lists, after setting an unset
         ``processor.patch_size`` to 14;
      3. as (2), with every image resized to 336x336 using LANCZOS.

    Only ``Exception`` subclasses trigger a retry, so KeyboardInterrupt and
    SystemExit still stop the run. An error from attempt 3 propagates to the
    caller.

    Returns:
        Whatever *processor* returns.
    """
    try:
        return processor(text=prompt, images=image, return_tensors="pt", padding=True)
    except Exception:
        pass  # retry below with explicit list inputs

    texts = [prompt] if isinstance(prompt, str) else prompt
    images = image if isinstance(image, list) else [image]

    try:
        # Set patch_size explicitly if the processor config left it unset
        if hasattr(processor, 'patch_size') and processor.patch_size is None:
            processor.patch_size = 14
        return processor(text=texts, images=images, return_tensors="pt", padding=True)
    except Exception:
        pass  # last resort below: force a standard image size

    standard_size = (336, 336)
    resized = [img.resize(standard_size, Image.LANCZOS) for img in images]
    return processor(text=texts, images=resized, return_tensors="pt", padding=True)

def get_image_embedding_e5v(image_path, model, processor, img_prompt):
    """Embed one image file with E5-V.

    The image is decoded to RGB and run through *model* together with
    *img_prompt*. The final-layer hidden state of the last token is
    L2-normalised and returned.

    Args:
        image_path: path of an image file readable by PIL.
        model: the LLaVA-NeXT model returned by ``load_e5v_model``.
        processor: the matching processor.
        img_prompt: prompt containing the ``<image>`` placeholder.

    Returns:
        1-D numpy array, or None if the image has a zero dimension, triggers
        an image-token/feature mismatch, or raises any other error. The
        reason is printed in each case.
    """
    try:
        # Decode inside a context manager so the file handle is released
        # immediately. convert() always returns a fully loaded, detached copy,
        # so truncated or corrupt files fail here rather than later.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')

        if image.width == 0 or image.height == 0:
            print(f"Warning: Image {image_path} has invalid dimensions")
            return None

        # Process the image with the prompt using safe processing
        inputs = process_single_image_safe(processor, img_prompt, image)

        # Move tensors to the target device; leave any non-tensor entries as-is
        inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        with torch.no_grad():
            try:
                outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            except (RuntimeError, ValueError) as e:
                # Image-token/feature mismatch: skip this image. Depending on
                # the transformers version this may surface as ValueError
                # rather than RuntimeError.
                if "do not match" in str(e):
                    print(f"Token mismatch for {image_path}, skipping")
                    return None
                raise
            # Last layer, last token, as in the E5-V recipe
            img_emb = F.normalize(outputs.hidden_states[-1][:, -1, :], dim=-1)
        return img_emb.cpu().numpy()[0]

    except Exception as e:
        print(f"Error processing image {image_path}: {type(e).__name__}: {e}")
        return None

def get_text_embedding_e5v(text, model, processor, text_prompt):
    """Embed *text* with E5-V, using no images.

    *text* is substituted for the ``<sent>`` placeholder in *text_prompt*. The
    final-layer hidden state of the last token is L2-normalised.

    Returns:
        1-D numpy array, or None (after printing the error) on any failure.
    """
    try:
        encoded = processor(
            text=[text_prompt.replace('<sent>', text)],
            images=None,
            return_tensors="pt",
            padding=True,
        )

        # Only tensors are moved; anything else passes through unchanged
        on_device = {}
        for name, value in encoded.items():
            on_device[name] = value.to(DEVICE) if isinstance(value, torch.Tensor) else value

        with torch.no_grad():
            out = model(**on_device, output_hidden_states=True, return_dict=True)
            last_token_state = out.hidden_states[-1][:, -1, :]
            embedding = F.normalize(last_token_state, dim=-1)

        return embedding.cpu().numpy()[0]
    except Exception as e:
        print(f"Error processing text: {e}")
        return None

# Load E5-V once; the embedding helpers below read these module-level names.
model, processor, img_prompt, text_prompt = load_e5v_model()

# Load the question records from INPUT_JSON
print("Loading dataset")
with open(INPUT_JSON, encoding="utf-8") as json_file:
    haystack_data = json.load(json_file)

def collect_files(directory):
    """Walk *directory* and index every .jpg/.png/.jpeg/.txt file (case-insensitive).

    Returns:
        ``(valid_files, file_mapping)``: the full paths in ``os.walk`` order,
        and a dict mapping each full path to the name it is reported under.
        That name is the file name itself for files directly in *directory*,
        and otherwise the name of the folder that contains the file.
    """
    wanted_suffixes = ('.jpg', '.png', '.jpeg', '.txt')
    paths = []
    report_names = {}

    for folder, _subfolders, names in os.walk(directory):
        at_top_level = folder == directory
        # NOTE(review): for nested folders this is the *immediate* parent's
        # name, not the first-level subfolder under *directory*; confirm that
        # this matches how "positive" entries are named.
        parent_name = os.path.basename(folder)
        for name in names:
            if not name.lower().endswith(wanted_suffixes):
                continue
            full_path = os.path.join(folder, name)
            paths.append(full_path)
            report_names[full_path] = name if at_top_level else parent_name

    return paths, report_names

# Index the dataset: full paths plus the name each path is reported under.
valid_files, file_mapping = collect_files(DATASET_PATH)
indexed_count = len(valid_files)
print(f"Found {indexed_count} valid files")

def get_file_embedding(file_path):
    """Embed one dataset file with E5-V, dispatching on its extension.

    Images (.jpg/.png/.jpeg, case-insensitive) go to
    ``get_image_embedding_e5v``. For .txt files, only the first 1000
    characters are read (UTF-8, undecodable bytes dropped) and passed to
    ``get_text_embedding_e5v``.

    Returns:
        The embedding from the matching helper, or None for unsupported
        extensions, whitespace-only text, or read errors (which are printed).
    """
    max_chars = 1000  # Limit text length to avoid issues
    lowered = file_path.lower()

    if lowered.endswith(('.jpg', '.png', '.jpeg')):
        return get_image_embedding_e5v(file_path, model, processor, img_prompt)
    if not lowered.endswith('.txt'):
        return None

    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            # Read only the prefix that is kept instead of loading the whole
            # file into memory and slicing it afterwards.
            text = f.read(max_chars)
        # Skip empty / whitespace-only text files
        if not text.strip():
            return None
        return get_text_embedding_e5v(text, model, processor, text_prompt)
    except Exception as e:
        print(f"Error reading text file {file_path}: {e}")
        return None

def process_files_in_batches(files):
    """Embed every path in *files* sequentially; despite the name, nothing is batched.

    A file whose embedding comes back as None is counted as skipped. Every
    100 files, memory is cleared. Every 1000 files, a progress line is
    printed, plus allocated GPU memory when CUDA is available.

    Returns:
        dict mapping file path -> embedding for each file that succeeded.
    """
    embeddings = {}
    processed = 0
    skipped = 0

    print("Processing dataset files")

    for position, path in enumerate(tqdm(files, desc="Processing files"), start=1):
        vector = get_file_embedding(path)
        if vector is None:
            skipped += 1
        else:
            embeddings[path] = vector
            processed += 1

        if position % 100 == 0:
            clear_memory()

        if position % 1000 == 0:
            print(f"Progress: Processed {processed} files, Skipped {skipped} files")
            if torch.cuda.is_available():
                print(f"GPU memory used: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    print(f"\nTotal: Processed {processed} files, Skipped {skipped} files")
    return embeddings

# Embed every indexed file; files that fail are simply absent from the dict.
file_embeddings = process_files_in_batches(valid_files)
embedded_count = len(file_embeddings)
print(f"Successfully processed {embedded_count} files")

def get_question_embedding(question):
    """Embed *question* with the E5-V text prompt, keeping only its first 500 characters."""
    clipped = question[:500]
    return get_text_embedding_e5v(clipped, model, processor, text_prompt)

def cosine_similarity(v1, v2):
    """Return the cosine similarity of two 1-D vectors.

    Inputs are upcast to float64 first. The E5-V model is loaded in float16
    on GPU, so embeddings can arrive as float16 arrays. Computing the dot
    product and norms in half precision rounds similarities to about three
    significant digits, which creates spurious ties when ranking many files.

    Returns:
        A Python float, or 0.0 when either vector has zero norm
        (to avoid division by zero).
    """
    a = np.asarray(v1, dtype=np.float64)
    b = np.asarray(v2, dtype=np.float64)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)

    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(a, b) / (norm_a * norm_b))

def get_top_k_similar(question_embedding, k=5):
    """Rank embedded files by cosine similarity to *question_embedding*.

    At most one file per report path (the ``file_mapping`` value, falling back
    to the file path itself) is kept, so several files from the same subfolder
    cannot crowd the list. Ties keep ``file_embeddings`` insertion order.

    Returns:
        Up to *k* ``(file_path, similarity)`` tuples, most similar first, or an
        empty list (after a warning) when no file embeddings exist.
    """
    if not file_embeddings:
        print("Warning: No file embeddings available")
        return []

    ranked = sorted(
        ((path, cosine_similarity(question_embedding, emb))
         for path, emb in file_embeddings.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )

    picked = []
    used_reports = set()
    for path, score in ranked:
        report = file_mapping.get(path, path)
        if report in used_reports:
            continue
        used_reports.add(report)
        picked.append((path, score))
        if len(picked) >= k:
            break

    return picked

def basename_without_ext(filename):
    """Return the last path component of *filename* with a known extension removed.

    Only the extensions listed below are stripped, matched case-insensitively.
    Any other suffix is left in place.
    """
    name = os.path.basename(filename)
    lowered = name.lower()
    suffixes = ('.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt')
    match = next((s for s in suffixes if lowered.endswith(s)), None)
    return name if match is None else name[:-len(match)]

# Process questions and calculate similarities


def _extract_question_answer(item):
    """Return ``(question, answer)`` taken from one record's ``conversations`` list.

    The last ``"human"`` turn is the question and the last ``"gpt"`` turn is
    the answer. Either one is ``""`` when absent.
    """
    question = ""
    answer = ""
    for conv in item.get("conversations", []):
        if conv.get("from") == "human":
            question = conv.get("value", "")
        elif conv.get("from") == "gpt":
            answer = conv.get("value", "")
    return question, answer


def _first_hit_rank(reported_paths, positive_base_names):
    """Return the 0-based rank of the first reported path whose extension-less
    basename is in *positive_base_names*, or None when nothing matches."""
    for rank, report_path in enumerate(reported_paths):
        if basename_without_ext(report_path) in positive_base_names:
            return rank
    return None


results = []
correct_at_1 = 0
correct_at_3 = 0
correct_at_5 = 0
total = 0
failed = 0  # questions scored as misses because embedding/retrieval failed

print("Processing questions and calculating similarities")
for idx, item in enumerate(tqdm(haystack_data, desc="Processing questions")):
    question, answer = _extract_question_answer(item)

    if not question:
        # Malformed record: nothing to evaluate, so it is excluded from the
        # recall denominator.
        print(f"Warning: Empty question at index {idx}")
        continue

    # Ground-truth files, compared by basename without extension
    positive_files = item.get("positive", [])
    positive_base_names = {basename_without_ext(p) for p in positive_files}

    question_embedding = get_question_embedding(question)
    if question_embedding is None:
        print(f"Failed to get embedding for question: {question[:50]}")
        top_5 = []
    else:
        top_5 = get_top_k_similar(question_embedding, 5)
        if len(top_5) == 0:
            print(f"Warning: No similar files found for question {idx}")

    # A question the retriever could not answer still counts toward `total`
    # (as a miss). Skipping it would silently inflate recall.
    if not top_5:
        failed += 1

    reported_top_5 = [file_mapping.get(file_path, file_path) for file_path, _ in top_5]

    # Recall@k holds iff the first correct hit is ranked below k
    hit_rank = _first_hit_rank(reported_top_5, positive_base_names)
    if hit_rank is not None:
        if hit_rank < 1:
            correct_at_1 += 1
        if hit_rank < 3:
            correct_at_3 += 1
        if hit_rank < 5:
            correct_at_5 += 1
    total += 1

    results.append({
        "question": question,
        "positive": positive_files,
        "top_5_retrieved": reported_top_5,
        "answer": answer,
    })

    # Clear memory and print progress periodically
    if (idx + 1) % 50 == 0:
        clear_memory()
        print(f"\nIntermediate results (after {total} questions):")
        print(f"Recall@1: {correct_at_1/total:.4f}")
        print(f"Recall@3: {correct_at_3/total:.4f}")
        print(f"Recall@5: {correct_at_5/total:.4f}")

# Calculate recall metrics
recall_at_1 = correct_at_1 / total if total > 0 else 0
recall_at_3 = correct_at_3 / total if total > 0 else 0
recall_at_5 = correct_at_5 / total if total > 0 else 0

# Build final result and save
final_result = {
    "model": MODEL_NAME,
    "metrics": {
        "recall@1": recall_at_1,
        "recall@3": recall_at_3,
        "recall@5": recall_at_5,
        "total_questions": total,
        "correct_at_1": correct_at_1,
        "correct_at_3": correct_at_3,
        "correct_at_5": correct_at_5,
        "failed_questions": failed,
    },
    "results": results,
}

with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(final_result, f, ensure_ascii=False, indent=2)

print("\n" + "=" * 50)
print(f"Processing complete! Results saved to {OUTPUT_JSON}")
print("=" * 50)
print(f"Total questions processed: {total}")
print(f"Recall@1: {recall_at_1:.4f} ({correct_at_1}/{total})")
print(f"Recall@3: {recall_at_3:.4f} ({correct_at_3}/{total})")
print(f"Recall@5: {recall_at_5:.4f} ({correct_at_5}/{total})")

# Final memory cleanup
clear_memory()
print("\nMemory cleared. Process complete!")