import json
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import os
import time
import argparse
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel
from typing import List, Tuple
import uvicorn
import asyncio
from collections import defaultdict
import threading

def load_json(json_path):
    """Load data from a ``.json`` or ``.jsonl`` file.

    :param json_path: Path to the input file; its extension (case-insensitive)
        selects the parser.
    :return: The decoded document for ``.json``; a list with one decoded value
        per non-blank line for ``.jsonl``; ``None`` (after printing a message)
        for any other extension.
    """
    _, ext = os.path.splitext(json_path)
    ext = ext.lower()
    if ext not in ('.json', '.jsonl'):
        # Reject unsupported extensions before touching the file system.
        print("Wrong path, please provide .json/.jsonl file")
        return None

    with open(json_path, 'r', encoding='utf-8') as f:
        if ext == '.json':
            return json.load(f)
        # JSON Lines: blank lines (commonly a trailing newline) carry no
        # record and would make json.loads raise, so skip them.
        return [json.loads(line) for line in f if line.strip()]

def save_json(save_path, json_list):
    """Write data to a ``.json`` or ``.jsonl`` file.

    For ``.json`` the whole object is dumped with 4-space indentation; for
    ``.jsonl`` each item of ``json_list`` is written on its own line. Non-ASCII
    characters are written as-is (``ensure_ascii=False``).

    :param save_path: Destination path; its extension (case-insensitive)
        selects the format.
    :param json_list: Data to serialise (an iterable of items for ``.jsonl``).
    :return: ``None``. Problems are reported with ``print`` rather than raised,
        except for errors from opening the file, which propagate.
    """
    _, ext = os.path.splitext(save_path)
    ext = ext.lower()
    if ext not in ('.json', '.jsonl'):
        # Validate first so an unsupported path never creates or truncates a file.
        print("Wrong path, please provide .json/.jsonl file")
        return

    with open(save_path, 'w', encoding='utf-8') as f:
        try:
            if ext == '.json':
                json.dump(json_list, f, ensure_ascii=False, indent=4)
            else:
                for item in json_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
        except (TypeError, ValueError, OSError) as err:
            # Still best-effort (no raise), but report the real cause instead
            # of blaming the path.
            print(f"Failed to write {save_path}: {err}")
            return
    print(f"Successfully written to {save_path}")

def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token states over dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the per-row mask total.
    """
    keep = attention_mask[..., None].bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return zeroed.sum(dim=1) / token_counts

class FaissHNSWSearcher:
    """Embed queries with a transformer model and search a FAISS index.

    Query embeddings are kept in an LRU cache keyed by the raw query text so
    repeated queries skip the model forward pass.
    """

    def __init__(self, model_name, index_path, doc_ids_path, device="auto", cache_size=50000):
        """
        Initialize the FAISS HNSW Searcher for ANCE.
        
        :param model_name: Name of the ANCE transformer model for embeddings.
        :param index_path: Path to the FAISS index file.
        :param doc_ids_path: Path to the saved document IDs file.
        :param device: Device to load the model ("auto", "cpu", or "cuda").
        :param cache_size: Size of the embedding cache.
        :raises FileNotFoundError: If the index or document-ID file is missing.
        """
        self.model_name = model_name
        self.index_path = index_path
        self.doc_ids_path = doc_ids_path
        self.device = device
        self.cache_size = cache_size

        # LRU embedding cache. A plain dict keeps insertion order, so the first
        # key is always the least recently used entry and no separate order
        # list has to be kept in sync.
        self.embedding_cache = {}

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name, device_map=self.device)

        # Load FAISS index
        self.index = self._load_faiss_index()

        # Load document IDs
        self.doc_ids = self._load_doc_ids()

        print(f"✅ ANCE FAISS searcher initialized successfully, embedding cache size: {self.cache_size}")

    @property
    def cache_order(self):
        """Snapshot of cached query texts, least recently used first.

        Kept for backward compatibility with code that reads or clears the
        former ``cache_order`` list; mutating the returned list does not
        affect the cache itself.
        """
        return list(self.embedding_cache)

    def _load_faiss_index(self):
        """Load the FAISS index from ``self.index_path``.

        :raises FileNotFoundError: If the index file does not exist.
        """
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"FAISS index file not found at {self.index_path}")
        print(f"Loading FAISS index from {self.index_path}...")
        return faiss.read_index(self.index_path)

    def _load_doc_ids(self):
        """Load the document-ID array from ``self.doc_ids_path`` via ``np.load``.

        :raises FileNotFoundError: If the document-ID file does not exist.
        """
        if not os.path.exists(self.doc_ids_path):
            raise FileNotFoundError(f"Document IDs file not found at {self.doc_ids_path}")
        print(f"Loading document IDs from {self.doc_ids_path}...")
        return np.load(self.doc_ids_path)

    def _get_cached_embedding(self, text):
        """Return the embedding for ``text``, computing and caching it on a miss.

        :param text: Query text, used verbatim as the cache key.
        :return: The array produced by :meth:`get_dense_embedding`. The cached
            object itself is returned, so callers must not modify it in place.
        """
        cache = self.embedding_cache
        if text in cache:
            # Re-insert to move the entry to the most-recently-used end (O(1)).
            embedding = cache.pop(text)
            cache[text] = embedding
            return embedding

        embedding = self.get_dense_embedding(text)
        cache[text] = embedding

        # Evict from the least-recently-used end until within capacity.
        while cache and len(cache) > self.cache_size:
            del cache[next(iter(cache))]

        return embedding

    def get_dense_embedding(self, text):
        """Compute a dense, L2-normalized embedding for ``text``.

        The input is tokenized (truncated to 512 tokens), run through the
        model, mean-pooled over the attention mask and L2-normalized.

        :param text: Input text passed to the tokenizer; no prompt prefix is added.
        :return: NumPy array of the normalized embedding (presumably shape
            ``(1, dim)`` for a single string — TODO confirm).
        """
        device = self.model.device
        batch_dict = self.tokenizer(
            text, padding=True, truncation=True, return_tensors="pt", max_length=512
        ).to(device)
        with torch.no_grad():
            if device.type == "cuda":
                # Mixed precision is only requested when the model actually
                # lives on a CUDA device.
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = self.model(**batch_dict)
            else:
                outputs = self.model(**batch_dict)
            # For ANCE model, use average pooling
            embedding = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
            # L2 normalization
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
        return embedding.cpu().numpy()

    def _pair_hits(self, ids_row, scores_row):
        """Pair one query's index labels with document IDs and scores.

        Negative labels are dropped: FAISS uses ``-1`` to pad a result row when
        fewer than ``top_k`` neighbours were found, and indexing ``doc_ids``
        with ``-1`` would silently return the last document.
        """
        return [
            (self.doc_ids[idx], score)
            for idx, score in zip(ids_row, scores_row)
            if idx >= 0
        ]

    def search(self, query, top_k=5, threads=8):
        """
        Perform a search using the FAISS HNSW index.
        
        :param query: Input query string.
        :param top_k: Number of nearest neighbors to retrieve.
        :param threads: Number of threads to use for FAISS search.
        :return: List of (doc_id, score) tuples; may hold fewer than ``top_k``
            entries when the index returns fewer neighbours.
        """
        # A single query is a batch of one; reuse the batch path so both
        # entry points behave identically.
        return self.batch_search([query], top_k=top_k, threads=threads)[0]

    def batch_search(self, queries, top_k=5, threads=8):
        """
        Perform batch search for multiple queries using the FAISS HNSW index.
        
        :param queries: List of input query strings.
        :param top_k: Number of nearest neighbors to retrieve.
        :param threads: Number of threads to use for FAISS search.
        :return: List of lists containing (doc_id, score) tuples for each
            query, in input order. Scores are the values returned by the
            index. An empty ``queries`` yields an empty list.
        """
        if len(queries) == 0:
            # np.vstack rejects an empty sequence; nothing to search anyway.
            return []

        # Embeddings come from the cache where possible. astype() copies, so
        # the in-place normalization below never touches cached arrays.
        query_embeddings = np.vstack(
            [self._get_cached_embedding(query).astype(np.float32) for query in queries]
        )

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embeddings)

        faiss.omp_set_num_threads(threads)

        D, I = self.index.search(query_embeddings, top_k)  # D = scores, I = doc indices

        # Retrieve document IDs for each query
        return [self._pair_hits(ids_row, scores_row) for ids_row, scores_row in zip(I, D)]

# Configuration class
class Config:
    """Server settings resolved from built-in defaults and environment variables.

    Environment variables MODEL_NAME, OUTPUT_DIR, HOST, PORT and DEVICE take
    precedence over the built-in defaults; :meth:`from_args` can then override
    individual fields from parsed command-line arguments.
    """

    def __init__(self):
        env = os.getenv
        # Each setting: environment variable if present, else the ANCE default.
        self.model_name = env('MODEL_NAME', 'sentence-transformers/msmarco-roberta-base-ance-firstp')
        self.output_dir = env('OUTPUT_DIR', "local_index_search/wikipedia_segment/ance")
        self.host = env('HOST', "0.0.0.0")
        self.port = int(env('PORT', 8000))
        self.device = env('DEVICE', "auto")
        self._derive_paths()

    def _derive_paths(self):
        """Recompute the index and doc-ID file paths from ``output_dir``."""
        base = self.output_dir
        self.index_path = os.path.join(base, 'ance_faiss_index.bin')
        self.doc_ids_path = os.path.join(base, 'doc_ids.npy')

    def from_args(self, args):
        """Apply command-line overrides; arguments left as ``None`` are ignored."""
        overrides = {
            name: getattr(args, name)
            for name in ("model_name", "output_dir", "host", "port", "device")
        }
        for name, value in overrides.items():
            if value is None:
                continue
            setattr(self, name, value)
        # The file paths depend on output_dir, so refresh them when it changed.
        if overrides["output_dir"] is not None:
            self._derive_paths()

    def print_config(self):
        """Print the active configuration to stdout."""
        report = (
            "=== ANCE Search Server Configuration ===",
            f"Model path: {self.model_name}",
            f"Index directory: {self.output_dir}",
            f"Index file: {self.index_path}",
            f"Document ID file: {self.doc_ids_path}",
            f"Device: {self.device}",
            f"Host: {self.host}",
            f"Port: {self.port}",
            "==========================",
        )
        for line in report:
            print(line)

# Concurrency monitoring middleware
class ConcurrencyMonitorMiddleware(BaseHTTPMiddleware):
    """Cap the number of in-flight requests and annotate responses with timing.

    Requests arriving while ``max_concurrent_requests`` are already in flight
    receive a 429 response. Accepted requests are counted per URL path in
    ``request_stats``.
    """

    def __init__(self, app, max_concurrent_requests: int = 5000):
        super().__init__(app)
        self.max_concurrent_requests = max_concurrent_requests
        self.current_requests = 0
        self.lock = threading.Lock()
        self.request_stats = defaultdict(int)

    def _try_admit(self, path):
        """Reserve a slot for ``path``; return False when the limit is reached."""
        with self.lock:
            if self.current_requests >= self.max_concurrent_requests:
                return False
            self.current_requests += 1
            self.request_stats[path] += 1
            return True

    def _release(self):
        """Give back the slot taken by :meth:`_try_admit`."""
        with self.lock:
            self.current_requests -= 1

    async def dispatch(self, request: Request, call_next):
        """Admit or reject the request, then time the downstream handler."""
        if not self._try_admit(request.url.path):
            return Response(
                content="Too many concurrent requests",
                status_code=429,
                headers={"Retry-After": "1"}
            )

        try:
            began = time.time()
            response = await call_next(request)
            elapsed = time.time() - began

            # Expose handler latency and the current in-flight count to the client.
            response.headers["X-Process-Time"] = str(elapsed)
            response.headers["X-Current-Requests"] = str(self.current_requests)
            return response
        finally:
            # Always release the slot, even when the handler raised.
            self._release()

# Global concurrency monitoring instance
# NOTE(review): this object wraps no app (None) and is a different instance
# from the one created through app.add_middleware() below, which receives the
# class rather than this object. Nothing else in this file reads `monitor`.
monitor = ConcurrencyMonitorMiddleware(None, max_concurrent_requests=5000)

# FastAPI application
app = FastAPI(title="ANCE FAISS HNSW Search Service", version="1.0.0")

# Add middleware: register the concurrency limiter class with the same
# 5000-request cap used for `monitor` above.
app.add_middleware(ConcurrencyMonitorMiddleware, max_concurrent_requests=5000)

# Request models
class SearchRequest(BaseModel):
    # Request body of POST /search.
    # Query strings; the /search handler rejects an empty list with HTTP 400.
    queries: List[str]
    # Number of neighbours requested per query (passed to batch_search as top_k).
    top_k: int = 5
    # Thread count passed to batch_search, which hands it to faiss.omp_set_num_threads.
    threads: int = 8

class SearchResult(BaseModel):
    # One retrieved document within a /search response.
    # Document identifier; the /search handler converts it with int().
    doc_id: int
    # Score reported by the index for this document, converted with float().
    score: float

class SearchResponse(BaseModel):
    # Response body of POST /search.
    # One inner list of SearchResult per query, in the order returned by batch_search.
    results: List[List[SearchResult]]
    # Wall-clock seconds spent inside searcher.batch_search, measured by the handler.
    search_time: float

# Global variables
# searcher: set by initialize_searcher(); endpoints answer 503 while it is None.
searcher = None
# config: assigned in the `if __name__ == "__main__"` block at the bottom of
# this file; it stays None when the module is imported without running that block.
config = None

def initialize_searcher():
    """Create the global :class:`FaissHNSWSearcher` from the global config.

    If no configuration has been set up yet (the module was imported by an
    ASGI server instead of being run as a script), one is built from defaults
    and environment variables so startup does not fail on ``config`` being
    ``None``.

    :raises Exception: Any error raised while loading the model, index or
        document IDs is logged and re-raised.
    """
    global searcher, config

    if config is None:
        # Only the __main__ block assigns `config`; cover the import path too.
        config = Config()

    try:
        config.print_config()
        searcher = FaissHNSWSearcher(
            model_name=config.model_name, 
            index_path=config.index_path, 
            doc_ids_path=config.doc_ids_path,
            device=config.device,
            cache_size=50000  # Set large cache to reduce repeated computation
        )
        print("✅ ANCE searcher initialized successfully")
    except Exception as e:
        print(f"❌ ANCE searcher initialization failed: {e}")
        raise

# NOTE(review): on_event hooks are reportedly deprecated in recent FastAPI
# releases in favour of lifespan handlers — confirm against the pinned version.
@app.on_event("startup")
async def startup_event():
    """Build the global searcher when the application starts.

    Calls initialize_searcher() synchronously; that function re-raises any
    initialization error, so a failure propagates out of this hook.
    """
    initialize_searcher()

@app.get("/")
async def root():
    """Return a static banner confirming the service is reachable."""
    banner = "ANCE FAISS HNSW Search Service is running"
    return {"message": banner}

@app.get("/health")
async def health_check():
    """Report readiness: 200 once the searcher exists, 503 before that."""
    if searcher is not None:
        return {"status": "healthy", "searcher_initialized": True}
    raise HTTPException(status_code=503, detail="Searcher not initialized")

@app.get("/stats")
async def get_stats():
    """Return concurrency and embedding-cache statistics.

    The concurrency figures are read from the live
    ``ConcurrencyMonitorMiddleware`` instance when it can be found in the
    application's middleware stack; otherwise they fall back to ``0`` / ``{}``
    and the configured limit of 5000.

    :raises HTTPException: 503 if the searcher has not been initialized.
    """
    if searcher is None:
        raise HTTPException(status_code=503, detail="Searcher not initialized")

    # Fallback values used when the middleware instance cannot be located.
    current_requests = 0
    max_concurrent_requests = 5000
    request_stats = {}

    # app.user_middleware only holds the middleware class and its options; the
    # instance is created by the framework when it builds the stack. Walk the
    # built stack through each layer's `.app` link to find it.
    layer = getattr(app, "middleware_stack", None)
    while layer is not None:
        if isinstance(layer, ConcurrencyMonitorMiddleware):
            with layer.lock:
                current_requests = layer.current_requests
                max_concurrent_requests = layer.max_concurrent_requests
                # Copy so the response does not alias the live counter dict.
                request_stats = dict(layer.request_stats)
            break
        layer = getattr(layer, "app", None)

    cache = searcher.embedding_cache
    capacity = searcher.cache_size
    cached_count = len(cache)

    return {
        "status": "healthy",
        "searcher_initialized": True,
        "current_requests": current_requests,
        "max_concurrent_requests": max_concurrent_requests,
        "request_stats": request_stats,
        "cache_stats": {
            "cache_size": cached_count,
            "cache_capacity": capacity,
            "cache_usage": cached_count / capacity if capacity > 0 else 0,
            # Take the first ten keys without copying the whole key set.
            "cached_queries_sample": [key for _, key in zip(range(10), cache)]
        },
        "timestamp": time.time()
    }

@app.get("/config")
async def get_config():
    """Expose the active configuration values as a JSON object.

    :raises HTTPException: 503 if the configuration has not been set up.
    """
    if config is None:
        raise HTTPException(status_code=503, detail="Config not initialized")

    exposed_fields = (
        "model_name",
        "output_dir",
        "index_path",
        "doc_ids_path",
        "device",
        "host",
        "port",
    )
    return {field: getattr(config, field) for field in exposed_fields}

@app.post("/clear_cache")
async def clear_cache():
    """Empty the searcher's embedding cache and report how many entries were dropped.

    :raises HTTPException: 503 if the searcher has not been initialized.
    """
    if searcher is None:
        raise HTTPException(status_code=503, detail="Searcher not initialized")

    dropped = len(searcher.embedding_cache)
    # Clear the cached embeddings first, then the recency bookkeeping.
    for store in (searcher.embedding_cache, searcher.cache_order):
        store.clear()

    summary = {
        "message": "Cache cleared successfully",
        "cleared_items": dropped,
        "current_cache_size": len(searcher.embedding_cache),
        "timestamp": time.time()
    }
    return summary

@app.post("/search", response_model=SearchResponse)
async def batch_search(request: SearchRequest):
    """Batch search endpoint.

    Runs every query in ``request.queries`` against the index and returns one
    result list per query together with the time spent searching.

    :raises HTTPException: 503 if the searcher is not initialized; 400 for an
        empty query list or a non-positive ``top_k`` / ``threads``; 500 if the
        search itself fails.
    """
    if searcher is None:
        raise HTTPException(status_code=503, detail="Searcher not initialized")

    if not request.queries:
        raise HTTPException(status_code=400, detail="Queries list cannot be empty")

    # Reject invalid numeric parameters up front so a client mistake is
    # reported as 400 instead of surfacing as a 500 from inside the search.
    if request.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    if request.threads < 1:
        raise HTTPException(status_code=400, detail="threads must be a positive integer")

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        time_start = time.perf_counter()
        results = searcher.batch_search(
            queries=request.queries,
            top_k=request.top_k,
            threads=request.threads
        )
        search_time = time.perf_counter() - time_start

        # Convert NumPy scalars to plain int/float for the response model.
        formatted_results = [
            [
                SearchResult(doc_id=int(doc_id), score=float(score))
                for doc_id, score in query_results
            ]
            for query_results in results
        ]

        return SearchResponse(
            results=formatted_results,
            search_time=search_time
        )

    except Exception as e:
        # Chain the cause so the original traceback stays in server logs.
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

def parse_args():
    """Build the CLI parser and parse ``sys.argv``.

    Every option defaults to ``None`` so callers can tell "not given" apart
    from an explicit value.
    """
    parser = argparse.ArgumentParser(description="ANCE FAISS HNSW Search Service")
    option_table = (
        ("--model-name", {
            "type": str,
            "help": "Path to the ANCE transformer model (default: from env MODEL_NAME or built-in default)",
        }),
        ("--output-dir", {
            "type": str,
            "help": "Directory containing FAISS index and doc IDs (default: from env OUTPUT_DIR or built-in default)",
        }),
        ("--host", {
            "type": str,
            "help": "Host to bind the server (default: from env HOST or 0.0.0.0)",
        }),
        ("--port", {
            "type": int,
            "help": "Port to bind the server (default: from env PORT or 8000)",
        }),
        ("--device", {
            "type": str,
            "choices": ["auto", "cpu", "cuda"],
            "help": "Device to use (default: from env DEVICE or auto)",
        }),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

# Script entry point: resolve configuration, then serve `app` with uvicorn.
if __name__ == "__main__":
    # Parse command line arguments
    args = parse_args()
    
    # Initialize configuration. This runs at module level, so it binds the
    # global `config` that initialize_searcher() and the /config endpoint read.
    # Precedence: CLI arguments > environment variables > built-in defaults.
    config = Config()
    config.from_args(args)
    
    # Run server - significantly increase concurrency limits to fully utilize GPU performance
    uvicorn.run(
        app, 
        host=config.host, 
        port=config.port,
        log_level="info",
        # Significantly increase concurrency processing capability
        workers=1,  # Single process, avoid duplicate model loading
        limit_concurrency=10000,  # Allow 10000 concurrent connections
        # limit_max_requests=None,  # Completely remove maximum request count limit
        timeout_keep_alive=60,  # Increase keep-alive time
        timeout_graceful_shutdown=60,  # Increase graceful shutdown timeout
        # Performance optimization
        loop="asyncio",  # Use asyncio event loop
        http="httptools",  # Use httptools parser
        ws="websockets",  # websocket support
        lifespan="on",  # Enable lifespan events
        access_log=False,  # Disable access logs to improve performance
        # Connection configuration
        backlog=4096,  # Increase backlog size
        # Additional performance optimization
        # NOTE(review): the option name suggests it targets the h11 parser,
        # while http="httptools" is selected above — confirm it has any effect.
        h11_max_incomplete_event_size=16 * 1024 * 1024,  # 16MB
    ) 