#!/usr/bin/env python3
"""
OPTIMIZED High-Level Code Generation Script - FIXED VERSION 4

Key optimizations:
1. Reuse existing embeddings instead of recomputing them
2. Only embed new high-level codes
3. Efficient clustering with existing embeddings
4. FIXED: Handle case where embeddings don't exist yet (first iteration)
5. FIXED: Use correct aiohttp approach for chat completions
6. FIXED: Enable concurrency for high-level code generation

This script performs the following tasks:
1. Loads existing embeddings from previous iterations (if they exist)
2. For every cluster with more than 1 node, generates a higher-level code
3. Uses LLM to generate strategic, high-level codes that capture cluster themes
4. Creates mappings of high-level codes to all datachunks in each cluster
5. Enhances the corpus with these mappings
"""

import argparse
import asyncio
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm

# Load environment variables (e.g. VLLM_QWEN_32B_URL / VLLM_QWEN_32B_MODEL,
# which OptimizedHighLevelCodeGenerator.__init__ falls back to) via python-dotenv.
load_dotenv()

# Add this script's directory to sys.path so the sibling-module imports that
# follow (embeddings, cluster) resolve regardless of the current working
# directory. This must run before those imports.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Import embedding utilities
from embeddings import AsyncVllmClient, build_embeddings_parquet_async
from cluster import Clusterer  # Use optimized clustering

class OptimizedHighLevelCodeGenerator:
    """
    Generate high-level codes for clusters of existing codes, reusing embeddings.

    Pipeline (see ``run_optimized_high_level_generation``):
      1. Load embeddings from a previous iteration, or build them if absent.
      2. Cluster the embeddings with ``Clusterer.cluster_fast``.
      3. For every cluster with more than one code, ask an OpenAI-compatible
         chat-completions endpoint (concurrently, bounded by
         ``max_concurrency``, over one shared HTTP session) for a single
         broader code summarising the cluster.
      4. Map every source code to its high-level code and add that mapping to
         the corpus as a ``high_level_code`` column.
    """

    # Line prefixes that mark a line as markup or model reasoning rather than
    # a bare code in the plain-text fallback of ``_extract_high_level_code``.
    # Matching is a plain ``str.startswith`` (so "The" also rejects
    # "Therefore ...", "No" also rejects "Now ..."), which deliberately errs
    # on the side of filtering reasoning prose.
    _FALLBACK_SKIP_PREFIXES = (
        '{', '}', '"', 'High-level', '-', 'Wait', '</think>', 'I think',
        'Maybe', 'Alternatively', 'Hmm', 'So', 'Let me', 'The', 'Another',
        'Perhaps', 'But', 'Yes', 'No', 'How about', 'I need', 'I should',
        'I can', 'I will', 'I would', 'I could', 'I might', 'I must',
        'I have', 'I am', 'I was', 'I were', 'I had', 'I did', 'I do',
    )

    def __init__(self,
                 model_url: str = None,
                 model_name: str = None,
                 max_concurrency: int = 32,
                 timeout: int = 30):
        """
        Initialize the high-level code generator.

        Args:
            model_url: Base URL of the vLLM server; defaults to the
                ``VLLM_QWEN_32B_URL`` environment variable.
            model_name: Model name sent in each request; defaults to the
                ``VLLM_QWEN_32B_MODEL`` environment variable.
            max_concurrency: Maximum number of chat requests in flight at once.
            timeout: Total per-request timeout in seconds.

        Raises:
            ValueError: If neither an argument nor the environment provides
                the model URL and name.
        """
        self.model_url = model_url or os.getenv("VLLM_QWEN_32B_URL")
        self.model_name = model_name or os.getenv("VLLM_QWEN_32B_MODEL")
        self.max_concurrency = max_concurrency
        self.timeout = timeout

        if not self.model_url or not self.model_name:
            raise ValueError("Model URL and name must be provided or set in environment variables")

        print("🚀 Initialized HighLevelCodeGenerator")
        print(f"   Model URL: {self.model_url}")
        print(f"   Model: {self.model_name}")
        print(f"   Max Concurrency: {self.max_concurrency}")

    @staticmethod
    def _unpack_embeddings_frame(embeddings_df: pd.DataFrame, source: str) -> Tuple[np.ndarray, List[str]]:
        """
        Split an embeddings frame into a NumPy array and its parallel list of codes.

        Codes are read from the ``code`` column when present, otherwise from ``tag``.

        Args:
            embeddings_df: Frame with an ``embedding`` column of vectors.
            source: Human-readable origin used in the error message.

        Raises:
            ValueError: If the frame has no ``embedding`` column.
        """
        if 'embedding' not in embeddings_df.columns:
            raise ValueError(f"No 'embedding' column found in {source}")
        embeddings_array = np.array(embeddings_df['embedding'].tolist())
        code_column = 'code' if 'code' in embeddings_df.columns else 'tag'
        return embeddings_array, embeddings_df[code_column].tolist()

    async def load_existing_embeddings(
            self, embeddings_path: str
    ) -> Tuple[Optional[pd.DataFrame], Optional[np.ndarray], Optional[List[str]]]:
        """
        Load embeddings saved by a previous iteration instead of recomputing them.

        Args:
            embeddings_path: Path to an existing embeddings parquet file.

        Returns:
            ``(embeddings_df, embeddings_array, codes_list)``, or
            ``(None, None, None)`` when the file does not exist (first iteration).

        Raises:
            ValueError: If the file has no ``embedding`` column.
        """
        print("🔄 Loading existing embeddings...")

        if not os.path.exists(embeddings_path):
            print(f"⚠️ Embeddings file not found: {embeddings_path}")
            print("🔄 This is likely the first iteration - will generate new embeddings")
            return None, None, None

        embeddings_df = pd.read_parquet(embeddings_path)
        print(f"✅ Loaded existing embeddings: {embeddings_df.shape}")

        embeddings_array, codes_list = self._unpack_embeddings_frame(embeddings_df, "embeddings file")

        print(f"   Embeddings shape: {embeddings_array.shape}")
        print(f"   Number of codes: {len(codes_list)}")

        return embeddings_df, embeddings_array, codes_list

    async def generate_embeddings_if_needed(self, corpus_df: pd.DataFrame, output_path: str) -> Tuple[pd.DataFrame, np.ndarray, List[str]]:
        """
        Build embeddings for the corpus via ``build_embeddings_parquet_async``.

        Args:
            corpus_df: Corpus DataFrame.
            output_path: Parquet path the embeddings are written to.

        Returns:
            ``(embeddings_df, embeddings_array, codes_list)``.

        Raises:
            ValueError: If the generated frame has no ``embedding`` column.
        """
        print("🔄 Generating new embeddings...")

        # The helper returns a 4-tuple; only the last element (the frame) is used.
        _, _, _, embeddings_df = await build_embeddings_parquet_async(
            corpus_df=corpus_df,
            output_parquet=output_path
        )

        print(f"✅ Generated embeddings: {embeddings_df.shape}")

        embeddings_array, codes_list = self._unpack_embeddings_frame(embeddings_df, "generated embeddings")
        return embeddings_df, embeddings_array, codes_list

    async def perform_optimized_clustering(self, embeddings_array: np.ndarray, output_dir: str) -> pd.DataFrame:
        """
        Cluster embeddings with ``Clusterer.cluster_fast`` and persist the labels.

        Args:
            embeddings_array: Pre-computed embeddings, one row per code.
            output_dir: Directory where ``cluster_assignments.parquet`` is written
                (created if missing).

        Returns:
            DataFrame with ``code_index`` (row position in ``embeddings_array``)
            and ``cluster_id`` columns.
        """
        print("🔄 Running OPTIMIZED clustering...")

        clusterer = Clusterer()

        # cluster_fast returns a dict with (at least) 'labels' and 'optimal_k'.
        cluster_results = clusterer.cluster_fast(embeddings_array)
        cluster_assignments = cluster_results['labels']
        optimal_k = cluster_results['optimal_k']

        print(f"✅ Clustering completed with k={optimal_k}")

        cluster_df = pd.DataFrame({
            'code_index': range(len(cluster_assignments)),
            'cluster_id': cluster_assignments
        })

        os.makedirs(output_dir, exist_ok=True)
        cluster_path = os.path.join(output_dir, "cluster_assignments.parquet")
        cluster_df.to_parquet(cluster_path, index=False)
        print(f"✅ Saved cluster assignments to {cluster_path}")

        return cluster_df

    async def generate_high_level_codes(self,
                                        corpus_df: pd.DataFrame,
                                        cluster_df: pd.DataFrame,
                                        codes_list: List[str],
                                        output_dir: str) -> pd.DataFrame:
        """
        Generate one high-level code per multi-code cluster, concurrently.

        At most ``self.max_concurrency`` requests (minimum 1) are in flight at
        once, and all requests share one HTTP session. Clusters whose request
        fails or yields an empty code are logged and skipped.

        Args:
            corpus_df: Corpus DataFrame (unused here; kept for API compatibility).
            cluster_df: Frame with ``code_index`` and ``cluster_id`` columns.
            codes_list: Codes, indexed by ``code_index``.
            output_dir: Directory where ``high_level_codes.parquet`` is written
                (created if missing).

        Returns:
            DataFrame with columns ``cluster_id``, ``high_level_code``,
            ``source_codes`` and ``num_source_codes``. The columns are present
            even when no code was generated.
        """
        print("🧠 Generating high-level codes with CONCURRENCY...")

        # Only clusters with more than one code get a higher-level code.
        cluster_info: List[Dict[str, Any]] = []
        for cluster_id, group in cluster_df.groupby('cluster_id'):
            if len(group) > 1:
                cluster_codes = [codes_list[idx] for idx in group['code_index']]
                cluster_info.append({
                    'cluster_id': cluster_id,
                    'cluster_codes': cluster_codes,
                    'num_codes': len(cluster_codes)
                })

        print(f"   Processing {len(cluster_info)} clusters concurrently...")

        results: List[Any] = []
        if cluster_info:
            # Semaphore(0) would block every request forever; allow at least one.
            semaphore = asyncio.Semaphore(max(1, self.max_concurrency))

            async with aiohttp.ClientSession() as session:
                async def bounded(codes: List[str]) -> Optional[str]:
                    async with semaphore:
                        return await self._generate_single_high_level_code_with_semaphore(codes, session=session)

                results = await asyncio.gather(
                    *(bounded(info['cluster_codes']) for info in cluster_info),
                    return_exceptions=True
                )

        high_level_codes = []
        for info, result in zip(cluster_info, results):
            # BaseException so a returned CancelledError is not mistaken for a code.
            if isinstance(result, BaseException):
                print(f"❌ Error processing cluster {info['cluster_id']}: {result}")
                continue

            if result:  # Skip None / empty codes
                high_level_codes.append({
                    'cluster_id': info['cluster_id'],
                    'high_level_code': result,
                    'source_codes': info['cluster_codes'],
                    'num_source_codes': info['num_codes']
                })

        # Explicit columns keep the schema stable when nothing was generated.
        high_level_df = pd.DataFrame(
            high_level_codes,
            columns=['cluster_id', 'high_level_code', 'source_codes', 'num_source_codes']
        )

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "high_level_codes.parquet")
        high_level_df.to_parquet(output_path, index=False)
        print(f"✅ Generated {len(high_level_df)} high-level codes concurrently")
        print(f"✅ Saved to {output_path}")

        return high_level_df

    async def _generate_single_high_level_code_with_semaphore(
            self,
            cluster_codes: List[str],
            session: "Optional[aiohttp.ClientSession]" = None) -> Optional[str]:
        """
        Request a single high-level code for one cluster.

        Despite the name, this method does not acquire a semaphore itself;
        the caller (``generate_high_level_codes``) bounds concurrency.

        Args:
            cluster_codes: Codes belonging to the cluster.
            session: Optional shared ``aiohttp.ClientSession``. When omitted, a
                short-lived session is created for this single request.

        Returns:
            The extracted high-level code, or None if the request raised,
            returned a non-200 status, or produced no content. Errors are
            printed rather than raised.
        """
        prompt = self._create_high_level_prompt(cluster_codes)
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 512
        }

        try:
            if session is None:
                async with aiohttp.ClientSession() as own_session:
                    return await self._post_chat_completion(own_session, payload)
            return await self._post_chat_completion(session, payload)
        except Exception as e:
            # Best-effort: one failing cluster must not abort the whole batch.
            print(f"❌ Error generating high-level code: {e}")
            return None

    async def _post_chat_completion(self, session: Any, payload: Dict[str, Any]) -> Optional[str]:
        """
        POST ``payload`` to ``<model_url>/v1/chat/completions`` and extract the code.

        Returns None on a non-200 status or when the message content is empty or null.
        """
        # rstrip avoids a double slash when the configured URL ends with '/'.
        url = f"{self.model_url.rstrip('/')}/v1/chat/completions"
        async with session.post(
            url,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as response:
            if response.status != 200:
                print(f"❌ Request failed with status {response.status}")
                return None
            result = await response.json()

        # Chat-completion message content may be null; treat that as "no answer".
        content = (result["choices"][0]["message"].get("content") or "").strip()
        if not content:
            return None
        return self._extract_high_level_code(content)

    def _create_high_level_prompt(self, cluster_codes: List[str]) -> str:
        """
        Build the user prompt asking for one high-level code for a cluster.

        Each cluster code is listed as a "- code" bullet. The model is asked to
        reply with ``{"high-level": ["..."]}`` JSON only.
        """
        codes_text = "\n".join([f"- {code}" for code in cluster_codes])

        prompt = f"""Generate exactly 1 high-level code that represents the following cluster of medium-level codes with respect to the question asked.

High-level codes are: Passage level, global level themes, tags, or patterns. Higher semantic hierarchy, broader sense that frame the document in a wider context. More general, broad, and macro.

Tag requirements:
Semantic + Pragmatic balance: Capture what the text is about (topics, entities, facts) and how it operates (high-level logic and pattern, language or narrative style, structural organization, intended audience, functionality, purpose, communicative intent, rhetorical strategies, property of the text).

Informative & concise – each tag ≤ 15 words, no redundancy, no punctuation except hyphens. For the linguistic style, you will try to make tags a compact noun phrases or compound keywords (Twitter style). You will try to optimize the reusability and generalizability of the tags.

Descending granularity – high‑level tags must be the most general and abstract. However, tags should still be descriptive, reusable, generalizable, not purely extractive. Try your best to minimize vagueness and ambiguity.

Unique – no tag should appear in more than one tier.

Use standard English, use normal spacing, do not use Camel case. Do not add hashtag.

Medium-level codes in this cluster:
{codes_text}

Respond with JSON ONLY in this exact format:
{{"high-level": ["code1"]}}

IMPORTANT: 
- Respond with JSON ONLY. Do not include explanations, prose, or code fences.
- Do NOT use <think> tags or any thinking mode.
- Do NOT include any analysis or step-by-step reasoning.
- Start your response directly with the JSON object."""

        return prompt

    @staticmethod
    def _try_json(text: str) -> Any:
        """Return ``json.loads(text)``, or None when ``text`` is not valid JSON."""
        try:
            return json.loads(text)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return None

    @staticmethod
    def _code_from_payload(data: Any) -> Optional[str]:
        """
        Pull the code out of a parsed ``{"high-level": [...]}`` payload.

        A list yields its first element. A bare string is taken as the code
        itself rather than being indexed to its first character. Returns None
        when ``data`` is not a dict or the value is missing, empty, or of
        another type, so the caller can try its next strategy.
        """
        if not isinstance(data, dict):
            return None
        value = data.get('high-level')
        if not value:
            return None
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (list, tuple)):
            return str(value[0]).strip()
        return None

    def _extract_high_level_code(self, content: str) -> str:
        """
        Extract the high-level code from a raw LLM response.

        Strategies, in order:
          1. If a ``</think>`` tag is present, keep only the non-empty text
             after the last one, so reasoning is not mistaken for the answer.
          2. The first ``{...}`` object mentioning ``"high-level"``.
          3. The whole (remaining) response parsed as JSON.
          4. The last line that is 6-199 characters long and does not start
             with one of ``_FALLBACK_SKIP_PREFIXES``.
          5. The last line of the response.

        Args:
            content: Raw message content returned by the chat endpoint.

        Returns:
            The extracted code. It may be an empty string, which callers treat
            as "no code".
        """
        content = content.strip()

        # Checking for the closing tag alone also handles responses that
        # contain only "</think>" without a matching opening tag.
        if '</think>' in content:
            after_think = content.rpartition('</think>')[2].strip()
            if after_think:
                content = after_think

        json_match = re.search(r'\{[^}]*"high-level"[^}]*\}', content, re.DOTALL)
        if json_match:
            code = self._code_from_payload(self._try_json(json_match.group(0)))
            if code is not None:
                return code

        code = self._code_from_payload(self._try_json(content))
        if code is not None:
            return code

        lines = content.split('\n')
        for line in reversed(lines):
            candidate = line.strip()
            if 5 < len(candidate) < 200 and not candidate.startswith(self._FALLBACK_SKIP_PREFIXES):
                return candidate

        # str.split always yields at least one element, so lines[-1] exists.
        return lines[-1].strip()

    async def enhance_corpus_with_high_level_codes(self,
                                                   corpus_df: pd.DataFrame,
                                                   high_level_df: pd.DataFrame,
                                                   output_path: str) -> pd.DataFrame:
        """
        Add a ``high_level_code`` column to the corpus and save it.

        Each row's ``tag`` is looked up among the clusters' ``source_codes``.
        Unmatched tags get NaN. If a source code appears in several clusters,
        the last one wins. NOTE: ``corpus_df`` is modified in place and also
        returned.

        Args:
            corpus_df: Corpus DataFrame with a ``tag`` column.
            high_level_df: Output of ``generate_high_level_codes``.
            output_path: Parquet path for the enhanced corpus (its parent
                directory is created if missing).

        Returns:
            The enhanced corpus DataFrame (same object as ``corpus_df``).
        """
        print("🔄 Enhancing corpus with high-level codes...")

        code_to_high_level: Dict[str, Any] = {}
        if not high_level_df.empty:
            for high_level_code, source_codes in zip(high_level_df['high_level_code'],
                                                     high_level_df['source_codes']):
                for source_code in source_codes:
                    code_to_high_level[source_code] = high_level_code

        corpus_df['high_level_code'] = corpus_df['tag'].map(code_to_high_level)

        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        corpus_df.to_parquet(output_path, index=False)
        print(f"✅ Enhanced corpus saved to {output_path}")

        return corpus_df

    async def run_optimized_high_level_generation(self,
                                                  corpus_df: pd.DataFrame,
                                                  existing_embeddings_path: str,
                                                  output_dir: str) -> pd.DataFrame:
        """
        Run the full pipeline: embeddings, clustering, high-level codes, corpus enrichment.

        Existing embeddings are reused when ``existing_embeddings_path`` exists.
        Otherwise new ones are built into ``<output_dir>/embeddings.parquet``.

        Args:
            corpus_df: Corpus DataFrame (modified in place by the enrichment step).
            existing_embeddings_path: Path to embeddings from a previous iteration.
            output_dir: Directory for all outputs (created if missing).

        Returns:
            The enhanced corpus DataFrame.
        """
        print("🚀 Starting OPTIMIZED high-level code generation...")

        # Create the output directory up front: on a first iteration new
        # embeddings are written into it before clustering would create it.
        os.makedirs(output_dir, exist_ok=True)

        embeddings_df, embeddings_array, codes_list = await self.load_existing_embeddings(existing_embeddings_path)

        if embeddings_df is None:
            print("🔄 No existing embeddings found - generating new ones...")
            embeddings_df, embeddings_array, codes_list = await self.generate_embeddings_if_needed(
                corpus_df, os.path.join(output_dir, "embeddings.parquet")
            )
        else:
            print("✅ Using existing embeddings - no recomputation needed!")

        cluster_df = await self.perform_optimized_clustering(embeddings_array, output_dir)

        high_level_df = await self.generate_high_level_codes(corpus_df, cluster_df, codes_list, output_dir)

        enhanced_corpus_path = os.path.join(output_dir, "enhanced_corpus.parquet")
        enhanced_corpus_df = await self.enhance_corpus_with_high_level_codes(
            corpus_df, high_level_df, enhanced_corpus_path
        )

        print("✅ OPTIMIZED high-level code generation completed!")
        return enhanced_corpus_df


async def main():
    """
    Command-line entry point for a single high-level generation run.

    Parses --corpus, --embeddings and --output (all required), loads the corpus
    parquet, runs the full pipeline with a default-configured generator, and
    prints the size of the enhanced corpus.
    """
    cli = argparse.ArgumentParser(description='Optimized High-Level Code Generation')
    for flag, help_text in (
        ('--corpus', 'Path to corpus parquet file'),
        ('--embeddings', 'Path to existing embeddings parquet file'),
        ('--output', 'Output directory'),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    opts = cli.parse_args()

    corpus = pd.read_parquet(opts.corpus)
    print(f"📂 Loaded corpus with {len(corpus)} records")

    enhanced = await OptimizedHighLevelCodeGenerator().run_optimized_high_level_generation(
        corpus_df=corpus,
        existing_embeddings_path=opts.embeddings,
        output_dir=opts.output,
    )

    print(f"✅ Process completed! Enhanced corpus has {len(enhanced)} records")


if __name__ == "__main__":
    # Drive the async entry point on a fresh event loop.
    asyncio.run(main())
