#!/usr/bin/env python3
"""
Query a final corpus parquet file with a question using the local LLM.

This script loads a final corpus parquet file and sends it along with a question
to the local LLM using the same environment variables and client setup as the
schema induction pipeline.

Usage:
  python query_final_corpus.py \
    --corpus-path schema_induction_pipeline/result_storage/MATH/iteration_01/final_corpus/final_corpus.parquet \
    --question "What are the main themes in this corpus?"

Environment variables:
  - VLLM_QWEN_32B_URL (required): URL for the local LLM server
  - VLLM_QWEN_32B_MODEL (required): Model name for the LLM
  - VLLM_TIMEOUT (optional): Request timeout in seconds (default 120)
"""

import os
import json
import argparse
import asyncio
import aiohttp
import pandas as pd
from typing import List, Dict, Optional

# Load .env for environment variables if present.
# python-dotenv is an optional dependency: a missing package is fine, but a
# failure while actually reading the .env file is reported instead of being
# silently discarded (it usually explains "variable not set" errors later on).
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    try:
        load_dotenv()
    except Exception as exc:  # best-effort: a bad .env must not block startup
        print(f"⚠️ Could not load .env file: {exc}")

# Environment variables for LLM connection
VLLM_TEXT_URL = os.getenv("VLLM_QWEN_32B_URL")
DEFAULT_TEXT_MODEL = os.getenv("VLLM_QWEN_32B_MODEL")
# Request timeout in seconds; must be an integer string when set.
REQUEST_TIMEOUT = int(os.getenv("VLLM_TIMEOUT", "120"))

class AsyncVLLMClient:
    """Small async HTTP client for the server's ``/v1/chat/completions`` endpoint.

    Use it as an async context manager: the aiohttp session is created in
    ``__aenter__`` and closed in ``__aexit__``.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        """Store connection settings; no network resources are opened here.

        Args:
            base_url: Server root URL. Trailing slashes are stripped so the
                endpoint path can be appended safely.
            timeout: Total timeout, in seconds, applied to each request.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def chat_completion(self, model: str, messages: List[Dict[str, str]], temperature: float = 0.2, max_tokens: int = 1024) -> Optional[Dict]:
        """POST a chat-completion request and return the decoded JSON body.

        Args:
            model: Model name placed in the request payload.
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature sent to the server.
            max_tokens: Generation limit sent to the server.

        Returns:
            The parsed JSON response on HTTP 200; ``None`` on any other status
            or when the request fails (a warning is printed in both cases).

        Raises:
            RuntimeError: If called before the session was created, i.e.
                outside of ``async with``.
        """
        if not self.session:
            raise RuntimeError("Client session not initialized")
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Honour the timeout this instance was configured with (rather than a
        # module-level global) and pass it as a ClientTimeout, the form
        # aiohttp expects for per-request timeouts.
        request_timeout = aiohttp.ClientTimeout(total=self.timeout)
        try:
            async with self.session.post(f"{self.base_url}/v1/chat/completions", json=payload, timeout=request_timeout) as resp:
                if resp.status == 200:
                    return await resp.json()
                txt = await resp.text()
                print(f"⚠️ chat HTTP {resp.status}: {txt}")
                return None
        except Exception as e:
            # Deliberately broad: any transport, timeout or decoding failure is
            # reported and turned into ``None`` for the caller.
            print(f"⚠️ chat request failed: {e}")
            return None

def load_corpus(corpus_path: str) -> pd.DataFrame:
    """Read the final-corpus parquet file at *corpus_path* into a DataFrame.

    Progress output (source path, row count, column names) is printed along
    the way.

    Raises:
        FileNotFoundError: If nothing exists at *corpus_path*.
    """
    if os.path.exists(corpus_path):
        print(f"📥 Loading corpus from: {corpus_path}")
        corpus = pd.read_parquet(corpus_path)

        row_count = len(corpus)
        column_names = corpus.columns.tolist()
        print(f"✅ Loaded corpus with {row_count} tags")
        print(f"📊 Columns: {column_names}")

        return corpus

    raise FileNotFoundError(f"Corpus file not found: {corpus_path}")

def prepare_corpus_context(df: pd.DataFrame, max_tags: int = 100) -> str:
    """Render the corpus tags as a numbered list to be used as LLM context.

    Args:
        df: Corpus frame. The ``tag`` column is used when present; otherwise
            the first column is used.
        max_tags: Upper bound on the number of rows included. A larger corpus
            is randomly down-sampled with a fixed seed, so repeated runs on
            the same corpus select the same rows.

    Returns:
        A header line followed by one ``"<n>. <tag>"`` line per selected row,
        numbered from 1.
    """
    # If there are too many tags, sample them
    if len(df) > max_tags:
        df_sample = df.sample(n=max_tags, random_state=42)
        print(f"📊 Sampling {max_tags} tags from {len(df)} total tags")
    else:
        df_sample = df

    # Prefer the 'tag' column; fall back to the first column when it is absent
    if 'tag' in df_sample.columns:
        tag_column = df_sample['tag']
    else:
        tag_column = df_sample.iloc[:, 0]

    # Build all lines first and join once instead of growing a string with +=
    numbered_lines = [f"{i}. {tag}\n" for i, tag in enumerate(tag_column.tolist(), 1)]
    return "Here are the tags from the final corpus:\n\n" + "".join(numbered_lines)

def clean_response(response: str) -> str:
    """Strip ``<think>`` / ``<thinking>`` sections from *response* and trim it.

    Closed sections are removed first; whatever follows a still-unclosed
    opening tag is then dropped through the end of the text. Surrounding
    whitespace is stripped from the result.
    """
    import re

    # Order matters: properly closed blocks go first, so that the greedy
    # "open tag to end of text" patterns only ever see unterminated tags.
    removal_patterns = (
        r'<think>.*?</think>',
        r'<thinking>.*?</thinking>',
        r'<think>.*',
        r'<thinking>.*',
    )
    cleaned = response
    for pattern in removal_patterns:
        cleaned = re.sub(pattern, '', cleaned, flags=re.DOTALL)

    return cleaned.strip()

async def query_llm_with_corpus(client: AsyncVLLMClient, question: str, corpus_context: str) -> Optional[str]:
    """Send the question and corpus context to the LLM.

    Args:
        client: Client used to issue the chat-completion request.
        question: The user's question about the corpus.
        corpus_context: Pre-rendered corpus text embedded in the prompt.

    Returns:
        The message content of the first choice, returned raw
        (``clean_response`` is intentionally not applied), or ``None`` when
        no response was received or it lacks the expected structure.
    """
    # The prompt text is kept exactly as the model has been receiving it,
    # including the four-space indentation on the lines after the first; the
    # pieces are spelled out so that embedded whitespace is visible.
    prompt = (
        "Analyze the following corpus of tags and provide a comprehensive analysis.\n"
        "\n"
        f"    Question: {question}\n"
        "\n"
        f"    {corpus_context}.\n"
        "    \n"
        "    It is recommended that you use exact phrase from list and use as much tags to answer all parts of the question."
    )

    messages = [
        {
            "role": "system", 
            "content": "You are a direct, analytical assistant. You must respond immediately with analysis without any thinking tags, reasoning markers, or internal dialogue. Never use <think>, <thinking>, or similar tags. Skip over this <think><thinking> we want nothing between these think tags. Start your response directly with the analysis."
        },
        {"role": "user", "content": prompt}
    ]

    print("🤖 Sending query to LLM...")
    response = await client.chat_completion(DEFAULT_TEXT_MODEL, messages, temperature=0.3, max_tokens=2048)

    if not response:
        print("❌ No response received from LLM")
        return None

    try:
        return response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        # Missing key, empty "choices" list, or a non-subscriptable value
        print(f"⚠️ Failed to extract content from response: {e}")
        return None

async def main():
    """Command-line entry point.

    Parses the arguments, verifies the LLM connection settings, loads the
    corpus, sends the question to the LLM, prints the answer and, when
    ``--output`` is given, also writes it to that file.

    Raises:
        ValueError: If the LLM URL or model environment variable is unset.
    """
    cli = argparse.ArgumentParser(description="Query a final corpus parquet file with a question using the local LLM.")
    cli.add_argument("--corpus-path", required=True, help="Path to the final corpus parquet file")
    cli.add_argument("--question", required=True, help="Question to ask about the corpus")
    cli.add_argument("--max-tags", type=int, default=100, help="Maximum number of tags to include in context (default 100)")
    cli.add_argument("--output", help="Output file to save the response (optional)")
    args = cli.parse_args()

    # Both connection settings have to be present before any work starts
    if not VLLM_TEXT_URL:
        raise ValueError("VLLM_QWEN_32B_URL environment variable not set")
    if not DEFAULT_TEXT_MODEL:
        raise ValueError("VLLM_QWEN_32B_MODEL environment variable not set")

    for banner in (
        f"🔗 LLM URL: {VLLM_TEXT_URL}",
        f"🤖 Model: {DEFAULT_TEXT_MODEL}",
        f"❓ Question: {args.question}",
    ):
        print(banner)

    # Corpus -> prompt context
    corpus_df = load_corpus(args.corpus_path)
    corpus_context = prepare_corpus_context(corpus_df, args.max_tags)

    # The client session only lives for the duration of the single query
    async with AsyncVLLMClient(VLLM_TEXT_URL, timeout=REQUEST_TIMEOUT) as llm:
        answer = await query_llm_with_corpus(llm, args.question, corpus_context)

    if not answer:
        print("❌ Failed to get response from LLM")
        return

    rule = "=" * 80
    print("\n" + rule)
    print("🤖 LLM RESPONSE:")
    print(rule)
    print(answer)
    print(rule)

    # Optional copy of the answer on disk
    if args.output:
        report_parts = [
            f"Question: {args.question}\n\n",
            f"Corpus: {args.corpus_path}\n\n",
            "Response:\n",
            answer,
        ]
        with open(args.output, 'w', encoding='utf-8') as out_file:
            out_file.writelines(report_parts)
        print(f"💾 Response saved to: {args.output}")

# Run the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())
