#!/usr/bin/env python3
"""
Descriptive Coverage Evaluator for Schema Induction Pipeline

This script evaluates how well the codes assigned by build_corpus cover
all aspects portrayed in each datachunk by calling Qwen 32B via the server 
specified in .env. Uses async/parallel processing for optimal performance.

Automatically loads server URLs and model names from .env file.

Usage:
    python descriptive_coverage_eval.py --results_path path/to/reusability_results.json --test_data path/to/test.csv --output coverage_results.json
"""

import asyncio
import argparse
import json
import sys
import os
import re
import random
import pandas as pd
from pathlib import Path
from typing import List, Dict, Any, Tuple
import aiohttp
from dataclasses import dataclass
from dotenv import load_dotenv

# Load environment variables from the main pipeline .env file.
# NOTE(review): the path is relative, so it appears to be resolved against the
# current working directory rather than this file's location — confirm the
# script is always launched from its own directory.
load_dotenv('../../main_pipeline/.env')

# Get server configuration from environment.
# The URL is read from VLLM_QWEN_32B_URL_2 (note the `_2` suffix) and the model
# name from VLLM_QWEN_32B_MODEL; the second argument of each call is the
# fallback used when the variable is unset.
VLLM_QWEN_32B_URL = os.getenv('VLLM_QWEN_32B_URL_2', 'http://localhost:8000')
VLLM_QWEN_32B_MODEL = os.getenv('VLLM_QWEN_32B_MODEL', 'Qwen/Qwen2.5-32B-Instruct')

# Load the coverage prompt.
# The directory must be added to sys.path before the `prompt` import below.
# NOTE(review): this is also a relative path, so the import presumably only
# works from the expected working directory — confirm.
sys.path.append('../../evaluation/distribution_metric')
from prompt import DESCRIPTIVE_COVERAGE_PROMPT

def chunk_text(text: str, chunk_size: int = 2048, overlap: int = 200) -> List[str]:
    """Split text into overlapping chunks based on word count.

    The text is tokenised with ``str.split()`` (any whitespace run is a
    separator). Consecutive chunks share ``overlap`` words.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words repeated at the start of the next chunk.

    Returns:
        ``[text]`` unchanged (original whitespace preserved) when the text has
        at most ``chunk_size`` words; otherwise a list of chunks whose words
        are re-joined with single spaces.

    Raises:
        ValueError: If the text must be split and ``overlap >= chunk_size``.
            The window would never advance in that case.
    """
    words = text.split()
    if len(words) <= chunk_size:
        return [text]

    # Distance the window moves between chunks. A non-positive step would
    # leave `start` stuck (or moving backwards) and loop forever.
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))

        # The last window reached the end of the text; stop so we do not emit
        # a trailing chunk made only of already-covered overlap words.
        if end >= len(words):
            break

        start += step

    return chunks


def sample_chunks_randomly(chunk_to_codes: Dict[str, List[str]], sample_ratio: float = 0.2) -> Dict[str, List[str]]:
    """Randomly sample a percentage of chunks for evaluation.

    Args:
        chunk_to_codes: Mapping of chunk ID to the codes assigned to it.
        sample_ratio: Fraction of chunks to keep.

    Returns:
        A new dict with the sampled chunk IDs and their codes. At least one
        chunk is kept whenever the input is non-empty, and never more than the
        input contains; an empty input yields an empty dict.
    """
    total_chunks = len(chunk_to_codes)
    # Keep at least one chunk, but never ask random.sample for more items than
    # exist (an empty mapping or a ratio above 1.0 would otherwise raise).
    sample_size = min(total_chunks, max(1, int(total_chunks * sample_ratio)))

    sampled_chunk_ids = random.sample(list(chunk_to_codes), sample_size)
    sampled_chunks = {chunk_id: chunk_to_codes[chunk_id] for chunk_id in sampled_chunk_ids}

    print(f"🎲 Random sampling: {sample_size}/{total_chunks} chunks ({sample_ratio*100:.0f}%)")
    print(f"   Sampled chunks: {sorted(sampled_chunk_ids)}")

    return sampled_chunks


@dataclass
class ProcessingStats:
    """Counters and score aggregates collected during a coverage evaluation run."""
    # Number of chunks submitted for evaluation.
    total_chunks: int = 0
    # Chunks for which a score was obtained.
    successful_evaluations: int = 0
    # Chunks whose evaluation errored or returned no usable response.
    failed_evaluations: int = 0
    # Mean of the successful scores; 0.0 until one is recorded.
    average_score: float = 0.0
    # Lowest score recorded. Defaults to 10.0, presumably the top of the
    # scoring scale — confirm against the prompt.
    min_score: float = 10.0
    # Highest score recorded. Defaults to 0.0.
    max_score: float = 0.0

class VLLMClient:
    """Small async client for a vLLM server's ``/v1/chat/completions`` endpoint.

    Intended to be used as an async context manager: entering opens an
    ``aiohttp.ClientSession`` with the configured total timeout, and exiting
    closes it.
    """

    def __init__(self, base_url: str, timeout: int = 30):
        """Store connection settings; no network resources are created here.

        Args:
            base_url: Server root URL, with or without a trailing slash.
            timeout: Total per-request timeout passed to ``aiohttp.ClientTimeout``.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout))
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()
            # Drop the closed session so later calls fail with a clear message
            # instead of an error from a closed aiohttp session.
            self.session = None

    def _chat_url(self) -> str:
        """Return the chat-completions URL, tolerating a trailing slash in ``base_url``."""
        return f"{self.base_url.rstrip('/')}/v1/chat/completions"

    async def chat_completion(self, messages: List[Dict[str, str]], model: str, 
                            temperature: float = 0.1, max_tokens: int = 100) -> Dict[str, Any]:
        """Make a chat completion request to the vLLM server.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            model: Model name sent in the request payload.
            temperature: Sampling temperature sent in the payload.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The decoded JSON body on HTTP 200. On any failure (session not
            open, non-200 status, or an exception while sending or decoding)
            the problem is printed and ``None`` is returned, despite the
            annotation; callers must check for it.
        """
        if self.session is None:
            print("❌ Request failed: client session is not open (use 'async with VLLMClient(...)')")
            return None

        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": False
        }

        # Best-effort boundary: every failure is reported and mapped to None so
        # one bad request does not abort a batch of concurrent evaluations.
        try:
            async with self.session.post(
                self._chat_url(),
                json=payload,
                headers={"Content-Type": "application/json"}
            ) as response:
                if response.status == 200:
                    return await response.json()
                print(f"❌ HTTP {response.status}: {await response.text()}")
                return None
        except Exception as e:
            print(f"❌ Request failed: {e}")
            return None

class DescriptiveCoverageEvaluator:
    """Ask an LLM to score how well the assigned codes cover each data chunk.

    For every chunk ID the matching text is looked up in the test datapoints,
    sent to the model together with its codes, and the numeric score is parsed
    from the reply. Aggregate statistics for the most recent run are kept in
    ``self.stats``.
    """

    def __init__(self):
        self.model_url = VLLM_QWEN_32B_URL
        self.model_name = VLLM_QWEN_32B_MODEL
        self.stats = ProcessingStats()

    def _clean_llm_response(self, content: str) -> str:
        """Strip ``<think>``/``</think>`` tags (keeping their inner text) and trim whitespace."""
        content = re.sub(r'</?think>', '', content, flags=re.IGNORECASE)
        return content.strip()

    def _extract_score_from_response(self, response_text: str) -> float:
        """Extract the numeric score from an LLM reply.

        An explicit ``Score: X`` (case-insensitive) is returned as-is. Failing
        that, the first number in the text is accepted only if it lies in
        1..10. When nothing usable is found a warning is printed and ``0.0``
        is returned.
        """
        # Look for "Score: X" pattern
        score_match = re.search(r'Score:\s*(\d+(?:\.\d+)?)', response_text, re.IGNORECASE)
        if score_match:
            return float(score_match.group(1))

        # Look for just a number
        number_match = re.search(r'\b(\d+(?:\.\d+)?)\b', response_text.strip())
        if number_match:
            score = float(number_match.group(1))
            if 1 <= score <= 10:
                return score

        print(f"⚠️ Could not extract score from: {response_text[:100]}...")
        return 0.0

    def _sanitize_text(self, text: str) -> str:
        """Sanitize text to prevent HTTP 400 errors from problematic characters."""
        # Replace << and >> patterns that cause server issues
        return text.replace('<<', '[[').replace('>>', ']]')

    @staticmethod
    def _parse_chunk_id(chunk_id: str):
        """Split a chunk ID into ``(datapoint_idx, chunk_idx)``.

        Recognised shapes:
            ``datapoint_X_chunk_Y`` -> ``(X, Y)``
            ``datapoint_X``         -> ``(X, None)``
            anything else           -> datapoint ``0``; the fourth
                                       ``_``-separated field is used as the
                                       chunk index when it is an integer,
                                       otherwise ``None``.

        ``chunk_idx`` of ``None`` means the ID names no specific chunk.

        Raises:
            ValueError, IndexError: If an ID containing ``datapoint_`` has a
                non-integer datapoint field, or a ``datapoint_X_chunk_Y`` ID
                has a missing or non-integer chunk field.
        """
        parts = chunk_id.split('_')
        has_datapoint = 'datapoint_' in chunk_id
        datapoint_idx = int(parts[1]) if has_datapoint else 0

        try:
            chunk_idx = int(parts[3])
        except (IndexError, ValueError):
            # Only the full "datapoint_X_chunk_Y" shape promises a chunk index;
            # for that shape a bad index is a malformed ID. Other shapes simply
            # do not carry one.
            if has_datapoint and '_chunk_' in chunk_id:
                raise
            chunk_idx = None

        return datapoint_idx, chunk_idx

    async def evaluate_single_chunk(self, chunk_id: str, datapoint_text: str, codes: List[str]) -> Dict[str, Any]:
        """Evaluate coverage for a single chunk.

        Args:
            chunk_id: Identifier used in log output.
            datapoint_text: The chunk text shown to the model.
            codes: Codes assigned to the chunk.

        Returns:
            ``{"success": True, "score": float, "raw_response": str}`` when the
            server returned a completion, otherwise
            ``{"success": False, "error": str}``. Exceptions are caught and
            reported through the failure shape.
        """
        try:
            # Format codes as a quoted, comma-separated list
            codes_text = ", ".join(f'"{code}"' for code in codes)

            prompt = DESCRIPTIVE_COVERAGE_PROMPT.format(
                document=self._sanitize_text(datapoint_text),
                keywords=codes_text
            )
            messages = [{"role": "user", "content": prompt}]

            async with VLLMClient(self.model_url, 30) as client:
                response = await client.chat_completion(
                    messages,
                    self.model_name,
                    temperature=0.1,
                    max_tokens=500
                )

            if not (response and 'choices' in response):
                print(f"❌ {chunk_id}: No valid response")
                return {"success": False, "error": "No valid response"}

            content = response['choices'][0]['message']['content'].strip()
            # NOTE(review): an unparseable reply yields score 0.0 and is still
            # reported as a success, which lowers the average — confirm this
            # is intended.
            score = self._extract_score_from_response(self._clean_llm_response(content))

            print(f"✅ {chunk_id}: Score {score}")
            return {
                "success": True,
                "score": score,
                "raw_response": content
            }

        except Exception as e:
            print(f"❌ {chunk_id}: Error - {e}")
            return {"success": False, "error": str(e)}

    async def evaluate_all_chunks(self, chunk_to_codes: Dict[str, List[str]], 
                                datapoints: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Evaluate coverage for all chunks.

        Args:
            chunk_to_codes: Mapping of chunk ID to its assigned codes.
            datapoints: Records with a ``'text'`` entry, indexed by the
                datapoint number embedded in the chunk ID.

        Returns:
            A dict with ``average_coverage_score``, ``total_chunks``,
            ``successful_evaluations``, ``failed_evaluations``, ``min_score``,
            ``max_score`` and per-chunk ``chunk_results``. With no successful
            evaluation the average, min and max are all ``0.0``.
        """
        print(f"🔄 Starting descriptive coverage evaluation for {len(chunk_to_codes)} chunks")
        print(f"📡 Using server: {self.model_url}")
        print(f"🤖 Using model: {self.model_name}")

        # Start from clean statistics so repeated calls do not mix counters
        # from earlier runs with this run's total and average.
        self.stats = ProcessingStats()

        # Create semaphore for concurrency control
        semaphore = asyncio.Semaphore(5)  # Limit concurrent requests

        async def evaluate_with_semaphore(chunk_id: str, codes: List[str]):
            async with semaphore:
                try:
                    datapoint_idx, chunk_idx = self._parse_chunk_id(chunk_id)
                except (ValueError, IndexError):
                    return {"success": False, "error": f"Invalid chunk ID format: {chunk_id}"}

                # Explicit lower bound: a negative index would otherwise wrap
                # around and silently pick a datapoint from the end.
                if not 0 <= datapoint_idx < len(datapoints):
                    return {"success": False, "error": "Datapoint index out of range"}

                full_text = datapoints[datapoint_idx]['text']

                # Default to the full text; narrow to the specific chunk only
                # when the ID names one that exists after re-chunking.
                chunk_text_content = full_text
                if chunk_idx is not None:
                    chunks = chunk_text(full_text, chunk_size=2048, overlap=200)
                    if 0 <= chunk_idx < len(chunks):
                        chunk_text_content = chunks[chunk_idx]

                return await self.evaluate_single_chunk(chunk_id, chunk_text_content, codes)

        # Process all chunks concurrently
        print("⚡ Processing chunks concurrently...")
        tasks = [
            evaluate_with_semaphore(chunk_id, codes)
            for chunk_id, codes in chunk_to_codes.items()
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results (gather preserves task order, so zip lines them up)
        chunk_results = {}
        scores = []

        for chunk_id, result in zip(chunk_to_codes, results):
            # BaseException rather than Exception: a cancelled task comes back
            # as CancelledError, which must not reach the JSON output.
            if isinstance(result, BaseException):
                print(f"❌ {chunk_id}: Exception - {result}")
                chunk_results[chunk_id] = {"success": False, "error": str(result)}
                self.stats.failed_evaluations += 1
            elif isinstance(result, dict) and result.get("success", False):
                chunk_results[chunk_id] = result
                scores.append(result["score"])
                self.stats.successful_evaluations += 1
            else:
                chunk_results[chunk_id] = result
                self.stats.failed_evaluations += 1

        # Calculate statistics
        self.stats.total_chunks = len(chunk_to_codes)
        if scores:
            self.stats.average_score = sum(scores) / len(scores)
            self.stats.min_score = min(scores)
            self.stats.max_score = max(scores)
        else:
            # No data: report 0.0 rather than the 10.0 placeholder, which
            # would present a minimum above the maximum.
            self.stats.min_score = 0.0

        return {
            "average_coverage_score": self.stats.average_score,
            "total_chunks": self.stats.total_chunks,
            "successful_evaluations": self.stats.successful_evaluations,
            "failed_evaluations": self.stats.failed_evaluations,
            "min_score": self.stats.min_score,
            "max_score": self.stats.max_score,
            "chunk_results": chunk_results
        }

def load_test_data(test_data_path: str) -> List[Dict[str, Any]]:
    """Read the test CSV and return one record per row.

    Each record is ``{"index": <row label>, "text": <value of the "text" column>}``,
    in file order. Progress messages are printed before and after loading.
    """
    print(f"📂 Loading test data from {test_data_path}")
    frame = pd.read_csv(test_data_path)

    # Build the list of records in a single pass over the rows
    records = [
        {"index": row_label, "text": row["text"]}
        for row_label, row in frame.iterrows()
    ]

    print(f"✅ Loaded {len(records)} datapoints")
    return records

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the coverage evaluation script."""
    parser = argparse.ArgumentParser(description='Evaluate descriptive coverage of codes')
    parser.add_argument('--results_path', type=str, required=True,
                        help='Path to reusability results JSON file')
    parser.add_argument('--test_data', type=str, required=True,
                        help='Path to test data CSV file')
    parser.add_argument('--output', type=str, default='coverage_results.json',
                        help='Output file for coverage results')
    # argparse applies %-formatting to help strings, so a literal percent sign
    # must be written as '%%'; a bare '%' makes --help raise ValueError.
    parser.add_argument('--sample_ratio', type=float, default=0.2,
                        help='Random sampling ratio (default 0.2 for 20%%)')
    parser.add_argument('--no_sampling', action='store_true',
                        help='Disable random sampling and evaluate all chunks')
    return parser


async def main():
    """Command-line entry point: load inputs, run the evaluation, save and summarise.

    Reads the chunk-to-codes mapping from the results JSON, optionally samples
    it, loads the test CSV, evaluates every selected chunk, writes the result
    dict to the output JSON file and prints a summary.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject impossible ratios up front with a usage error instead of failing
    # later inside the sampler.
    if not args.no_sampling and not 0 < args.sample_ratio <= 1:
        parser.error('--sample_ratio must be greater than 0 and at most 1')

    # Load reusability results
    print(f"📂 Loading reusability results from {args.results_path}")
    with open(args.results_path, 'r', encoding='utf-8') as f:
        reusability_results = json.load(f)

    # Extract chunk to codes mapping
    chunk_to_codes = reusability_results.get('chunk_to_codes_mapping', {})
    print(f"📊 Found {len(chunk_to_codes)} chunks in results")

    # Apply random sampling if not disabled. Nothing can be drawn from an
    # empty mapping, so sampling is skipped in that case.
    if not chunk_to_codes:
        print("⚠️ No chunks found in results; nothing to sample")
    elif not args.no_sampling:
        chunk_to_codes = sample_chunks_randomly(chunk_to_codes, args.sample_ratio)

    print(f"📊 Evaluating {len(chunk_to_codes)} chunks")

    # Load test data
    datapoints = load_test_data(args.test_data)

    # Initialize evaluator
    evaluator = DescriptiveCoverageEvaluator()

    # Run evaluation
    results = await evaluator.evaluate_all_chunks(chunk_to_codes, datapoints)

    # Save results
    print(f"💾 Saving results to {args.output}")
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Print summary
    print("\n🎯 Coverage Evaluation Complete - Summary")
    print("=" * 50)
    print(f"📊 Average Coverage Score: {results['average_coverage_score']:.2f}/10")
    print(f"📋 Total Chunks: {results['total_chunks']}")
    print(f"✅ Successful Evaluations: {results['successful_evaluations']}")
    print(f"❌ Failed Evaluations: {results['failed_evaluations']}")
    print(f"📈 Score Range: {results['min_score']:.1f} - {results['max_score']:.1f}")
# Run the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    asyncio.run(main())
