import os
import json
import torch
import numpy as np
from PIL import Image
from pathlib import Path
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

# Model initialization: load the CLIP ViT-L/14 weights and the matching
# processor once at import time; every function below shares them.
MODEL_ID = "openai/clip-vit-large-patch14"
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

# Configuration
DATASET_PATH = "../dataset/dataset_full/"  # root directory scanned for candidates
MULTIHAYSTACK_JSON = "Multihaystack.json"  # queries with positive labels
OUTPUT_JSON = "clip_results.json"  # where metrics and per-query results go
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.txt'}  # lowercase, dot-prefixed

# File discovery
def find_all_files(root_dir, extensions=None):
    """Recursively collect the files under *root_dir* that can be embedded.

    Args:
        root_dir: Directory to walk with ``os.walk``.
        extensions: Optional iterable of dot-prefixed extensions to accept
            (compared case-insensitively). Defaults to ``SUPPORTED_FORMATS``.

    Returns:
        Sorted list of matching paths, each joined onto *root_dir*. Sorting
        makes the processing order, and therefore tie-breaking between equal
        similarity scores downstream, independent of filesystem listing order.
        A missing *root_dir* yields an empty list.
    """
    if extensions is None:
        allowed = SUPPORTED_FORMATS
    else:
        allowed = {ext.lower() for ext in extensions}

    matches = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() in allowed:
                matches.append(os.path.join(dirpath, filename))
    return sorted(matches)

# Basename extraction
def get_file_basename(file_path):
    """Return the name used to match a retrieved item against a positive label.

    The directory part is always dropped, and a trailing separator is ignored,
    so ``"data/folder1/"`` gives ``"folder1"``. If the remaining name ends in
    one of the recognised media/text extensions (case-insensitive), that
    extension is stripped too: ``"some/dir/img.JPG"`` gives ``"img"``.
    Otherwise the name is returned unchanged, which is how subfolder names are
    compared.
    """
    separators = os.sep + (os.altsep or "")
    name = os.path.basename(file_path.rstrip(separators))
    stem, ext = os.path.splitext(name)
    if ext.lower() in {'.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt'}:
        return stem
    return name

# Directory name extraction
def get_dir_name(file_path, dataset_root=None):
    """Map a file path to the retrieval unit (group) it belongs to.

    - A file directly inside the dataset root is its own unit, and its file
      name is returned.
    - A file anywhere below a subfolder is grouped under that top-level
      subfolder's name.
    - A path outside the root falls back to the file name.

    Args:
        file_path: Path of a discovered file.
        dataset_root: Root directory to compare against. Defaults to
            ``DATASET_PATH``.

    Returns:
        The file name or the top-level subfolder name, as a string.
    """
    root = Path(DATASET_PATH if dataset_root is None else dataset_root).resolve()
    parent = Path(file_path).parent.resolve()

    # relative_to() of the root itself has no parts, so handle it explicitly.
    if parent == root:
        return os.path.basename(file_path)

    try:
        return parent.relative_to(root).parts[0]
    except ValueError:
        # Not located under the root.
        return os.path.basename(file_path)

# Embedding generation
def _read_text_file(file_path):
    """Read a text file as UTF-8, retrying as Latin-1 if it is not valid UTF-8."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        # Latin-1 maps every byte to a character, so this read cannot fail on decoding.
        with open(file_path, 'r', encoding='latin-1') as f:
            return f.read()


def _embed_image_file(file_path):
    """Return un-normalised CLIP image features for a single image file."""
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(file_path) as img:
        image = img.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        return model.get_image_features(**inputs)


def _embed_text_content(text):
    """Return un-normalised CLIP text features for a single string."""
    # CLIP's text encoder only has positional embeddings for a fixed number of
    # tokens. Longer inputs must be truncated or the forward pass fails.
    inputs = processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model.config.text_config.max_position_embeddings,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        return model.get_text_features(**inputs)


def process_files(file_paths):
    """Compute L2-normalised CLIP embeddings for image and text files.

    Args:
        file_paths: Paths to embed. ``.jpg``/``.jpeg``/``.png`` files go through
            the image tower and ``.txt`` files through the text tower. Anything
            else is reported and skipped.

    Returns:
        ``(file_embeddings, file_names)``: two parallel lists. Each embedding is
        the NumPy array for one file (a batch of one, as ``main`` squeezes axis
        0), and ``file_names`` holds the matching path. Files that fail to load
        or encode are reported and left out of both lists.
    """
    print(f"Processing {len(file_paths)} files")
    file_embeddings = []
    file_names = []

    for file_path in tqdm(file_paths):
        file_ext = os.path.splitext(file_path)[1].lower()
        try:
            if file_ext in {'.jpg', '.jpeg', '.png'}:
                features = _embed_image_file(file_path)
            elif file_ext == '.txt':
                features = _embed_text_content(_read_text_file(file_path))
            else:
                # Skip explicitly so no stale or unbound features are used.
                print(f"Skipping unsupported file {file_path}")
                continue
            features = features / features.norm(dim=-1, keepdim=True)
        except Exception as e:
            # Best effort per file: one bad file must not abort the whole run.
            print(f"Error processing file {file_path}: {e}")
            continue

        file_embeddings.append(features.cpu().numpy())
        file_names.append(file_path)

    return file_embeddings, file_names

# Text embedding generation
def encode_text(text):
    """Embed query text with CLIP's text tower.

    Args:
        text: The query string.

    Returns:
        NumPy array of L2-normalised text features, one row per input.

    Inputs longer than the text encoder's positional-embedding length are
    truncated. Otherwise a long question would raise and abort the run.
    """
    inputs = processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=model.config.text_config.max_position_embeddings,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        text_features = model.get_text_features(**inputs)

    # Normalise to unit length so dot products are cosine similarities.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    return text_features.cpu().numpy()

# Similarity computation
def get_top_k_results(text_embedding, file_embeddings, file_names, k=5, group_key=None):
    """Return the best-scoring file from each of the top-*k* distinct groups.

    Args:
        text_embedding: Query embedding, shape ``(D,)`` or ``(1, D)``.
        file_embeddings: ``N`` candidate embeddings, given either as an
            ``(N, D)`` array or as a sequence of ``(D,)`` / ``(1, D)`` arrays.
        file_names: Paths aligned with *file_embeddings*.
        k: Maximum number of results. ``k <= 0`` yields no results.
        group_key: Callable mapping a path to its group. Only the highest
            scoring file per group is kept. Defaults to ``get_dir_name``.

    Returns:
        ``(top_files, top_similarities)``, ordered by descending dot-product
        similarity (cosine similarity when the inputs are L2-normalised).
        Equal scores keep their input order.
    """
    if group_key is None:
        group_key = get_dir_name

    num_candidates = len(file_embeddings)
    if k <= 0 or num_candidates == 0:
        return [], []

    # One matrix-vector product replaces a per-row np.dot loop. The loop also
    # relied on float() of a 1-element array, which NumPy has deprecated.
    query = np.asarray(text_embedding).reshape(-1)
    candidates = np.asarray(file_embeddings).reshape(num_candidates, -1)
    similarities = (candidates @ query).tolist()

    # sorted() is stable even with reverse=True, so ties keep input order.
    ranked = sorted(zip(file_names, similarities), key=lambda pair: pair[1], reverse=True)

    top_files = []
    top_similarities = []
    seen_groups = set()
    for file_path, sim in ranked:
        group = group_key(file_path)
        if group in seen_groups:
            continue
        seen_groups.add(group)
        top_files.append(file_path)
        top_similarities.append(sim)
        if len(top_files) >= k:
            break

    return top_files, top_similarities

def _extract_question_answer(item):
    """Return ``(question, answer)`` from an item's ``conversations`` turns.

    The last ``human`` turn is the question and the last ``gpt`` turn is the
    answer. Either is ``None`` when absent.
    """
    question = None
    answer = None
    for conv in item.get("conversations", []):
        speaker = conv.get("from")
        if speaker == "human":
            question = conv.get("value")
        elif speaker == "gpt":
            answer = conv.get("value")
    return question, answer


def _first_hit_rank(retrieved_names, positive_files):
    """Return the 0-based rank of the first retrieved name matching a positive.

    Names are compared after ``get_file_basename`` normalisation. Returns
    ``None`` if nothing matches.
    """
    positive_keys = {get_file_basename(p) for p in positive_files}
    for rank, name in enumerate(retrieved_names):
        if get_file_basename(name) in positive_keys:
            return rank
    return None


def main():
    """Embed the dataset, answer every MultiHaystack query and report recall.

    Writes the per-query top-5 retrieved names and Recall@1/3/5 to
    ``OUTPUT_JSON``. A query counts as a hit at *k* when any of its
    ``positive`` entries matches one of the first *k* retrieved names.
    """
    # Discover files first (cheap).
    file_paths = find_all_files(DATASET_PATH)
    if not file_paths:
        print(f"No files found in {DATASET_PATH}. Please check the directory.")
        return

    # Load the queries before the slow embedding pass, so a missing file is
    # reported immediately instead of after embedding the whole dataset.
    try:
        with open(MULTIHAYSTACK_JSON, 'r', encoding='utf-8') as f:
            haystack_data = json.load(f)
    except FileNotFoundError:
        print(f"File {MULTIHAYSTACK_JSON} not found. Please make sure it exists in the current directory.")
        return

    file_embeddings, file_names = process_files(file_paths)
    if not file_embeddings:
        # np.vstack raises on an empty list, so report this instead of crashing.
        print("No files could be embedded. Please check the errors above.")
        return
    file_embeddings = np.vstack([emb.squeeze(0) for emb in file_embeddings])

    cutoffs = (1, 3, 5)
    correct = {cutoff: 0 for cutoff in cutoffs}
    total_queries = 0
    results = []

    for item in tqdm(haystack_data):
        question, answer = _extract_question_answer(item)
        if not question:
            continue

        question_embedding = encode_text(question)
        top_files, _ = get_top_k_results(question_embedding, file_embeddings, file_names, k=5)
        top_retrieved = [get_dir_name(file_path) for file_path in top_files]

        positive_files = item.get("positive", [])
        hit_rank = _first_hit_rank(top_retrieved, positive_files)
        if hit_rank is not None:
            for cutoff in cutoffs:
                if hit_rank < cutoff:
                    correct[cutoff] += 1
        total_queries += 1

        results.append({
            "question": question,
            "positive": positive_files,
            "top_5_retrieved": top_retrieved,
            "answer": answer,
        })

    recall = {
        cutoff: correct[cutoff] / total_queries if total_queries > 0 else 0
        for cutoff in cutoffs
    }

    output = {
        "metrics": {f"recall@{cutoff}": recall[cutoff] for cutoff in cutoffs},
        "results": results,
    }

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to {OUTPUT_JSON}")
    for cutoff in cutoffs:
        print(f"Recall@{cutoff}: {recall[cutoff]:.4f}")


if __name__ == "__main__":
    main()