#!/usr/bin/env python3
"""
Build corpus script for test inference evaluation
"""

import asyncio
import json
import os
import random
import re
import sys
import time
from typing import Any, Dict, List, Tuple, Optional
import pandas as pd
import numpy as np
from dotenv import load_dotenv

def retry_on_connection_failure(max_retries: int = 3, base_delay: float = 2.0):
    """Build a decorator that retries an async callable on connection failures.

    An exception counts as a connection failure when any of the marker
    substrings below occurs in its class name or in ``str(exception)``.
    Any other exception is logged and re-raised immediately.

    Args:
        max_retries: Number of retries after the first attempt. Negative
            values are treated as 0 (a single attempt).
        base_delay: Base delay in seconds; the wait before retry ``n``
            (0-based) is ``base_delay * 2**n`` plus up to one second of jitter.

    Returns:
        A decorator for coroutine functions. The wrapped coroutine returns
        the callee's result, or re-raises the last connection error once all
        attempts are exhausted.
    """
    # Function-scope import: the file's import header does not provide functools.
    import functools

    # Built once instead of on every failed attempt.
    connection_markers = (
        'ClientConnectorError',
        'TimeoutError',
        'ConnectionError',
        'Connect call failed',
        'Connection refused',
        'Connection reset',
        'Connection timeout',
        'DNS resolution failed',
        'nodename nor servname provided',
    )
    # Guarantee at least one attempt so `last_exception` is never None below.
    total_attempts = max(max_retries, 0) + 1

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped coroutine's name and docstring
        async def wrapper(*args, **kwargs):
            last_exception = None

            for attempt in range(total_attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    error_str = str(e)
                    error_type = type(e).__name__

                    is_connection_error = any(
                        marker in error_str or marker in error_type
                        for marker in connection_markers
                    )

                    # Anything that is not a connection problem is not retried.
                    if not is_connection_error:
                        print(f"❌ Non-connection error, not retrying: {error_type}: {error_str[:100]}...")
                        raise

                    # No point sleeping after the final attempt.
                    if attempt == total_attempts - 1:
                        break

                    # Exponential backoff with jitter.
                    delay = base_delay * (2 ** attempt) + random.uniform(0, 1)

                    print(f"⚠️  Connection failed (attempt {attempt + 1}/{total_attempts}): {error_type}")
                    print(f"🔄 Retrying in {delay:.1f} seconds...")

                    await asyncio.sleep(delay)

            # Every attempt failed with a connection error.
            print(f"❌ Connection failed after {total_attempts} attempts")
            raise last_exception

        return wrapper
    return decorator
import aiohttp

# Load environment variables from the pipeline's .env file.
# NOTE(review): the path is relative — presumably resolved against the current
# working directory; confirm against how this script is launched.
load_dotenv("../../main_pipeline/.env")

# Configuration
MAX_CONCURRENCY = 16  # size of the asyncio.Semaphore created in process_datapoints
REQUEST_TIMEOUT = 120  # handed to VLLMClient as its timeout (aiohttp ClientTimeout total)
BATCH_SIZE = 10  # number of training codes sent per embeddings request

def chunk_text(text: str, chunk_size: int = 2048, overlap: int = 200) -> List[str]:
    """Split text into overlapping chunks measured in whitespace-separated words.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks; must be
            smaller than ``chunk_size``. A negative value leaves a gap of
            skipped words between chunks.

    Returns:
        ``[text]`` unchanged when it has at most ``chunk_size`` words;
        otherwise a list of chunks, each made of words re-joined with single
        spaces.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (the window could never advance).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    words = text.split()
    if len(words) <= chunk_size:
        return [text]

    # Each window starts `step` words after the previous one.
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Stop once a window reaches the end, to avoid a redundant tail chunk.
        if end >= len(words):
            break

    return chunks
class VLLMClient:
    """Async HTTP client for a vLLM server's chat and embeddings endpoints.

    Use it as an async context manager: the aiohttp session is created on
    entry and closed on exit. Requests are POSTed to
    ``{model_url}/v1/chat/completions`` and ``{model_url}/v1/embeddings``.
    """

    def __init__(self, model_url: str, timeout: int = 120):
        """Store the base URL and the total per-request timeout in seconds."""
        self.model_url = model_url
        self.timeout = timeout
        self.session = None  # created in __aenter__

    async def __aenter__(self):
        self.session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout))
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()
            # Drop the closed session so later misuse is reported clearly.
            self.session = None

    async def _post_json(self, path: str, payload: Dict) -> Dict:
        """POST ``payload`` as JSON to ``model_url + path`` and return the decoded JSON body."""
        if self.session is None:
            raise RuntimeError("VLLMClient must be entered with 'async with' before sending requests")
        async with self.session.post(f'{self.model_url}{path}', json=payload) as response:
            return await response.json()

    @retry_on_connection_failure(max_retries=3, base_delay=2.0)
    async def chat_completion(self, messages: List[Dict], model: str, temperature: float = 0.1, max_tokens: int = 1000) -> Dict:
        """Send a chat completion request and return the decoded JSON response.

        Connection-related failures are retried by the decorator.
        """
        payload = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens
        }
        return await self._post_json('/v1/chat/completions', payload)

    # Retried like chat_completion so one transient connection failure does
    # not abort a whole batch of concurrent embedding lookups.
    @retry_on_connection_failure(max_retries=3, base_delay=2.0)
    async def embeddings(self, texts: List[str], model: str) -> Dict:
        """Send an embeddings request for ``texts`` and return the decoded JSON response."""
        payload = {
            'model': model,
            'input': texts
        }
        return await self._post_json('/v1/embeddings', payload)

class OptimizedTestInferenceCorpusBuilder:
    """Optimized corpus builder with full concurrency and vectorized similarity"""
    
    def __init__(self, question: str, train_corpus_path: str, hierarchical_tree_path: str, use_hierarchical_linkage: bool = True):
        """Store settings, read endpoint configuration, and load the corpus and tree.

        Args:
            question: Research question inserted into the LLM prompts.
            train_corpus_path: Parquet file read by ``_load_training_data``.
            hierarchical_tree_path: JSON file read by ``_load_hierarchical_tree``.
            use_hierarchical_linkage: When False, parent/grandparent lookups
                are skipped and reported as None.
        """
        self.question = question
        self.train_corpus_path = train_corpus_path
        self.hierarchical_tree_path = hierarchical_tree_path
        self.use_hierarchical_linkage = use_hierarchical_linkage

        # Report which linkage mode is active.
        linkage_message = (
            "✅ Hierarchical linkage enabled - will retrieve parent/grandparent codes"
            if self.use_hierarchical_linkage
            else "⚠️  Hierarchical linkage disabled - parent/grandparent codes will be None"
        )
        print(linkage_message)

        # Endpoint and model names come from the environment, with local defaults.
        env = os.getenv
        self.model_url = env("VLLM_QWEN_32B_URL_2", "http://localhost:8000")
        self.embed_url = env("VLLM_EMBEDDING_URL", "http://localhost:8001")
        self.model_name = env("VLLM_QWEN_32B_MODEL", "Qwen/Qwen3-32B")
        self.embed_model = env("DEFAULT_EMBEDDING_MODEL", "qwen3-embed-0.6b")

        # Filled in by the loaders below and by _compute_train_embeddings.
        self.train_codes = []
        self.train_embeddings = None
        self.hierarchical_tree = None

        # Run counters, all starting at zero, held on a throwaway object.
        counter_names = (
            'total_chunks',
            'open_coding_requests',
            'replacement_requests',
            'parent_retrieval_requests',
            'embedding_requests',
            'codes_generated',
            'codes_replaced',
            'codes_discarded',
        )
        self.stats = type('Stats', (), dict.fromkeys(counter_names, 0))()

        self._load_training_data()
        self._load_hierarchical_tree()
    
    def _load_training_data(self):
        """Read the training corpus parquet file and keep its 'tag' column as the code list."""
        corpus_path = self.train_corpus_path
        print(f"📂 Loading training corpus from {corpus_path}")
        self.train_codes = pd.read_parquet(corpus_path)['tag'].tolist()
        print(f"✅ Loaded {len(self.train_codes)} training codes")
    
    def _load_hierarchical_tree(self):
        """Load hierarchical tree for parent/grandparent retrieval"""
        if self.hierarchical_tree_path and os.path.exists(self.hierarchical_tree_path):
            print(f"📂 Loading hierarchical tree from {self.hierarchical_tree_path}")
            with open(self.hierarchical_tree_path, 'r') as f:
                self.hierarchical_tree = json.load(f)
            print(f"✅ Loaded hierarchical tree")
        else:
            print(f"⚠️  Hierarchical tree not found at {self.hierarchical_tree_path or 'None'}")
            self.hierarchical_tree = None
    
    def _clean_llm_response(self, content: str) -> str:
        """Strip ``<think>...</think>`` blocks from an LLM reply and trim whitespace."""
        # Case-insensitive, and '.' spans newlines so multi-line blocks are removed.
        think_block = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
        without_reasoning = think_block.sub('', content)
        return without_reasoning.strip()
    
    def _clean_open_coding_response(self, content: str) -> str:
        """Reduce an open-coding reply to the span from its first '{' to its last '}'.

        ``<think>...</think>`` blocks are removed first. If no '{' is left
        the result is an empty string; if no '}' follows, everything from
        the first '{' onward is kept.
        """
        stripped = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL | re.IGNORECASE)

        # Drop the preamble before the first opening brace (all text if there is none).
        first_open = stripped.find('{')
        body = stripped[first_open:] if first_open != -1 else ''

        # Drop whatever trails the last closing brace.
        last_close = body.rfind('}')
        if last_close != -1:
            body = body[:last_close + 1]

        return body.strip()
    
    async def _compute_train_embeddings(self):
        """Compute embeddings for training codes"""
        if self.train_embeddings is not None:
            return
            
        print(f"🔄 Computing embeddings for training codes using LLM...")
        
        # Process in batches
        embeddings = []
        for i in range(0, len(self.train_codes), BATCH_SIZE):
            batch = self.train_codes[i:i + BATCH_SIZE]
            print(f"  Processing batch {i//BATCH_SIZE + 1}/{(len(self.train_codes) + BATCH_SIZE - 1)//BATCH_SIZE}")
            
            async with VLLMClient(self.embed_url, REQUEST_TIMEOUT) as client:
                response = await client.embeddings(batch, self.embed_model)
                
                if response and 'data' in response:
                    batch_embeddings = [item['embedding'] for item in response['data']]
                    embeddings.extend(batch_embeddings)
                    self.stats.embedding_requests += 1
        
        # Convert to numpy array and normalize
        self.train_embeddings = np.array(embeddings)
        self.train_embeddings = self.train_embeddings / np.linalg.norm(self.train_embeddings, axis=1, keepdims=True)
        
        print(f"✅ Computed and normalized embeddings for {len(self.train_codes)} training codes")
    
    async def _find_similar_codes_vectorized(self, codes: List[str], top_k: int = 10) -> Tuple[List[Tuple[str, float]], Dict]:
        """Find similar codes using vectorized cosine similarity"""
        print(f"🔍 Debug: _find_similar_codes_vectorized called with {len(codes)} codes, top_k={top_k}")
        print(f"📊 Debug: train_embeddings is None: {self.train_embeddings is None}")
        
        if self.train_embeddings is None:
            await self._compute_train_embeddings()
        
        # Compute embeddings for input codes
        async with VLLMClient(self.embed_url, REQUEST_TIMEOUT) as client:
            response = await client.embeddings(codes, self.embed_model)
            self.stats.embedding_requests += 1
            
            if response and 'data' in response:
                code_embeddings = np.array([item['embedding'] for item in response['data']])
                code_embeddings = code_embeddings / np.linalg.norm(code_embeddings, axis=1, keepdims=True)
            else:
                return [], {}
        
        # Vectorized cosine similarity
        similarities = np.dot(code_embeddings, self.train_embeddings.T)
        
        results = []
        final_codes_embeddings = {}
        
        for i, code in enumerate(codes):
            print(f"🔍 Debug: Processing code {i+1}: {code[:50]}...")
            
            # Get top-k similar codes
            top_indices = np.argsort(similarities[i])[-top_k:][::-1]
            print(f"📊 Debug: Top indices: {top_indices}")
            print(f"📊 Debug: Train codes list length: {len(self.train_codes)}")
            
            similar_codes = []
            for idx in top_indices:
                if idx < len(self.train_codes):
                    similarity = similarities[i][idx]
                    similar_codes.append((self.train_codes[idx], similarity))
                    print(f"📊 Debug: Added code {idx}: {self.train_codes[idx][:50]}... (sim: {similarity:.3f})")
            
            print(f"📊 Debug: Final similar codes count: {len(similar_codes)}")
            results.append(similar_codes)
            final_codes_embeddings[code] = code_embeddings[i]
        
        return results, final_codes_embeddings
    
    async def _retrieve_hierarchical_codes(self, code: str) -> Tuple[str, str, int]:
        """Retrieve parent and grandparent codes from hierarchical tree"""
        if not self.use_hierarchical_linkage:
            return None, None, 0
        
        if not self.hierarchical_tree:
            return None, None, 0
        
        self.stats.parent_retrieval_requests += 1
        
        # Search for the code in the hierarchical tree
        for node in self.hierarchical_tree.get('nodes', []):
            if node.get('code') == code:
                parent_code = node.get('parent_code')
                grandparent_code = node.get('grandparent_code')
                hierarchy_level = node.get('hierarchy_level', 0)
                return parent_code, grandparent_code, hierarchy_level
        
        return None, None, 0
    
    async def _open_coding(self, chunk: str) -> List[str]:
        """Ask the LLM for open codes describing ``chunk``.

        The reply is expected to hold a JSON object with a ``codes`` list.
        Only non-blank string entries are returned, so callers can slice and
        embed every code safely.

        Returns:
            The list of codes, or an empty list when there is no usable
            response, the JSON cannot be parsed, or ``codes`` is not a list.
        """
        from prompts import OPEN_CODING_PROMPT

        messages = [
            {"role": "user", "content": OPEN_CODING_PROMPT.format(
                question=self.question,
                chunk=chunk
            )}
        ]

        async with VLLMClient(self.model_url, REQUEST_TIMEOUT) as client:
            response = await client.chat_completion(messages, self.model_name, temperature=0.1, max_tokens=1000)

        if not (response and 'choices' in response):
            return []

        # `content` may be null in the response; treat that as an empty reply.
        content = (response['choices'][0]['message']['content'] or '').strip()
        print(f"🔍 Raw open coding response: {content[:100]}...")

        # Clean the response
        cleaned_content = self._clean_open_coding_response(content)
        print(f"🧹 Cleaned open coding response: {cleaned_content[:100]}...")

        try:
            data = json.loads(cleaned_content)
        except json.JSONDecodeError as e:
            print(f"❌ Error parsing open coding response: {e}")
            return []

        raw_codes = data.get('codes', []) if isinstance(data, dict) else []
        if not isinstance(raw_codes, list):
            raw_codes = []
        # Drop entries that are not usable text (numbers, objects, blanks).
        codes = [code for code in raw_codes if isinstance(code, str) and code.strip()]

        print(f"📊 Parsed {len(codes)} codes from open coding")
        self.stats.open_coding_requests += 1
        self.stats.codes_generated += len(codes)
        print(f"🔄 Returning {len(codes)} codes from open coding")
        return codes
    
    async def _evaluate_code_replacement(self, current_code: str, candidate_codes: List[str]) -> Dict[str, Any]:
        """Ask the LLM whether one of ``candidate_codes`` can replace ``current_code``.

        The candidates are shown as a 1-based numbered list. The cleaned
        reply must be either ``none`` or a single candidate number.

        Returns:
            A dict with ``can_replace``, ``reason`` and ``confidence``. When
            ``can_replace`` is True it also carries ``selected_code``.
        """
        from prompts import CODE_REPLACEMENT_PROMPT

        # Format candidate codes for the prompt
        candidate_text = "\n".join([f"{i+1}. {code}" for i, code in enumerate(candidate_codes)])

        messages = [
            {"role": "user", "content": CODE_REPLACEMENT_PROMPT.format(
                question=self.question,
                current_code=current_code,
                candidate_codes=candidate_text
            )}
        ]

        async with VLLMClient(self.model_url, REQUEST_TIMEOUT) as client:
            response = await client.chat_completion(messages, self.model_name, temperature=0.1, max_tokens=512)

        if not (response and 'choices' in response):
            return {"can_replace": False, "reason": "No response from LLM", "confidence": 0.0}

        # `content` may be null in the response; treat that as an empty reply.
        content = (response['choices'][0]['message']['content'] or '').strip()
        content = self._clean_llm_response(content)

        response_text = content.strip().lower()
        self.stats.replacement_requests += 1

        # Debug: print what we actually received
        print(f'  🤖 LLM Response: "{content.strip()}" -> processed as: "{response_text}"')

        if response_text == 'none':
            return {"can_replace": False, "reason": "LLM found no suitable replacement", "confidence": 0.8}

        # isdecimal() accepts exactly what int() can parse; isdigit() also
        # accepts characters such as "²" that make int() raise ValueError.
        if not response_text.isdecimal():
            return {"can_replace": False, "reason": f"Invalid response format: {response_text}", "confidence": 0.0}

        selected_number = int(response_text)  # 1-based position in the candidate list
        if not 1 <= selected_number <= len(candidate_codes):
            return {"can_replace": False, "reason": f"Invalid candidate number: {response_text}", "confidence": 0.0}

        return {
            "can_replace": True,
            "reason": f"LLM selected candidate {response_text}",
            "confidence": 0.8,
            "selected_code": candidate_codes[selected_number - 1]
        }
    
    async def _process_single_open_code(self, open_code: str) -> Optional[Dict[str, Any]]:
        """Map one open code onto a training code, or discard it.

        The 100 nearest training codes are offered to the LLM as
        replacement candidates.

        Returns:
            A record describing the chosen training code, or None when no
            neighbours were found or the LLM picked no replacement. Both of
            those cases are counted as discarded.
        """
        print(f"🔍 Processing open code: {open_code[:50]}...")

        # Nearest training codes for this single open code.
        neighbour_lists, _ = await self._find_similar_codes_vectorized([open_code], top_k=100)
        neighbours = neighbour_lists[0] if neighbour_lists else []

        if not neighbours:
            self.stats.codes_discarded += 1
            return None

        # Entries may be (code, similarity) tuples or dicts with a "code" key.
        candidate_codes = []
        for entry in neighbours:
            candidate_codes.append(entry[0] if isinstance(entry, tuple) else entry["code"])

        verdict = await self._evaluate_code_replacement(open_code, candidate_codes)

        if not verdict["can_replace"]:
            self.stats.codes_discarded += 1
            return None

        selected_code = verdict["selected_code"]

        # Similarity of the first neighbour matching the selected code.
        selected_similarity = 0.0
        for entry in neighbours:
            if isinstance(entry, tuple):
                if entry[0] == selected_code:
                    selected_similarity = entry[1]
                    break
            elif isinstance(entry, dict) and entry.get("code") == selected_code:
                selected_similarity = entry.get("similarity", 0.0)
                break

        parent_code, grandparent_code, hierarchy_level = await self._retrieve_hierarchical_codes(selected_code)

        self.stats.codes_replaced += 1
        record = {
            "code": selected_code,
            "original_code": open_code,
            "similarity_score": selected_similarity,
            "replacement_confidence": verdict["confidence"],
            "replacement_reason": verdict["reason"],
            "parent_code": parent_code,
            "grandparent_code": grandparent_code,
            "hierarchy_level": hierarchy_level
        }
        return record
    
    async def _process_chunk_with_replacement(self, chunk: str) -> List[Dict[str, Any]]:
        """Open-code one chunk and map each resulting code onto a training code.

        Returns:
            The records of the codes that were kept, or an empty list when
            open coding produced nothing.
        """
        open_codes = await self._open_coding(chunk)
        if not open_codes:
            return []

        print(f"📊 Open codes received: {len(open_codes)}")
        print(f"🔥 Processing {len(open_codes)} open codes concurrently")

        # All open codes of the chunk are handled concurrently.
        outcomes = await asyncio.gather(
            *(self._process_single_open_code(open_code) for open_code in open_codes)
        )

        # None marks a discarded code.
        kept = [outcome for outcome in outcomes if outcome is not None]
        print(f"📊 Kept {len(kept)} codes, discarded {len(outcomes) - len(kept)} codes")

        return kept
    
    async def process_datapoints(self, datapoints: List[str]) -> Tuple[List[Dict[str, Any]], Dict]:
        """Chunk every datapoint, code all chunks concurrently, and group the codes.

        Args:
            datapoints: Raw texts to process.

        Returns:
            ``(results, embeddings)``. ``results`` has exactly one entry per
            datapoint, in input order, of the form
            ``{"datapoint": index, "chunks": chunk_count, "codes": [...]}``,
            where ``codes`` keeps the first record seen for each distinct
            code. ``embeddings`` is always an empty dict; it is kept so
            callers that unpack two values keep working.
        """
        print(f"🚀 Processing {len(datapoints)} datapoints with OPTIMIZED CONCURRENCY...")
        print(f"📊 Training codes available: {len(self.train_codes)}")
        print(f"⚡ Max concurrency: {MAX_CONCURRENCY}")
        print(f"🔧 Using vectorized cosine similarity")

        # Compute embeddings if not already computed
        if self.train_embeddings is None:
            await self._compute_train_embeddings()

        start_time = time.time()

        # STEP 1: Chunk ALL datapoints and remember which datapoint owns each chunk
        print(f"\n📝 STEP 1: Chunking all datapoints and creating tasks...")
        all_chunks = []
        chunk_owners = []   # chunk_owners[k] is the datapoint index of all_chunks[k]
        chunk_counts = []   # chunk_counts[i] is the number of chunks of datapoint i

        for i, datapoint in enumerate(datapoints):
            print(f"📝 Processing datapoint {i+1}/{len(datapoints)}")

            chunks = chunk_text(datapoint, chunk_size=2048, overlap=200)
            self.stats.total_chunks += len(chunks)
            print(f"  📊 Generated {len(chunks)} chunks")

            chunk_counts.append(len(chunks))
            all_chunks.extend(chunks)
            chunk_owners.extend([i] * len(chunks))

        # STEP 2: Process ALL chunks in parallel, bounded by the semaphore
        print(f"\n⚡ STEP 2: Processing ALL {len(all_chunks)} tasks in parallel...")

        semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

        async def process_with_semaphore(chunk):
            # The coroutine is created only after a slot has been acquired.
            async with semaphore:
                return await self._process_chunk_with_replacement(chunk)

        chunk_results = await asyncio.gather(*(process_with_semaphore(chunk) for chunk in all_chunks))
        print(f"✅ All {len(all_chunks)} tasks completed!")

        # STEP 3: Group results by datapoint in one pass over the chunk results
        print(f"\n📊 STEP 3: Grouping results by datapoint...")
        codes_by_datapoint = [[] for _ in datapoints]
        seen_by_datapoint = [set() for _ in datapoints]

        for owner, chunk_result in zip(chunk_owners, chunk_results):
            for code_info in chunk_result:
                code = code_info["code"]
                # Keep only the first record of each code within a datapoint.
                if code not in seen_by_datapoint[owner]:
                    seen_by_datapoint[owner].add(code)
                    codes_by_datapoint[owner].append(code_info)

        results = []
        for i, (chunk_count, datapoint_codes) in enumerate(zip(chunk_counts, codes_by_datapoint)):
            # Exactly one entry per datapoint, with that datapoint's own chunk count.
            results.append({
                "datapoint": i,
                "chunks": chunk_count,
                "codes": datapoint_codes
            })
            print(f"   Datapoint {i+1}: {len(datapoint_codes)} unique codes from {chunk_count} chunks")

        end_time = time.time()
        print(f"✅ Processing completed in {end_time - start_time:.2f} seconds")

        # Print statistics
        print(f"\n📊 Processing Statistics:")
        print(f"  Total datapoints: {len(datapoints)}")
        print(f"  Total chunks: {self.stats.total_chunks}")
        print(f"  Open coding requests: {self.stats.open_coding_requests}")
        print(f"  Replacement requests: {self.stats.replacement_requests}")
        print(f"  Parent retrieval requests: {self.stats.parent_retrieval_requests}")
        print(f"  Embedding requests: {self.stats.embedding_requests}")
        print(f"  Codes generated: {self.stats.codes_generated}")
        print(f"  Codes replaced: {self.stats.codes_replaced}")
        print(f"  Codes discarded: {self.stats.codes_discarded}")
        print(f"  Success rate: {self.stats.codes_replaced / max(self.stats.codes_generated, 1) * 100:.1f}%")

        # No embeddings are collected at this level; see the Returns note.
        return results, {}
    
    def save_results(self, results: List[Dict[str, Any]], output_path: str):
        """Write per-code rows to a parquet file and a JSON summary next to it.

        Args:
            results: Per-datapoint entries as produced by ``process_datapoints``.
            output_path: Destination of the parquet file. The summary goes
                to the same path with its extension replaced by
                ``_summary.json``, so it can never overwrite the parquet
                file, whatever extension ``output_path`` has.

        A datapoint whose ``codes`` entry is malformed is skipped as a whole
        with a warning; none of its rows are written.
        """
        # Save detailed results
        detailed_results = []
        for result in results:
            rows = []
            try:
                for code_info in result['codes']:
                    rows.append({
                        'datapoint': result['datapoint'],
                        'code': code_info["code"],
                        'original_code': code_info['original_code'],
                        'similarity_score': code_info['similarity_score'],
                        'replacement_confidence': code_info['replacement_confidence'],
                        'replacement_reason': code_info['replacement_reason'],
                        'parent_code': code_info['parent_code'],
                        'grandparent_code': code_info['grandparent_code'],
                        'hierarchy_level': code_info['hierarchy_level']
                    })
            except TypeError as e:
                print(f"⚠️  Skipping result due to TypeError in 'codes' iteration: {e}. Result: {result.get('datapoint', 'N/A')}")
                continue
            except KeyError as e:
                print(f"⚠️  Skipping result due to KeyError in 'codes' iteration: {e}. Result: {result.get('datapoint', 'N/A')}")
                continue
            # Commit the rows only once the whole datapoint converted cleanly.
            detailed_results.extend(rows)

        # Save to parquet
        df = pd.DataFrame(detailed_results)
        df.to_parquet(output_path, index=False)
        print(f"✅ Saved detailed results to {output_path}")

        # Save summary. The path is derived from the extension-less root; a
        # plain str.replace('.parquet', ...) leaves other paths unchanged.
        summary_root, _ = os.path.splitext(output_path)
        summary_path = summary_root + '_summary.json'
        summary = {
            'total_datapoints': len(results),
            'total_codes': len(detailed_results),
            'unique_codes': len({row['code'] for row in detailed_results}),
            'statistics': {
                'total_chunks': self.stats.total_chunks,
                'open_coding_requests': self.stats.open_coding_requests,
                'replacement_requests': self.stats.replacement_requests,
                'parent_retrieval_requests': self.stats.parent_retrieval_requests,
                'embedding_requests': self.stats.embedding_requests,
                'codes_generated': self.stats.codes_generated,
                'codes_replaced': self.stats.codes_replaced,
                'codes_discarded': self.stats.codes_discarded,
                'success_rate': self.stats.codes_replaced / max(self.stats.codes_generated, 1) * 100
            }
        }

        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        print(f"✅ Saved summary to {summary_path}")

async def main():
    """Command-line entry point: load inputs, run the builder, save the results."""
    import argparse

    parser = argparse.ArgumentParser(description='Build corpus for test inference evaluation')
    parser.add_argument('--question', required=True, help='Research question')
    parser.add_argument('--train_corpus', required=True, help='Path to training corpus parquet file')
    parser.add_argument('--hierarchical_tree', default=None, help='Path to hierarchical tree JSON file (optional)')
    parser.add_argument('--test_data', required=True, help='Path to test data CSV file')
    parser.add_argument('--output', required=True, help='Output path for results')
    parser.add_argument("--use_hierarchical_linkage", dest="use_hierarchical_linkage", action="store_true", default=True, help="Use parent and grandparent linkage (default: True)")
    # store_true with default=True can never yield False, so a second flag
    # writing to the same destination provides the off switch.
    parser.add_argument("--no_hierarchical_linkage", dest="use_hierarchical_linkage", action="store_false", help="Disable parent and grandparent linkage")
    args = parser.parse_args()

    # Load test data
    print(f"📂 Loading test data from {args.test_data}")
    test_df = pd.read_csv(args.test_data)
    # Blank cells are read as NaN (a float); turn them into empty strings so
    # every datapoint is text and the datapoint indices stay aligned with the CSV rows.
    datapoints = test_df['text'].fillna("").astype(str).tolist()
    print(f"✅ Loaded {len(datapoints)} test datapoints")

    # Initialize builder. A missing tree path is passed through as None,
    # which the tree loader reports as "not found"; a placeholder filename
    # could match a real file in the working directory.
    builder = OptimizedTestInferenceCorpusBuilder(
        question=args.question,
        train_corpus_path=args.train_corpus,
        hierarchical_tree_path=args.hierarchical_tree,
        use_hierarchical_linkage=args.use_hierarchical_linkage
    )

    # Compute embeddings first
    await builder._compute_train_embeddings()

    # Process datapoints
    results, _ = await builder.process_datapoints(datapoints)

    # Save results
    builder.save_results(results, args.output)

if __name__ == "__main__":
    asyncio.run(main())
