import os
import sys
import json
import uuid
import random
import logging
from tqdm import tqdm
from PIL import Image

import torch
from transformers import CLIPModel, CLIPProcessor
from filelock import FileLock
# Token limit applied to prompts before they are handed to CLIP
# (consumed by truncate_prompt_tokens).
CLIP_MAX_TOKENS = 77
def initialize_clip_model(device: str = "cuda"):
    """
    Fetch the pretrained CLIP checkpoint and its matching processor.

    Args:
        device: Torch device string the model is moved onto.

    Returns:
        A ``(processor, model)`` tuple built from the
        ``openai/clip-vit-base-patch32`` checkpoint.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    # Tuple elements are evaluated left to right: processor first, then model.
    return (
        CLIPProcessor.from_pretrained(checkpoint),
        CLIPModel.from_pretrained(checkpoint).to(device),
    )

def compute_clip_similarity(
    image: Image.Image,
    prompt: str,
    processor: CLIPProcessor,
    model: CLIPModel,
    device: str = "cuda"
) -> float:
    """
    Score how well an image matches a text prompt with CLIP.

    The prompt is first shortened by ``truncate_prompt_tokens``; image and
    text are then encoded in a single forward pass and compared by cosine
    similarity.

    Args:
        image: Image to score.
        prompt: Text to compare the image against.
        processor: CLIP processor used to build the model inputs.
        model: CLIP model producing the image and text embeddings.
        device: Torch device the inputs are moved to before inference.

    Returns:
        The cosine similarity between the two embeddings as a Python float.
    """
    clipped_prompt = truncate_prompt_tokens(prompt, processor)
    batch = processor(
        text=[clipped_prompt],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    batch = batch.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        result = model(**batch)

    img_vec = result.image_embeds
    txt_vec = result.text_embeds

    # Scale both embeddings to unit L2 length so their dot product is the
    # cosine similarity.
    img_unit = img_vec / img_vec.norm(p=2, dim=-1, keepdim=True)
    txt_unit = txt_vec / txt_vec.norm(p=2, dim=-1, keepdim=True)

    return (img_unit * txt_unit).sum(dim=-1).item()

def save_clip_score(
    key: str,
    prompt: str,
    image_path: str,
    similarity: float,
    output_file: str = "clip_scores.jsonl"
):
    """
    Append one similarity record to a JSONL file.

    Writes are serialised through a ``FileLock`` on ``<output_file>.lock`` so
    that concurrent writers do not interleave their lines.

    Args:
        key: Identifier of the scored entry.
        prompt: Prompt the image was compared against.
        image_path: Path of the scored image.
        similarity: Similarity score to record.
        output_file: JSONL file the record is appended to.
    """
    payload = dict(
        key=key,
        prompt=prompt,
        image_path=image_path,
        similarity=similarity,
    )
    with FileLock(f"{output_file}.lock"), open(output_file, "a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(payload)}\n")

def yield_keys_prompts(jsonl_file, use_original=False, skip_count=4000):
    """
    Stream ``(key, prompts)`` pairs from a JSONL file.

    Every non-blank line must be a JSON object with an ``entry`` object that
    holds ``index`` (used as the key). Prompts are taken from, in order of
    precedence:

    - ``entry["caption"]`` when ``use_original`` is true;
    - ``data["result"]`` when a single ``result`` field is present;
    - ``data["result1"]`` and, if present, ``data["result2"]``; for each of
      these the ``extended`` text is used when non-empty, otherwise
      ``generated``.

    Args:
        jsonl_file: Path of the JSONL file to read.
        use_original: Yield the original caption instead of generated prompts.
        skip_count: Number of leading physical lines to ignore (default 4000).

    Yields:
        ``(key, prompts)`` where ``prompts`` is a list.

    Raises:
        KeyError: If a required field is missing from a line.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            # Blank lines still count towards skip_count: it is defined in
            # terms of physical lines of the file.
            if line_number <= skip_count:
                continue
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of failing in json.loads.
            if not line.strip():
                continue

            data = json.loads(line)
            entry = data['entry']
            key = entry['index']

            if use_original:
                # The result fields are irrelevant in this mode, so they are
                # not required to be present.
                prompts = [entry['caption']]
            elif 'result' in data:
                prompts = [data['result']]
            else:
                results = [data['result1']]
                # 'result2' lives next to 'result1' on the top-level object,
                # so that is where its presence has to be checked.
                if 'result2' in data:
                    results.append(data['result2'])
                prompts = [
                    result.get('extended') or result['generated']
                    for result in results
                ]

            yield key, prompts
def truncate_prompt_tokens(prompt: str, processor: CLIPProcessor) -> str:
    """
    Shorten a prompt so that it fits CLIP's token limit.

    The prompt is tokenized with the processor's tokenizer. If it already
    fits, the original text is returned unchanged; otherwise the leading
    tokens that fit are converted back into a string.

    Args:
        prompt: Text to shorten.
        processor: CLIP processor whose tokenizer defines the token count.

    Returns:
        ``prompt`` itself when it fits, else the text rebuilt from its first
        ``CLIP_MAX_TOKENS - 2`` tokens.
    """
    # CLIP_MAX_TOKENS is the size of the whole context window, and the
    # tokenizer adds a begin-of-text and an end-of-text marker when encoding,
    # so only CLIP_MAX_TOKENS - 2 positions are left for prompt content.
    max_content_tokens = CLIP_MAX_TOKENS - 2

    tokens = processor.tokenizer.tokenize(prompt)
    if len(tokens) <= max_content_tokens:
        # Nothing to cut: hand back the caller's text as-is rather than a
        # token-to-string reconstruction, which may not reproduce it exactly.
        return prompt

    return processor.tokenizer.convert_tokens_to_string(tokens[:max_content_tokens])

def get_folder_path(root_dir, key, num_folders=200):
    """
    Map a key onto one of ``num_folders`` bucket directories under ``root_dir``.

    The bucket is derived from the key's UUIDv5 (DNS namespace) taken modulo
    ``num_folders``; directories are named ``folder_1`` .. ``folder_<num_folders>``.
    NOTE(review): the original comment says this mirrors the generation
    script's layout -- keep the two in sync.

    Args:
        root_dir: Directory containing the bucket folders.
        key: String key to place.
        num_folders: Number of buckets.

    Returns:
        Path of the bucket directory for ``key``.
    """
    key_digest = uuid.uuid5(uuid.NAMESPACE_DNS, key)
    # UUID.int is the 128-bit big-endian integer value of the UUID's bytes.
    bucket_number = key_digest.int % num_folders + 1
    return os.path.join(root_dir, f"folder_{bucket_number}")

def should_skip(key, num_workers, worker_index):
    """
    Decide whether a key belongs to a different worker.

    Each key is assigned to the worker whose index equals the key's UUIDv5
    (DNS namespace) integer value modulo ``num_workers``.

    Args:
        key: String key to test.
        num_workers: Total number of workers sharing the input.
        worker_index: Index of the calling worker.

    Returns:
        True when the key is assigned to another worker, False when it is
        this worker's to process.
    """
    owning_worker = uuid.uuid5(uuid.NAMESPACE_DNS, key).int % num_workers
    return owning_worker != worker_index

def main():
    """
    Command-line entry point: score previously generated images with CLIP.

    Usage:
        clip_inference.py <jsonl_file> <root_dir> <num_folders> <num_workers>
                          <worker_index> [original] [output_file]

    The sixth argument enables original-caption mode only when it is exactly
    ``original``. Because arguments are positional, any other placeholder
    value must be supplied there in order to pass ``output_file``.

    For every key assigned to this worker and every prompt of that key, the
    expected image path is rebuilt, the image is loaded, its CLIP similarity
    to the prompt is computed, and one record is appended to ``output_file``.
    Missing or unreadable images are logged and skipped.
    """
    if len(sys.argv) < 6:
        print(
            "Usage: clip_inference.py <jsonl_file> <root_dir> <num_folders> "
            "<num_workers> <worker_index> [original] [output_file]"
        )
        sys.exit(1)

    jsonl_file = sys.argv[1]
    root_dir = sys.argv[2]
    num_folders = int(sys.argv[3])
    num_workers = int(sys.argv[4])
    worker_index = int(sys.argv[5])
    use_original = len(sys.argv) > 6 and sys.argv[6] == "original"
    # An empty string for the output file falls back to the default as well.
    output_file = sys.argv[7] if len(sys.argv) > 7 and sys.argv[7] else "clip_scores.jsonl"

    # Setup logging
    os.makedirs('logs', exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] [CLIP Worker] %(message)s',
        handlers=[
            logging.FileHandler("logs/clip_worker.log"),
            logging.StreamHandler()
        ]
    )

    # Determine device. CUDA_VISIBLE_DEVICES must not be used as the device
    # index: it selects which physical GPUs the process may see, and those are
    # renumbered from 0 inside the process (and the value may be a
    # comma-separated list). The first visible GPU is therefore always cuda:0.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    logging.info(f"Running on device: {device}")

    # Initialize CLIP
    logging.info("Initializing CLIP model...")
    clip_processor, clip_model = initialize_clip_model(device=device)

    # Start processing
    logging.info(f"Starting CLIP inference on '{jsonl_file}'. Worker index: {worker_index}/{num_workers}")
    with tqdm(total=0, position=worker_index, desc=f"Worker {worker_index}") as pbar:
        for key, prompts in yield_keys_prompts(jsonl_file, use_original=use_original, skip_count=4000):
            # Worker-based sharding
            if should_skip(key, num_workers, worker_index):
                continue

            # The folder depends only on the key, so resolve it once per key.
            folder_path = get_folder_path(root_dir, key, num_folders=num_folders)

            for prompt in prompts:
                # Rebuild the filename from the key and prompt.
                output_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, f"{key}_{prompt}").hex
                image_filename = f"{output_uuid}_image.webp"
                image_path = os.path.join(folder_path, image_filename)

                if not os.path.exists(image_path):
                    # If the image wasn't generated or doesn't exist, skip
                    logging.info(f"Image not found: {image_path} -- skipping.")
                    continue

                # Load the existing image. The context manager closes the
                # underlying file as soon as the RGB copy has been made.
                try:
                    with Image.open(image_path) as raw_image:
                        image = raw_image.convert("RGB")
                except Exception as e:
                    # Best effort: one unreadable file must not stop the run.
                    logging.error(f"Error opening image '{image_path}': {e}")
                    continue

                # Compute similarity
                sim_score = compute_clip_similarity(
                    image=image,
                    prompt=prompt,
                    processor=clip_processor,
                    model=clip_model,
                    device=device
                )

                # Save result to JSONL
                save_clip_score(
                    key=key,
                    prompt=prompt,
                    image_path=image_path,
                    similarity=sim_score,
                    output_file=output_file
                )

                pbar.update(1)

    logging.info("CLIP inference completed.")

if __name__ == "__main__":
    main()
