import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer


class Reranker:
    """
    Scores and filters passages with a Qwen3 reranker model.

    Every (query, document) pair is wrapped in a fixed chat prompt asking the
    model to answer "yes" or "no". The relevance score is the probability of
    "yes" after a softmax over just the "no"/"yes" logits at the final
    position.
    """

    # Markers removed by _preprocess_doc. The "Text" marker must be stripped
    # before the generic one: the generic pattern also matches the start of
    # "Passage #N Text: ", and running it first would leave a stray "Text: ".
    _TEXT_MARKER_RE = re.compile(r"Passage #\d+ Text: ")
    _PASSAGE_MARKER_RE = re.compile(r"Passage #\d+ ")
    _TITLE_KEY = "Title:"

    def __init__(
        self,
        model_name="Qwen/Qwen3-Reranker-0.6B",
        max_length=8192,
        attn_implementation="flash_attention_2",
    ):
        """
        Initializes the Reranker with the specified Qwen model.

        Args:
            model_name: Hugging Face model id or local path to load.
            max_length: Upper bound on the token length of one scored
                sequence, including the fixed prompt prefix and suffix.
            attn_implementation: Attention backend forwarded to
                ``from_pretrained``. The default "flash_attention_2" requires
                the flash-attn package; pass e.g. "sdpa" or "eager" otherwise.
        """
        # Left padding keeps the last column of every batch row on a real
        # token, which is the position _compute_logits reads.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation=attn_implementation,
        )

        # Vocabulary ids of the only two answers the prompt allows.
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = max_length

        # The suffix opens the assistant turn with an empty think block, so
        # the next predicted token is the yes/no answer.
        prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)

    def _format_instruction(self, query, doc, instruction=None):
        """
        Formats one (query, document) pair as the user message body.

        Args:
            query: The search query.
            doc: The document text to judge.
            instruction: Task description; defaults to a web-search retrieval
                instruction.

        Returns:
            The "<Instruct>/<Query>/<Document>" string fed to the model.
        """
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _preprocess_doc(self, knowledge_str):
        """
        Splits one knowledge string into a list of passages, using "Title:"
        as the key separator.

        "Passage #N Text: " markers are removed entirely, then any remaining
        "Passage #N " prefixes are removed. The text is then split on
        "Title:". Each non-blank segment becomes "Title: <segment stripped>".

        Returns:
            A list of passage strings (empty if the input has no content).
        """
        cleaned = self._TEXT_MARKER_RE.sub("", knowledge_str)
        cleaned = self._PASSAGE_MARKER_RE.sub("", cleaned)

        documents = []
        for segment in cleaned.split(self._TITLE_KEY):
            content = segment.strip()
            if content:
                documents.append(f"{self._TITLE_KEY} {content}")
        return documents

    def _process_inputs(self, pairs):
        """
        Tokenizes formatted pairs and wraps each in the chat prefix/suffix.

        Each pair is truncated so that prefix + pair + suffix fits within
        ``self.max_length``. The batch is then left-padded and moved to the
        model's device.
        """
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self.max_length
            - len(self.prefix_tokens)
            - len(self.suffix_tokens),
        )
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    @torch.no_grad()
    def _compute_logits(self, inputs):
        """
        Returns P("yes") for every row of the tokenized batch.

        Uses the logits at the last position, restricted to the "no" and
        "yes" tokens. The result is a list of Python floats in [0, 1].
        """
        last_logits = self.model(**inputs).logits[:, -1, :]
        # Upcast the two selected logits so fp16 rounding does not collapse
        # nearby probabilities (e.g. both close to 1.0) into ties.
        pair_logits = torch.stack(
            [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]],
            dim=1,
        ).float()
        log_probs = torch.nn.functional.log_softmax(pair_logits, dim=1)
        return log_probs[:, 1].exp().tolist()

    @staticmethod
    def _format_passages(passages):
        """Renumbers passages from 1 as "Passage #i Title: ..." blocks joined by blank lines."""
        formatted_passages = []
        for i, passage in enumerate(passages, 1):
            content = passage.replace("Title:", "", 1).strip()
            formatted_passages.append(f"Passage #{i} Title: {content}")
        return "\n\n".join(formatted_passages)

    def rerank_documents(self, query, documents, top_k=None):
        """
        Reranks a list of documents or image-text pairs based on a query.

        Args:
            query: The search query.
            documents: One of:
                - a single "Passage #N Title: ... / Passage #N Text: ..."
                  string, which is split with ``_preprocess_doc``;
                - a list of dicts (e.g. image-text pairs), each scored by its
                  "caption" value ("" when the key is absent);
                - a list of plain document strings.
            top_k: Number of best items to keep. Defaults to half the items,
                rounded down, so a single item yields an empty result.
                Non-positive values also yield an empty result.

        Returns:
            For string input: the kept passages renumbered as
            "Passage #i Title: ..." blocks joined by blank lines ("" when
            nothing is kept). For list input: a list of the original items,
            highest score first.
        """
        if isinstance(documents, str):
            items = self._preprocess_doc(documents)
            texts = items
        else:
            items = list(documents)
            texts = [
                item.get("caption", "") if isinstance(item, dict) else item
                for item in items
            ]

        if top_k is None:
            top_k = len(items) // 2

        kept = []
        # Skip the model entirely when nothing would be kept (also avoids
        # tokenizing an empty batch).
        if items and top_k > 0:
            pairs = [self._format_instruction(query, text) for text in texts]
            scores = self._compute_logits(self._process_inputs(pairs))
            # sorted() is stable, so equal scores keep their input order.
            ranked = sorted(zip(items, scores), key=lambda x: x[1], reverse=True)
            kept = [item for item, _ in ranked[:top_k]]

        # --- Formatting Output based on Input Type ---
        if isinstance(documents, str):
            return self._format_passages(kept)
        return kept


if __name__ == "__main__":
    demo = Reranker()

    # Demo 1: text knowledge handed over as one raw "Passage #N ..." string;
    # the reranker returns the top half re-formatted as a single string.
    exercise_query = "What are the benefits of exercise?"
    exercise_knowledge = """Passage #1 Title: The Importance of Physical Activity
Passage #1 Text: Regular exercise has numerous benefits for your physical and mental health. It can help with weight management, reduce the risk of chronic diseases, and improve your mood.

Passage #2 Title: Exercise and Mental Health
Passage #2 Text: Physical activity is known to boost endorphins, which are natural mood lifters. It can also help reduce symptoms of anxiety and depression.

Passage #3 Title: How to Start an Exercise Routine
Passage #3 Text: To begin an exercise routine, it's important to start slow and gradually increase the intensity. Choose activities you enjoy to make it a sustainable habit.

Passage #4 Title: The Risks of a Sedentary Lifestyle
Passage #4 Text: A lack of physical activity can lead to various health problems, including obesity, heart disease, and type 2 diabetes."""

    ranked_passages = demo.rerank_documents(exercise_query, exercise_knowledge)
    print("--- Reranked Text Knowledge (from single string) ---")
    print(ranked_passages)

    # Demo 2: image-text pairs handed over as dicts; ranking uses each
    # "caption", and the top half of the original dicts comes back.
    dog_query = "Photos of happy dogs"
    dog_captions = (
        "A golden retriever playing fetch in a sunny park.",
        "A small puppy sleeping in a basket.",
        "A group of different dog breeds running on a beach.",
        "A cat sitting on a windowsill.",
        "A joyful corgi with its tongue out, looking at the camera.",
    )
    dog_pairs = [
        {"image_path": f"path/to/image{number}.jpg", "caption": caption}
        for number, caption in enumerate(dog_captions, start=1)
    ]

    ranked_pairs = demo.rerank_documents(dog_query, dog_pairs)
    print("--- Reranked Image-Text Pairs ---")
    for rank, pair in enumerate(ranked_pairs, start=1):
        print(f"Rank {rank}: {pair}\n")
