import os
import sys
import json
import logging
from filelock import FileLock

import torch
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm

# Maximum sequence length (in tokens) accepted by the CLIP text encoder.
# The tokenizer's start/end special tokens count toward this limit.
CLIP_MAX_TOKENS = 77

def initialize_clip_model(
    device: str = "cuda",
    model_id: str = "openai/clip-vit-base-patch32",
):
    """
    Load a CLIP processor and model from Hugging Face.

    Args:
        device: Torch device the model is moved to (e.g. "cuda" or "cpu").
        model_id: Hub identifier or local path of the CLIP checkpoint.
            Defaults to the ViT-B/32 checkpoint this script was written for.

    Returns:
        A ``(processor, model)`` tuple with the model on ``device`` in eval mode.
    """
    processor = CLIPProcessor.from_pretrained(model_id)
    model = CLIPModel.from_pretrained(model_id).to(device)
    # from_pretrained already returns an eval-mode model; stating it explicitly
    # documents that this model is only ever used for inference.
    model.eval()
    return processor, model

def batch_prompt_tokens(
    prompt: str,
    processor: CLIPProcessor,
    max_tokens: int = CLIP_MAX_TOKENS,
) -> list:
    """
    Split ``prompt`` into text segments that each fit CLIP's context window.

    When a segment is later passed to ``processor(text=...)`` the tokenizer
    wraps it in its special tokens (start/end of text). Each segment therefore
    gets ``max_tokens`` minus that overhead worth of content tokens; using the
    full ``max_tokens`` would make the caller's ``truncation=True`` silently
    drop the tail of every chunk.

    Args:
        prompt: Text to split.
        processor: CLIP processor whose tokenizer defines token boundaries.
        max_tokens: Sequence budget per segment, special tokens included.

    Returns:
        A non-empty list of segment strings that can be passed as a batch to
        the processor. A prompt that tokenizes to nothing (e.g. ``""``) is
        returned unchanged as the single segment, so callers never end up
        encoding an empty batch.
    """
    tokenizer = processor.tokenizer
    tokens = tokenizer.tokenize(prompt)
    if not tokens:
        return [prompt]

    # Reserve room for the special tokens added when each chunk is re-encoded.
    overhead = tokenizer.num_special_tokens_to_add(pair=False)
    chunk_size = max(1, max_tokens - overhead)

    # NOTE: decoding tokens to a string and re-tokenizing it is not guaranteed
    # to reproduce exactly the same tokens; callers keep truncation enabled as
    # a safety net.
    return [
        tokenizer.convert_tokens_to_string(tokens[start : start + chunk_size])
        for start in range(0, len(tokens), chunk_size)
    ]

def _embed_text_chunks(
    chunks: list,
    processor: CLIPProcessor,
    model: CLIPModel,
    device: str,
) -> torch.Tensor:
    """
    Encode a batch of text chunks and pool them into one unit-length vector.

    Each chunk embedding is L2-normalised before averaging, so every chunk
    contributes equally to the pooled direction regardless of its raw feature
    magnitude. The mean is then re-normalised to unit length.

    Returns:
        Tensor of shape [1, embed_dim] with unit L2 norm.
    """
    inputs = processor(
        text=chunks,
        images=None,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)

    with torch.no_grad():
        embeds = model.get_text_features(**inputs)  # [num_chunks, embed_dim]

    embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
    pooled = embeds.mean(dim=0, keepdim=True)  # [1, embed_dim]
    return pooled / pooled.norm(p=2, dim=-1, keepdim=True)


def compute_text_to_text_similarity(
    text1: str,
    text2: str,
    processor: CLIPProcessor,
    model: CLIPModel,
    device: str = "cuda"
) -> float:
    """
    Compute the cosine similarity between two texts using CLIP's text encoder.

    Texts longer than CLIP's context window are split into chunks (see
    ``batch_prompt_tokens``). Each chunk is embedded, and the chunk embeddings
    of one text are pooled into a single vector before comparison.

    Args:
        text1: First text (e.g. the reference caption).
        text2: Second text (e.g. the generated result).
        processor: CLIP processor used for tokenization.
        model: CLIP model providing ``get_text_features``.
        device: Torch device the inputs are moved to; must match the model.

    Returns:
        Cosine similarity as a Python float in [-1, 1].
    """
    # Fall back to the raw text if chunking yields nothing (e.g. an empty
    # string), so the processor is never called with an empty batch.
    text1_chunks = batch_prompt_tokens(text1, processor) or [text1]
    text2_chunks = batch_prompt_tokens(text2, processor) or [text2]

    text1_embed = _embed_text_chunks(text1_chunks, processor, model, device)
    text2_embed = _embed_text_chunks(text2_chunks, processor, model, device)

    # Both vectors have unit length, so their dot product is the cosine similarity.
    return (text1_embed * text2_embed).sum(dim=-1).item()

def save_text_text_score(
    key: str,
    caption: str,
    result: str,
    similarity: float,
    output_file: str = "text_text_scores.jsonl"
):
    """
    Append one similarity record as a single JSON line to ``output_file``.

    The append is guarded by a ``FileLock`` on ``<output_file>.lock``, so
    several workers (threads or processes) appending to the same output file
    do not interleave their lines.
    """
    payload = dict(key=key, caption=caption, result=result, similarity=similarity)
    with FileLock(f"{output_file}.lock"), open(output_file, "a", encoding="utf-8") as sink:
        sink.write(json.dumps(payload) + "\n")

def _configure_logging() -> None:
    """Send INFO-level logs to both logs/text2text_worker.log and the console."""
    os.makedirs('logs', exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] [Text2Text] %(message)s',
        handlers=[
            logging.FileHandler("logs/text2text_worker.log"),
            logging.StreamHandler()
        ]
    )


def _parse_record(data: dict) -> tuple:
    """
    Extract ``(key, caption, result)`` from one decoded input record.

    Expected layout: ``data["entry"]["index"]``, ``data["entry"]["caption"]``,
    and either a truthy ``data["result"]`` or ``data["result1"]["generated"]``.

    Raises:
        KeyError: a required field is missing.
        TypeError: the record is not a mapping, or caption/result are not strings.
    """
    entry = data["entry"]
    key = entry["index"]
    caption = entry["caption"]
    result = data.get("result") or data["result1"]["generated"]
    if not isinstance(caption, str) or not isinstance(result, str):
        raise TypeError("caption and result must both be strings")
    return key, caption, result


def main():
    """
    Score CLIP text-to-text similarity for every record of an input JSONL file.

    Example usage:
        python text_text_inference.py input.jsonl output.jsonl

    Each non-blank input line must be a JSON object with ``entry.index``,
    ``entry.caption`` and either ``result`` or ``result1.generated``. One score
    record per valid line is appended to the output JSONL. Blank lines are
    ignored. Malformed lines are logged with their line number and skipped, so
    a single bad record does not abort a long run.
    """
    if len(sys.argv) < 3:
        print("Usage: text_text_inference.py <input_jsonl> <output_jsonl>")
        sys.exit(1)

    input_jsonl = sys.argv[1]
    output_file = sys.argv[2]

    _configure_logging()

    # Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Running on device: {device}")

    # Initialize CLIP
    logging.info("Initializing CLIP model...")
    clip_processor, clip_model = initialize_clip_model(device=device)

    logging.info(f"Starting text-to-text inference on '{input_jsonl}'.")
    skipped = 0
    with open(input_jsonl, 'r', encoding='utf-8') as f, tqdm(desc="Text2Text Inference") as pbar:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue

            try:
                key, caption, result = _parse_record(json.loads(line))
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                skipped += 1
                logging.warning(
                    "Skipping malformed line %d of '%s': %r", line_no, input_jsonl, err
                )
                continue

            sim_score = compute_text_to_text_similarity(
                text1=caption,
                text2=result,
                processor=clip_processor,
                model=clip_model,
                device=device
            )

            save_text_text_score(
                key=key,
                caption=caption,
                result=result,
                similarity=sim_score,
                output_file=output_file
            )

            pbar.update(1)

    if skipped:
        logging.warning("Skipped %d malformed line(s) in '%s'.", skipped, input_jsonl)
    logging.info("Text-to-text inference completed.")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
    