import json
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import os
import time
import argparse
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel
from typing import List, Tuple, Optional
import uvicorn
import asyncio
from collections import defaultdict
import threading
import ray
from ray.util.queue import Queue
import psutil
import GPUtil
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
import random

# Number of GPUs reserved by Ray for each ANSESearchWorker actor (passed to
# ray.remote(num_gpus=...)); also the divisor used to derive the worker count
# from the configured GPU total.
GPU_ALLOCATE_PER_RAY_INSTANCE = 1 # Default is 1,2
# During start-up, initialize_ray_cluster() blocks on every
# `parellel_load_count`-th worker before launching further ones.
# (The name is misspelled, but it is referenced elsewhere in this file.)
parellel_load_count = 3
# Index file name joined onto the output directory by Config.from_args().
index_file_name = 'ance_faiss_index_flat.bin'
# index_file_name = 'ance_faiss_index_64_2000.bin'


def load_json(json_path):
    """Load data from a ``.json`` or ``.jsonl`` file.

    :param json_path: Path to the file; the extension selects the parser.
    :return: The parsed document for ``.json``; a list with one parsed object
        per non-blank line for ``.jsonl``; ``None`` (after printing a message)
        when the extension is neither of the two.
    """
    _, ext = os.path.splitext(json_path)
    # Validate the extension before touching the file system.
    if ext not in ('.json', '.jsonl'):
        print("Wrong path, please provide .json/.jsonl file")
        return None

    with open(json_path, 'r', encoding='utf-8') as f:
        if ext == '.json':
            return json.load(f)
        # Skip blank lines (e.g. a trailing newline at the end of the file),
        # which json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

def save_json(save_path, json_list):
    """Write data to a ``.json`` or ``.jsonl`` file.

    :param save_path: Destination path; the extension selects the format.
    :param json_list: For ``.json`` any JSON-serializable object (written with
        4-space indentation); for ``.jsonl`` an iterable whose items are each
        written as one line.
    :return: ``None``. Progress and problems are reported with ``print``.
    """
    _, ext = os.path.splitext(save_path)
    # Reject unsupported extensions *before* opening the file, so that a wrong
    # path never creates or truncates anything on disk.
    if ext not in ('.json', '.jsonl'):
        print("Wrong path, please provide .json/.jsonl file")
        return

    try:
        with open(save_path, 'w', encoding='utf-8') as f:
            if ext == '.json':
                json.dump(json_list, f, ensure_ascii=False, indent=4)
            else:
                for item in json_list:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
    except (TypeError, ValueError) as exc:
        # Serialization problems are reported as such rather than being
        # mislabelled as a wrong path.
        print(f"Failed to serialize data to {save_path}: {exc}")
        return

    print(f"Successfully written to {save_path}")

def build_faiss_index(index_path: str, n_gpu: int, faiss_gpu: bool=True):
    """Load a pre-built FAISS index and optionally shard it across GPUs.

    :param index_path: Path to the serialized FAISS index.
    :param n_gpu: Number of GPUs (device ids ``0 .. n_gpu-1``) to shard over.
    :param faiss_gpu: When true and ``n_gpu > 0`` the index is cloned to GPU;
        otherwise the CPU index is returned unchanged.
    :return: The index to search (GPU clone or CPU index).
    """
    print(f"Loading FAISS index from {index_path}...")
    cpu_index = faiss.read_index(index_path)

    if faiss_gpu and n_gpu > 0:
        print(f"Using {n_gpu} GPU(s) for FAISS")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.usePrecomputed = False
        vres = faiss.GpuResourcesVector()
        vdev = faiss.Int32Vector()
        # GPU resources are only created on the GPU path, so the CPU path
        # never touches GPU-only FAISS symbols.
        gpu_resources = []
        for device_id in range(n_gpu):
            res = faiss.StandardGpuResources()
            gpu_resources.append(res)
            vdev.push_back(device_id)
            vres.push_back(res)
        index = faiss.index_cpu_to_gpu_multiple(vres, vdev, cpu_index, co)
        # Keep the Python resource wrappers alive for as long as the index is.
        # NOTE(review): recent FAISS releases presumably manage this lifetime
        # themselves; holding the reference is harmless either way.
        index.referenced_objects = gpu_resources
    else:
        print("Using CPU for FAISS")
        index = cpu_index

    print(f"FAISS index loaded successfully with {index.ntotal} vectors")
    return index

@ray.remote(num_gpus=GPU_ALLOCATE_PER_RAY_INSTANCE)
class ANSESearchWorker:
    """Ray actor that embeds queries and searches a FAISS index.

    Each actor loads its own SentenceTransformer model, FAISS index and
    document-id array, and keeps an LRU cache of query embeddings together
    with simple request statistics.
    """

    def __init__(self, model_name, index_path, doc_ids_path, worker_id, n_gpu=1, faiss_gpu=True, cache_size=50000):
        """
        Initialize the ANCE Search Worker.

        :param model_name: Name of the ANCE transformer model for embeddings.
        :param index_path: Path to the FAISS index file.
        :param doc_ids_path: Path to the saved document IDs file.
        :param worker_id: Identifier used in logs and statistics.
        :param n_gpu: Accepted for interface compatibility but ignored; the
            number of GPUs used is the number Ray assigned to this actor.
        :param faiss_gpu: Whether to use GPU for FAISS (only honoured when Ray
            assigned at least one GPU).
        :param cache_size: Maximum number of cached query embeddings.
        :raises FileNotFoundError: If the index or document-id file is missing.
        """
        # Function-scope import keeps this class self-contained.
        from collections import OrderedDict

        self.model_name = model_name
        self.index_path = index_path
        self.doc_ids_path = doc_ids_path
        self.worker_id = worker_id
        self.cache_size = cache_size

        # Ray decides which GPU(s) this actor owns; inside the actor the first
        # assigned GPU is addressed as cuda:0.
        ray_gpu_ids = ray.get_gpu_ids()
        if ray_gpu_ids:
            self.device = "cuda:0"
            self.gpu_id = int(ray_gpu_ids[0])
            print(f"Worker {worker_id} - Ray allocated GPU IDs: {ray_gpu_ids}")
        else:
            self.device = "cpu"
            self.gpu_id = None
            print(f"Worker {worker_id} - Using CPU")

        # Use exactly the GPUs Ray handed out, and never ask FAISS for a GPU
        # index when the actor has none.
        self.n_gpu = len(ray_gpu_ids)
        self.faiss_gpu = bool(faiss_gpu) and self.n_gpu > 0

        # LRU embedding cache: least recently used entry first.
        self.embedding_cache = OrderedDict()

        # Request statistics
        self.request_count = 0
        self.total_processing_time = 0
        self.current_load = 0

        self.model = SentenceTransformer(self.model_name, device=self.device)
        self.index = self._load_faiss_index()
        self.doc_ids = self._load_doc_ids()

        print(f"✅ ANCE Worker {worker_id} initialized successfully, device: {self.device}")

    def _load_faiss_index(self):
        """Load the FAISS index from disk (on GPU when enabled)."""
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"FAISS index file not found at {self.index_path}")
        return build_faiss_index(self.index_path, self.n_gpu, self.faiss_gpu)

    def _load_doc_ids(self):
        """Load the document-id array that maps FAISS row numbers to doc ids."""
        if not os.path.exists(self.doc_ids_path):
            raise FileNotFoundError(f"Document IDs file not found at {self.doc_ids_path}")
        # SECURITY NOTE: allow_pickle=True executes pickled code on load; only
        # point this at trusted, locally produced files.
        return np.load(self.doc_ids_path, allow_pickle=True)

    def _cache_lookup(self, text):
        """Return the cached embedding for *text* (marking it most recently used) or None."""
        embedding = self.embedding_cache.get(text)
        if embedding is not None:
            self.embedding_cache.move_to_end(text)
        return embedding

    def _cache_store(self, text, embedding):
        """Insert *text* as the most recent entry and evict down to ``cache_size``."""
        self.embedding_cache[text] = embedding
        self.embedding_cache.move_to_end(text)
        while self.embedding_cache and len(self.embedding_cache) > self.cache_size:
            self.embedding_cache.popitem(last=False)

    def _get_cached_embedding(self, text):
        """Get or compute embedding (with cache)"""
        embedding = self._cache_lookup(text)
        if embedding is None:
            embedding = self.get_dense_embedding(text)
            self._cache_store(text, embedding)
        return embedding

    def get_dense_embedding(self, text):
        """Encode one text into a normalized embedding vector."""
        with torch.no_grad():
            embedding = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)
            return embedding[0]

    def batch_get_dense_embedding(self, texts):
        """Encode a list of texts into a matrix of normalized embeddings."""
        with torch.no_grad():
            embeddings = self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
            return embeddings

    def _search_embeddings(self, embeddings, top_k, threads):
        """Search the index with a 2-D embedding matrix.

        :return: One list of ``(doc_id, score)`` tuples per input row.
        """
        faiss.omp_set_num_threads(threads)
        query_matrix = np.ascontiguousarray(embeddings, dtype=np.float32)
        D, I = self.index.search(query_matrix, top_k)
        # FAISS pads with id -1 when fewer than top_k neighbours exist; those
        # must be dropped, otherwise doc_ids[-1] would be returned as a hit.
        return [
            [(self.doc_ids[idx], score) for idx, score in zip(row_ids, row_scores) if idx >= 0]
            for row_ids, row_scores in zip(I, D)
        ]

    def search(self, query, top_k=5, threads=8):
        """Search for a single query and update load statistics.

        :param query: Query text.
        :param top_k: Maximum number of hits to return.
        :param threads: OpenMP thread count handed to FAISS.
        :return: List of ``(doc_id, score)`` tuples.
        """
        start_time = time.time()
        self.current_load += 1

        try:
            # The model already returns normalized embeddings.
            query_embedding = self._get_cached_embedding(query)
            results = self._search_embeddings(query_embedding.reshape(1, -1), top_k, threads)[0]

            self.request_count += 1
            self.total_processing_time += time.time() - start_time
            return results
        finally:
            self.current_load -= 1

    def batch_search(self, queries, top_k=5, threads=8):
        """Search for several queries at once.

        :param queries: List of query texts (duplicates are allowed).
        :param top_k: Maximum number of hits per query.
        :param threads: OpenMP thread count handed to FAISS.
        :return: One list of ``(doc_id, score)`` tuples per query, in input order.
        """
        start_time = time.time()
        num_queries = len(queries)
        self.current_load += num_queries

        try:
            if num_queries == 0:
                return []

            query_to_embedding = {}
            uncached_queries = []

            # Iterate over unique queries only: a query repeated within one
            # batch must be embedded and cached exactly once.
            for query in dict.fromkeys(queries):
                cached = self._cache_lookup(query)
                if cached is None:
                    uncached_queries.append(query)
                else:
                    query_to_embedding[query] = cached

            if uncached_queries:
                new_embeddings = self.batch_get_dense_embedding(uncached_queries)
                for query, embedding in zip(uncached_queries, new_embeddings):
                    # Copy the row so a cached entry does not keep the whole
                    # batch matrix alive.
                    embedding = embedding.copy()
                    query_to_embedding[query] = embedding
                    self._cache_store(query, embedding)

            # Rebuild the matrix in the caller's original order.
            all_embeddings = np.array([query_to_embedding[query] for query in queries])
            results = self._search_embeddings(all_embeddings, top_k, threads)

            self.request_count += num_queries
            self.total_processing_time += time.time() - start_time
            return results
        finally:
            self.current_load -= num_queries

    def get_worker_stats(self):
        """Return request, cache and GPU statistics for this worker as a dict."""
        gpu_stats = {}
        try:
            gpus = GPUtil.getGPUs()
            if self.gpu_id is not None and self.gpu_id < len(gpus):
                gpu = gpus[self.gpu_id]
                gpu_stats = {
                    "gpu_id": self.gpu_id,
                    "gpu_utilization": gpu.load * 100,
                    "gpu_memory_used": gpu.memoryUsed,
                    "gpu_memory_total": gpu.memoryTotal,
                    "gpu_memory_utilization": gpu.memoryUtil * 100,
                    "gpu_temperature": gpu.temperature
                }
        except Exception:
            # GPU telemetry is best effort; statistics must still be returned.
            gpu_stats = {"gpu_id": self.gpu_id, "gpu_utilization": 0}

        return {
            "worker_id": self.worker_id,
            "device": self.device,
            "request_count": self.request_count,
            "current_load": self.current_load,
            "avg_processing_time": self.total_processing_time / max(1, self.request_count),
            "cache_size": len(self.embedding_cache),
            "cache_capacity": self.cache_size,
            "gpu_stats": gpu_stats,
            "faiss_gpu": self.faiss_gpu,
            "n_gpu": self.n_gpu
        }

    def clear_cache(self):
        """Empty the embedding cache and return how many entries were removed."""
        old_cache_size = len(self.embedding_cache)
        self.embedding_cache.clear()
        return old_cache_size

# Configuration class
class Config:
    """Server configuration resolved from defaults, environment and CLI.

    Precedence (lowest to highest): built-in defaults, the environment
    variables ``MODEL_NAME``, ``OUTPUT_DIR``, ``HOST``, ``PORT``, ``NUM_GPUS``
    and ``FAISS_GPU``, then command line arguments applied via ``from_args``.
    """

    def __init__(self):
        self.model_name = os.getenv('MODEL_NAME', 'sentence-transformers/msmarco-roberta-base-ance-firstp')
        self.output_dir = os.getenv('OUTPUT_DIR', "local_index_search/wikipedia_segment/ance")
        self.host = os.getenv('HOST', "0.0.0.0")
        self.port = int(os.getenv('PORT', 8000))
        self.num_gpus = int(os.getenv('NUM_GPUS', 2))
        # 'true' keeps its meaning; common truthy spellings are accepted too.
        self.faiss_gpu = os.getenv('FAISS_GPU', 'true').strip().lower() in ('true', '1', 'yes', 'on')

        self._refresh_paths()

    def _refresh_paths(self):
        """Derive the index and document-id paths from ``output_dir``."""
        # Single place for path construction, so the default and the
        # --output-dir code paths always use the same index file name.
        self.index_path = os.path.join(self.output_dir, index_file_name)
        self.doc_ids_path = os.path.join(self.output_dir, 'doc_ids.npy')

    def from_args(self, args):
        """Update configuration from command line arguments.

        :param args: Namespace-like object; attributes that are missing or
            ``None`` leave the current value untouched.
        """
        for field in ('model_name', 'output_dir', 'host', 'port', 'num_gpus', 'faiss_gpu'):
            value = getattr(args, field, None)
            if value is not None:
                setattr(self, field, value)
        self._refresh_paths()

    def print_config(self):
        """Print current configuration"""
        print("=== Multi-GPU ANCE Search Server Configuration ===")
        print(f"Model path: {self.model_name}")
        print(f"Index directory: {self.output_dir}")
        print(f"Index file: {self.index_path}")
        print(f"Document ID file: {self.doc_ids_path}")
        print(f"Number of GPUs: {self.num_gpus}")
        print(f"FAISS GPU acceleration: {self.faiss_gpu}")
        print(f"Host: {self.host}")
        print(f"Port: {self.port}")
        print("============================")

# Concurrency monitoring middleware
class ConcurrencyMonitorMiddleware(BaseHTTPMiddleware):
    """Caps the number of in-flight requests and tags responses with timing headers.

    Requests arriving while the cap is reached are answered with HTTP 429.
    Admitted requests are counted per URL path in ``request_stats``.
    """

    def __init__(self, app, max_concurrent_requests: int = 5000):
        super().__init__(app)
        self.lock = threading.Lock()
        self.request_stats = defaultdict(int)
        self.current_requests = 0
        self.max_concurrent_requests = max_concurrent_requests

    async def dispatch(self, request: Request, call_next):
        """Admit or reject *request*, then forward it and annotate the response."""
        # Decide admission and update the counters under the lock.
        with self.lock:
            admitted = self.current_requests < self.max_concurrent_requests
            if admitted:
                self.current_requests += 1
                self.request_stats[request.url.path] += 1

        if not admitted:
            return Response(
                content="Too many concurrent requests",
                status_code=429,
                headers={"Retry-After": "1"}
            )

        try:
            started_at = time.time()
            response = await call_next(request)
            elapsed = time.time() - started_at

            response.headers["X-Process-Time"] = str(elapsed)
            response.headers["X-Current-Requests"] = str(self.current_requests)
            return response
        finally:
            # Release the slot whether the handler succeeded or raised.
            with self.lock:
                self.current_requests -= 1

# Global variables
# `workers` holds the Ray actor handles created by initialize_ray_cluster().
# `config` is the active Config instance; it is assigned in the __main__ block
# and stays None until then.
workers = []
config = None

def _await_worker_batch(pending):
    """Wait until every ``(worker_id, handle)`` in *pending* finished its constructor.

    :return: ``(ready_handles, last_error)``; handles whose construction failed
        are left out and the most recent failure (or None) is reported.
    """
    ready = []
    last_error = None
    for worker_id, handle in pending:
        try:
            # A method call only completes once the actor's __init__ has run,
            # and it raises if __init__ failed.
            ray.get(handle.get_worker_stats.remote())
        except Exception as exc:
            print(f"❌ Failed to create Worker {worker_id}: {exc}")
            last_error = exc
        else:
            ready.append(handle)
            print(f"✅ Created Worker {worker_id} successfully")
    return ready, last_error


def initialize_ray_cluster():
    """Initialize Ray and create the search workers.

    Restarts Ray if it is already running, clamps the configured GPU count to
    what is available, launches the workers in batches of
    ``parellel_load_count`` and stores the successfully constructed actor
    handles in the module-level ``workers`` list.

    :raises RuntimeError: If not a single worker could be created.
    """
    global workers, config

    # `config` is only assigned in the __main__ block; fall back to the
    # environment/default configuration when the module is imported instead.
    if config is None:
        config = Config()

    # Clean up existing Ray instance
    if ray.is_initialized():
        ray.shutdown()

    # Check available GPU count
    available_gpus = 0
    if torch.cuda.is_available():
        available_gpus = torch.cuda.device_count()
        print(f"Main process detected {available_gpus} GPUs")

        if config.num_gpus > available_gpus:
            print(f"⚠️  Requested GPU count ({config.num_gpus}) exceeds available count ({available_gpus}), adjusting to {available_gpus}")
            config.num_gpus = available_gpus
    else:
        print("❌ Main process did not detect GPU, workers will use CPU")
        config.faiss_gpu = False  # Force CPU usage

    ray_init_kwargs = {
        "ignore_reinit_error": True,
        "log_to_driver": True,
        "configure_logging": False,
        "include_dashboard": False,
    }

    # Only pin Ray's GPU count when FAISS should run on GPU.
    if available_gpus > 0 and config.faiss_gpu:
        ray_init_kwargs["num_gpus"] = config.num_gpus
        print(f"Initializing Ray with {config.num_gpus} GPUs...")
    else:
        print("Initializing Ray with CPU only...")

    ray.init(**ray_init_kwargs)

    # Give Ray a moment to settle before querying its resources.
    time.sleep(3)

    resources = ray.available_resources()
    print(f"Ray available resources: {resources}")

    # The actor class requests GPU_ALLOCATE_PER_RAY_INSTANCE GPUs. If the
    # cluster cannot satisfy that, the actor would stay pending forever, so
    # fall back to a CPU-only actor.
    if ray.cluster_resources().get("GPU", 0) >= GPU_ALLOCATE_PER_RAY_INSTANCE:
        worker_cls = ANSESearchWorker
    else:
        worker_cls = ANSESearchWorker.options(num_gpus=0)
        config.faiss_gpu = False

    workers = []
    pending = []
    last_error = None
    batch_size = max(1, parellel_load_count)
    batch_start = time.time()
    total_workers = max(1, int(config.num_gpus / GPU_ALLOCATE_PER_RAY_INSTANCE)) if config.faiss_gpu else 1

    for worker_id in range(total_workers):
        try:
            handle = worker_cls.remote(
                model_name=config.model_name,
                index_path=config.index_path,
                doc_ids_path=config.doc_ids_path,
                worker_id=worker_id,
                n_gpu=config.num_gpus if config.faiss_gpu else 0,
                faiss_gpu=config.faiss_gpu,
                cache_size=50000
            )
        except Exception as exc:
            print(f"❌ Failed to create Worker {worker_id}: {exc}")
            last_error = exc
            continue

        pending.append((worker_id, handle))

        # Load at most `batch_size` workers concurrently, then wait for all of
        # them before starting the next batch.
        if len(pending) >= batch_size:
            ready, batch_error = _await_worker_batch(pending)
            workers.extend(ready)
            last_error = batch_error or last_error
            pending = []
            batch_end = time.time()
            print(f"One load execution time: {batch_end - batch_start} seconds")
            batch_start = batch_end

    # Wait for the final, possibly partial batch as well, so constructor
    # failures are never silently ignored.
    if pending:
        ready, batch_error = _await_worker_batch(pending)
        workers.extend(ready)
        last_error = batch_error or last_error
        print(f"One load execution time: {time.time() - batch_start} seconds")

    if not workers:
        raise RuntimeError(f"Unable to create any worker: {last_error}") from last_error

    print(f"✅ Successfully initialized {len(workers)} workers")

def get_random_worker():
    """Pick one worker handle uniformly at random.

    Raises an HTTP 503 error when no workers have been created.
    """
    if workers:
        return random.choice(workers)
    raise HTTPException(status_code=503, detail="No workers available")

async def get_all_worker_stats():
    """Get statistics for all workers.

    All workers are queried concurrently. A worker whose call fails is logged
    and left out of the result.

    :return: List of statistics dicts, in worker order.
    """
    if not workers:
        return []

    # Issue every remote call up front and await them together, instead of
    # paying one round trip per worker sequentially.
    outcomes = await asyncio.gather(
        *(worker.get_worker_stats.remote() for worker in workers),
        return_exceptions=True,
    )

    stats = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            print(f"Failed to get worker statistics: {outcome}")
        elif isinstance(outcome, BaseException):
            # Cancellation and similar signals must not be swallowed.
            raise outcome
        else:
            stats.append(outcome)
    return stats

# FastAPI application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the Ray workers, then shut Ray down on exit."""
    # Initialize Ray cluster on startup
    initialize_ray_cluster()
    try:
        yield
    finally:
        # Clean up Ray resources on shutdown, including when the application
        # exits through an exception.
        if ray.is_initialized():
            ray.shutdown()

# ASGI application object; `lifespan` starts and stops the Ray workers.
app = FastAPI(
    title="Multi-GPU ANCE FAISS Search Service", 
    version="1.0.0",
    lifespan=lifespan
)
# Reject requests with HTTP 429 once 100000 are in flight at the same time.
app.add_middleware(ConcurrencyMonitorMiddleware, max_concurrent_requests=100000)

# Request models
# Body of POST /search.
class SearchRequest(BaseModel):
    # Query texts; the endpoint rejects an empty list with HTTP 400.
    queries: List[str]
    # Number of hits requested per query.
    top_k: int = 5
    # Thread count forwarded to the worker, which passes it to
    # faiss.omp_set_num_threads.
    threads: int = 8

# One hit: a document id and the score FAISS reported for it.
class SearchResult(BaseModel):
    doc_id: int
    score: float

# Response of POST /search.
class SearchResponse(BaseModel):
    # One list of hits per query, in the order the queries were sent.
    results: List[List[SearchResult]]
    # Wall-clock seconds spent waiting for the worker call.
    search_time: float

@app.get("/")
async def root():
    # Static banner confirming that the HTTP service is up.
    banner = {"message": "Multi-GPU ANCE FAISS Search Service is running"}
    return banner

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Healthy means at least one worker handle exists; otherwise answer 503.
    worker_count = len(workers)
    if worker_count == 0:
        raise HTTPException(status_code=503, detail="No workers available")
    return {"status": "healthy", "num_workers": worker_count}

@app.get("/stats")
async def get_stats():
    """Get statistics for all workers"""
    # Workers are only queried when there are any; with none, the payload has
    # the same shape with an empty list.
    worker_stats = await get_all_worker_stats() if workers else []

    payload = {
        "status": "healthy",
        "num_workers": len(workers),
        "workers": worker_stats,
        # Load balancer removed, this field is kept empty.
        "load_balancer_stats": {},
        "timestamp": time.time()
    }
    return payload

@app.get("/config")
async def get_config():
    """Get current configuration"""
    if config is None:
        raise HTTPException(status_code=503, detail="Config not initialized")

    # Configuration attributes exposed to clients, in response order.
    exposed_fields = (
        "model_name",
        "output_dir",
        "index_path",
        "doc_ids_path",
        "num_gpus",
        "faiss_gpu",
        "host",
        "port",
    )
    return {field: getattr(config, field) for field in exposed_fields}

@app.post("/search", response_model=SearchResponse)
async def batch_search(request: SearchRequest):
    """Batch search endpoint.

    Sends all queries to one randomly chosen worker and returns its hits.
    Responds with 400 for an empty query list or a non-positive ``top_k`` /
    ``threads``, 503 when no worker is available and 500 when the search
    itself fails.
    """
    if not workers:
        raise HTTPException(status_code=503, detail="No workers available")

    if not request.queries:
        raise HTTPException(status_code=400, detail="Queries list cannot be empty")

    # Reject values the worker cannot use up front, so the client gets a 400
    # instead of an opaque 500 from inside the worker.
    if request.top_k <= 0:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    if request.threads <= 0:
        raise HTTPException(status_code=400, detail="threads must be a positive integer")

    try:
        # Randomly select a worker
        worker = get_random_worker()

        time_start = time.time()
        results = await worker.batch_search.remote(
            queries=request.queries,
            top_k=request.top_k,
            threads=request.threads
        )
        search_time = time.time() - time_start

        # Convert the worker's (doc_id, score) tuples into response models.
        formatted_results = [
            [
                SearchResult(doc_id=int(doc_id), score=float(score))
                for doc_id, score in query_results
            ]
            for query_results in results
        ]
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 503 from get_random_worker)
        # instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

    return SearchResponse(
        results=formatted_results,
        search_time=search_time
    )

@app.post("/clear_cache")
async def clear_cache():
    """Clear cache for all workers"""
    if len(workers) == 0:
        raise HTTPException(status_code=503, detail="No workers available")

    # Ask each worker in turn; every call returns how many entries it dropped.
    cleared_per_worker = [await handle.clear_cache.remote() for handle in workers]

    summary = {
        "message": "Cache cleared successfully",
        "workers_cleared": len(cleared_per_worker),
        "total_cleared_items": sum(cleared_per_worker),
        "timestamp": time.time()
    }
    return summary

def parse_args(argv=None):
    """Parse command line arguments.

    Every option defaults to ``None`` so that an omitted option leaves the
    value from the environment/default configuration untouched.

    :param argv: Optional list of argument strings; ``None`` means
        ``sys.argv[1:]``.
    :return: The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Multi-GPU ANCE FAISS Search Service")
    parser.add_argument("--model-name", type=str,
                        help="Path to the ANCE transformer model")
    parser.add_argument("--output-dir", type=str,
                        help="Directory containing FAISS index and doc IDs")
    parser.add_argument("--host", type=str,
                        help="Host to bind the server")
    parser.add_argument("--port", type=int,
                        help="Port to bind the server")
    parser.add_argument("--num-gpus", type=int,
                        help="Number of GPUs to use")
    # default=None (not True): otherwise the flag would always look "set" and
    # silently override the FAISS_GPU environment variable.
    parser.add_argument("--faiss-gpu", dest='faiss_gpu', action='store_true', default=None,
                        help="Use GPU for FAISS (default: FAISS_GPU env var, else enabled)")
    parser.add_argument("--no-faiss-gpu", dest='faiss_gpu', action='store_false', default=None,
                        help="Disable GPU for FAISS")
    return parser.parse_args(argv)

# Script entry point: build the configuration, then serve the app with uvicorn.
if __name__ == "__main__":
    # Parse command line arguments
    args = parse_args()
    
    # Initialize configuration
    # Assigned at module level, so this is the `config` global that the
    # endpoints and initialize_ray_cluster() read.
    config = Config()
    config.from_args(args)
    config.print_config()
    
    # Run server
    # A single uvicorn process is used (workers=1); the Ray workers are
    # created by the app's lifespan handler.
    uvicorn.run(
        app, 
        host=config.host, 
        port=config.port,
        log_level="info",
        workers=1,
        limit_concurrency=100000,
        timeout_keep_alive=60,
        timeout_graceful_shutdown=60,
        loop="asyncio",
        http="httptools",
        ws="websockets",
        lifespan="on",
        access_log=False,
        backlog=4096,
        h11_max_incomplete_event_size=16 * 1024 * 1024,
    )
