import requests
import json
import time
from typing import Dict, Any, List, Tuple
import functools
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

class TimeoutError(Exception):
    """
    Signals that one attempt took longer than the permitted duration.

    Note: inside this module the name shadows Python's built-in
    ``TimeoutError``; this class derives directly from ``Exception``.
    """

class MaxRetriesExceededError(Exception):
    """Signals that every permitted attempt failed and no retries remain."""

def retry_with_timeout(max_retries=5, timeout_seconds=30, backoff_factor=1.0):
    """
    Retry decorator for instance methods, with a timeout and a retry limit.

    The wrapped method is invoked up to ``max_retries + 1`` times (the initial
    attempt plus the retries). Unless the caller supplies one, a ``timeout``
    keyword argument is injected into the call, so the wrapped method must
    accept ``timeout``. If the instance defines ``max_retries`` or
    ``timeout_seconds`` attributes, they take precedence over the decorator
    arguments.

    :param max_retries: Maximum retry count, default 5 times
    :param timeout_seconds: Timeout duration (seconds), default 30 seconds
    :param backoff_factor: Base wait in seconds; the wait after failed attempt
        number ``n`` (0-based) is ``backoff_factor * 2 ** n``, default 1.0
    :raises TimeoutError: If the last failed attempt was a timeout (this
        module's ``TimeoutError``)
    :raises MaxRetriesExceededError: If the last failed attempt raised any
        other exception; the original error is attached as ``__cause__``
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Use instance configuration, if not available use decorator defaults
            actual_max_retries = getattr(self, 'max_retries', max_retries)
            actual_timeout = getattr(self, 'timeout_seconds', timeout_seconds)

            # The injected timeout never changes between attempts, so set it once
            kwargs.setdefault('timeout', actual_timeout)

            last_exception = None

            for attempt in range(actual_max_retries + 1):  # +1 to include the initial attempt
                try:
                    # A monotonic clock keeps the measurement immune to
                    # wall-clock adjustments happening during the call
                    start_time = time.monotonic()
                    result = func(self, *args, **kwargs)
                    elapsed_time = time.monotonic() - start_time

                    # Additional check on top of the timeout handed to the call:
                    # an attempt that finished too late is treated as a failure
                    if elapsed_time > actual_timeout:
                        raise TimeoutError(f"Operation timeout, took {elapsed_time:.2f} seconds, exceeded limit {actual_timeout} seconds")
                except Exception as e:
                    # Every ordinary exception (request errors and timeouts
                    # included) counts as a failed attempt
                    last_exception = e

                    if attempt < actual_max_retries:
                        wait_time = backoff_factor * (2 ** attempt)  # Exponential backoff
                        print(f"❌ Attempt {attempt + 1} failed: {e}")
                        print(f"⏱️ Waiting {wait_time:.1f} seconds before retry...")
                        time.sleep(wait_time)
                    else:
                        print(f"❌ All {actual_max_retries + 1} attempts failed")
                        break
                else:
                    # Successfully return result
                    if attempt > 0:
                        print(f"✅ Attempt {attempt + 1} succeeded")
                    return result

            # All retries failed: timeouts propagate unchanged, anything else is
            # wrapped while keeping the underlying error as the cause
            if isinstance(last_exception, TimeoutError):
                raise last_exception
            raise MaxRetriesExceededError(
                f"Reached maximum retry count {actual_max_retries}, last error: {last_exception}"
            ) from last_exception

        return wrapper
    return decorator

class FaissSearchClient:
    """
    HTTP client for an ANCE FAISS search server.

    Wraps a pooled ``requests.Session`` and exposes a health check plus
    single and batch search calls. Search requests go through the
    ``retry_with_timeout`` decorator, which reads ``max_retries`` and
    ``timeout_seconds`` from the instance.

    The client can be used as a context manager so that the underlying
    session is closed on exit.
    """

    # Connection pool sizing handed to the HTTPAdapter; kept in one place so
    # the configuration and the start-up message cannot drift apart.
    POOL_CONNECTIONS = 256   # Number of connection pools, support more concurrency
    POOL_MAXSIZE = 2000      # Maximum connections per pool

    def __init__(self, base_url: str = "http://localhost:8000", max_retries: int = 5, timeout_seconds: int = 300):
        """
        Initialize ANCE FAISS search client

        :param base_url: Server address, default http://localhost:8000
        :param max_retries: Maximum retry count, default 5 times
        :param timeout_seconds: Timeout duration (seconds), default 300 seconds
        """
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()
        self.max_retries = max_retries
        self.timeout_seconds = timeout_seconds

        # Configure connection pool - large pool size to fully utilize server's high concurrency capability
        adapter = HTTPAdapter(
            pool_connections=self.POOL_CONNECTIONS,
            pool_maxsize=self.POOL_MAXSIZE,
            max_retries=0,        # No retry at adapter level, retry handled at upper level
            pool_block=False      # Don't block when pool is full, create new connections instead
        )

        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

        # NOTE(review): requests does not appear to consult a ``timeout``
        # attribute on Session; the attribute is kept for any code that reads
        # it, and every request in this class passes ``timeout`` explicitly.
        self.session.timeout = timeout_seconds

        print(f"✅ ANCE search client initialized, connection pool size: {self.POOL_MAXSIZE}, supports high concurrency access")

    def close(self) -> None:
        """Close the underlying HTTP session and release its pooled connections."""
        self.session.close()

    def __enter__(self):
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

    def update_retry_settings(self, max_retries: int = None, timeout_seconds: int = None):
        """
        Update retry settings; arguments left as None keep their current value

        :param max_retries: Maximum retry count
        :param timeout_seconds: Timeout duration (seconds)
        """
        if max_retries is not None:
            self.max_retries = max_retries
        if timeout_seconds is not None:
            self.timeout_seconds = timeout_seconds
            self.session.timeout = timeout_seconds

        print(f"🔧 Retry settings updated: max_retries={self.max_retries}, timeout_seconds={self.timeout_seconds}")

    def get_retry_settings(self) -> Dict[str, int]:
        """
        Get current retry settings

        :return: Dictionary with the keys "max_retries" and "timeout_seconds"
        """
        return {
            "max_retries": self.max_retries,
            "timeout_seconds": self.timeout_seconds
        }

    def health_check(self) -> Dict[str, Any]:
        """
        Check service health status (single attempt, no retry)

        :return: The JSON body of GET /health, or
                 {"status": "error", "error": <message>} when the request fails
        """
        try:
            response = self.session.get(f"{self.base_url}/health", timeout=self.timeout_seconds)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"❌ Health check failed: {e}")
            return {"status": "error", "error": str(e)}

    @retry_with_timeout()
    def _batch_search_with_retry(self, queries: List[str], top_k: int = 5, threads: int = 8, timeout: int = None) -> Tuple[List[List[Tuple[int, float]]], float]:
        """
        Internal implementation of batch search with retry

        Posts the queries to /search and reads the "results" and
        "search_time" fields of the JSON response.

        :param queries: Query list
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use
        :param timeout: Request timeout duration; a falsy value falls back to
                        the instance's timeout_seconds
        :return: (Search results, search time) where search results is a list of (doc_id, score) lists for each query
        """
        payload = {
            "queries": queries,
            "top_k": top_k,
            "threads": threads
        }

        response = self.session.post(
            f"{self.base_url}/search",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=timeout or self.timeout_seconds
        )
        response.raise_for_status()

        data = response.json()

        # Convert each hit from a {"doc_id": ..., "score": ...} mapping to a (doc_id, score) tuple
        results = [
            [(hit["doc_id"], hit["score"]) for hit in query_results]
            for query_results in data["results"]
        ]

        return results, data["search_time"]

    def batch_search(self, queries: List[str], top_k: int = 5, threads: int = 8) -> Tuple[List[List[Tuple[int, float]]], float]:
        """
        Batch search

        Failures are reported on stdout and then re-raised unchanged.

        :param queries: Query list
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use
        :return: (Search results, search time) where search results is a list of (doc_id, score) lists for each query
        """
        try:
            return self._batch_search_with_retry(queries, top_k, threads)
        except (MaxRetriesExceededError, TimeoutError) as e:
            # Retry budget exhausted or final attempt timed out
            print(f"❌ Batch search finally failed: {e}")
            raise
        except Exception as e:
            print(f"❌ Batch search request failed: {e}")
            raise

    def search(self, query: str, top_k: int = 5, threads: int = 8) -> Tuple[List[Tuple[int, float]], float]:
        """
        Single query search (internally calls batch_search)

        :param query: Single query string
        :param top_k: Return top k most similar results
        :param threads: Number of threads to use
        :return: (Search results, search time) where search results is a (doc_id, score) list
        """
        results, search_time = self.batch_search([query], top_k, threads)
        return results[0], search_time

def cal_h(score_list):
    """
    Return the reciprocal of the summed reciprocals of the given scores.

    This equals the harmonic mean of the scores divided by their count.
    A score equal to 0 is replaced by 0.0001 so that it can be inverted.
    An empty list raises ZeroDivisionError.
    """
    reciprocal_total = 0
    for value in score_list:
        # Substitute a small positive number for zero to keep 1 / value defined
        usable = 0.0001 if value == 0 else value
        reciprocal_total += 1 / usable
    return 1 / reciprocal_total

def rerank3(results_init, top_k) -> list[int]:
    '''
    Merge several ranked result lists by scoring each document with cal_h.

    args:
    results_init: list[list[id, score]], one ranked list per query
    top_k: number of document ids to keep
    return:
    Directly return the document id, representing the rank, e.g., [123, 456, 789]

    For dense retrieval (bge, ance) with HNSW Index, the smaller the better
    For sparse retrieval (bm25) and dense retrieval with Flat Index, the larger the better, so currently try to take the reciprocal directly
    '''
    retrieval = 'ance'
    results = results_init
    if retrieval == 'bm25':
        # Larger-is-better scores are inverted; new lists are built so the
        # caller's data is left untouched
        results = [[[idx, 1 / score] for idx, score in row] for row in results_init]

    # Collect, per document, every distance it received across the lists
    distances_by_doc = {}
    for ranked in results:
        for doc_id, dis in ranked:
            distances_by_doc.setdefault(doc_id, []).append(dis)

    docs_score = {doc_id: cal_h(distances) for doc_id, distances in distances_by_doc.items()}

    # Ascending order: the smallest combined score comes first
    sorted_items = sorted(docs_score.items(), key=lambda pair: pair[1])[:top_k]
    print("merge_results: ", sorted_items)
    return [doc_id for doc_id, _ in sorted_items]

def cal_h3(score_list): # Top 10
    """
    Combine (score, rank) pairs into a fused rank value and a minimum score.

    Returns a tuple ``(1 / sum(1 / rank), lowest score)``; the lowest score
    starts from 100000, so that value is returned when no score is smaller.
    """
    inverse_rank_total = 0
    lowest_score = 100000
    for entry_score, entry_rank in score_list:
        inverse_rank_total += 1 / entry_rank
        if entry_score < lowest_score:
            lowest_score = entry_score
    return 1 / inverse_rank_total, lowest_score

def rerank5(results_init, top_k) -> list[int]:
    '''
    Merge several ranked result lists using reciprocal ranks, with the
    bias-adjusted minimum score as tie-breaker (see cal_h3).

    args:
    results_init: list[list[id, score]], one ranked list per query; the input
        rows are copied and not modified
    top_k: number of document ids to keep
    return:
    Directly return the document id, representing the rank, e.g., [123, 456, 789]

    For dense retrieval (bge, ance) with HNSW Index, the smaller the better
    For sparse retrieval (bm25) and dense retrieval with Flat Index, the larger the better, so currently try to take the reciprocal directly
    '''
    # Copy each row so entries can be replaced without touching the original
    results = [row.copy() for row in results_init]
    retrieval = 'ance'
    USE_HNSW = '1'
    if retrieval == 'bm25' or USE_HNSW == '0':
        # If bm25 retriever (or Flat index), take reciprocal
        for result in results:
            for i, (idx, score) in enumerate(result):
                result[i] = [idx, 1 / score]

    # bias: shift every score in a list by that list's (last - first) spread
    for result in results:
        if not result:
            # An empty list has no spread and nothing to shift
            continue
        bias = result[-1][1] - result[0][1]
        for i, (idx, score) in enumerate(result):
            result[i] = [idx, score - bias]

    # doc_id -> list of (shifted distance, 1-based rank), one per list containing it
    document_hash = {}
    for result in results:
        for rank, (doc_id, dis) in enumerate(result, 1):
            document_hash.setdefault(doc_id, []).append((dis, rank))

    # cal_h3 yields (fused rank value, minimum score); sort by both, ascending
    docs_score = {doc_id: cal_h3(value_list) for doc_id, value_list in document_hash.items()}
    sorted_items = sorted(docs_score.items(), key=lambda x: (x[1][0], x[1][1]))[:top_k]
    pred_results = [item[0] for item in sorted_items]

    # Debug output of the first two lists; zip stops early, so fewer than
    # two input lists no longer raises IndexError
    for label, result in zip(("results0: ", "resutls1: "), results):
        print(label, result)
    print("merge_results: ", sorted_items)
    return pred_results


def rerank1(results_init, top_k) -> list[int]:
    '''
    Merge several ranked result lists by summed distance.

    A document's score is the sum of its distances over all lists. For every
    list that did not return the document, the distance of that list's last
    entry is added instead. Documents are returned in ascending score order.

    args:
    results_init: list[list[id, score]], one ranked list per query
    top_k: number of document ids to keep
    return:
    Directly return the document id, representing the rank, e.g., [123, 456, 789]

    For dense retrieval (bge, ance) with HNSW Index, the smaller the better
    For sparse retrieval (bm25) and dense retrieval with Flat Index, the larger the better, so currently try to take the reciprocal directly
    '''
    # If bm25 retriever, take reciprocal (on fresh lists, leaving the input untouched)
    retrieval = 'ance'
    results = results_init
    if retrieval == 'bm25':
        results = [[[idx, 1 / score] for idx, score in row] for row in results_init]

    # doc_id -> [summed distance, bitmask of the lists that contained the doc]
    document_hash = {}
    for i, result in enumerate(results):
        for doc_id, dis in result:
            entry = document_hash.setdefault(doc_id, [0, 0])
            entry[0] += dis
            # OR, not add: a doc id repeated within one list must not carry
            # into the bit of the next list
            entry[1] |= 1 << i

    # Penalty per list = distance of its last entry (presumably the largest,
    # assuming each list is sorted ascending); an empty list contributes 0
    mx = [result[-1][1] if result else 0 for result in results]

    # Take the document with the smallest distance when summed across different queries.
    # If a document is not recalled in a query, use the maximum distance in top_k.
    for value in document_hash.values():
        seen_mask = value[1]
        for i, penalty in enumerate(mx):
            if not seen_mask & (1 << i):
                value[0] += penalty
                value[1] |= 1 << i

    sorted_items = sorted(document_hash.items(), key=lambda x: x[1][0])[:top_k]
    pred_results = [item[0] for item in sorted_items]
    print("merge_results: ", sorted_items)
    return pred_results

def test_client():
    """
    Manual smoke test for batch search against a running server.

    Runs a health check, sends the selected query set with top_k=100 and
    prints the 1-based positions at which ground-truth ids appear in the
    first query's results.
    """
    client = FaissSearchClient(base_url="http://10.0.128.190:8000")

    # hp
    # client = FaissSearchClient(base_url="http://10.0.128.56:8002")

    print("🔍 Checking ANCE search service status...")
    health = client.health_check()
    print(f"Service status: {health}")

    if health.get("status") != "healthy":
        print("❌ Service not ready, please start ANCE search server first")
        return

    # Available (queries, ground-truth doc ids) pairs; only the LAST entry is
    # used. Reorder the list to try a different one.
    candidate_cases = [
        (
            [
                "Random House Tower building",
                "888 7th Avenue skyscraper",
                # "Are Random House Tower and 888 7th Avenue both used for real estate?"
            ],
            [16605491, 17626850],
        ),
        (
            [
                "Force India driver born in 1990",
                "Sergio Pérez, Force India driver",
                # "Which other Mexican Formula One race car driver has held the podium besides the Force India driver born in 1990?"
            ],
            [38828650, 19801645],
        ),
        (
            [
                'The Livesey Hal War Memorial',
                ' World War II with over 60 million casualties',
                # "The Livesey Hal War Memorial commemorates the fallen of which war, that had over 60 million casualties?"
            ],
            [35527133, 240900],
        ),
        (
            [
                "what economy of urk netherlands depends on",
                "urk Netherlands economy",
            ],
            [25289154],
        ),
        (
            [
                # "what industries drives economy of Urk, Netherlands",
                # "Urk, Netherlands economy relies on agriculture and tourism"
                # "Billy King was an Australian rules footballer who participated in a game that was contested between the South Melbourne Football Club and Carlton Football Club, and was held at what location in Melbourne in 1945?"
                "what does the economy of this place depend upon",
            ],
            [25289154],
            # gt = [19328903, 6924192]
        ),
    ]
    queries, gt = candidate_cases[-1]
    print(f"\n🔍 Starting ANCE batch search for {len(queries)} queries...")

    try:
        # time_start and search_time are only consumed by the optional
        # reporting code kept in the comments below
        time_start = time.time()
        top_results, search_time = client.batch_search(queries, top_k=100)
        print(top_results)
        for position, (idx, score) in enumerate(top_results[0], 1):
            if idx in gt:
                print("rank: ", position)

        # Optional: compare the fusion strategies
        # res1 = rerank1(top_results, top_k=100)
        # print("res1: ")
        # for i, item in enumerate(res1, 1):
        #     for g in gt:
        #         if g == item:
        #             print(i, item)
        # res3 = rerank3(top_results, top_k=100)
        # print("res3: ")
        # for i, item in enumerate(res3, 1):
        #     for g in gt:
        #         if g == item:
        #             print(i, item)
        # res5 = rerank5(top_results, top_k=100)
        # print("res5: ")
        # for i, item in enumerate(res5, 1):
        #     for g in gt:
        #         if g == item:
        #             print(i, item)

        # Optional: timing and per-query result listing
        # time_end = time.time()
        # print(f"⏱️ Total time: {time_end - time_start:.4f} seconds")
        # print(f"⏱️ Server search time: {search_time:.4f} seconds")
        # score_sum = []
        # print("\n🔍 ANCE Top Batch Search Results:")
        # for i, query in enumerate(queries):
        #     print(f"\n🔎 Query: {query}")
        #     cnt = 0
        #     cur_score = 0
        #     for doc_id, score in top_results[i]:
        #         cnt += 1
        #         cur_score += score
        #         print(f"📄 Doc ID: {doc_id} | {cnt}🔢 ANCE Score: {score:.4f}")
        #     print("Total score sum: ", cur_score)

    except Exception as e:
        print(f"❌ Search failed: {e}")

def test_single_query():
    """Manual smoke test: run one query against the default server and print the hits."""
    client = FaissSearchClient()
    query = "what is the numerical value of the speed of light in a vacuum"

    print(f"\n🔍 ANCE single query test: {query}")

    try:
        hits, search_time = client.search(query, top_k=5)
        print(f"⏱️ Search time: {search_time:.4f} seconds")

        print("\n🔍 ANCE search results:")
        for doc_id, score in hits:
            print(f"📄 Doc ID: {doc_id} | 🔢 ANCE Score: {score:.4f}")
    except Exception as e:
        # Any failure (network, retries exhausted, formatting) is only reported
        print(f"❌ Single query failed: {e}")

def test_retry_functionality():
    """
    Manual check of the retry behaviour.

    When the local server is healthy, runs a single search, changes the retry
    settings and runs a batch search. Otherwise, points a second client at a
    wrong port to watch the retry mechanism give up.
    """
    print("\n=== Testing ANCE search retry functionality ===")

    # Short timeout and few retries keep the check quick
    client = FaissSearchClient(
        base_url="http://localhost:8000",
        max_retries=3,
        timeout_seconds=10
    )
    print(f"📋 Current retry settings: {client.get_retry_settings()}")

    # The health check itself is not retried and usually succeeds
    print("\n🔍 Testing health check...")
    health = client.health_check()
    print(f"Health check result: {health}")

    service_ready = health.get("status") == "healthy"
    if not service_ready:
        print("❌ Service not ready, cannot test retry functionality")
        print("💡 Can try connecting to wrong address to test retry mechanism:")

        # Nothing is expected to listen on this port, so every attempt should fail
        unreachable_client = FaissSearchClient(
            base_url="http://localhost:9999",  # Wrong port
            max_retries=2,
            timeout_seconds=5
        )

        print(f"\n🔧 Testing retry mechanism with wrong address...")
        try:
            results, search_time = unreachable_client.search("test query")
            print("Unexpectedly succeeded!")
        except (MaxRetriesExceededError, TimeoutError) as e:
            print(f"✅ Retry mechanism working properly: {e}")
        except Exception as e:
            print(f"❌ Other error: {e}")

        return

    # Normal search, expected to succeed
    print("\n🔍 Testing normal search...")
    try:
        query = "what is the speed of light"
        results, search_time = client.search(query, top_k=3)
        print(f"✅ Search succeeded, took {search_time:.4f} seconds")
        print(f"📄 Found {len(results)} results")
    except Exception as e:
        print(f"❌ Search failed: {e}")

    # Change the settings on the live client
    print("\n🔧 Updating retry settings...")
    client.update_retry_settings(max_retries=2, timeout_seconds=15)

    # Batch search under the new settings
    print("\n🔍 Testing batch search (new settings)...")
    try:
        queries = ["test query 1", "test query 2"]
        results, search_time = client.batch_search(queries, top_k=2)
        print(f"✅ Batch search succeeded, took {search_time:.4f} seconds")
        print(f"📄 Processed {len(results)} queries")
    except Exception as e:
        print(f"❌ Batch search failed: {e}")

if __name__ == "__main__":
    # Script entry point: runs the manual checks below against a live search
    # server. Only the batch-search check is enabled; uncomment the others
    # to run them as well.
    print("=== ANCE FAISS Search Client Test ===")
    
    # Test batch search
    test_client()
    
    # Test single query
    # test_single_query()
    
    # Test retry functionality
    # test_retry_functionality() 