import os
import json
import torch
import open_clip
from PIL import Image
import numpy as np
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
from pathlib import Path

# Configuration
# Identifier passed to open_clip to load both the model and its tokenizer.
MODEL_NAME = 'hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K'
# Root directory that main() walks for files to embed; get_dir_name() also
# derives result names relative to it. A relative path is interpreted against
# the current working directory.
DATASET_PATH = "../dataset/dataset_full/"
# Query file read by main(); items are accessed via their "conversations"
# list and optional "positive" list.
INPUT_JSON = "Multihaystack.json"
# Destination for the metrics and per-question results written by main().
OUTPUT_JSON = "openclip_results.json"
# Lowercase file extensions (with leading dot) accepted by find_all_files().
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.txt'}
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")

# Model initialization
# NOTE: these statements run at import time, not only when the file is
# executed as a script, so importing this module loads the model.
# `preprocess` is the callable applied to PIL images in process_files().
model, preprocess = open_clip.create_model_from_pretrained(MODEL_NAME)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
# Put the model in evaluation mode before moving it to the target device.
model.eval()
model = model.to(DEVICE)

# File discovery
def find_all_files(root_dir, supported_formats=None):
    """Recursively collect files under *root_dir* that have a supported extension.

    Args:
        root_dir: Directory to walk.
        supported_formats: Optional collection of lowercase extensions
            (including the leading dot) to accept. When omitted, the
            module-level ``SUPPORTED_FORMATS`` is used.

    Returns:
        A list of file paths (each is ``dirpath`` joined with the file name as
        produced by ``os.walk``). The order is deterministic: directories are
        visited in sorted order and files are sorted within each directory.
    """
    if supported_formats is None:
        supported_formats = SUPPORTED_FORMATS

    all_files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # os.walk reports entries in arbitrary filesystem order. Sorting the
        # directory list in place steers the traversal, and sorting the file
        # names makes the returned order reproducible from run to run.
        dirnames.sort()
        for filename in sorted(filenames):
            # Extension match is case-insensitive ('.PNG' counts as '.png').
            if os.path.splitext(filename)[1].lower() in supported_formats:
                all_files.append(os.path.join(dirpath, filename))
    return all_files

# Basename extraction
def get_file_basename(file_path):
    """Return the bare name of *file_path* used when comparing results.

    Directory components are always dropped. If the extension (compared
    case-insensitively) is one of the recognised media/text types, it is
    removed as well; any other suffix is left in place, so a name without a
    recognised extension is returned unchanged.

    Args:
        file_path: A file name, directory name, or path.

    Returns:
        The final path component, without its extension when that extension
        is recognised.
    """
    # Strip directories first so both branches agree: previously only the
    # unknown-extension branch dropped the leading directories.
    name = os.path.basename(file_path)
    stem, ext = os.path.splitext(name)
    if ext.lower() in {'.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt'}:
        return stem
    return name

# Directory name extraction
def get_dir_name(file_path):
    """Map *file_path* to the name that identifies it as a retrieval result.

    Both the dataset root (``DATASET_PATH``) and the file's parent directory
    are resolved before comparison.

    Returns:
        * the first directory component below the dataset root when the file
          lives in a subfolder of the root;
        * the file's own base name when it sits directly in the dataset root
          or is not located under the root at all.
    """
    root = Path(DATASET_PATH).resolve()
    containing_dir = Path(file_path).parent.resolve()
    own_name = os.path.basename(file_path)

    try:
        below_root = containing_dir.relative_to(root).parts
    except ValueError:
        # The parent directory is not inside the dataset root.
        return own_name

    # An empty tuple means the parent *is* the dataset root.
    return below_root[0] if below_root else own_name

# Embedding generation
def process_files(file_paths):
    """Embed every supported file in *file_paths* with the loaded model.

    ``.jpg``/``.jpeg``/``.png`` files are opened with PIL, converted to RGB,
    run through ``preprocess`` and encoded with ``model.encode_image``.
    ``.txt`` files are read as UTF-8 (undecodable bytes are ignored),
    tokenized and encoded with ``model.encode_text``. Files with any other
    extension are skipped. A file that raises during processing is reported
    on stdout and skipped, so one bad file does not abort the run.

    Args:
        file_paths: Iterable of file paths with a known length.

    Returns:
        ``(file_embeddings, file_names)``: two parallel lists holding, for
        each successfully processed file, its L2-normalized float32 embedding
        as a NumPy array and its path.
    """
    print(f"Processing {len(file_paths)} files")
    file_embeddings = []
    file_names = []

    # Mixed precision is only wanted on CUDA. torch.autocast takes the device
    # type explicitly, which avoids the deprecated torch.cuda.amp.autocast()
    # and the warning it emits on CPU-only hosts.
    device_type = torch.device(DEVICE).type
    use_amp = device_type == "cuda"

    for file_path in tqdm(file_paths):
        try:
            file_ext = os.path.splitext(file_path)[1].lower()

            if file_ext in {'.jpg', '.jpeg', '.png'}:
                # The context manager releases the file handle even when
                # decoding or preprocessing fails.
                with Image.open(file_path) as img:
                    img_tensor = preprocess(img.convert('RGB')).unsqueeze(0).to(DEVICE)

                with torch.no_grad(), torch.autocast(device_type=device_type, enabled=use_amp):
                    embedding = model.encode_image(img_tensor)

            elif file_ext == '.txt':
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    text_content = f.read()

                text_tokens = tokenizer([text_content]).to(DEVICE)
                with torch.no_grad(), torch.autocast(device_type=device_type, enabled=use_amp):
                    embedding = model.encode_text(text_tokens)
            else:
                continue

            # Normalize in float32: under CUDA autocast the encoder output is
            # half precision, which loses accuracy in the norm and makes the
            # later NumPy dot products slow and imprecise.
            embedding = embedding.float()
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            file_embeddings.append(embedding.cpu().numpy())
            file_names.append(file_path)

        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print(f"Error processing file {file_path}: {e}")

    return file_embeddings, file_names

# Text embedding generation
def encode_text(text):
    """Return the L2-normalized text embedding of *text* as a NumPy array.

    Args:
        text: The string to embed (a single query).

    Returns:
        A float32 NumPy array produced by ``model.encode_text`` for a batch
        containing only *text*, normalized along the last dimension.
    """
    text_tokens = tokenizer([text]).to(DEVICE)

    # torch.autocast with an explicit device type replaces the deprecated
    # torch.cuda.amp.autocast() and stays silent on CPU-only hosts, where
    # mixed precision is simply disabled.
    device_type = torch.device(DEVICE).type
    with torch.no_grad(), torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
        question_embedding = model.encode_text(text_tokens)

    # Normalize in float32 so the similarity scores are not limited by the
    # half-precision output produced under CUDA autocast.
    question_embedding = question_embedding.float()
    question_embedding = question_embedding / question_embedding.norm(dim=-1, keepdim=True)
    return question_embedding.cpu().numpy()

# Similarity computation
def get_top_k_results(text_embedding, file_embeddings, file_names, k=5, group_key=None):
    """Return the *k* best-matching files, at most one per group.

    Files are ranked by the dot product between their embedding and
    *text_embedding*. Walking the ranking from best to worst, a file is kept
    only if no earlier file mapped to the same group key.

    Args:
        text_embedding: Query embedding; a single vector, either 1-D or with
            a leading batch axis of size one.
        file_embeddings: 2-D array with one row per entry of *file_names*.
        file_names: File paths, parallel to the rows of *file_embeddings*.
        k: Maximum number of results. Values ``<= 0`` yield no results.
        group_key: Optional callable mapping a file path to its
            de-duplication key. Defaults to ``get_dir_name``.

    Returns:
        ``(top_files, top_similarities)``: the selected paths and their
        similarity scores, best first. Equal scores keep corpus order.
    """
    if group_key is None:
        group_key = get_dir_name

    top_files = []
    top_similarities = []
    if k <= 0:
        return top_files, top_similarities

    # reshape(-1) instead of squeeze(): with a single file squeeze() would
    # collapse the scores to a 0-d array and indexing it below would fail.
    similarities = np.dot(file_embeddings, text_embedding.T).reshape(-1)

    # A stable sort on the negated scores gives descending order with ties
    # resolved by original position, so the ranking is reproducible.
    sorted_indices = np.argsort(-similarities, kind="stable")

    seen_groups = set()
    for idx in sorted_indices:
        group = group_key(file_names[idx])

        # Only the best-scoring file of each group is reported.
        if group in seen_groups:
            continue

        top_files.append(file_names[idx])
        top_similarities.append(similarities[idx])
        seen_groups.add(group)

        if len(top_files) >= k:
            break

    return top_files, top_similarities

def _extract_question_answer(item):
    """Return ``(question, answer)`` taken from *item*'s conversation turns.

    The value of the last "human" turn is the question and the value of the
    last "gpt" turn is the answer; each is ``None`` when no such turn exists.
    """
    question = None
    answer = None
    for conv in item["conversations"]:
        if conv["from"] == "human":
            question = conv["value"]
        elif conv["from"] == "gpt":
            answer = conv["value"]
    return question, answer


def _hit_within(positive_basenames, retrieved_basenames, k):
    """Return True if any of the first *k* retrieved names is a positive."""
    return any(name in positive_basenames for name in retrieved_basenames[:k])


def main():
    """Run the retrieval evaluation and write the results to ``OUTPUT_JSON``.

    Steps: discover the dataset files, load the queries from ``INPUT_JSON``,
    embed the files, retrieve the top five results for every question, and
    record for each of k = 1, 3, 5 the fraction of answered questions with at
    least one positive among the first k results. Items without a question
    are skipped and do not count towards the totals.

    The function returns early, after printing a message, when no files are
    found, when ``INPUT_JSON`` does not exist, or when no file could be
    embedded.
    """
    # File discovery
    file_paths = find_all_files(DATASET_PATH)
    if not file_paths:
        print(f"No files found in {DATASET_PATH}")
        return

    # Load the queries before the expensive embedding pass, so a missing
    # input file is reported immediately instead of after embedding the
    # whole dataset.
    try:
        with open(INPUT_JSON, 'r', encoding='utf-8') as f:
            haystack_data = json.load(f)
    except FileNotFoundError:
        print(f"File {INPUT_JSON} not found")
        return

    # Generate embeddings
    file_embeddings, file_names = process_files(file_paths)
    if not file_embeddings:
        # np.vstack raises ValueError on an empty list; this happens when
        # every file failed to process.
        print(f"No files could be embedded from {DATASET_PATH}")
        return
    file_embeddings = np.vstack(file_embeddings)

    # Process questions
    results = []
    hits = {1: 0, 3: 0, 5: 0}  # cutoff -> number of questions with a hit
    total_queries = 0

    for item in tqdm(haystack_data):
        question, answer = _extract_question_answer(item)
        if not question:
            continue

        # Retrieve results
        question_embedding = encode_text(question)
        top_files, _ = get_top_k_results(question_embedding, file_embeddings, file_names, k=5)
        top_retrieved = [get_dir_name(file_path) for file_path in top_files]

        # Check accuracy: both sides are reduced with get_file_basename so
        # they are compared in the same form.
        positive_files = item.get("positive", [])
        positive_basenames = {get_file_basename(pos_file) for pos_file in positive_files}
        retrieved_basenames = [get_file_basename(result) for result in top_retrieved]

        for cutoff in hits:
            if _hit_within(positive_basenames, retrieved_basenames, cutoff):
                hits[cutoff] += 1

        total_queries += 1

        # Store result
        results.append({
            "question": question,
            "positive": positive_files,
            "top_5_retrieved": top_retrieved,
            "answer": answer
        })

    # Calculate metrics (0 when no question was evaluated)
    recall_at_1 = hits[1] / total_queries if total_queries > 0 else 0
    recall_at_3 = hits[3] / total_queries if total_queries > 0 else 0
    recall_at_5 = hits[5] / total_queries if total_queries > 0 else 0

    # Create output
    output = {
        "metrics": {
            "recall@1": recall_at_1,
            "recall@3": recall_at_3,
            "recall@5": recall_at_5
        },
        "results": results
    }

    # Save results
    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to {OUTPUT_JSON}")
    print(f"Recall@1: {recall_at_1:.4f}")
    print(f"Recall@3: {recall_at_3:.4f}")
    print(f"Recall@5: {recall_at_5:.4f}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()