# reranker_interface.py (Revised for Index-based Access)
import concurrent.futures
import os
import random
from abc import ABC, abstractmethod
from typing import List

import torch
from openai import OpenAI

from .rank_gpt import permutation_pipeline
from .universal_data_loader import (
    UniversalBanditDataLoader,
)  # Reference the loader

# API key read from the environment at import time; this is None when the
# OPENAI_API_KEY variable is not set. Used as the default `api_key` of
# LLMReranker.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Default endpoint used as the `base_url` of LLMReranker's OpenAI client.
BASE_URL = "https://api.openai.com/v1"


class Reranker(ABC):
    """Common interface for rerankers driven by query/arm indices.

    The base class only stores the data loader and the torch device;
    concrete subclasses supply the actual ranking logic in `rerank`.
    """

    def __init__(
        self, data_loader: UniversalBanditDataLoader, device: torch.device
    ):
        """
        Keep the collaborators every reranker needs.

        Args:
            data_loader: An initialized UniversalBanditDataLoader instance,
                         stored so subclasses can fetch text.
            device: The torch device to use for tensor operations.
        """
        self.device = device
        self.data_loader = data_loader

    @abstractmethod
    def rerank(
        self,
        batch_query_indices: torch.Tensor,
        batch_candidate_arm_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Pick the best candidate arm for each query in the batch.

        Args:
            batch_query_indices: Query indices, one per batch row.
                                 Shape: [B]
            batch_candidate_arm_indices: Candidate arm indices for each
                                         query. Shape: [B, k]

        Returns:
            A tensor of shape [B]; entry i is the position (0 to k-1) of
            the top-ranked arm among the candidates of query i.
        """
        ...


class RandomReranker(Reranker):
    """
    Placeholder reranker: it looks up the query and arm texts through the
    data loader (to show how that is done) and then picks an arm at random.
    """

    def rerank(
        self,
        batch_query_indices: torch.Tensor,
        batch_candidate_arm_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return a uniformly random candidate position for every query.

        Args:
            batch_query_indices: Tensor of query indices. Shape: [B]
            batch_candidate_arm_indices: Tensor of candidate arm indices.
                                         Shape: [B, k]

        Returns:
            A tensor of shape [B] on `self.device` with values drawn from
            0 to k-1.
        """
        loader = self.data_loader

        # --- Demonstration of text lookup ---
        # Query texts are fetched in a single call ...
        _query_texts = loader.get_texts_by_indices(
            query_indices=batch_query_indices.tolist()
        )["queries"]

        # ... while arm texts are fetched one batch row at a time.
        num_queries = batch_candidate_arm_indices.size(0)
        _arms_text_per_query = [
            loader.get_texts_by_indices(
                arm_indices=batch_candidate_arm_indices[row].tolist()
            )["arms"]
            for row in range(num_queries)
        ]

        # A real reranker would feed the fetched texts to a model (e.g., a
        # Cross-Encoder) here. This placeholder ignores them and samples.
        num_candidates = batch_candidate_arm_indices.size(1)
        return torch.randint(
            0, num_candidates, (num_queries,), device=self.device
        )


class LLMReranker(Reranker):
    """
    A reranker that uses a large language model (LLM) to rerank candidates.

    Every query and its candidate arm texts are passed through
    `permutation_pipeline`; the ordering it returns is converted into a
    per-candidate rank tensor.
    """

    # How many times `permutation_pipeline` is attempted for one query
    # before falling back to a random ordering.
    _MAX_ATTEMPTS = 3

    def __init__(
        self,
        data_loader: UniversalBanditDataLoader,
        device: torch.device,
        model_id: str = "gpt-4.1-nano",
        base_url: str = BASE_URL,
        api_key: str = OPENAI_API_KEY,
    ):
        """
        Create the reranker and its OpenAI client.

        Args:
            data_loader: An initialized UniversalBanditDataLoader instance.
            device: The torch device to use for tensor operations.
            model_id: Model name forwarded to `permutation_pipeline`.
            base_url: Endpoint handed to the OpenAI client.
            api_key: API key handed to the OpenAI client.
        """
        super().__init__(data_loader, device)
        self.llm_client = OpenAI(api_key=api_key, base_url=base_url)
        self.model_id = model_id

    def rerank(
        self,
        batch_query_indices: torch.Tensor,
        batch_candidate_arm_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rank the candidate arms of every query with the LLM.

        Args:
            batch_query_indices: Tensor of query indices. Shape: [B]
            batch_candidate_arm_indices: Tensor of candidate arm indices.
                                         Shape: [B, k]

        Returns:
            A long tensor on `self.device`. For a non-empty batch it has
            shape [B, k]; entry [i, j] is the rank (0 = best) that the
            reranked ordering gives candidate j of query i. An empty batch
            yields an empty long tensor.

        NOTE(review): the `Reranker.rerank` docstring describes a [B]
        tensor holding only the top-ranked position, whereas this method
        returns the full rank matrix. Confirm what callers expect; the
        top-ranked position is `result.argmin(dim=1)`.
        """
        query_texts = self.data_loader.get_texts_by_indices(
            query_indices=batch_query_indices.tolist()
        )["queries"]

        batch_candidate_arms_text = [
            self.data_loader.get_texts_by_indices(
                arm_indices=arm_row.tolist()
            )["arms"]
            for arm_row in batch_candidate_arm_indices
        ]

        # For each query and its k candidate arm texts, build an item:
        #   {"query": query_text, "hits": [{"content": arm_text}, ...]}
        items = [
            {
                "query": query_text,
                "hits": [{"content": arm_text} for arm_text in arm_texts],
            }
            for query_text, arm_texts in zip(
                query_texts, batch_candidate_arms_text
            )
        ]

        # The LLM calls are independent, so they are issued concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            reranked_items = list(executor.map(self._rerank_item, items))

        rank_rows = [
            self._ranks_from_hits(arm_texts, item["hits"])
            for arm_texts, item in zip(
                batch_candidate_arms_text, reranked_items
            )
        ]

        # An explicit dtype keeps the result integral even when the batch
        # is empty (torch.tensor([]) would otherwise default to float).
        return torch.tensor(rank_rows, dtype=torch.long, device=self.device)

    def _rerank_item(self, item: dict) -> dict:
        """Run `permutation_pipeline` on one item, retrying on failure.

        If every attempt raises, the item's hits are shuffled in place and
        the item is returned, i.e. the fallback is a random ordering in
        the same format as the input.
        """
        for _ in range(self._MAX_ATTEMPTS):
            try:
                return permutation_pipeline(
                    item=item,
                    model_name=self.model_id,
                    api_key=self.llm_client.api_key,
                    base_url=self.llm_client.base_url,
                )
            except Exception as e:  # deliberate best-effort: retry on any error
                print(f"Reranking failed: {e}. Retrying...")
        random.shuffle(item["hits"])
        return item

    @staticmethod
    def _ranks_from_hits(
        original_texts: List[str], reranked_hits: List[dict]
    ) -> List[int]:
        """Convert a reranked hit list into one rank per original candidate.

        Example: for original texts [a, b, c] and reranked hits [c, a, b]
        the original positions in reranked order are [2, 0, 1], so the
        returned ranks are [1, 2, 0].

        Hits are matched to candidates by their "content" text. Identical
        texts are matched first-come-first-served so each candidate gets a
        distinct rank, and candidates with no matching hit are ranked last
        in their original order, so the result is always a permutation of
        0..k-1.
        """
        # text -> original positions still unmatched (ascending order)
        unmatched = {}
        for position, text in enumerate(original_texts):
            unmatched.setdefault(text, []).append(position)

        order = []
        for hit in reranked_hits:
            pending = unmatched.get(hit.get("content"))
            if pending:
                order.append(pending.pop(0))

        # Candidates whose text never came back from the pipeline go last.
        matched = set(order)
        order.extend(
            position
            for position in range(len(original_texts))
            if position not in matched
        )

        ranks = [0] * len(original_texts)
        for new_rank, original_position in enumerate(order):
            ranks[original_position] = new_rank
        return ranks
