#%%
import os
import re
import json
import time
import random
import argparse
import logging
import requests
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional, Union

import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from vllm import LLM, SamplingParams
from tqdm.auto import tqdm

from utils.config import CONFIG
import torch
from sentence_transformers import SentenceTransformer


from typing import List, Dict, Any, Tuple
import json
import time
import threading
from queue import Queue


# Configure the root logger at import time (INFO and above) and create the
# logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Config:
    """Static settings for this module.

    Endpoints and the API key are read from the project-level ``CONFIG``
    mapping; the remaining attributes are tuning constants. Everything is a
    class attribute, so values are evaluated once when the module is imported.
    """

    # Base URL that WikipediaSearcher sends its ``action=query`` requests to.
    WIKI_ENDPOINT = CONFIG['WIKI_ENDPOINT']
    OPENROUTER_URL = CONFIG['OPENROUTER_URL']
    # The OPENROUTER_API_KEY environment variable wins over the CONFIG entry.
    OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY", CONFIG['OPENROUTER_KEY'])
    # First CUDA device when one is available, otherwise the CPU.
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # float16 when CUDA is available, float32 otherwise.
    # NOTE(review): not referenced anywhere in the code visible in this file.
    DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

    # Search behavior
    # Default number of hits requested by WikipediaSearcher.search.
    TOP_K = 10
    # Upper bound applied to the requested hit count in WikipediaSearcher.search.
    MAX_SRLIMIT = 10
    # NOTE(review): RELEV_TH, PASSLEN, WAIT_REQ and TIMEOUT are not read through
    # ``config`` in the visible code; BatchAgenticRAGSystem and
    # BatchModelInterface define local constants with the same values instead.
    RELEV_TH = 0.30
    PASSLEN = 5
    WAIT_REQ = 0.3
    # Number of attempts made by WikipediaSearcher._wiki_request.
    RETRY = 3
    # Per-request timeout (seconds) passed to requests.get for Wikipedia calls.
    WTIME = 25
    TIMEOUT = 120

    # NOTE(review): the two values below are not referenced in the visible
    # code; their intended meaning cannot be determined from here — TODO confirm.
    MIN_TOTAL_15 = 11
    MAX_EXTRA_SEARCH = 0 

# Module-wide settings handle. Config only defines class attributes, so this
# instance simply exposes them under a lowercase name.
config = Config()

class PromptTemplates:
    """Prompt templates meant to be filled in with ``str.format``.

    Every template uses ``{name}`` placeholders. Literal JSON braces inside a
    template are therefore doubled (``{{`` / ``}}``) so that ``format`` emits
    them verbatim instead of treating them as replacement fields.
    """

    # Placeholders: query
    AMBIGUITY_DETECTION = """You are an expert at analyzing query ambiguity.
Your task is to determine if a query is ambiguous and to classify the ambiguity type.

Analyze the following query and decide:
1) reasoning.
1) Is the query ambiguous?
2) What specific aspects make it ambiguous?
3) What extra information would clarify it?
4) Classify the ambiguity as one of: "syntactic", "general", "semantic", or "none".

Query: {query}

Return STRICT JSON:
{{
  "reasoning": "string",
  "is_ambiguous": true/false,
  "ambiguity_type": "syntactic" | "general" | "semantic" |"none",
  "ambiguous_aspects": ["..."],
  "clarification_needed": "string",
}}

Definitions:
- syntactic: The sentence permits multiple plausible grammatical parses (attachment/scope/coordination/pronoun reference).
- general: The query is overspecific; a broader, closely related formulation would better capture the user's true information need.
- semantic: The input itself is clear in syntax but underspecified in meaning, allowing multiple valid interpretations at the level of world knowledge, concepts, or intent.
"""

    # Placeholders: query, analysis
    # The JSON skeleton braces are escaped; with single braces ``format``
    # would try to interpret the skeleton as a replacement field and raise.
    QUERY_CLARIFICATION = """You are an expert at clarifying ambiguous queries.
Given the original query and an ambiguity analysis, rewrite the query into TWO possible clarified versions. 
Each version must be specific, actionable, and faithful to a plausible intent.

Original Query: {query}
Ambiguity Analysis (JSON): {analysis}

Write STRICT JSON:
{{
  "reasoning": "why these clarifications resolve the ambiguity",
  "clarified_query1": "string",
  "clarified_query2": "string",
}}
"""

    # Placeholders: max_searches, current_searches, query, context
    REACT_AGENT = """You are a research assistant following ReAct (Reasoning, Acting, Observing).

Available Actions:
- SEARCH[query]  → run a search using the configured method
- ANSWER[text]   → provide a final answer now

Constraints:
- Max searches allowed: {max_searches}
- Searches used so far: {current_searches}
- Do NOT reuse the exact same search query as previously used in context.

Task Query: {query}

Previous Context:
{context}

Instructions:
1) THINK about the next best step.
2) If more evidence is needed, choose SEARCH[very specific query].
3) If sufficient, choose ANSWER[concise, well-supported answer].
4) If you have already reached the maximum allowed searches, you MUST output ANSWER[...] now.

Respond in EXACT format:
THOUGHT: <your internal reasoning, one short paragraph>
ACTION: SEARCH[...specific query...]  OR  ACTION: ANSWER[...final answer...]
"""

    # Placeholders: query, search_results
    INFORMATION_SYNTHESIS = """You are an expert at synthesizing information from multiple sources.

Original Task Query: {query}
Search Results Summary (JSON): {search_results}

Write a comprehensive, well-structured answer that:
- Directly addresses the query
- Uses key facts from the provided passages
- Notes any conflicts or uncertainties
- Is concise and clear

Return only the answer text (no JSON).
"""

class WikipediaSearcher:
    """Small client for the MediaWiki ``action=query`` API.

    The endpoint, retry count, timeout and result limits are read from the
    module-level ``config`` object.
    """

    def __init__(self):
        self.base_url = config.WIKI_ENDPOINT
        self.headers = {"User-Agent": "agentic-rag/1.1"}

    @staticmethod
    def _retry_after_seconds(header_value: Any, fallback: int) -> int:
        """Return the ``Retry-After`` delay in seconds, or ``fallback``.

        ``Retry-After`` may be missing or carry an HTTP-date instead of a
        number of seconds; both cases fall back to the caller's backoff value
        rather than raising.
        """
        try:
            return max(0, int(header_value))
        except (TypeError, ValueError):
            return fallback

    def _wiki_request(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send one query to the API and return the decoded JSON body.

        Args:
            params: Request parameters merged over the defaults
                ``action=query``, ``format=json`` and ``utf8=1``.

        Returns:
            The parsed JSON response.

        Raises:
            RuntimeError: When every attempt was rate-limited (HTTP 429) or
                timed out.
            requests.exceptions.HTTPError: For any other non-success status.
        """
        base_params = {"action": "query", "format": "json", "utf8": 1}
        full_params = {**base_params, **params}

        for attempt in range(config.RETRY):
            try:
                resp = requests.get(self.base_url, params=full_params,
                                    headers=self.headers, timeout=config.WTIME)
                if resp.status_code == 429:
                    # Honour the server's hint when it is usable, otherwise
                    # back off exponentially; jitter avoids synchronized retries.
                    wait_time = self._retry_after_seconds(
                        resp.headers.get("Retry-After"), 2 ** attempt
                    )
                    time.sleep(wait_time + random.random())
                    continue
                resp.raise_for_status()
                return resp.json()
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout):
                time.sleep(min(60, 2 ** attempt) + random.random())

        raise RuntimeError("Wikipedia request exceeded maximum retries")

    def search(self, query: str, limit: Optional[int] = None) -> List[Dict]:
        """Run a full-text search and return the raw hit dictionaries.

        Args:
            query: Search string passed as ``srsearch``.
            limit: Desired number of hits; defaults to ``config.TOP_K`` and is
                capped at ``config.MAX_SRLIMIT``.

        Returns:
            The ``query.search`` list from the response, or ``[]`` if absent.
        """
        limit = min(limit or config.TOP_K, config.MAX_SRLIMIT)
        params = {"list": "search", "srsearch": query, "srlimit": limit}
        result = self._wiki_request(params)
        return result.get("query", {}).get("search", [])

    def get_content(self, page_id: int) -> str:
        """Return the plain-text extract of a page, or ``""`` if unavailable.

        Args:
            page_id: MediaWiki page id.
        """
        # MediaWiki treats a boolean parameter as true whenever it is present,
        # whatever its value. ``exintro`` is therefore omitted (not sent as
        # False) so that the extract is not restricted to the intro section.
        params = {
            "pageids": str(page_id),
            "prop": "extracts",
            "explaintext": True
        }
        result = self._wiki_request(params)
        pages = result.get("query", {}).get("pages", {})
        return pages.get(str(page_id), {}).get("extract", "")

class FAISSSearcher:
    """Dense retriever backed by a FAISS index and a sentence-embedding model."""

    def __init__(self,
                 sentence_model: Union[str, SentenceTransformer],
                 index_path: Optional[str] = None,
                 documents_path: Optional[str] = None):
        """Prepare the encoder and, when both paths are given, load the index.

        Args:
            sentence_model: A ready ``SentenceTransformer`` or a model name to
                load onto ``config.DEVICE``.
            index_path: Path of a serialized FAISS index.
            documents_path: Path of the JSON file holding the indexed documents.
        """
        encoder = sentence_model
        if not isinstance(encoder, SentenceTransformer):
            encoder = SentenceTransformer(encoder, device=str(config.DEVICE))
        self.sentence_model = encoder

        # Inputs are truncated to 512 tokens; eval() puts the model in inference mode.
        self.sentence_model.max_seq_length = 512
        self.sentence_model.eval()

        self.index = None
        self.documents = []
        if index_path and documents_path:
            self.load_index(index_path, documents_path)

    def load_index(self, index_path: str, documents_path: str):
        """Read the FAISS index and the ``"documents"`` list from disk."""
        self.index = faiss.read_index(index_path)
        with open(documents_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        self.documents = payload["documents"]

    def search(self, query: str, k: int = 10) -> List[Dict]:
        """Return up to ``k`` nearest documents for ``query``.

        Each hit is a dict with ``score``, ``content`` (the document's
        ``"passage"`` field) and ``index``. Positions reported by FAISS that
        fall outside the loaded document list are skipped.

        Raises:
            ValueError: If no index has been loaded.
        """
        if self.index is None:
            raise ValueError("FAISS index not loaded")

        encoded = self.sentence_model.encode([query], convert_to_tensor=True, device=config.DEVICE)
        scores, indices = self.index.search(encoded.cpu().numpy(), k)

        # Only the first row matters: exactly one query was encoded.
        hits = []
        for similarity, position in zip(scores[0], indices[0]):
            if position < 0 or position >= len(self.documents):
                continue
            hits.append({
                "score": float(similarity),
                "content": self.documents[position]["passage"],
                "index": int(position)
            })
        return hits


# Re-binds the module logger; logging.getLogger returns the same logger object
# for the same name, so this is equivalent to the binding made near the imports.
logger = logging.getLogger(__name__)

class BatchModelInterface:
    """Text-generation front end for an HTTP chat API or a local vLLM model.

    With ``model_type == "api"`` requests are POSTed to the configured
    OpenRouter URL; any other value loads ``model_id`` through vLLM. Every
    generation method returns ``(text, metadata)`` tuples, where ``metadata``
    is currently always an empty dict. Call counters live on the class so they
    aggregate across instances.
    """

    CALL_COUNT: int = 0
    INDIVIDUAL_COUNT: Dict[str, int] = {}
    # API calls are issued from worker threads, so counter updates are serialized.
    _COUNT_LOCK = threading.Lock()

    def __init__(self, model_type: str, model_id: str, max_length: int = 32768,
                 batch_size: int = 10, max_concurrent: int = 5):
        """Set up either the HTTP credentials or the local vLLM engine.

        Args:
            model_type: ``"api"`` for the HTTP backend; anything else selects vLLM.
            model_id: Model identifier sent to the API or loaded by vLLM.
            max_length: ``max_model_len`` given to vLLM (unused for the API).
            batch_size: Stored for callers; not read by this class.
            max_concurrent: Maximum number of simultaneous API requests.
        """
        self.model_type = model_type
        self.model_id = model_id
        self.max_length = max_length
        self.batch_size = batch_size
        self.max_concurrent = max_concurrent

        if model_type == "api":
            # Import config here to avoid circular import
            from utils.config import CONFIG
            import os
            OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY", CONFIG['OPENROUTER_KEY'])
            OPENROUTER_URL = CONFIG['OPENROUTER_URL']

            self.headers = {
                "Authorization": f"Bearer {OPENROUTER_KEY}",
                "Content-Type": "application/json"
            }
            self.openrouter_url = OPENROUTER_URL
        else:
            self.llm = LLM(
                model=model_id,
                tensor_parallel_size=torch.cuda.device_count() if torch.cuda.is_available() else 1,
                max_model_len=max_length,
                gpu_memory_utilization=0.9
            )
            self.tokenizer = self.llm.get_tokenizer()

    def _record_call(self) -> None:
        """Increment the global and per-model call counters under the lock."""
        with BatchModelInterface._COUNT_LOCK:
            BatchModelInterface.CALL_COUNT += 1
            counts = BatchModelInterface.INDIVIDUAL_COUNT
            counts[self.model_id] = counts.get(self.model_id, 0) + 1

    def _extract_content_from_message(self, message: Any) -> str:
        """Extract content from OpenRouter message, handling different formats.

        Strings are returned unchanged. For dicts, the first of ``content``,
        ``reasoning``, ``text``, ``response`` holding a non-``None`` value is
        used (recursing into nested dicts). Anything else is stringified, with
        ``None`` mapped to ``""``.
        """
        if isinstance(message, str):
            return message

        if not isinstance(message, dict):
            return str(message) if message is not None else ""

        content_keys = ["content", "reasoning", "text", "response"]
        for key in content_keys:
            if key in message and message[key] is not None:
                content = message[key]
                if isinstance(content, str):
                    return content
                elif isinstance(content, dict):
                    return self._extract_content_from_message(content)
                else:
                    return str(content)

        return str(message)

    def generate(self, messages: List[Dict], temperature: float = 0.7, max_tokens: int = 32768) -> Tuple[str, Dict]:
        """Single generation (backward compatibility).

        Returns:
            ``(text, metadata)``; ``("", {})`` if the local backend produced
            no output.
        """
        if self.model_type == "api":
            return self._generate_api_single(messages, temperature, max_tokens)
        else:
            results = self._generate_batch_vllm([{"messages": messages}], temperature, max_tokens)
            return results[0] if results else ("", {})

    def _generate_api_single(self, messages: List[Dict], temperature: float, max_tokens: int) -> Tuple[str, Dict]:
        """Single API call with retries.

        Never raises: after the final failed attempt the error is returned as
        the text (``"API Error: ..."`` or ``"Max retries exceeded"``).
        """
        payload = {
            "model": self.model_id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        RETRY = 3
        TIMEOUT = 120

        for attempt in range(RETRY):
            try:
                resp = requests.post(self.openrouter_url, headers=self.headers,
                                     json=payload, timeout=TIMEOUT)
                if resp.status_code == 429:
                    # Rate limited: wait as instructed (or back off) and retry.
                    wait_time = int(resp.headers.get("Retry-After", 2**attempt))
                    time.sleep(wait_time + random.random())
                    continue

                resp.raise_for_status()
                result = resp.json()

                self._record_call()

                # Handle different OpenRouter output formats
                message = result["choices"][0]["message"]
                content = self._extract_content_from_message(message)
                # Metadata is an empty dict on every path so the declared
                # Tuple[str, Dict] contract holds for success and failure alike.
                return content, {}

            except Exception as e:
                logger.warning(f"API call attempt {attempt + 1} failed: {e}")
                if attempt == RETRY - 1:
                    return f"API Error: {str(e)}", {}
                time.sleep(min(60, 2**attempt) + random.random())

        return "Max retries exceeded", {}

    def generate_batch_sync(self, batch_requests: List[Dict],
                            temperature: float = 0.7, max_tokens: int = 32768) -> List[Tuple[str, Dict]]:
        """Synchronous batch generation using threading for API calls.

        Args:
            batch_requests: Dicts carrying a ``"messages"`` list each.
            temperature: Sampling temperature applied to every request.
            max_tokens: Generation limit applied to every request.

        Returns:
            One ``(text, metadata)`` tuple per request, in request order. A
            request that raises yields ``("", {})`` and is logged.
        """
        if self.model_type != "api":
            return self._generate_batch_vllm(batch_requests, temperature, max_tokens)
        if not batch_requests:
            return []

        from concurrent.futures import ThreadPoolExecutor

        def run_one(request: Dict) -> Tuple[str, Dict]:
            content, meta = self._generate_api_single(request["messages"], temperature, max_tokens)
            return content, (meta or {})

        # At least one worker even if max_concurrent was configured as <= 0.
        pool_size = max(1, min(self.max_concurrent, len(batch_requests)))
        results: List[Tuple[str, Dict]] = []
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            futures = [pool.submit(run_one, request) for request in batch_requests]
            # Collect in submission order so results line up with the requests.
            for idx, future in enumerate(futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    logger.error(f"Request {idx} failed: {e}")
                    results.append(("", {}))

        return results

    def _generate_batch_vllm(self, batch_requests: List[Dict],
                             temperature: float, max_tokens: int) -> List[Tuple[str, Dict]]:
        """Batch generation for vLLM (local models).

        Returns:
            One ``(text, {})`` tuple per request, using the first candidate of
            each vLLM output with surrounding whitespace stripped.
        """
        from vllm import SamplingParams

        prompts = [self._messages_to_prompt(req["messages"]) for req in batch_requests]
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens, stop=None)

        outputs = self.llm.generate(prompts, sampling_params)
        results = []

        for output in outputs:
            content = output.outputs[0].text.strip()
            # A real 2-tuple: callers unpack ``text, meta = generate(...)``.
            results.append((content, {}))
            self._record_call()

        return results

    def _messages_to_prompt(self, messages: List[Dict]) -> str:
        """Convert messages to prompt format for vLLM.

        Produces ``System:`` / ``Human:`` / ``Assistant:`` lines and ends with
        an open ``Assistant: `` turn. Messages with any other role are skipped.
        """
        labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
        parts = []
        for m in messages:
            role, content = m["role"], m["content"]
            if role in labels:
                parts.append(f"{labels[role]}: {content}\n")
        parts.append("Assistant: ")
        return "".join(parts)

class BatchAgenticRAGSystem:
    """Extended RAG system with batch processing capabilities.

    Combines a sentence-embedding model (passage scoring), a retriever
    (``WikipediaSearcher`` or ``FAISSSearcher``) and a ``BatchModelInterface``
    generator. Per-query state (search counter, used queries, history) is kept
    on the instance and reset for every query processed in a batch.
    """

    def __init__(self, web_or_faiss: str, sentence_transformer: str, tr_or_api: str,
                 model_id: str, max_length: int = 32768, n_search: int = 5,
                 batch_size: int = 10, max_concurrent: int = 5, use_async: bool = True,
                 faiss_index_path: Optional[str] = None, faiss_documents_path: Optional[str] = None):
        """Build the embedding model, the retriever and the generator.

        Args:
            web_or_faiss: ``"web"`` selects Wikipedia search; anything else FAISS.
            sentence_transformer: Name or path of the embedding model.
            tr_or_api: Backend type forwarded to ``BatchModelInterface``.
            model_id: Generator model identifier.
            max_length: Forwarded to ``BatchModelInterface``.
            n_search: Stored on the instance; not read by the visible methods.
            batch_size: Number of queries handled per chunk.
            max_concurrent: Forwarded to ``BatchModelInterface``.
            use_async: Stored on the instance; not read by the visible methods.
            faiss_index_path: FAISS index file (FAISS mode only).
            faiss_documents_path: FAISS documents file (FAISS mode only).
        """
        self.search_method = web_or_faiss
        self.n_search = n_search
        self.batch_size = batch_size
        self.use_async = use_async
        self.current_searches = 0

        DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.sentence_model = SentenceTransformer(sentence_transformer, device=str(DEVICE))
        self.sentence_model.max_seq_length = 512
        self.sentence_model.eval()

        if web_or_faiss == "web":
            self.searcher = WikipediaSearcher()
        else:
            self.searcher = FAISSSearcher(self.sentence_model, faiss_index_path, faiss_documents_path)

        self.model = BatchModelInterface(
            tr_or_api,
            model_id,
            max_length,
            batch_size=batch_size,
            max_concurrent=max_concurrent
        )

        self.search_results = []
        self.context_history = []
        self.used_search_queries = set()

    @staticmethod
    def _parse_json_object(text: str) -> Any:
        """Parse a JSON value from a model reply, tolerating common wrapping.

        Tries the raw text first, then the span from the first ``{`` to the
        last ``}`` (handles code fences and surrounding prose), then that span
        with trailing commas before ``}``/``]`` removed.

        Raises:
            json.JSONDecodeError: If none of the attempts yields valid JSON.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end < start:
                raise
            candidate = text[start:end + 1]
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return json.loads(re.sub(r",\s*([}\]])", r"\1", candidate))

    @staticmethod
    def _default_analysis(reasoning: Optional[str] = None) -> Dict:
        """Return the "not ambiguous" analysis, optionally with a reasoning note."""
        analysis = {
            "is_ambiguous": False,
            "ambiguity_type": "none",
            "ambiguous_aspects": [],
            "clarification_needed": "",
        }
        if reasoning is not None:
            analysis["reasoning"] = reasoning
        return analysis

    def compute_relevance(self, query: str, text: str) -> float:
        """Compute relevance between query and text.

        Returns the cosine similarity of the two embeddings, or ``0.0`` (with
        a logged warning) if encoding or comparison fails.
        """
        try:
            with torch.no_grad():
                embs = self.sentence_model.encode([query, text], convert_to_tensor=True)
                sim = torch.nn.functional.cosine_similarity(embs[0], embs[1], dim=0)
                return float(sim)
        except Exception as e:
            logger.warning(f"Relevance computation failed: {e}")
            return 0.0

    def extract_passages(self, content: str, query: str, max_passages: int = 3) -> List[str]:
        """Extract relevant passages from content.

        The text is split on runs of ``.``, ``!`` and ``?``; fragments of 20
        characters or fewer are dropped. Every remaining sentence whose
        relevance reaches the threshold contributes a window of neighbouring
        sentences. Windows are ranked by the score of their centre sentence,
        and identical windows are returned only once.

        Args:
            content: Source text; non-string or blank input yields ``[]``.
            query: Query the sentences are scored against.
            max_passages: Maximum number of passages to return.
        """
        if not isinstance(content, str) or not content.strip():
            return []

        sentences = [s.strip() for s in re.split(r'[.!?]+', content)]
        sentences = [s for s in sentences if len(s) > 20]

        RELEV_TH = 0.30
        PASSLEN = 5
        half_window = PASSLEN // 2

        scored = []
        for i, sent in enumerate(sentences):
            score = self.compute_relevance(query, sent)
            if score >= RELEV_TH:
                start = max(0, i - half_window)
                end = min(len(sentences), i + half_window + 1)
                scored.append((score, " ".join(sentences[start:end])))

        scored.sort(key=lambda x: x[0], reverse=True)

        # Short documents make several windows collapse onto the same text;
        # keep only the best-scored copy so the top slots are not duplicates.
        passages: List[str] = []
        seen = set()
        for _, passage in scored:
            if len(passages) >= max_passages:
                break
            if passage not in seen:
                seen.add(passage)
                passages.append(passage)
        return passages

    def analyze_ambiguity(self, query: str) -> Dict:
        """Analyze ambiguity with better error handling.

        Returns the model's analysis dict with ``is_ambiguous`` and
        ``ambiguity_type`` guaranteed present. Any failure (model error,
        unparseable or non-object reply) yields a "not ambiguous" default whose
        ``reasoning`` field describes the failure.
        """
        AMBIGUITY_DETECTION = """You are an expert at analyzing query ambiguity.
Your task is to determine if a query is ambiguous and to classify the ambiguity type.

Analyze the following query and decide:
1) reasoning.
1) Is the query ambiguous?
2) What specific aspects make it ambiguous?
3) What extra information would clarify it?
4) Classify the ambiguity as one of: "syntactic", "general", "semantic", or "none".

Query: {query}

Return STRICT JSON:
{{
  "reasoning": "string",
  "is_ambiguous": true/false,
  "ambiguity_type": "syntactic" | "general" | "semantic" |"none",
  "ambiguous_aspects": ["..."],
  "clarification_needed": "string",
}}

Definitions:
- syntactic: The sentence permits multiple plausible grammatical parses (attachment/scope/coordination/pronoun reference).
- general: The query is overspecific; a broader, closely related formulation would better capture the user's true information need.
- semantic: The input itself is clear in syntax but underspecified in meaning, allowing multiple valid interpretations at the level of world knowledge, concepts, or intent.
"""

        messages = [
            {"role": "system", "content": "You are an expert query analyzer."},
            {"role": "user", "content": AMBIGUITY_DETECTION.format(query=query)}
        ]

        try:
            resp, _ = self.model.generate(messages, temperature=0.1, max_tokens=30000)

            if not isinstance(resp, str):
                resp = str(resp) if resp is not None else ""

            data = self._parse_json_object(resp)
            if not isinstance(data, dict):
                logger.warning("Ambiguity analysis response was not a JSON object")
                return self._default_analysis("JSON parsing failed")

            data.setdefault("is_ambiguous", False)
            data.setdefault("ambiguity_type", "none")
            return data

        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON response for ambiguity analysis: {e}")
            return self._default_analysis("JSON parsing failed")
        except Exception as e:
            logger.error(f"Error in analyze_ambiguity: {e}")
            return self._default_analysis(f"Analysis failed: {str(e)}")

    def clarify_query(self, query: str, analysis: Dict) -> str:
        """Clarify query with better error handling.

        Returns ``query`` unchanged when ``analysis`` does not flag ambiguity
        or when anything goes wrong. Otherwise returns the model's
        ``clarified_query1`` (or ``clarified_query``); if the reply contains no
        JSON at all, the non-blank raw reply is used as the clarified query.
        """
        if not analysis.get("is_ambiguous", False):
            return query

        QUERY_CLARIFICATION = """You are an expert at clarifying ambiguous queries.
Given the original query and an ambiguity analysis, rewrite the query into TWO possible clarified versions. 
Each version must be specific, actionable, and faithful to a plausible intent.

Original Query: {query}
Ambiguity Analysis (JSON): {analysis}

Write STRICT JSON:
{{
  "reasoning": "why these clarifications resolve the ambiguity",
  "clarified_query1": "string",
  "clarified_query2": "string",
}}
"""

        messages = [
            {"role": "system", "content": "You are an expert at clarifying ambiguous queries."},
            {"role": "user", "content": QUERY_CLARIFICATION.format(
                query=query, analysis=json.dumps(analysis, ensure_ascii=False)
            )}
        ]

        try:
            clarified_response, _ = self.model.generate(messages, temperature=0.2, max_tokens=30000)

            if not isinstance(clarified_response, str):
                clarified_response = str(clarified_response) if clarified_response is not None else ""

            try:
                parsed = self._parse_json_object(clarified_response)
            except json.JSONDecodeError:
                return clarified_response if clarified_response.strip() else query

            if not isinstance(parsed, dict):
                return query
            candidate = parsed.get("clarified_query1", parsed.get("clarified_query", query))
            # Guard against null / empty / non-string values in the reply.
            if isinstance(candidate, str) and candidate.strip():
                return candidate
            return query

        except Exception as e:
            logger.error(f"Error in clarify_query: {e}")
            return query

    def perform_search(self, query: str) -> List[Dict]:
        """Perform search with error handling.

        Increments the search counter, records the normalized query, and
        dispatches to the web or FAISS implementation. Returns ``[]`` if the
        search raises.
        """
        self.current_searches += 1
        self.used_search_queries.add(query.strip().lower())

        try:
            if self.search_method == "web":
                return self._perform_web_search(query)
            else:
                return self._perform_faiss_search(query)
        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []

    def _perform_web_search(self, query: str) -> List[Dict]:
        """Perform Wikipedia search.

        Fetches the text of the first five hits and extracts query-relevant
        passages from each. Pages whose fetch fails or returns no text are
        skipped.
        """
        raw = self.searcher.search(query, limit=10)
        detailed = []
        WAIT_REQ = 0.3

        for r in raw[:5]:
            try:
                content = self.searcher.get_content(r["pageid"])
                if content:
                    passages = self.extract_passages(content, query)
                    detailed.append({
                        "title": r["title"],
                        "page_id": r["pageid"],
                        "snippet": r.get("snippet", ""),
                        "passages": passages,
                        "full_content": content[:2000]
                    })
                # Pause between page fetches to stay polite to the API.
                time.sleep(WAIT_REQ)
            except Exception as e:
                logger.warning(f"Failed to get content for page {r.get('pageid', 'unknown')}: {e}")

        return detailed

    def _perform_faiss_search(self, query: str) -> List[Dict]:
        """Perform FAISS search.

        Retrieves ten documents and extracts query-relevant passages from
        each; a document that fails processing is logged and skipped.
        """
        raw = self.searcher.search(query, k=10)
        detailed = []

        for r in raw:
            try:
                passages = self.extract_passages(r["content"], query)
                detailed.append({
                    "content": r["content"][:2000],
                    "score": r["score"],
                    "index": r["index"],
                    "passages": passages
                })
            except Exception as e:
                logger.warning(f"Failed to process FAISS result: {e}")

        return detailed

    def process_queries_batch(self, queries: List[str]) -> List[str]:
        """Process multiple queries in batch.

        Args:
            queries: Queries to answer.

        Returns:
            One answer string per query, in input order.

        Raises:
            ValueError: If ``batch_size`` is not positive (a negative step
                would otherwise skip every query silently).
        """
        if self.batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        results = []
        total_batches = (len(queries) - 1) // self.batch_size + 1

        # Process in chunks
        for i in tqdm(range(0, len(queries), self.batch_size)):
            batch_queries = queries[i:i + self.batch_size]
            logger.info(f"Processing batch {i//self.batch_size + 1}/{total_batches}")

            batch_results = self._process_batch_chunk(batch_queries)
            results.extend(batch_results)

            # Short pause between chunks, skipped after the last one.
            if i + self.batch_size < len(queries):
                time.sleep(0.5)

        return results

    def _process_batch_chunk(self, queries: List[str]) -> List[str]:
        """Process a single batch chunk.

        Returns one answer per query; a query whose processing raises yields
        ``""``, and a failure of the chunk as a whole yields ``""`` for every
        query.
        """
        try:
            logger.info(f"Step 1: Analyzing ambiguity for {len(queries)} queries")
            ambiguity_requests = []
            for query in queries:
                if not isinstance(query, str):
                    query = str(query) if query is not None else ""

                messages = [
                    {"role": "system", "content": "You are an expert query analyzer."},
                    {"role": "user", "content": f"Analyze this query for ambiguity: {query}"}
                ]
                ambiguity_requests.append({"messages": messages, "query": query})

            # NOTE(review): the replies are not consumed below — every query is
            # treated as unambiguous — so this round-trip currently only costs
            # model calls. Confirm whether it should be removed or wired in.
            ambiguity_results = self.model.generate_batch_sync(ambiguity_requests, temperature=0.1, max_tokens=30000)

            logger.info("Step 2: Parsing ambiguity results")
            analyses = [self._default_analysis() for _ in ambiguity_results]
            clarified_queries = [queries[i] for i in range(len(ambiguity_results))]

            logger.info("Step 3: Processing individual queries")
            final_results = []

            for i, (original_query, clarified_query, analysis) in enumerate(zip(queries, clarified_queries, analyses)):
                try:
                    # Reset per-query state so queries do not influence each other.
                    self.current_searches = 0
                    self.search_results = []
                    self.context_history = []
                    self.used_search_queries = set()

                    result = self._process_single_query_simple(original_query)
                    final_results.append(result)

                except Exception as e:
                    logger.error(f"Failed to process query {i}: {e}")
                    final_results.append("")

            return final_results

        except Exception as e:
            logger.error(f"Batch processing failed: {e}")
            return [""] * len(queries)

    def _process_single_query_simple(self, query: str) -> str:
        """Simple query processing without complex ReAct loop.

        Runs one search, concatenates up to ten extracted passages and asks
        the model to answer from them. Returns a short diagnostic string
        instead of an answer when the query is empty, nothing is retrieved, or
        processing raises.
        """
        try:
            if not isinstance(query, str):
                query = str(query) if query is not None else ""

            if not query.strip():
                return "Empty query"

            search_results = self.perform_search(query)

            if not search_results:
                return "No search results found"

            # Web and FAISS results both expose their passages under "passages".
            all_passages = [
                passage
                for result in search_results[:10]
                for passage in result.get("passages", [])
            ]

            if not all_passages:
                return "No relevant passages found"

            top_content = " ".join(all_passages[:10])

            messages = [
                {"role": "system", "content": "You are a helpful assistant. Answer the query based on the provided information."},
                {"role": "user", "content": f"Query: {query}\n\nInformation: {top_content}\n\nAnswer:"}
            ]

            response, _ = self.model.generate(messages, temperature=0.3, max_tokens=30000)
            return response if isinstance(response, str) else str(response)

        except Exception as e:
            logger.error(f"Simple processing failed: {e}")
            return f"Processing error: {str(e)}"

def main_batch():
    """Command-line entry point for the batch RAG run.

    Parses the arguments, loads the dataset's ``original_query`` fields,
    answers them with ``BatchAgenticRAGSystem`` and writes one JSON line per
    dataset record (``qid``, ``original_query``, ``rag_answer``) to the output
    file. Any failure is logged and re-raised.
    """
    import argparse
    from pathlib import Path

    cli = argparse.ArgumentParser(description="Batch Agentic RAG with OpenRouter")
    cli.add_argument("--web_or_faiss", choices=["web", "faiss"], required=True)
    cli.add_argument("--sentence_transformer", required=True)
    cli.add_argument("--tr_or_api", choices=["transformer", "api"], required=True)
    cli.add_argument("--model_id", required=True)
    cli.add_argument("--max_length", type=int, default=30000)
    cli.add_argument("--n_search", type=int, default=5)
    cli.add_argument("--batch_size", type=int, default=1, help="Batch size for processing")
    cli.add_argument("--max_concurrent", type=int, default=1, help="Max concurrent requests")
    cli.add_argument("--use_async", action="store_true", help="Use async processing")
    cli.add_argument("--faiss_index_path")
    cli.add_argument("--faiss_documents_path")
    cli.add_argument("--dataset_path", default="dataset/MIRAGE.jsonl")
    cli.add_argument("--output", default="")
    args = cli.parse_args()

    # FAISS mode cannot work without both the index and its documents file.
    faiss_paths_complete = bool(args.faiss_index_path and args.faiss_documents_path)
    if args.web_or_faiss == "faiss" and not faiss_paths_complete:
        cli.error("--faiss_index_path and --faiss_documents_path are required for faiss mode")

    try:
        from utils.utils import load_jsonl

        dataset = load_jsonl(args.dataset_path)
        queries = []
        for record in dataset:
            queries.append(record.get("original_query", ""))

        logger.info(f"Loaded {len(queries)} queries from dataset")

        system = BatchAgenticRAGSystem(
            web_or_faiss=args.web_or_faiss,
            sentence_transformer=args.sentence_transformer,
            tr_or_api=args.tr_or_api,
            model_id=args.model_id,
            max_length=args.max_length,
            n_search=args.n_search,
            batch_size=args.batch_size,
            max_concurrent=args.max_concurrent,
            use_async=args.use_async,
            faiss_index_path=args.faiss_index_path,
            faiss_documents_path=args.faiss_documents_path
        )

        # Derive a default output file name from the model id and search mode.
        output_target = args.output
        if not output_target:
            model_simple = args.model_id.split("/")[-1].replace(":", "_")
            suffix = "_web" if args.web_or_faiss == "web" else ""
            output_target = f"final_results/3_n5_k10_val_agent_{model_simple}{suffix}.jsonl"

        out_path = Path(output_target)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Processing {len(queries)} queries in batches of {args.batch_size}")

        all_results = system.process_queries_batch(queries)

        with open(out_path, "w", encoding="utf-8") as fout:
            for record, answer in zip(dataset, all_results):
                row = {
                    "qid": record.get("qid"),
                    "original_query": record.get("original_query"),
                    "rag_answer": answer
                }
                fout.write(json.dumps(row, ensure_ascii=False) + "\n")

        logger.info(f"Completed processing. Results saved to {output_target}")
        logger.info(f"Total API calls: {BatchModelInterface.CALL_COUNT}")
        logger.info(f"Per model calls: {BatchModelInterface.INDIVIDUAL_COUNT}")

    except Exception as e:
        logger.error(f"Main execution failed: {e}")
        raise

# Run the batch pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main_batch()
