import json
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import os
import time
import argparse
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel
from typing import List, Tuple, Optional
import uvicorn
import asyncio
from collections import defaultdict
import threading
import ray
from ray.util.queue import Queue
import psutil
import GPUtil
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
import random

def load_json(json_path):
    """Load a ``.json`` or ``.jsonl`` file.

    The extension is checked before the file is opened, so an unsupported
    path is reported without touching the filesystem.

    :param json_path: Path to the file; its extension selects the parser.
    :return: The parsed document for ``.json``; a list with one parsed value
        per non-blank line for ``.jsonl``; ``None`` (after printing a
        message) for any other extension.
    """
    _, ext = os.path.splitext(json_path)
    if ext not in ('.json', '.jsonl'):
        print("Wrong path, please provide .json/.jsonl file")
        return None

    with open(json_path, 'r', encoding='utf-8') as f:
        if ext == '.json':
            return json.load(f)
        # Blank lines (e.g. a trailing empty line) are not records; skip them
        # instead of letting json.loads fail on an empty string.
        return [json.loads(line) for line in f if line.strip()]

def save_json(save_path, json_list):
    """Write data to a ``.json`` or ``.jsonl`` file.

    The extension is validated before the file is opened, so an unsupported
    path no longer creates or truncates a file on disk.

    :param save_path: Destination path; its extension selects the format.
    :param json_list: For ``.json`` any JSON-serialisable object (written with
        ``indent=4``); for ``.jsonl`` an iterable whose items are written one
        per line.
    :return: ``None``. Problems are reported on stdout rather than raised,
        except for errors from opening the file, which propagate.
    """
    _, ext = os.path.splitext(save_path)
    if ext not in ('.json', '.jsonl'):
        print("Wrong path, please provide .json/.jsonl file")
        return

    with open(save_path, 'w', encoding='utf-8') as f:
        try:
            if ext == '.json':
                json.dump(json_list, f, ensure_ascii=False, indent=4)
            else:
                f.writelines(
                    json.dumps(item, ensure_ascii=False) + '\n'
                    for item in json_list
                )
        except (TypeError, ValueError, OSError) as e:
            # Report the real cause (e.g. a non-serialisable object) instead
            # of blaming the path.
            print(f"Failed to write {save_path}: {e}")
            return
    print(f"Successfully written to {save_path}")
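# Note: this class is not decorated with @ray.remote here; it is wrapped with
# ray.remote(num_gpus=...) when workers are created, so each worker's GPU share is set at runtime.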
class ANSESearchWorker:
    """Dense-retrieval worker: embeds queries and searches a FAISS index.

    The class is plain Python here; the caller wraps it with ``ray.remote`` so
    each instance runs as its own Ray actor. Query embeddings are memoised in
    a bounded least-recently-used (LRU) cache.
    """

    def __init__(self, model_name, index_path, doc_ids_path, worker_id, cache_size=50000):
        """
        Initialize the ANCE Search Worker on the device Ray assigned to it.

        :param model_name: Name of the ANCE transformer model for embeddings.
        :param index_path: Path to the FAISS index file.
        :param doc_ids_path: Path to the saved document IDs file.
        :param worker_id: Worker ID (used in logs and statistics).
        :param cache_size: Maximum number of query embeddings kept in the cache.
        :raises FileNotFoundError: If the index or document-ID file is missing.
        """
        self.model_name = model_name
        self.index_path = index_path
        self.doc_ids_path = doc_ids_path
        self.worker_id = worker_id
        self.cache_size = cache_size

        # Ask Ray which GPU(s), if any, were assigned to this worker.
        ray_gpu_ids = ray.get_gpu_ids()
        print(f"Worker {worker_id} - Ray allocated GPU IDs: {ray_gpu_ids}")
        if ray_gpu_ids:
            # Ray exposes only the assigned GPU(s) to the worker process, so
            # the first visible CUDA device is always index 0; the ID reported
            # by Ray is kept separately for GPU statistics.
            self.device = "cuda:0"
            self.gpu_id = int(ray_gpu_ids[0])
        else:
            # No GPU assigned: fall back to CPU instead of crashing on an
            # empty ID list.
            self.device = "cpu"
            self.gpu_id = None

        # Embedding cache. A plain dict preserves insertion order, so the
        # first key is always the least recently used entry.
        self.embedding_cache = {}

        # Request statistics
        self.request_count = 0
        self.total_processing_time = 0
        self.current_load = 0

        # Load SentenceTransformer model
        self.model = SentenceTransformer(self.model_name, device=self.device)

        # Load FAISS index
        self.index = self._load_faiss_index()

        # Load document IDs
        self.doc_ids = self._load_doc_ids()

        print(f"✅ ANCE Worker {worker_id} initialized successfully, device: {self.device}")

    def _load_faiss_index(self):
        """Load the FAISS index from disk."""
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"FAISS index file not found at {self.index_path}")
        return faiss.read_index(self.index_path)

    def _load_doc_ids(self):
        """Load document IDs from disk."""
        if not os.path.exists(self.doc_ids_path):
            raise FileNotFoundError(f"Document IDs file not found at {self.doc_ids_path}")
        # NOTE(review): allow_pickle=True executes pickled payloads on load;
        # only point this at trusted files.
        return np.load(self.doc_ids_path, allow_pickle=True)

    def _cache_get(self, text):
        """Return the cached embedding for ``text`` (marking it most recently used) or ``None``."""
        embedding = self.embedding_cache.pop(text, None)
        if embedding is not None:
            # Re-inserting moves the key to the end of the dict (most recent).
            self.embedding_cache[text] = embedding
        return embedding

    def _cache_put(self, text, embedding):
        """Store ``embedding`` as the most recent entry and evict the oldest entries over capacity."""
        self.embedding_cache.pop(text, None)
        self.embedding_cache[text] = embedding
        while self.embedding_cache and len(self.embedding_cache) > self.cache_size:
            del self.embedding_cache[next(iter(self.embedding_cache))]

    def _get_cached_embedding(self, text):
        """Get or compute the embedding of a single text (with cache)."""
        embedding = self._cache_get(text)
        if embedding is None:
            embedding = self.get_dense_embedding(text)
            self._cache_put(text, embedding)
        return embedding

    def get_dense_embedding(self, text):
        """Encode one text and return its normalized embedding vector."""
        with torch.no_grad():
            embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
            return embedding[0]

    def batch_get_dense_embedding(self, texts):
        """Encode a list of texts and return one normalized embedding per text."""
        with torch.no_grad():
            embeddings = self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
            return embeddings

    def _embed_queries(self, queries):
        """Return a float32 array with one embedding row per query, in query order.

        Cached embeddings are reused; each distinct uncached query is encoded
        once in a single batch and then added to the cache.
        """
        embeddings = [self._cache_get(query) for query in queries]
        # Distinct cache misses, in first-seen order.
        missing = list(dict.fromkeys(
            query for query, embedding in zip(queries, embeddings) if embedding is None
        ))
        if missing:
            computed = dict(zip(missing, self.batch_get_dense_embedding(missing)))
            # Fill the gaps in place so every row stays aligned with its query.
            embeddings = [
                computed[query] if embedding is None else embedding
                for query, embedding in zip(queries, embeddings)
            ]
            for query, embedding in computed.items():
                self._cache_put(query, embedding)
        return np.ascontiguousarray(np.vstack(embeddings), dtype=np.float32)

    def _format_hits(self, scores, indices):
        """Pair document IDs with scores, dropping FAISS ``-1`` padding labels.

        FAISS pads the label array with -1 when it finds fewer than ``top_k``
        neighbours; indexing ``doc_ids`` with -1 would return the last document.
        """
        return [
            (self.doc_ids[idx], score)
            for idx, score in zip(indices, scores)
            if idx >= 0
        ]

    def search(self, query, top_k=5, threads=8):
        """Search the index for one query and update load statistics.

        :param query: Query text.
        :param top_k: Number of neighbours requested from FAISS.
        :param threads: OpenMP thread count passed to FAISS.
        :return: List of ``(doc_id, score)`` tuples.
        """
        start_time = time.time()
        self.current_load += 1

        try:
            # FAISS expects a 2-D (n_queries, dim) float32 array.
            query_embedding = np.ascontiguousarray(
                self._get_cached_embedding(query), dtype=np.float32
            ).reshape(1, -1)
            # The model already returns normalized embeddings, so no need to normalize again
            faiss.omp_set_num_threads(threads)

            D, I = self.index.search(query_embedding, top_k)
            results = self._format_hits(D[0], I[0])

            # Update statistics
            processing_time = time.time() - start_time
            self.request_count += 1
            self.total_processing_time += processing_time

            return results
        finally:
            self.current_load -= 1

    def batch_search(self, queries, top_k=5, threads=8):
        """Search the index for several queries at once.

        :param queries: Sequence of query texts.
        :param top_k: Number of neighbours requested from FAISS per query.
        :param threads: OpenMP thread count passed to FAISS.
        :return: One list of ``(doc_id, score)`` tuples per query, in the same
            order as ``queries``; an empty list for an empty batch.
        """
        queries = list(queries)
        if not queries:
            return []

        start_time = time.time()
        self.current_load += len(queries)

        try:
            query_embeddings = self._embed_queries(queries)
            # The model already returns normalized embeddings, so no need to normalize again
            faiss.omp_set_num_threads(threads)

            D, I = self.index.search(query_embeddings, top_k)
            results = [self._format_hits(scores, indices) for scores, indices in zip(D, I)]

            # Update statistics
            processing_time = time.time() - start_time
            self.request_count += len(queries)
            self.total_processing_time += processing_time

            return results
        finally:
            self.current_load -= len(queries)

    def get_worker_stats(self):
        """Return request, cache and (best-effort) GPU statistics for this worker."""
        gpu_stats = {}
        if self.gpu_id is None:
            # CPU worker: there is no GPU to query.
            gpu_stats = {"gpu_id": None, "gpu_utilization": 0}
        else:
            try:
                gpus = GPUtil.getGPUs()
                if self.gpu_id < len(gpus):
                    gpu = gpus[self.gpu_id]
                    gpu_stats = {
                        "gpu_id": self.gpu_id,
                        "gpu_utilization": gpu.load * 100,
                        "gpu_memory_used": gpu.memoryUsed,
                        "gpu_memory_total": gpu.memoryTotal,
                        "gpu_memory_utilization": gpu.memoryUtil * 100,
                        "gpu_temperature": gpu.temperature
                    }
            except Exception:
                # GPU telemetry is best-effort; never fail a stats request over it.
                gpu_stats = {"gpu_id": self.gpu_id, "gpu_utilization": 0}

        return {
            "worker_id": self.worker_id,
            "device": self.device,
            "request_count": self.request_count,
            "current_load": self.current_load,
            "avg_processing_time": self.total_processing_time / max(1, self.request_count),
            "cache_size": len(self.embedding_cache),
            "cache_capacity": self.cache_size,
            "gpu_stats": gpu_stats
        }

    def clear_cache(self):
        """Empty the embedding cache and return how many entries were removed."""
        old_cache_size = len(self.embedding_cache)
        self.embedding_cache.clear()
        return old_cache_size

# Configuration class
class Config:
    """Server settings: built-in defaults, overridable by environment variables and CLI arguments."""

    def __init__(self):
        """Populate every setting from its environment variable, falling back to the default."""
        env = os.getenv
        self.model_name = env('MODEL_NAME', 'sentence-transformers/msmarco-roberta-base-ance-firstp')
        self.output_dir = env('OUTPUT_DIR', "local_index_search/wikipedia_segment/ance")
        self.host = env('HOST', "0.0.0.0")
        self.port = int(env('PORT', 8000))
        self.num_gpus = int(env('NUM_GPUS', 2))
        self.gpu_allocate_per_ray_instance = float(env('GPU_ALLOCATE_PER_RAY_INSTANCE', 1.0))

        self._refresh_paths()

    def _refresh_paths(self):
        """Derive the index and document-ID file paths from ``output_dir``."""
        self.index_path = os.path.join(self.output_dir, 'ance_faiss_index_64_2000.bin')
        self.doc_ids_path = os.path.join(self.output_dir, 'doc_ids.npy')

    def from_args(self, args):
        """Overwrite settings with every command line argument that is not ``None``."""
        overridable = (
            'model_name',
            'output_dir',
            'host',
            'port',
            'num_gpus',
            'gpu_allocate_per_ray_instance',
        )
        for field in overridable:
            value = getattr(args, field)
            if value is None:
                continue
            setattr(self, field, value)
            if field == 'output_dir':
                # The file paths live under output_dir, so rebuild them.
                self._refresh_paths()

    def print_config(self):
        """Print the current configuration to stdout."""
        report = (
            "=== Multi-GPU ANCE Search Server Configuration ===",
            f"Model path: {self.model_name}",
            f"Index directory: {self.output_dir}",
            f"Index file: {self.index_path}",
            f"Document ID file: {self.doc_ids_path}",
            f"Number of GPUs: {self.num_gpus}",
            f"Number of GPUs allocated per Ray instance: {self.gpu_allocate_per_ray_instance}",
            f"Host: {self.host}",
            f"Port: {self.port}",
            "============================",
        )
        for line in report:
            print(line)

# Concurrency monitoring middleware
class ConcurrencyMonitorMiddleware(BaseHTTPMiddleware):
    """Count in-flight requests and reject new ones with HTTP 429 once a limit is reached."""

    def __init__(self, app, max_concurrent_requests: int = 5000):
        super().__init__(app)
        self.lock = threading.Lock()
        self.request_stats = defaultdict(int)
        self.current_requests = 0
        self.max_concurrent_requests = max_concurrent_requests

    def _admit(self, request):
        """Reserve a slot for ``request``; return ``False`` when the limit is already reached."""
        with self.lock:
            if self.current_requests >= self.max_concurrent_requests:
                return False
            self.current_requests += 1
            self.request_stats[request.url.path] += 1
            return True

    def _release(self):
        """Give back the slot taken by ``_admit``."""
        with self.lock:
            self.current_requests -= 1

    async def dispatch(self, request: Request, call_next):
        """Run the request if a slot is free and attach timing/concurrency headers to the response."""
        if not self._admit(request):
            return Response(
                content="Too many concurrent requests",
                status_code=429,
                headers={"Retry-After": "1"}
            )

        try:
            started_at = time.time()
            response = await call_next(request)
            elapsed = time.time() - started_at

            response.headers["X-Process-Time"] = str(elapsed)
            response.headers["X-Current-Requests"] = str(self.current_requests)
            return response
        finally:
            # Always free the slot, even when the downstream handler raised.
            self._release()

# Global variables
# Ray actor handles for the search workers; filled in by initialize_ray_cluster().
workers = []
# Active Config instance; assigned in the __main__ block before the server starts.
config = None

def initialize_ray_cluster():
    """Start a fresh Ray runtime and create the search workers.

    Rebinds the module-level ``workers`` list and may adjust
    ``config.num_gpus``: it is capped at the number of detected GPUs, or set
    to 2 when no GPU is detected. Each worker is confirmed ready (its
    constructor has finished) before it is added to ``workers``, so model or
    index loading errors are reported here rather than on the first request.

    :raises ValueError: If ``config.gpu_allocate_per_ray_instance`` is not positive.
    :raises Exception: If workers were requested but none could be created.
    """
    global workers, config

    # Support serving the app without going through the __main__ block
    # (e.g. `uvicorn module:app`), where no Config has been created yet.
    if config is None:
        config = Config()

    gpu_share = config.gpu_allocate_per_ray_instance
    if gpu_share <= 0:
        raise ValueError(f"gpu_allocate_per_ray_instance must be > 0, got {gpu_share}")

    # Clean up existing Ray instance
    if ray.is_initialized():
        ray.shutdown()

    # Check available GPU count
    available_gpus = 0
    if torch.cuda.is_available():
        available_gpus = torch.cuda.device_count()
        print(f"Main process detected {available_gpus} GPUs")

        if config.num_gpus > available_gpus:
            print(f"⚠️  Requested GPU count ({config.num_gpus}) exceeds available count ({available_gpus}), adjusting to {available_gpus}")
            config.num_gpus = available_gpus
    else:
        print("❌ Main process did not detect GPU, workers will use CPU")
        config.num_gpus = 2  # Still create multiple workers, but will use CPU

    # Explicitly set GPU resources when initializing Ray
    ray_init_kwargs = {
        "ignore_reinit_error": True,
        "log_to_driver": True,
        "configure_logging": False,
        "include_dashboard": False,
    }

    # Set GPU resource count
    if available_gpus > 0:
        ray_init_kwargs["num_gpus"] = config.num_gpus
        print(f"Initializing Ray with {config.num_gpus} GPUs...")
    else:
        print("Initializing Ray with CPU only...")

    ray.init(**ray_init_kwargs)

    # Wait for Ray to fully initialize
    time.sleep(3)

    # Check Ray resources
    resources = ray.available_resources()
    print(f"Ray available resources: {resources}")

    # The small tolerance keeps float division from truncating an exact
    # quotient (for example 0.7 / 0.1 evaluates to 6.999...).
    total_workers = int(config.num_gpus / gpu_share + 1e-9)

    # Without GPUs the actors must not request GPU resources, otherwise Ray
    # can never schedule them.
    per_worker_gpus = gpu_share if available_gpus > 0 else 0
    RemoteWorker = ray.remote(num_gpus=per_worker_gpus)(ANSESearchWorker)

    workers = []
    last_error = None

    # Submit every actor first so the workers load their models in parallel.
    pending = []
    for worker_id in range(total_workers):
        try:
            handle = RemoteWorker.remote(
                model_name=config.model_name,
                index_path=config.index_path,
                doc_ids_path=config.doc_ids_path,
                worker_id=worker_id,
                cache_size=50000
            )
            pending.append((worker_id, handle))
        except Exception as e:
            last_error = e
            print(f"❌ Failed to create Worker {worker_id}: {e}")

    # Actor creation is asynchronous: a constructor failure only surfaces when
    # a method result is fetched, so probe each worker before trusting it.
    for worker_id, handle in pending:
        try:
            ray.get(handle.get_worker_stats.remote())
        except Exception as e:
            last_error = e
            print(f"❌ Failed to create Worker {worker_id}: {e}")
            continue
        workers.append(handle)
        print(f"✅ Created Worker {worker_id} successfully")

    if total_workers > 0 and not workers:
        raise Exception(f"Unable to create any worker: {last_error}")

    print(f"✅ Successfully initialized {len(workers)} workers")

def get_random_worker():
    """Pick one worker uniformly at random; respond with HTTP 503 when none exist."""
    if workers:
        return random.choice(workers)
    raise HTTPException(status_code=503, detail="No workers available")

async def get_all_worker_stats():
    """Collect statistics from every worker concurrently.

    All workers are queried at once, so a slow worker no longer delays the
    others. A failing worker is logged and left out of the result.

    :return: List of per-worker statistics dicts, in worker order.
    """
    if not workers:
        return []

    # Submit all calls first; each failure stays isolated to its own worker.
    refs = []
    for worker in workers:
        try:
            refs.append(worker.get_worker_stats.remote())
        except Exception as e:
            print(f"Failed to get worker statistics: {e}")

    # return_exceptions=True delivers failures as values instead of aborting
    # the whole gather on the first error.
    outcomes = await asyncio.gather(*refs, return_exceptions=True)

    stats = []
    for outcome in outcomes:
        if isinstance(outcome, BaseException):
            print(f"Failed to get worker statistics: {outcome}")
        else:
            stats.append(outcome)
    return stats

# FastAPI application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the Ray workers on startup, stop Ray on shutdown."""
    # Initialize Ray cluster on startup
    initialize_ray_cluster()
    try:
        yield
    finally:
        # Runs even when the server exits through an exception or a
        # cancellation, so Ray is not left running.
        if ray.is_initialized():
            ray.shutdown()

# The application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="Multi-GPU ANCE FAISS Search Service", 
    version="1.0.0",
    lifespan=lifespan
)
# Track in-flight requests and answer 429 above the limit. The limit passed
# here (100000) replaces the middleware's own default of 5000.
app.add_middleware(ConcurrencyMonitorMiddleware, max_concurrent_requests=100000)

# Request models
# Request body of POST /search.
class SearchRequest(BaseModel):
    # Query texts; the /search endpoint rejects an empty list with HTTP 400.
    queries: List[str]
    # Forwarded to the worker's batch_search as top_k.
    top_k: int = 5
    # Forwarded to the worker's batch_search as threads.
    threads: int = 8

# One retrieved document in a search response.
class SearchResult(BaseModel):
    # Document identifier; the /search endpoint converts it with int().
    doc_id: int
    # Score reported by the index search; converted with float() by the endpoint.
    score: float

# Response body of POST /search.
class SearchResponse(BaseModel):
    # One list of results per query.
    results: List[List[SearchResult]]
    # Seconds spent waiting for the worker call, as measured by the endpoint.
    search_time: float

# The old startup/shutdown event handlers were removed; the lifespan context manager above replaces them.

@app.get("/")
async def root():
    """Landing endpoint that confirms the service is running."""
    status_message = "Multi-GPU ANCE FAISS Search Service is running"
    return {"message": status_message}

@app.get("/health")
async def health_check():
    """Report healthy with the worker count, or HTTP 503 when there are no workers."""
    worker_count = len(workers)
    if worker_count == 0:
        raise HTTPException(status_code=503, detail="No workers available")
    return {"status": "healthy", "num_workers": worker_count}

@app.get("/stats")
async def get_stats():
    """Return per-worker statistics together with the worker count and a timestamp."""
    # With no workers there is nothing to query; report an empty list.
    per_worker = await get_all_worker_stats() if workers else []

    return {
        "status": "healthy",
        "num_workers": len(workers),
        "workers": per_worker,
        "load_balancer_stats": {},  # Load balancer removed, this field is empty
        "timestamp": time.time()
    }

@app.get("/config")
async def get_config():
    """Expose the active configuration values, or HTTP 503 when no config is set."""
    if config is None:
        raise HTTPException(status_code=503, detail="Config not initialized")

    exposed_fields = (
        "model_name",
        "output_dir",
        "index_path",
        "doc_ids_path",
        "num_gpus",
        "host",
        "port",
    )
    return {field: getattr(config, field) for field in exposed_fields}

@app.post("/search", response_model=SearchResponse)
async def batch_search(request: SearchRequest):
    """Batch search endpoint.

    Sends the whole batch to one randomly chosen worker and returns one
    result list per query.

    :raises HTTPException: 503 when no workers exist, 400 for an empty query
        list or a non-positive ``top_k``/``threads``, 500 when the worker
        call fails.
    """
    if not workers:
        raise HTTPException(status_code=503, detail="No workers available")

    if not request.queries:
        raise HTTPException(status_code=400, detail="Queries list cannot be empty")

    # Reject values the index search cannot use, rather than surfacing them
    # as an opaque 500 from the worker.
    if request.top_k < 1 or request.threads < 1:
        raise HTTPException(status_code=400, detail="top_k and threads must be positive integers")

    try:
        # Randomly select a worker
        worker = get_random_worker()

        # Execute search
        time_start = time.time()
        results = await worker.batch_search.remote(
            queries=request.queries,
            top_k=request.top_k,
            threads=request.threads
        )
        search_time = time.time() - time_start

        # Convert worker tuples into response models
        formatted_results = [
            [
                SearchResult(doc_id=int(doc_id), score=float(score))
                for doc_id, score in query_results
            ]
            for query_results in results
        ]

        return SearchResponse(
            results=formatted_results,
            search_time=search_time
        )

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 503) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

@app.post("/clear_cache")
async def clear_cache():
    """Empty the embedding cache of every worker and report how many entries were dropped."""
    if not workers:
        raise HTTPException(status_code=503, detail="No workers available")

    # Workers are cleared one after another; each call returns its removed-entry count.
    cleared_counts = [await worker.clear_cache.remote() for worker in workers]

    return {
        "message": "Cache cleared successfully",
        "workers_cleared": len(cleared_counts),
        "total_cleared_items": sum(cleared_counts),
        "timestamp": time.time()
    }

def parse_args(argv=None):
    """Parse command line arguments.

    Every option defaults to ``None`` so that ``Config.from_args`` overrides
    only what was given on the command line. The previous ``default=1.0`` on
    ``--gpu-allocate-per-ray-instance`` always replaced the value read from
    the ``GPU_ALLOCATE_PER_RAY_INSTANCE`` environment variable.

    :param argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Multi-GPU ANCE FAISS Search Service")
    parser.add_argument("--model-name", type=str,
                        help="Path to the ANCE transformer model")
    parser.add_argument("--output-dir", type=str,
                        help="Directory containing FAISS index and doc IDs")
    parser.add_argument("--host", type=str,
                        help="Host to bind the server")
    parser.add_argument("--port", type=int,
                        help="Port to bind the server")
    parser.add_argument("--num-gpus", type=int,
                        help="Number of GPUs to use")
    parser.add_argument("--gpu-allocate-per-ray-instance", type=float,
                        help="Number of GPUs to allocate per Ray instance (default: 1.0, can be fractional like 0.5)")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Command line values take precedence over environment/default settings.
    cli_args = parse_args()

    config = Config()
    config.from_args(cli_args)
    config.print_config()

    # Server options are gathered in one place and handed to uvicorn together.
    server_options = {
        "host": config.host,
        "port": config.port,
        "log_level": "info",
        "workers": 1,
        "limit_concurrency": 100000,
        "timeout_keep_alive": 60,
        "timeout_graceful_shutdown": 60,
        "loop": "asyncio",
        "http": "httptools",
        "ws": "websockets",
        "lifespan": "on",
        "access_log": False,
        "backlog": 4096,
        "h11_max_incomplete_event_size": 16 * 1024 * 1024,
    }
    uvicorn.run(app, **server_options)