import os
import json
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
from pathlib import Path

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face hub id of the embedding model.
MODEL_NAME = "jinaai/jina-clip-v2"
# Root directory scanned recursively for candidate files.
DATASET_PATH = "../dataset/dataset_full/"
# Evaluation queries (read) and retrieval results + metrics (written).
INPUT_JSON = "Multihaystack.json"
OUTPUT_JSON = "jinav2_results.json"
# Lower-case extensions (with leading dot) that get embedded: images and plain text.
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".txt"}

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

# Model initialization
# First attempt: the library's default attention implementation, in half
# precision on CUDA (float32 on CPU). If that load raises for any reason, retry
# once in float32 with the "eager" attention implementation; a failure of the
# retry propagates.
try:
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        trust_remote_code=True,
    )
except Exception as load_error:
    print(f"Failed to load with optimized attention: {load_error}")
    print("Fallback to eager attention")
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        trust_remote_code=True,
    )
    print("Model loaded with eager attention")
else:
    print("Model loaded with optimized attention")

# Place the weights on the selected device and switch the module to evaluation mode.
model = model.to(DEVICE)
model.eval()

# File discovery
def find_all_files(root_dir, extensions=None):
    """Recursively collect files under *root_dir* whose extension is supported.

    Args:
        root_dir: Directory to walk (passed straight to ``os.walk``).
        extensions: Optional iterable of extensions including the leading dot
            (e.g. ``{".png"}``); compared case-insensitively. Defaults to
            ``SUPPORTED_FORMATS``.

    Returns:
        List of ``os.path.join(dirpath, filename)`` paths in a deterministic
        order: each directory's files sorted by name, then its sub-directories
        visited in sorted order.
    """
    allowed = SUPPORTED_FORMATS if extensions is None else {ext.lower() for ext in extensions}
    all_files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # os.walk yields entries in filesystem order, which differs between
        # machines. Sorting dirnames in place fixes the descent order, so the
        # file order (and hence tie-breaking when ranking) is reproducible.
        dirnames.sort()
        for filename in sorted(filenames):
            if os.path.splitext(filename)[1].lower() in allowed:
                all_files.append(os.path.join(dirpath, filename))
    return all_files

# Basename extraction
def get_file_basename(file_path):
    """Return the last path component of *file_path*, minus a known media extension.

    Both dataset "positive" entries and retrieved names go through this
    function before they are compared. It therefore reduces every input to
    the same kind of key, whatever the directory prefix:

        "clips/cat.MP4" -> "cat"
        "cat.jpg"       -> "cat"
        "cat"           -> "cat"
        "dir/a.zip"     -> "a.zip"   (unknown extension is kept)

    Args:
        file_path: A file or directory path/name.

    Returns:
        The base name, with the extension removed only if it is one of the
        recognised media/document extensions (matched case-insensitively).
    """
    # Strip the directory part first. Otherwise a known extension would leave
    # the prefix in place ("clips/cat") and the key could never equal a bare name.
    name = os.path.basename(file_path)
    stem, ext = os.path.splitext(name)
    if ext.lower() in {'.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt'}:
        return stem
    return name

# Directory name extraction
def get_dir_name(file_path, root_dir=None):
    """Return the top-level dataset entry that *file_path* belongs to.

    Every file inside ``<root>/<entry>/...`` (at any depth) maps to
    ``"<entry>"``. A file that sits directly in the root, or that is not
    located under the root at all, maps to its own file name.

    Args:
        file_path: Path of a discovered file.
        root_dir: Dataset root to measure against; defaults to ``DATASET_PATH``.

    Returns:
        The grouping key as a string.
    """
    root = Path(DATASET_PATH if root_dir is None else root_dir).resolve()
    parent = Path(file_path).parent.resolve()

    if parent == root:
        # relative_to() would give Path(".") here, whose .parts is empty,
        # so the subfolder branch below would raise IndexError.
        return os.path.basename(file_path)

    try:
        return parent.relative_to(root).parts[0]
    except ValueError:
        # Parent is not located under the root.
        return os.path.basename(file_path)

# Embedding generation
def process_files(file_paths):
    """Embed each supported file in *file_paths* using the global ``model``.

    * ``.jpg`` / ``.jpeg`` / ``.png`` paths are passed to ``model.encode_image``.
    * ``.txt`` files are read as UTF-8 (undecodable bytes are dropped) and the
      text is passed to ``model.encode_text`` without a ``task`` argument.
    * Any other extension is skipped without a message.

    If processing a file raises, the error is printed and the file is skipped,
    so one bad file does not abort the run.

    Returns:
        ``(embeddings, names)``: parallel lists. Every embedding is a
        ``(1, D)`` tensor on ``DEVICE`` and every name is the input path.
    """
    print(f"Processing {len(file_paths)} files")
    embeddings, names = [], []

    for path in tqdm(file_paths):
        try:
            suffix = os.path.splitext(path)[1].lower()
            if suffix in {'.jpg', '.jpeg', '.png'}:
                raw = model.encode_image(path, truncate_dim=None)
            elif suffix == '.txt':
                with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
                    text = handle.read()
                raw = model.encode_text(text, truncate_dim=None)
            else:
                continue
            # Flatten whatever encode_* returned into a single-row tensor.
            vector = torch.tensor(raw).reshape(1, -1).to(DEVICE)
        except Exception as err:
            print(f"Error processing file {path}: {err}")
            continue

        embeddings.append(vector)
        names.append(path)

    return embeddings, names

# Text embedding generation
def encode_text(text):
    """Embed a search query with the model's ``retrieval.query`` task.

    Args:
        text: The query string.

    Returns:
        A ``(1, D)`` tensor on ``DEVICE``. No normalization is applied here.
    """
    raw = model.encode_text(text, task='retrieval.query', truncate_dim=None)
    return torch.tensor(raw).reshape(1, -1).to(DEVICE)

# Similarity computation
def get_top_k_results(text_embedding, file_embeddings, file_names, k=5, group_key=None):
    """Rank files by dot-product similarity to a query, keeping one file per group.

    Args:
        text_embedding: Query embedding tensor of shape ``(1, D)`` or ``(D,)``.
        file_embeddings: Candidate embeddings of shape ``(N, D)``; row ``i``
            belongs to ``file_names[i]``.
        file_names: Paths aligned with the rows of *file_embeddings*.
        k: Maximum number of results; ``k <= 0`` yields empty lists.
        group_key: Callable mapping a path to its de-duplication key. Only the
            best-scoring file of each key is kept. Defaults to ``get_dir_name``.

    Returns:
        ``(top_files, top_similarities)`` in descending similarity order. Ties
        keep the input order, because ``sorted`` is stable.
    """
    if group_key is None:
        group_key = get_dir_name
    if k <= 0:
        return [], []

    # reshape(-1) rather than squeeze(): with a single candidate squeeze()
    # collapses to a 0-d tensor and tolist() returns a bare float, which zip()
    # cannot iterate.
    similarities = torch.matmul(text_embedding, file_embeddings.T).reshape(-1).tolist()

    ranked = sorted(zip(file_names, similarities), key=lambda pair: pair[1], reverse=True)

    top_files = []
    top_similarities = []
    seen_groups = set()
    for file_path, sim in ranked:
        group = group_key(file_path)
        if group in seen_groups:
            continue
        seen_groups.add(group)
        top_files.append(file_path)
        top_similarities.append(sim)
        if len(top_files) >= k:
            break

    return top_files, top_similarities

def _extract_question_answer(item):
    """Return ``(question, answer)`` from a dataset item's ``conversations`` list.

    The last turn whose ``"from"`` is ``"human"`` is the question, and the
    last turn whose ``"from"`` is ``"gpt"`` is the answer. Either value is
    None when no such turn exists.
    """
    question = None
    answer = None
    for turn in item["conversations"]:
        if turn["from"] == "human":
            question = turn["value"]
        elif turn["from"] == "gpt":
            answer = turn["value"]
    return question, answer


def _hit_within(positive_basenames, retrieved, k):
    """Return True if any of the first *k* retrieved names matches a positive basename."""
    return any(get_file_basename(name) in positive_basenames for name in retrieved[:k])


def main():
    """Run the retrieval evaluation end to end.

    1. Load the queries from ``INPUT_JSON`` (UTF-8).
    2. Embed every supported file under ``DATASET_PATH`` and L2-normalize each row.
    3. For each item that has a human question, retrieve the top 5 files (at
       most one per top-level dataset entry). Score recall@1/3/5 against the
       item's ``positive`` list, comparing names via ``get_file_basename``.
    4. Write the metrics and per-query results to ``OUTPUT_JSON`` and print recall.
    """
    # Load the queries first, so a missing input file fails before the
    # expensive embedding pass.
    try:
        with open(INPUT_JSON, 'r', encoding='utf-8') as f:
            haystack_data = json.load(f)
    except FileNotFoundError:
        print(f"File {INPUT_JSON} not found")
        return

    # File discovery
    file_paths = find_all_files(DATASET_PATH)
    if not file_paths:
        print(f"No files found in {DATASET_PATH}")
        return

    # Generate embeddings
    embeddings, file_names = process_files(file_paths)
    if not embeddings:
        # torch.cat raises on an empty list: every file was skipped or failed.
        print(f"No embeddings could be generated from {DATASET_PATH}")
        return
    file_embeddings = torch.cat(embeddings, dim=0).to(DEVICE)

    # Normalize rows so the dot product in get_top_k_results is cosine similarity.
    file_embeddings = file_embeddings / torch.norm(file_embeddings, dim=1, keepdim=True)

    results = []
    cutoffs = (1, 3, 5)
    hits = dict.fromkeys(cutoffs, 0)
    total_queries = 0

    for item in tqdm(haystack_data):
        question, answer = _extract_question_answer(item)
        if not question:
            continue

        question_embedding = encode_text(question)
        question_embedding = question_embedding / torch.norm(question_embedding, dim=1, keepdim=True)

        top_files, _ = get_top_k_results(question_embedding, file_embeddings, file_names, k=5)
        top_retrieved = [get_dir_name(file_path) for file_path in top_files]

        positive_files = item.get("positive", [])
        positive_basenames = {get_file_basename(pos_file) for pos_file in positive_files}
        for cutoff in cutoffs:
            if _hit_within(positive_basenames, top_retrieved, cutoff):
                hits[cutoff] += 1
        total_queries += 1

        results.append({
            "question": question,
            "positive": positive_files,
            "top_5_retrieved": top_retrieved,
            "answer": answer,
        })

    recall = {
        cutoff: hits[cutoff] / total_queries if total_queries > 0 else 0
        for cutoff in cutoffs
    }

    output = {
        "metrics": {f"recall@{cutoff}": recall[cutoff] for cutoff in cutoffs},
        "results": results,
    }

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to {OUTPUT_JSON}")
    for cutoff in cutoffs:
        print(f"Recall@{cutoff}: {recall[cutoff]:.4f}")

if __name__ == "__main__":
    main()