#!/usr/bin/env python3
"""
Final Descriptive Fitness Evaluator for Schema Induction Pipeline

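This script evaluates how well the codes assigned by build_corpus describe
each datachunk by asking Qwen 32B (via the server specified in .env) to rate
every chunk against its codes on a 1-10 scale. Requests are sent concurrently
with asyncio, with the number of in-flight requests capped by a semaphore.

The server URL and model name are read from the main pipeline's .env file
(VLLM_QWEN_32B_URL_2 / VLLM_QWEN_32B_MODEL) unless given on the command line.
By default only a random 20% of the chunks is evaluated; pass --no_sampling
to evaluate all of them.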

Usage:
    python descriptive_fitness_evaluator_final.py --results_path path/to/reusability_results.json --test_data path/to/test.csv --output fitness_results.json
"""

import argparse
import asyncio
import json
import os
import random
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import pandas as pd
from dotenv import load_dotenv

# Load environment variables from the main pipeline .env file.
# The path is resolved against this script's directory so the script works
# from any working directory; the original cwd-relative path is kept as a
# fallback. load_dotenv() does not override variables that are already set,
# so the first source that defines a variable wins.
_SCRIPT_DIR = Path(__file__).resolve().parent
load_dotenv(str(_SCRIPT_DIR / '..' / '..' / 'main_pipeline' / '.env'))
load_dotenv('../../main_pipeline/.env')

def chunk_text(text: str, chunk_size: int = 2048, overlap: int = 200):
    """Split text into overlapping chunks based on word count.

    Text with at most ``chunk_size`` whitespace-separated words is returned
    unchanged as a single chunk. Longer text is cut into windows of
    ``chunk_size`` words, each starting ``chunk_size - overlap`` words after
    the previous one. Chunk words are re-joined with single spaces, so the
    original whitespace is not preserved in that case.

    Args:
        text: Text to split.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks; must be
            smaller than ``chunk_size``.

    Returns:
        A non-empty list of chunk strings.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (either would make the window stop
            advancing and loop forever on long input).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})")

    words = text.split()
    if len(words) <= chunk_size:
        return [text]

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Stop once a window reaches the end, so no trailing chunk is made of overlap alone.
        if end >= len(words):
            break

    return chunks

def sample_chunks_randomly(chunk_to_codes: Dict[str, List[str]], sample_ratio: float = 0.2) -> Dict[str, List[str]]:
    """Randomly sample a fraction of chunks for evaluation.

    At least one chunk is kept whenever any exist, and the sample never
    exceeds the number of available chunks (a ratio above 1.0 keeps every
    chunk). Uses the module-level ``random`` state.

    Args:
        chunk_to_codes: Mapping of chunk id to its assigned codes.
        sample_ratio: Fraction of chunks to keep.

    Returns:
        A new dict holding the sampled subset of ``chunk_to_codes``; empty
        input yields an empty dict.
    """
    total_chunks = len(chunk_to_codes)
    # int() truncates: clamp so tiny ratios still sample one chunk, and so
    # ratios > 1 (or an empty mapping) don't make random.sample raise.
    sample_size = min(total_chunks, max(1, int(total_chunks * sample_ratio)))

    sampled_chunk_ids = random.sample(list(chunk_to_codes), sample_size)
    sampled_chunks = {chunk_id: chunk_to_codes[chunk_id] for chunk_id in sampled_chunk_ids}

    print(f"🎲 Random sampling: {sample_size}/{total_chunks} chunks ({sample_ratio*100:.0f}%)")
    print(f"   Sampled chunks: {sorted(sampled_chunk_ids)}")

    return sampled_chunks



@dataclass
class FitnessStats:
    """Statistics for fitness evaluation.

    Running counters and score aggregates, updated per chunk by
    DescriptiveFitnessEvaluator.evaluate_chunk.
    """
    total_chunks: int = 0  # chunks processed (successful + failed)
    successful_evaluations: int = 0  # chunks for which a valid score was obtained
    failed_evaluations: int = 0  # chunks for which no valid score was obtained
    total_score: float = 0.0  # sum of successful scores; divide by successful_evaluations for the mean
    min_score: float = 10.0  # starts at the top of the accepted 1-10 range so the first real score replaces it
    max_score: float = 0.0  # starts below the accepted 1-10 range so the first real score replaces it

class DescriptiveFitnessEvaluator:
    """Evaluates descriptive fitness of codes for datachunks.

    Each chunk's text and its assigned codes are formatted into
    ``DESCRIPTIVE_FITNESS_PROMPT`` and sent to an OpenAI-compatible
    ``/v1/chat/completions`` endpoint; the reply is parsed into a score in
    [1, 10] and aggregated into a :class:`FitnessStats`.
    """

    # Tried in order. For each pattern only its first match is considered,
    # and it is accepted only if the value lies in [1, 10].
    _SCORE_PATTERNS = (
        re.compile(r'Score:\s*(\d+(?:\.\d+)?)'),
        re.compile(r'(\d+(?:\.\d+)?)/10'),
        re.compile(r'^(\d+(?:\.\d+)?)$'),
        re.compile(r'(\d+(?:\.\d+)?)'),
    )

    def __init__(self, server_url: Optional[str] = None, model_name: Optional[str] = None,
                 max_concurrency: int = 10, request_timeout: float = 30.0,
                 chunk_size: int = 2048, chunk_overlap: int = 200):
        """Configure the evaluator.

        Args:
            server_url: Base URL of the LLM server. Defaults to the
                ``VLLM_QWEN_32B_URL_2`` environment variable, then
                ``http://localhost:8000``.
            model_name: Model name sent with each request. Defaults to
                ``VLLM_QWEN_32B_MODEL``, then ``Qwen/Qwen3-32B``.
            max_concurrency: Maximum number of requests in flight at once.
            request_timeout: Total per-request timeout in seconds.
            chunk_size: Words per chunk when re-chunking datapoints.
            chunk_overlap: Words shared between consecutive chunks.
                NOTE(review): chunk_size/chunk_overlap presumably must match
                the chunking used when the chunk ids were produced — confirm
                against build_corpus.

        Raises:
            ValueError: if ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be at least 1, got {max_concurrency}")

        # Load from environment variables if not provided
        self.server_url = server_url or os.getenv('VLLM_QWEN_32B_URL_2', 'http://localhost:8000')
        self.model_name = model_name or os.getenv('VLLM_QWEN_32B_MODEL', 'Qwen/Qwen3-32B')
        self.request_timeout = request_timeout
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.semaphore = asyncio.Semaphore(max_concurrency)  # Limit concurrent requests
        self.stats = FitnessStats()

        # Make the prompt module importable. The script-relative path works
        # from any working directory; the original cwd-relative entry is kept
        # as a fallback. Each entry is appended only once so repeated
        # instantiation does not keep growing sys.path.
        script_relative = str(Path(__file__).resolve().parent / '..' / '..' / 'evaluation' / 'distribution_metric')
        for prompt_dir in (script_relative, '../../evaluation/distribution_metric'):
            if prompt_dir not in sys.path:
                sys.path.append(prompt_dir)
        from prompt import DESCRIPTIVE_FITNESS_PROMPT
        self.prompt_template = DESCRIPTIVE_FITNESS_PROMPT

    def _sanitize_text(self, text: str) -> str:
        """Sanitize text to prevent HTTP 400 errors from problematic characters"""
        # Replace << and >> patterns that cause server issues
        sanitized = text.replace('<<', '[[').replace('>>', ']]')
        return sanitized

    async def _call_llm(self, session: "aiohttp.ClientSession", document: str, keywords: List[str]) -> Tuple[bool, float]:
        """Ask the LLM how well ``keywords`` describe ``document``.

        Returns:
            ``(True, score)`` with a score in [1, 10] on success, otherwise
            ``(False, 0.0)``. Request, HTTP and parsing errors are printed and
            reported as failures rather than raised.
        """
        async with self.semaphore:
            try:
                # Sanitize the document to prevent HTTP 400 errors
                sanitized_document = self._sanitize_text(document)

                # NOTE(review): ``keywords`` is passed as a list, so the prompt
                # receives its repr (e.g. "['a', 'b']"); confirm this is the
                # format DESCRIPTIVE_FITNESS_PROMPT expects.
                prompt = self.prompt_template.format(
                    document=sanitized_document,
                    keywords=keywords
                )

                payload = {
                    "model": self.model_name,
                    "messages": [
                        {"role": "user", "content": prompt}
                    ],
                    "temperature": 0.1,
                    # Small budget: only a score is expected back. If the model
                    # starts with reasoning it may be cut off; _clean_response
                    # then discards the unterminated reasoning.
                    "max_tokens": 50
                }

                # Tolerate a trailing slash in the configured base URL.
                url = f"{self.server_url.rstrip('/')}/v1/chat/completions"
                async with session.post(
                    url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.request_timeout)
                ) as response:
                    if response.status != 200:
                        print(f"❌ HTTP {response.status}: {await response.text()}")
                        return False, 0.0

                    result = await response.json()
                    # OpenAI-style responses allow a null ``content``; treat it as an empty reply.
                    content = (result['choices'][0]['message'].get('content') or '').strip()
                    content = self._clean_response(content)

                    score = self._extract_score(content)
                    if score is None:
                        print(f"❌ Failed to extract score from: {content}")
                        return False, 0.0
                    return True, score

            except Exception as e:
                # Type name included because some errors (e.g. timeouts) have an empty message.
                print(f"❌ Error calling LLM: {type(e).__name__}: {e}")
                return False, 0.0

    def _clean_response(self, response: str) -> str:
        """Strip reasoning blocks and common lead-in phrases from an LLM reply.

        Three shapes of ``<think>`` output are handled:

        * complete ``<think>...</think>`` blocks are removed;
        * a stray ``</think>`` with no opener (the opening tag was not part of
          the generated text) marks everything before it as reasoning, which
          is dropped;
        * an unterminated ``<think>`` (reasoning cut off, e.g. by
          ``max_tokens``) is dropped together with everything after it, so no
          number inside the reasoning can be mistaken for the score.
        """
        response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
        if '</think>' in response:
            response = response.rsplit('</think>', 1)[1]
        response = re.sub(r'<think>.*', '', response, flags=re.DOTALL)

        # Strip first so the anchored prefix pattern sees the reply's first word.
        response = response.strip()
        response = re.sub(r'^(Score:|The score is|I would rate this|Rating:)\s*', '', response, flags=re.IGNORECASE)

        return response.strip()

    def _extract_score(self, content: str) -> Optional[float]:
        """Extract a numerical score in [1, 10] from an LLM reply.

        Recognizes forms such as ``"Score: 7"``, ``"7/10"`` and a bare
        ``"7"``, falling back to the first number in the text.

        Returns:
            The score, or ``None`` if no pattern yields a value in [1, 10].
        """
        for pattern in self._SCORE_PATTERNS:
            match = pattern.search(content)
            if match:
                score = float(match.group(1))
                if 1 <= score <= 10:
                    return score

        return None

    async def evaluate_chunk(self, session: "aiohttp.ClientSession", chunk_text: str, codes: List[str]) -> Tuple[bool, float]:
        """Score one chunk against its codes and record the outcome in ``self.stats``.

        A chunk without codes is not sent to the LLM; it is counted as a
        failed evaluation so the stats agree with the per-chunk results
        reported by :meth:`evaluate_all_chunks`.

        Returns:
            ``(success, score)`` from :meth:`_call_llm`, or ``(False, 0.0)``
            for a chunk without codes.
        """
        # ``chunk_text`` shadows the module-level function of the same name
        # inside this method only; the parameter name is kept for callers.
        if codes:
            success, score = await self._call_llm(session, chunk_text, codes)
        else:
            success, score = False, 0.0

        self.stats.total_chunks += 1
        if success:
            self.stats.successful_evaluations += 1
            self.stats.total_score += score
            self.stats.min_score = min(self.stats.min_score, score)
            self.stats.max_score = max(self.stats.max_score, score)
        else:
            self.stats.failed_evaluations += 1

        return success, score

    def _resolve_chunk_text(self, chunk_id: str, datapoints: List[str],
                            chunk_cache: Dict[int, List[str]]) -> str:
        """Return the text for ``chunk_id`` (format ``datapoint_X_chunk_Y``).

        The datapoint is re-chunked with this evaluator's chunk size and
        overlap; the chunk list is memoised in ``chunk_cache`` so a datapoint
        with several sampled chunks is split only once. If ``Y`` is not a
        valid chunk index, a warning is printed and the full datapoint text is
        returned.

        Raises:
            ValueError: if the id cannot be parsed, the datapoint index is out
                of range, or the datapoint is not a string.
        """
        parts = chunk_id.split('_')
        try:
            datapoint_idx = int(parts[1])
            chunk_idx = int(parts[3])
        except (IndexError, ValueError):
            raise ValueError(f"unrecognised chunk id {chunk_id!r} (expected datapoint_X_chunk_Y)") from None

        if not 0 <= datapoint_idx < len(datapoints):
            raise ValueError(f"datapoint index {datapoint_idx} out of range ({len(datapoints)} datapoints loaded)")

        full_text = datapoints[datapoint_idx]
        if not isinstance(full_text, str):
            # pandas reads empty CSV cells as NaN floats, which cannot be word-split.
            raise ValueError(f"datapoint {datapoint_idx} is not text: {full_text!r}")

        chunks = chunk_cache.get(datapoint_idx)
        if chunks is None:
            chunks = chunk_text(full_text, chunk_size=self.chunk_size, overlap=self.chunk_overlap)
            chunk_cache[datapoint_idx] = chunks

        if 0 <= chunk_idx < len(chunks):
            return chunks[chunk_idx]

        # Fallback to the full text, as before, but make the mismatch visible.
        print(f"⚠️ {chunk_id}: chunk index {chunk_idx} not found ({len(chunks)} chunks); using full datapoint text")
        return full_text

    def _record_failure(self, chunk_results: Dict[str, Dict[str, Any]], chunk_id: str,
                        error: BaseException) -> None:
        """Print ``error`` for ``chunk_id``, store a failure entry and count it in the stats."""
        print(f"❌ Error evaluating {chunk_id}: {error}")
        chunk_results[chunk_id] = {"success": False, "score": 0.0, "error": str(error)}
        self.stats.total_chunks += 1
        self.stats.failed_evaluations += 1

    async def evaluate_all_chunks(self, chunk_to_codes: Dict[str, List[str]],
                                datapoints: List[str]) -> Dict[str, Any]:
        """Evaluate every chunk concurrently and summarise the results.

        Args:
            chunk_to_codes: Mapping of chunk id (``datapoint_X_chunk_Y``) to
                the codes assigned to that chunk.
            datapoints: Full datapoint texts, indexed by the ``X`` in each id.

        Returns:
            Dict with ``average_fitness_score``, ``total_chunks``,
            ``successful_evaluations``, ``failed_evaluations``, ``min_score``,
            ``max_score`` and per-chunk ``chunk_results`` (in the order of
            ``chunk_to_codes``). Chunks whose id or datapoint cannot be
            resolved are recorded as failures with an ``error`` message
            instead of aborting the whole run.
        """
        print(f"🔄 Starting descriptive fitness evaluation for {len(chunk_to_codes)} chunks")
        print(f"📡 Using server: {self.server_url}")
        print(f"🤖 Using model: {self.model_name}")

        # Stats describe this call only; otherwise a second call on the same
        # evaluator would report cumulative totals next to this call's results.
        self.stats = FitnessStats()
        chunk_results: Dict[str, Dict[str, Any]] = {}
        chunk_cache: Dict[int, List[str]] = {}
        jobs: List[Tuple[str, str, List[str]]] = []

        # Resolve every chunk's text before creating any coroutine, so a bad id
        # can't abort the loop and leave already-created coroutines un-awaited.
        for chunk_id, codes in chunk_to_codes.items():
            try:
                text = self._resolve_chunk_text(chunk_id, datapoints, chunk_cache)
            except ValueError as e:
                self._record_failure(chunk_results, chunk_id, e)
                continue
            jobs.append((chunk_id, text, codes))

        async with aiohttp.ClientSession() as session:
            print(f"⚡ Processing {len(jobs)} chunks concurrently...")
            results = await asyncio.gather(
                *(self.evaluate_chunk(session, text, codes) for _, text, codes in jobs),
                return_exceptions=True
            )

        for (chunk_id, _, _), result in zip(jobs, results):
            # BaseException also covers a CancelledError returned by gather.
            if isinstance(result, BaseException):
                self._record_failure(chunk_results, chunk_id, result)
                continue

            success, score = result
            chunk_results[chunk_id] = {"success": success, "score": score}
            if success:
                print(f"✅ {chunk_id}: Score {score:.1f}")
            else:
                print(f"❌ {chunk_id}: Failed evaluation")

        # Report per-chunk results in input order regardless of when they were recorded.
        chunk_results = {chunk_id: chunk_results[chunk_id] for chunk_id in chunk_to_codes}

        # Calculate final metrics
        has_scores = self.stats.successful_evaluations > 0
        avg_score = self.stats.total_score / self.stats.successful_evaluations if has_scores else 0.0

        return {
            "average_fitness_score": avg_score,
            "total_chunks": self.stats.total_chunks,
            "successful_evaluations": self.stats.successful_evaluations,
            "failed_evaluations": self.stats.failed_evaluations,
            "min_score": self.stats.min_score if has_scores else 0.0,
            "max_score": self.stats.max_score if has_scores else 0.0,
            "chunk_results": chunk_results
        }

def load_test_data(test_data_path: str) -> List[str]:
    """Load datapoint texts from a CSV file.

    Uses the ``text`` column when present, otherwise the first column (with
    a warning). Row order is preserved, since callers index the returned list
    by datapoint position. Missing cells, which pandas reads as NaN, become
    empty strings. Other non-string values are converted with ``str`` so
    downstream word splitting cannot fail.

    Args:
        test_data_path: Path to the CSV file.

    Returns:
        One string per CSV row.
    """
    print(f"📂 Loading test data from {test_data_path}")
    df = pd.read_csv(test_data_path)

    # Get text column (assume 'text' or first column)
    if 'text' in df.columns:
        column = df['text']
    else:
        print(f"⚠️ No 'text' column found; using first column {df.columns[0]!r}")
        column = df.iloc[:, 0]

    datapoints = column.fillna('').astype(str).tolist()

    print(f"✅ Loaded {len(datapoints)} datapoints")
    return datapoints

async def main():
    """Command-line entry point for the descriptive fitness evaluation.

    Loads the reusability results and the test data, optionally samples a
    subset of chunks, scores them with DescriptiveFitnessEvaluator, prints a
    summary and, if ``--output`` is given, writes the results as JSON.

    Returns:
        The results dict produced by
        ``DescriptiveFitnessEvaluator.evaluate_all_chunks``.
    """
    parser = argparse.ArgumentParser(description='Evaluate descriptive fitness of codes for datachunks')
    parser.add_argument('--results_path', type=str, required=True, help='Path to reusability results JSON file')
    parser.add_argument('--test_data', type=str, required=True, help='Path to test data CSV file')
    parser.add_argument('--output', type=str, help='Path to save fitness results JSON file')
    parser.add_argument('--server_url', type=str, help='LLM server URL (overrides .env)')
    parser.add_argument('--model', type=str, help='Model name (overrides .env)')
    # argparse %-formats help strings, so a literal percent sign must be written as %%
    # (an unescaped '%)' makes --help raise ValueError).
    parser.add_argument('--sample_ratio', type=float, default=0.2, help='Random sampling ratio (default 0.2 for 20%%)')
    parser.add_argument('--no_sampling', action='store_true', help='Disable random sampling and evaluate all chunks')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducible chunk sampling (default: unseeded)')

    args = parser.parse_args()

    if not args.no_sampling and not 0 < args.sample_ratio <= 1:
        parser.error(f'--sample_ratio must be in (0, 1], got {args.sample_ratio}')
    if args.seed is not None:
        random.seed(args.seed)

    # Load reusability results
    print(f"📂 Loading results from {args.results_path}")
    with open(args.results_path, 'r', encoding='utf-8') as f:
        reusability_results = json.load(f)

    # Load test data
    datapoints = load_test_data(args.test_data)

    # Extract chunk mapping
    chunk_to_codes = reusability_results.get('chunk_to_codes_mapping', {})
    print(f"📊 Found {len(chunk_to_codes)} chunks in results")
    if not chunk_to_codes:
        print("⚠️ No 'chunk_to_codes_mapping' entries found; nothing to evaluate")

    # Apply random sampling if not disabled (an empty mapping has nothing to sample)
    if not args.no_sampling and chunk_to_codes:
        chunk_to_codes = sample_chunks_randomly(chunk_to_codes, args.sample_ratio)

    print(f"📊 Evaluating {len(chunk_to_codes)} chunks")

    # Initialize evaluator (falls back to .env values for unset options)
    evaluator = DescriptiveFitnessEvaluator(
        server_url=args.server_url,
        model_name=args.model
    )

    # Run evaluation
    fitness_results = await evaluator.evaluate_all_chunks(chunk_to_codes, datapoints)

    # Print summary
    print("\n🎯 Descriptive Fitness Results:")
    print(f"   Average Score: {fitness_results['average_fitness_score']:.2f}/10")
    print(f"   Total Chunks: {fitness_results['total_chunks']}")
    print(f"   Successful: {fitness_results['successful_evaluations']}")
    print(f"   Failed: {fitness_results['failed_evaluations']}")
    print(f"   Score Range: {fitness_results['min_score']:.1f} - {fitness_results['max_score']:.1f}")

    # Save results
    if args.output:
        print(f"\n💾 Saving results to {args.output}")
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(fitness_results, f, indent=2)
        print("✅ Results saved successfully!")

    return fitness_results

if __name__ == "__main__":
    # Script entry point: drive the async CLI coroutine to completion on a fresh event loop.
    asyncio.run(main())
