import json
import os
import time
import argparse
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import uvicorn
import asyncio
import random
import ray
from contextlib import asynccontextmanager
from pyserini.search.lucene import LuceneSearcher

# Core search class: using Pyserini multi-field search
class PyseriniMultiFieldSearch:
    """BM25 search over a prebuilt Lucene index, accessed through Pyserini.

    Every query is restricted to the ``contents`` field. The raw stored form
    of each hit is parsed as JSON and must provide ``id`` and ``contents``
    keys.
    """

    def __init__(self, index_dir, k1=0.9, b=0.4):
        """Open the index and configure BM25 scoring.

        Args:
            index_dir: Path of the Lucene index directory.
            k1: BM25 ``k1`` value handed to ``set_bm25``.
            b: BM25 ``b`` value handed to ``set_bm25``.

        Raises:
            FileNotFoundError: If ``index_dir`` does not exist.
        """
        print(f"Loading index from {index_dir}")
        if not os.path.exists(index_dir):
            raise FileNotFoundError(f"Index directory does not exist: {index_dir}")
        self.searcher = LuceneSearcher(index_dir)
        print("Index loading completed")
        self.searcher.set_bm25(k1, b)
        print(f"BM25 parameters set: k1={k1}, b={b}")

    def batch_search(self, queries, top_k=10, threads=4):
        """Run all ``queries`` in one batch call and format the hits.

        Args:
            queries: Query strings. Surrounding whitespace is stripped before
                the query is sent; a blank query is sent as ``contents:*``.
            top_k: Value passed to the searcher as ``k``.
            threads: Value passed to the searcher as ``threads``.

        Returns:
            A dict keyed by the original (unstripped) query string. Each value
            is a list of ``(id, contents, score)`` tuples. Identical query
            strings share a single entry, and a query for which the searcher
            returned no entry maps to an empty list.
        """
        # The parentheses keep every word of the query inside the contents
        # field instead of only the first one.
        field_queries = []
        for query in queries:
            text = query.strip()
            field_queries.append(f"contents:({text})" if text else "contents:*")

        # Positional ids keep results addressable even when query texts repeat.
        query_ids = [str(position) for position in range(len(queries))]
        results_dict = self.searcher.batch_search(
            field_queries,
            query_ids,
            k=top_k,
            threads=threads,
        )

        final_results = {}
        for query_id, query in zip(query_ids, queries):
            formatted = []
            for hit in results_dict.get(query_id, ()):
                # Decode the stored document once per hit, not once per field.
                document = json.loads(hit.raw)
                formatted.append((document["id"], document["contents"], hit.score))
            final_results[query] = formatted
        return final_results

# Ray Worker: runs on CPU only
@ray.remote
class LuceneSearchWorker:
    """Ray remote class that owns one ``PyseriniMultiFieldSearch`` instance."""

    def __init__(self, index_dir, worker_id, k1=0.9, b=0.4):
        """Remember the worker id and build the searcher for ``index_dir``."""
        self.worker_id = worker_id
        self.searcher = PyseriniMultiFieldSearch(index_dir, k1=k1, b=b)
        print(f"✅ Worker {worker_id} initialization completed (CPU mode)")

    def batch_search(self, queries, top_k=10, threads=4):
        """Forward the batch to the wrapped searcher and return its result."""
        return self.searcher.batch_search(queries, top_k=top_k, threads=threads)

# Configuration class (only essential parameters)
class Config:
    """Runtime settings of the search service, initialised to defaults."""

    def __init__(self):
        self.index_dir = "lucene_index"  # Default index directory
        self.host = "0.0.0.0"
        self.port = 8000
        self.num_cpus = 4  # CPU Worker count
        self.k1 = 0.9  # BM25 parameter k1
        self.b = 0.4   # BM25 parameter b

    def from_args(self, args):
        """Overlay parsed command-line values onto the current settings.

        Args:
            args: Namespace exposing ``index_dir``, ``host``, ``port``,
                ``num_cpus``, ``k1`` and ``b`` attributes.

        String options are applied only when non-empty. Numeric options are
        applied whenever they are not ``None``, so an explicit ``0`` (for
        example ``--port 0``) is kept rather than silently replaced by the
        default.
        """
        if args.index_dir:
            self.index_dir = args.index_dir
        if args.host:
            self.host = args.host
        # All numeric options share the same "None means not supplied" rule.
        for name in ("port", "num_cpus", "k1", "b"):
            value = getattr(args, name)
            if value is not None:
                setattr(self, name, value)

    def print_config(self):
        """Print the effective settings to stdout."""
        print("=== Configuration Information ===")
        print(f"Index directory: {self.index_dir}")
        print(f"CPU Worker count: {self.num_cpus}")
        print(f"Service address: {self.host}:{self.port}")
        print(f"BM25 parameters: k1={self.k1}, b={self.b}")

# Initialize Ray and Workers
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start Ray and the search workers before serving; shut Ray down after.

    Reads the module-level ``config`` and rebinds the module-level ``workers``
    list with one ``LuceneSearchWorker`` handle per ``config.num_cpus``.

    Raises:
        RuntimeError: If no worker handle could be created.
    """
    global workers, config
    ray.init(ignore_reinit_error=True)
    try:
        workers = []
        # Short grace period for Ray; awaited so the event loop is not blocked.
        await asyncio.sleep(3)
        for i in range(config.num_cpus):
            # NOTE(review): Ray creates actors asynchronously, so an error
            # raised inside the worker's __init__ (e.g. a missing index)
            # presumably does not surface here — confirm.
            try:
                worker = LuceneSearchWorker.remote(config.index_dir, i, config.k1, config.b)
            except Exception as e:
                print(f"❌ Worker {i} creation failed: {e}")
            else:
                workers.append(worker)
                print(f"✅ Worker {i} created successfully")
        # Ensure at least one Worker
        if not workers:
            raise RuntimeError("All Workers failed to create, please check index directory and Ray configuration")
        print(f"Initialization completed {len(workers)} CPU Workers")
        yield
    finally:
        # Runs on failed startup and on errors during serving as well.
        ray.shutdown()

# FastAPI related
# ASGI application object; startup and shutdown are delegated to `lifespan`.
app = FastAPI(title="Pyserini CPU Parallel Search Service", lifespan=lifespan)
# app.add_middleware(ConcurrencyMonitorMiddleware, max_concurrent_requests=100000)
# Module-level state shared by `lifespan` and the request handlers.
workers = []  # Ray worker handles; rebound by `lifespan` at startup
config = None  # Config instance; assigned in the __main__ block before the server runs

# Request/Response models
# Request body of POST /search.
class SearchRequest(BaseModel):
    # Query strings to run as one batch.
    queries: List[str]
    # Forwarded to the worker as the `top_k` argument of its batch search.
    top_k: int = 10
    # Forwarded to the worker as the `threads` argument of its batch search
    # (one value for the whole batch).
    threads: int = 4

# Response body of POST /search.
class SearchResponse(BaseModel):
    # Hits grouped by query: {query: [(id, content, score), ...]}
    results: dict
    # Elapsed time of the worker call as measured by the endpoint, in seconds.
    search_time: float

# API endpoints
@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """Run a batch of queries on one randomly chosen worker.

    Args:
        request: Parsed request body with ``queries``, ``top_k`` and
            ``threads``.

    Returns:
        A ``SearchResponse`` holding the worker's results and the elapsed
        time of the worker call in seconds.

    Raises:
        HTTPException: 503 when no worker is registered; 400 when the query
            list is empty or when ``top_k`` or ``threads`` is below 1.
    """
    if not workers:
        raise HTTPException(status_code=503, detail="No available Workers")
    if not request.queries:
        raise HTTPException(status_code=400, detail="Query list cannot be empty")
    # Reject sizes that cannot describe a result count or a thread count
    # before they are shipped to a worker.
    if request.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1")
    if request.threads < 1:
        raise HTTPException(status_code=400, detail="threads must be at least 1")

    # Randomly select a Worker
    worker = random.choice(workers)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()
    results = await worker.batch_search.remote(
        request.queries,
        request.top_k,
        request.threads,
    )
    search_time = time.perf_counter() - started

    return SearchResponse(results=results, search_time=search_time)

@app.get("/health")
async def health_check():
    """Return a fixed "healthy" status plus the number of registered workers."""
    worker_count = len(workers)
    return {"status": "healthy", "num_workers": worker_count}

# Command line argument parsing
def parse_args(argv=None):
    """Parse the service's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour of the
            original no-argument call.

    Returns:
        The ``argparse.Namespace`` with ``index_dir``, ``host``, ``port``,
        ``num_cpus``, ``k1`` and ``b``. Options that were not supplied are
        ``None``; ``--k1`` and ``--b`` are mandatory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--index-dir", type=str, help="Index directory path")
    parser.add_argument("--host", type=str, help="Service host address")
    parser.add_argument("--port", type=int, help="Service port")
    parser.add_argument("--num-cpus", type=int, help="CPU Worker count")
    parser.add_argument("--k1", type=float, required=True, help="BM25 parameter k1 (required)")
    parser.add_argument("--b", type=float, required=True, help="BM25 parameter b (required)")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Effective configuration: built-in defaults overridden by CLI values.
    args = parse_args()
    config = Config()
    config.from_args(args)
    config.print_config()

    # Uvicorn options gathered in one place; only host and port come from
    # the configuration, the rest are fixed for this service.
    server_options = dict(
        host=config.host,
        port=config.port,
        log_level="info",
        workers=1,
        limit_concurrency=100000,
        timeout_keep_alive=60,
        timeout_graceful_shutdown=60,
        loop="asyncio",
        http="httptools",
        ws="websockets",
        lifespan="on",
        access_log=False,
        backlog=4096,
        h11_max_incomplete_event_size=16 * 1024 * 1024,
    )

    # Run server
    uvicorn.run(app, **server_options)