from __future__ import annotations

import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field

from .utils import logger


class RerankModel(BaseModel):
    """
    Wrapper for rerank functions that can be used with LightRAG.

    Example usage:
    ```python
    from lightrag.rerank import RerankModel, jina_rerank

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=jina_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "api_key": "your_api_key_here",
            "base_url": "https://api.jina.ai/v1/rerank"
        }
    )

    # Use in LightRAG
    rag = LightRAG(
        rerank_model_func=rerank_model.rerank,
        # ... other configurations
    )

    # Query with rerank enabled (default)
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)
    )
    ```

    Or define a custom function directly:
    ```python
    async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_api_key_here",
            top_n=top_n or 10,
            **kwargs
        )

    rag = LightRAG(
        rerank_model_func=my_rerank_func,
        # ... other configurations
    )

    # Control rerank per query
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)  # Enable rerank for this query
    )
    ```
    """

    rerank_func: Callable[[Any], List[Dict]]
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_n: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function."""
        # Merge extra kwargs with model kwargs
        kwargs = {**self.kwargs, **extra_kwargs}
        return await self.rerank_func(
            query=query, documents=documents, top_n=top_n, **kwargs
        )


class MultiRerankModel(BaseModel):
    """Multiple rerank models for different modes/scenarios."""

    # Primary rerank model (used if mode-specific models are not defined)
    rerank_model: Optional[RerankModel] = None

    # Mode-specific rerank models
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_n: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank using the appropriate model based on mode."""

        # Select model based on mode
        if mode == "entity" and self.entity_rerank_model:
            model = self.entity_rerank_model
        elif mode == "relation" and self.relation_rerank_model:
            model = self.relation_rerank_model
        elif mode == "chunk" and self.chunk_rerank_model:
            model = self.chunk_rerank_model
        elif self.rerank_model:
            model = self.rerank_model
        else:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await model.rerank(query, documents, top_n, **kwargs)


async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_n: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Model identifier
        base_url: API endpoint URL
        api_key: API authentication key
        top_n: Number of top results to return
        **kwargs: Additional API-specific parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents

    if not documents:
        return documents

    # Prepare documents for reranking - handle both text and dict formats
    prepared_docs = []
    for doc in documents:
        if isinstance(doc, dict):
            # Use 'content' field if available, otherwise use 'text' or convert to string
            text = doc.get("content") or doc.get("text") or str(doc)
        else:
            text = str(doc)
        prepared_docs.append(text)

    # Prepare request
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}

    if top_n is not None:
        data["top_n"] = min(top_n, len(prepared_docs))

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents

                result = await response.json()

                # Extract reranked results
                if "results" in result:
                    # Standard format: results contain index and relevance_score
                    reranked_docs = []
                    for item in result["results"]:
                        if "index" in item:
                            doc_idx = item["index"]
                            if 0 <= doc_idx < len(documents):
                                reranked_doc = documents[doc_idx].copy()
                                if "relevance_score" in item:
                                    reranked_doc["rerank_score"] = item[
                                        "relevance_score"
                                    ]
                                reranked_docs.append(reranked_doc)
                    return reranked_docs
                else:
                    logger.warning("Unexpected rerank API response format")
                    return documents

    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return documents


async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_n: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Jina rerank model name
        top_n: Number of top results to return
        base_url: Jina API endpoint
        api_key: Jina API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
    )


async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    top_n: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Cohere rerank model name
        top_n: Number of top results to return
        base_url: Cohere API endpoint
        api_key: Cohere API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
    )


# Convenience function for custom API endpoints
async def custom_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_n: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using a custom API endpoint.
    This is useful for self-hosted or custom rerank services.
    """
    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
    )


if __name__ == "__main__":
    import asyncio

    async def main():
        # Example usage
        docs = [
            {"content": "The capital of France is Paris."},
            {"content": "Tokyo is the capital of Japan."},
            {"content": "London is the capital of England."},
        ]

        query = "What is the capital of France?"

        result = await jina_rerank(
            query=query, documents=docs, top_n=2, api_key="your-api-key-here"
        )
        print(result)

    asyncio.run(main())
