import functools
import time
from typing import Any, Dict, List, Optional, Tuple

import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException, Timeout

# Custom exception classes
class MaxRetriesExceededError(Exception):
    """Signal that every permitted attempt of a retried call has failed."""

# Retry decorator
def retry_with_timeout(max_retries: int = 5, backoff_factor: float = 0.1):
    """
    Decorator that retries a call on ``requests`` failures with exponential backoff.

    The wrapped function is called up to ``max_retries`` times in total. Only
    ``requests.exceptions.RequestException`` triggers a retry. That includes
    ``Timeout`` and the ``HTTPError`` raised by ``Response.raise_for_status``.
    Any other exception propagates immediately. There is no sleep after the
    final attempt.

    :param max_retries: Maximum number of attempts (must be >= 1)
    :param backoff_factor: Backoff factor, retry interval = backoff_factor * (2 **(retry_count - 1)),
        where retry_count is the 1-based number of the attempt that just failed
    :raises ValueError: at decoration time if ``max_retries`` < 1
    :return: A decorator. Once every attempt fails, the wrapper raises the
        built-in ``TimeoutError`` if the last failure was a ``Timeout``, and
        ``MaxRetriesExceededError`` otherwise. Either way the last requests
        exception is chained as ``__cause__``.
    """
    if max_retries < 1:
        # With zero attempts the wrapped function would never run at all.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            last_exception = None

            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RequestException as e:  # Timeout is a RequestException subclass
                    last_exception = e
                    # Back off before the next attempt; never sleep after the last one.
                    if attempt < max_retries - 1:
                        time.sleep(backoff_factor * (2 ** attempt))

            if isinstance(last_exception, Timeout):
                raise TimeoutError(f"Request timeout (tried {max_retries} times)") from last_exception
            raise MaxRetriesExceededError(f"Reached maximum retry count {max_retries} times") from last_exception
        return wrapper
    return decorator

class LuceneSearchClient:
    """HTTP client for a Pyserini/Lucene batch-retrieval server.

    Uses two endpoints under ``base_url``: ``POST /search`` and ``GET /health``.
    All requests share one pooled ``requests.Session``. Call :meth:`close`, or
    use the client as a context manager, to release its connections.
    """

    def __init__(self, base_url: str = "http://localhost:8000", max_retries: int = 5, timeout_seconds: int = 30):
        """
        Initialize Pyserini batch retrieval client

        :param base_url: Server base URL (a trailing '/' is stripped)
        :param max_retries: Maximum number of attempts per batch_search call
            (values below 1 are treated as 1)
        :param timeout_seconds: Default timeout duration (seconds)
        """
        self.base_url = base_url.rstrip('/')
        self.max_retries = max_retries
        self.timeout_seconds = timeout_seconds

        # One session with a large connection pool so concurrent callers can
        # reuse keep-alive connections instead of opening new ones.
        self.session = requests.Session()
        adapter = HTTPAdapter(
            pool_connections=128,
            pool_maxsize=1000,
            max_retries=0,  # Retry handled by retry_with_timeout, not urllib3
            pool_block=False
        )
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self.session.close()

    def __enter__(self) -> "LuceneSearchClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def _parse_results(queries: List[str], raw_results: Dict[str, Any]) -> List[List[Tuple[int, float]]]:
        """Convert the server's ``results`` mapping into one (doc_id, score) list per query.

        Each hit is unpacked as ``(doc_id, text, score)``. ``text`` is dropped
        and ``doc_id`` is converted with ``int()``, because the server returns
        it as a string.

        If every query is a key of ``raw_results``, the output is built by
        lookup, so it is aligned with ``queries``. Duplicate queries collapse to
        one key in the JSON object, yet each still gets its own result list.
        Otherwise the mapping's own iteration order is used.
        """
        def to_pairs(hits):
            return [(int(doc_id), score) for doc_id, _text, score in hits]

        if all(query in raw_results for query in queries):
            return [to_pairs(raw_results[query]) for query in queries]
        # NOTE(review): the keys do not match the queries (the server may
        # normalize them). Fall back to the server's ordering and confirm that
        # it matches the request order.
        return [to_pairs(hits) for hits in raw_results.values()]

    def _batch_search_once(self, queries: List[str], top_k: int = 5, threads: int = 8, timeout: Optional[int] = None) -> Tuple[List[List[Tuple[int, float]]], float]:
        """
        Perform a single POST /search attempt (no retry).

        :param queries: Query list
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use (forwarded to the server)
        :param timeout: Request timeout duration (overrides default)
        :return: (Search results, search time), where the search results hold
            one (doc_id, score) list per query
        :raises requests.exceptions.RequestException: on connection errors,
            timeouts, or HTTP error status codes
        """
        payload = {
            "queries": queries,
            "top_k": top_k,
            "threads": threads
        }

        response = self.session.post(
            f"{self.base_url}/search",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=timeout or self.timeout_seconds
        )
        response.raise_for_status()  # Raise HTTP error status codes

        data = response.json()
        return self._parse_results(queries, data["results"]), data["search_time"]

    # Backward-compatible entry point. It always uses the decorator's default
    # policy (5 attempts), whatever self.max_retries is; batch_search honors
    # the instance setting instead.
    _batch_search_with_retry = retry_with_timeout()(_batch_search_once)

    def batch_search(self, queries: List[str], top_k: int = 5, threads: int = 8) -> Tuple[List[List[Tuple[int, float]]], float]:
        """
        Batch search

        :param queries: Query list
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use
        :return: (Search results, search time), where the search results hold
            one (doc_id, score) list per query
        :raises MaxRetriesExceededError: when all attempts fail with a
            non-timeout requests error
        :raises TimeoutError: when the last failed attempt timed out
        """
        # Build the retry wrapper per call so the instance's max_retries (and
        # any later change to it) is honored. At least one attempt is made.
        search_with_retry = retry_with_timeout(max_retries=max(1, self.max_retries))(
            type(self)._batch_search_once
        )
        try:
            return search_with_retry(
                self,
                queries=queries,
                top_k=top_k,
                threads=threads
            )
        except (MaxRetriesExceededError, TimeoutError) as e:
            print(f"❌ Batch search finally failed: {e}")
            raise
        except Exception as e:
            print(f"❌ Batch search request failed: {e}")
            raise

    def search(self, query: str, top_k: int = 5, threads: int = 8) -> Tuple[List[Tuple[int, float]], float]:
        """
        Single query search (internally calls batch_search)

        :param query: Single query string
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use
        :return: (Search results, search time), where the search results are
            one (doc_id, score) list
        """
        results, search_time = self.batch_search([query], top_k, threads)
        return results[0], search_time

    def health_check(self) -> Dict[str, Any]:
        """Return the server's /health JSON.

        This check is best-effort: on any failure it returns
        ``{"status": "unhealthy", "error": ...}`` instead of raising.
        """
        try:
            response = self.session.get(f"{self.base_url}/health", timeout=self.timeout_seconds)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {"status": "unhealthy", "error": str(e)}
    

def cal_h(score_list):
    res = 0
    for score in score_list:
        if score == 0:
            score = 0.0001
        res += 1 / score
    return 1 / res
    # return len(score_list) / res

def rerank(results, top_k, *, higher_is_better: bool = True) -> list[int]:
    '''
    Fuse several ranked result lists into a single ranking of document ids.

    args:
    results: list[list[(id, score)]], one result list per query. It is not modified.
    top_k: number of document ids to return.
    higher_is_better: True when a larger score means a better match, as with
        sparse retrieval (bm25) and dense retrieval with a Flat Index. Each score
        is then replaced by its reciprocal so that smaller means better. Set it
        to False for dense retrieval (bge, ance) with an HNSW Index, where scores
        are already distances (smaller is better). Defaults to True, the
        previously hard-coded bm25 behavior.
    return:
    Directly return the document id, representing the rank, e.g., [123, 456, 789]

    Every distance a document receives across all lists is combined with
    ``cal_h``. Documents are then sorted ascending by that value, with ties
    kept in first-seen order.
    '''
    distances_by_doc = {}  # doc_id -> distances collected from every result list
    for query_results in results:  # one ranked list per query
        for doc_id, score in query_results:
            # Work on local values so the caller's lists are never mutated.
            distance = 1 / score if higher_is_better else score
            distances_by_doc.setdefault(doc_id, []).append(distance)

    fused_scores = {doc_id: cal_h(distances) for doc_id, distances in distances_by_doc.items()}
    ranked = sorted(fused_scores.items(), key=lambda item: item[1])[:top_k]
    return [doc_id for doc_id, _ in ranked]

if __name__ == "__main__":
    import os

    # Server URL can be overridden via LUCENE_SEARCH_URL
    # (alternative server: http://10.0.128.56:8003).
    client = LuceneSearchClient(
        base_url=os.environ.get("LUCENE_SEARCH_URL", "http://10.0.128.190:8001")
    )

    # Health check
    print("Health status:", client.health_check())

    # Single query search
    query = "what does the economy of this place depend upon"
    results, search_time = client.search(query, top_k=100)
    print(f"\nQuery: {query} (took: {search_time:.2f} seconds)")
    print(results)

    # Print the 1-based rank of every ground-truth document found in the results.
    gt = {25289154}
    for rank, (doc_id, _score) in enumerate(results, 1):
        if doc_id in gt:
            print("rank: ", rank)

    # Batch search + rank fusion example:
    #   batch_results, batch_time = client.batch_search(["Express airline"], top_k=100)
    #   print(rerank(batch_results, top_k=100))
    #   print(f"\nTotal batch search time: {batch_time:.2f} seconds")