import asyncio
import atexit
import gc
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Iterator, List, Optional, Tuple

import aiohttp
import numpy as np
import pandas as pd
import requests
from tqdm import tqdm

from dotenv import load_dotenv
load_dotenv()

# Optional dependency: pyarrow. HAVE_PA records whether the import succeeded so
# that Parquet helpers can raise a clear error instead of a NameError later.
try:
    import pyarrow as pa
    import pyarrow.parquet as pq
    HAVE_PA = True
except Exception:
    HAVE_PA = False

# Runtime configuration, read from environment variables.
# EMBEDDING_MODEL and VLLM_EMBEDDING_URL have no default and may be None here;
# the embedding builders check them before use.
EMBEDDING_MODEL = os.getenv("DEFAULT_EMBEDDING_MODEL")
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL")
# Number of texts handled per batch (reduced from 5000 to 1000 for better performance).
CHUNK_SIZE = int(os.getenv("EMBEDDING_CHUNK_SIZE", "1000"))  # Reduced from 5000 to 1000 for better performance
# Upper bound on concurrent requests requested by the caller.
MAX_CONCURRENCY = int(os.getenv("VLLM_MAX_CONCURRENCY", "128"))
# Request timeout, in seconds, handed to the async client.
REQUEST_TIMEOUT = int(os.getenv("VLLM_TIMEOUT", "120"))

# Filesystem locations. PROJECT_ROOT is two directories above this file;
# each path below can be overridden through the environment variable of the same name.
THIS_DIR = os.path.abspath(os.path.dirname(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..", ".."))
TEMP_FILES_DIR = os.getenv("TEMP_FILES_DIR", os.path.join(PROJECT_ROOT, "temp_files"))
# Default input (text chunks) and output (embeddings) Parquet files.
CODE_CHUNKS_PARQUET = os.getenv("CODE_CHUNKS_PARQUET", os.path.join(TEMP_FILES_DIR, "code.parquet"))
EMBEDDINGS_PARQUET = os.getenv("EMBEDDINGS_PARQUET", os.path.join(TEMP_FILES_DIR, "embeddings.parquet"))

# Process-wide cleanup hook. atexit.register returns the function it is given,
# so applying it as a decorator registers the hook and leaves the name bound
# to the plain function.
@atexit.register
def cleanup_resources():
    """Run a garbage-collection pass and report that cleanup happened."""
    gc.collect()
    print("🧹 Cleaned up resources")

# Graceful-shutdown handler shared by SIGINT and SIGTERM
def signal_handler(signum, frame):
    """Report the received signal, run the cleanup hook, then exit with status 0.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame active when the signal arrived (unused).
    """
    print(f"\n🛑 Received signal {signum}, cleaning up...")
    cleanup_resources()
    # sys.exit(0) is exactly "raise SystemExit(0)".
    raise SystemExit(0)

# Install the handler for Ctrl-C and for termination requests
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

class AsyncVllmClient:
    """Asynchronous client for the ``/v1/embeddings`` endpoint of a vLLM server.

    Intended to be used as an async context manager: the aiohttp session is
    created on entry and closed on exit.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        # Trailing slashes are removed so the endpoint path can be appended directly.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Populated by __aenter__; stays None until the context is entered.
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()

    async def embeddings(self, model: str, inputs: List[str]) -> Optional[Dict[str, Any]]:
        """Request embeddings for ``inputs`` from the server.

        Args:
            model: Model name sent in the request payload.
            inputs: Texts to embed.

        Returns:
            The decoded JSON response on HTTP 200; ``None`` on any other status
            or on a non-timeout error.

        Raises:
            RuntimeError: If the session has not been created yet.
            asyncio.TimeoutError: Re-raised after logging so callers can retry.
        """
        if not self.session:
            raise RuntimeError("Client session not initialized")

        endpoint = f"{self.base_url}/v1/embeddings"
        body = {"model": model, "input": inputs}

        try:
            # A single request is never allowed more than 30 seconds in total,
            # regardless of the configured client timeout.
            per_request_limit = min(self.timeout, 30)
            limits = aiohttp.ClientTimeout(total=per_request_limit, connect=10)
            async with self.session.post(endpoint, json=body, timeout=limits) as response:
                if response.status != 200:
                    txt = await response.text()
                    print(f"⚠️ embeddings HTTP {response.status}: {txt}")
                    return None
                return await response.json()
        except asyncio.TimeoutError:
            print(f"⚠️ embeddings request timeout after {per_request_limit}s")
            raise
        except Exception as e:
            print(f"⚠️ embeddings request failed: {e}")
            return None

class DirectVllmClient:
    """Blocking client that posts straight to a vLLM server (synchronous fallback)."""

    def __init__(self, base_url: str, timeout: int = 120):
        # Trailing slashes are removed so the endpoint path can be appended directly.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # One requests session is reused for every call made by this client.
        self.session = requests.Session()

    def embeddings(self, model: str, inputs: List[str]) -> Dict[str, Any]:
        """Post ``inputs`` to ``/v1/embeddings`` and return the decoded JSON body.

        Raises whatever ``raise_for_status`` raises for a non-success response.
        """
        reply = self.session.post(
            f"{self.base_url}/v1/embeddings",
            json={"model": model, "input": inputs},
            timeout=self.timeout,
        )
        reply.raise_for_status()
        return reply.json()

    def __del__(self):
        """Close the HTTP session when the client is garbage collected."""
        # The attribute is missing if __init__ failed before creating the session.
        if not hasattr(self, 'session'):
            return
        self.session.close()

def _ensure_pyarrow():
    """Raise RuntimeError unless the optional pyarrow import succeeded."""
    if HAVE_PA:
        return
    raise RuntimeError("pyarrow is required to read/write Parquet. Please install pyarrow.")

def load_texts_from_parquet(parquet_path: str = CODE_CHUNKS_PARQUET) -> List[str]:
    """Collect text values from one Parquet file or from every ``*.parquet`` file in a directory.

    For each file the text column is the first of ``text``, ``code``, ``chunk``,
    ``content``, ``snippet`` that exists; otherwise the first column whose Arrow
    type is string or large_string. Null values become ``""`` and every other
    value is converted with ``str``.

    Raises:
        RuntimeError: If pyarrow is not installed.
        FileNotFoundError: If a directory contains no ``.parquet`` files.
        ValueError: If a file has no usable text column.
    """
    _ensure_pyarrow()

    if os.path.isdir(parquet_path):
        # Sorted so files are always read in the same order.
        parquet_files: List[str] = sorted(
            os.path.join(parquet_path, entry)
            for entry in os.listdir(parquet_path)
            if entry.endswith(".parquet")
        )
    else:
        parquet_files = [parquet_path]

    if not parquet_files:
        raise FileNotFoundError(f"No Parquet files found at {parquet_path}")

    preferred_columns = ("text", "code", "chunk", "content", "snippet")
    collected: List[Any] = []

    for file_path in parquet_files:
        table = pq.read_table(file_path)
        available = table.column_names

        # First choice: a well-known column name, in priority order.
        chosen = next((name for name in preferred_columns if name in available), None)

        # Second choice: the first string-typed column in schema order.
        if chosen is None:
            for name in available:
                arrow_type = table.schema.field(name).type
                if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
                    chosen = name
                    break

        if chosen is None:
            raise ValueError(f"No suitable text column found in {file_path}. Columns: {table.column_names}")

        collected.extend(table[chosen].to_pylist())

    # Normalise every value to str, mapping nulls to the empty string.
    return ["" if value is None else str(value) for value in collected]

def _batched(seq: List[str], batch_size: int) -> Tuple[int, List[str]]:
    total = len(seq)
    idx = 0
    while idx < total:
        end = min(idx + batch_size, total)
        yield (idx // batch_size) + 1, seq[idx:end]
        idx = end

async def build_embeddings_parquet_async(
    corpus_df: Optional[pd.DataFrame] = None,
    input_parquet: str = CODE_CHUNKS_PARQUET,
    output_parquet: str = EMBEDDINGS_PARQUET,
    chunk_size: int = CHUNK_SIZE,
    max_concurrency: int = MAX_CONCURRENCY,
) -> Tuple[str, np.ndarray, List[str], pd.DataFrame]:
    """Embed every text of a corpus through the async vLLM client and write a Parquet file.

    Texts come from the ``tag`` column of the corpus if present, else the
    ``text`` column, else from ``load_texts_from_parquet(input_parquet)``.
    Each text is sent as its own request; requests of one batch run
    concurrently, limited by a semaphore. A request that fails is replaced by
    an all-zero vector so the output stays aligned with the input rows.

    Args:
        corpus_df: Corpus to embed. When ``None`` it is read from ``input_parquet``.
            When it has a ``tag`` or ``text`` column, an ``embedding`` column is
            added to this DataFrame in place.
        input_parquet: Parquet path used when ``corpus_df`` is ``None`` or has
            neither a ``tag`` nor a ``text`` column.
        output_parquet: Destination Parquet path.
        chunk_size: Number of texts per batch; must be positive.
        max_concurrency: Requested concurrency; the effective value is clamped
            to the range 1..8.

    Returns:
        ``(output_parquet, embeddings_array, codes, corpus_df)`` where ``codes``
        is the list of embedded texts and ``corpus_df`` carries the
        ``embedding`` column.

    Raises:
        RuntimeError: If pyarrow is missing or there were no texts to embed.
        ValueError: If required environment variables are unset, ``chunk_size``
            is not positive, or no input source is available.
    """
    _ensure_pyarrow()

    if not VLLM_EMBEDDING_URL:
        raise ValueError("VLLM_EMBEDDING_URL environment variable not set")
    if not EMBEDDING_MODEL:
        raise ValueError("DEFAULT_EMBEDDING_MODEL environment variable not set")
    if chunk_size <= 0:
        # Guards the batch-count arithmetic below against division by zero.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    os.makedirs(TEMP_FILES_DIR, exist_ok=True)

    # Use provided corpus_df or load from file
    if corpus_df is None:
        if input_parquet is None:
            raise ValueError("Either corpus_df or input_parquet must be provided")
        corpus_df = pd.read_parquet(input_parquet)
        print(f"Loaded {len(corpus_df):,} records from {input_parquet}")
    else:
        print(f"Using provided corpus_df with {len(corpus_df):,} records")

    print(f"Columns: {corpus_df.columns.tolist()}")

    # The same column decides both what is embedded and what is returned as codes.
    if 'tag' in corpus_df.columns:
        source_column: Optional[str] = 'tag'
    elif 'text' in corpus_df.columns:
        source_column = 'text'
    else:
        source_column = None

    if source_column is not None:
        texts = corpus_df[source_column].tolist()
        print(f"Using '{source_column}' column for embeddings")
    else:
        # Fall back to original method, which needs a path to read from.
        if input_parquet is None:
            raise ValueError("Either corpus_df or input_parquet must be provided")
        texts = load_texts_from_parquet(input_parquet)
        print("Using original text loading method")

    print(f"Generating embeddings for {len(texts):,} texts")
    print(f"🔗 Using async client with concurrency={max_concurrency}, chunk_size={chunk_size}")

    # Create async client
    async with AsyncVllmClient(VLLM_EMBEDDING_URL, REQUEST_TIMEOUT) as client:
        print(f"🔗 Connected to vLLM server: {VLLM_EMBEDDING_URL}")
        print(f"🤖 Using embedding model: {EMBEDDING_MODEL}")

        # Failed requests are stored as None and replaced by zero vectors once
        # the real embedding dimension is known.
        all_embeddings: List[Optional[List[float]]] = []
        total_failed = 0

        # Cap at 8 concurrent requests to avoid overwhelming the server; the
        # lower bound of 1 prevents a zero-permit semaphore from deadlocking.
        adjusted_concurrency = max(1, min(max_concurrency, 8))
        semaphore = asyncio.Semaphore(adjusted_concurrency)
        print(f"   🔧 Using conservative concurrency: {adjusted_concurrency} (from {max_concurrency})")

        start_time = time.time()
        total_batches = (len(texts) + chunk_size - 1) // chunk_size

        for batch_idx, batch in tqdm(_batched(texts, chunk_size), total=total_batches, desc="Embedding batches"):
            batch_start = time.time()

            print(f"   📦 Processing batch {batch_idx}/{total_batches} ({len(batch)} texts)")

            # One request per text; the semaphore bounds how many run at once.
            batch_tasks = [
                _embed_single_text_with_semaphore(client, EMBEDDING_MODEL, [text], semaphore)
                for text in batch
            ]
            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            failed_count = 0
            for i, result in enumerate(batch_results):
                embedding: Optional[List[float]] = None
                if isinstance(result, Exception):
                    print(f"      ⚠️ Request {i} failed: {result}")
                elif result and isinstance(result, dict):
                    try:
                        candidate = result["data"][0]["embedding"]
                    except (KeyError, IndexError, TypeError) as e:
                        print(f"      ⚠️ Parse failed for request {i}: {e}")
                    else:
                        if isinstance(candidate, list) and candidate:
                            embedding = candidate
                        else:
                            print(f"      ⚠️ Parse failed for request {i}: unexpected embedding payload")
                else:
                    print(f"      ⚠️ Request {i} returned None")

                if embedding is None:
                    failed_count += 1
                all_embeddings.append(embedding)

            total_failed += failed_count

            # Continue processing even if some failed
            if failed_count > 0:
                print(f"      ⚠️ {failed_count}/{len(batch)} requests failed in this batch, continuing...")

            batch_time = time.time() - batch_start
            # Only requests that produced a usable vector count as successful.
            successful = len(batch) - failed_count
            throughput = successful / batch_time if batch_time > 0 else 0

            print(f"      ✅ Batch completed in {batch_time:.2f}s: {successful}/{len(batch)} successful")
            print(f"      ⚡ Throughput: {throughput:.2f} texts/second")

            # Clean up memory
            gc.collect()

            # Pause between batches; batch_idx is 1-based, so the last batch
            # is the one equal to total_batches and gets no trailing delay.
            if batch_idx < total_batches:
                await asyncio.sleep(0.5)

        if not all_embeddings:
            raise RuntimeError("No embeddings were produced.")

        # Size the zero placeholders from a real embedding so every row has the
        # same length; 1024 is used only when not a single request succeeded.
        fallback_dim = 1024
        embedding_dim = next(
            (len(vec) for vec in all_embeddings if vec is not None),
            fallback_dim,
        )
        all_embeddings = [
            vec if vec is not None else [0.0] * embedding_dim
            for vec in all_embeddings
        ]

        # Check success rate
        total_requests = len(texts)
        successful_embeddings = total_requests - total_failed
        success_rate = successful_embeddings / total_requests if total_requests > 0 else 0

        print("📊 Embedding Summary:")
        print(f"   Total requests: {total_requests}")
        print(f"   Successful: {successful_embeddings}")
        print(f"   Failed: {total_requests - successful_embeddings}")
        print(f"   Success rate: {success_rate:.1%}")

        if success_rate < 0.5:
            print(f"⚠️ Warning: Low success rate ({success_rate:.1%}), but continuing with available embeddings...")

        # Create enhanced output with all original columns plus embeddings
        if source_column is not None:
            # Add embeddings to the original corpus DataFrame
            corpus_df['embedding'] = all_embeddings

            # Write enhanced Parquet with all original columns plus embeddings
            corpus_df.to_parquet(output_parquet)
            print(f"✓ Wrote {len(corpus_df):,} embeddings to {output_parquet}")
            print(f"✓ Preserved all original columns: {corpus_df.columns.tolist()}")

            # Extract codes for return
            codes = corpus_df[source_column].tolist()
        else:
            # Fall back to original simple format
            emb_array = pa.array(all_embeddings, type=pa.list_(pa.float32()))
            text_array = pa.array(texts, type=pa.string())
            table = pa.Table.from_arrays([text_array, emb_array], names=["text", "embedding"])
            pq.write_table(table, output_parquet)
            print(f"✓ Wrote {len(texts):,} embeddings to {output_parquet}")

            # For fallback case, create a simple corpus_df
            corpus_df = pd.DataFrame({
                'text': texts,
                'embedding': all_embeddings
            })
            codes = texts

        total_time = time.time() - start_time
        print(f"✅ Embedding generation completed in {total_time:.2f}s")

        # Drop the local reference; the vectors stay alive in corpus_df.
        del all_embeddings
        gc.collect()

        # Return embeddings array, codes list, and enhanced corpus DataFrame
        embeddings_array = np.array(corpus_df['embedding'].tolist())
        return output_parquet, embeddings_array, codes, corpus_df

async def _embed_single_text_with_semaphore(client: AsyncVllmClient, model: str, texts: List[str], semaphore: asyncio.Semaphore) -> Optional[Dict]:
    """Call ``client.embeddings`` while holding ``semaphore``, retrying when it raises.

    Up to three attempts are made. The wait between attempts starts at one
    second and doubles each time; the semaphore stays held during the wait.
    Returns whatever ``client.embeddings`` returns, or ``None`` once every
    attempt has raised.
    """
    attempts_allowed = 3
    backoff = 1.0  # seconds before the next attempt

    async with semaphore:
        for attempt_no in range(1, attempts_allowed + 1):
            out_of_attempts = attempt_no == attempts_allowed
            try:
                return await client.embeddings(model, texts)
            except asyncio.TimeoutError:
                if out_of_attempts:
                    print(f"      ❌ Timeout after {attempts_allowed} attempts")
                    return None
                print(f"      ⏰ Timeout on attempt {attempt_no}/{attempts_allowed}, retrying in {backoff}s...")
            except Exception as e:
                if out_of_attempts:
                    print(f"      ❌ Failed after {attempts_allowed} attempts: {e}")
                    return None
                print(f"      ⚠️ Error on attempt {attempt_no}/{attempts_allowed}: {e}, retrying in {backoff}s...")

            # Reached only after a failed, non-final attempt: exponential backoff.
            await asyncio.sleep(backoff)
            backoff *= 2

    return None

async def build_embeddings_parquet(
    corpus_df: Optional[pd.DataFrame] = None,
    input_parquet: str = CODE_CHUNKS_PARQUET,
    output_parquet: str = EMBEDDINGS_PARQUET,
    chunk_size: int = CHUNK_SIZE,
    client: Optional[DirectVllmClient] = None,
    use_async: bool = True,
    max_concurrency: int = MAX_CONCURRENCY,
) -> Tuple[str, np.ndarray, List[str], pd.DataFrame]:
    """Read texts from Parquet, compute embeddings in chunks, and write a single Parquet with columns:
    - text: utf8
    - embedding: list<float32>
    - All original columns preserved for datapoint tracking

    Args:
        corpus_df: Corpus to embed. When ``None`` it is read from ``input_parquet``.
            When it has a ``tag`` or ``text`` column, an ``embedding`` column is
            added to this DataFrame in place.
        input_parquet: Parquet path used when ``corpus_df`` is ``None`` or has
            neither a ``tag`` nor a ``text`` column.
        output_parquet: Destination Parquet path.
        chunk_size: Number of texts per batch; must be positive.
        client: Synchronous client for the ``use_async=False`` path; ignored
            when ``use_async`` is true. Built from the environment when omitted.
        use_async: Delegate to ``build_embeddings_parquet_async`` (default).
        max_concurrency: Forwarded to the async builder; unused on the sync path.

    Returns:
        The output parquet path, embeddings array, codes list, and enhanced
        corpus DataFrame.

    Raises:
        RuntimeError: If pyarrow is missing or no embeddings were produced.
        ValueError: If required environment variables are unset, ``chunk_size``
            is not positive, or no input source is available.
    """

    # Use async version by default for better performance
    if use_async:
        print("🚀 Using async embedding generation for better performance...")
        # This function is a coroutine, so an event loop is always running here:
        # await the async builder directly and let its errors propagate as-is.
        return await build_embeddings_parquet_async(
            corpus_df=corpus_df,
            input_parquet=input_parquet,
            output_parquet=output_parquet,
            chunk_size=chunk_size,
            max_concurrency=max_concurrency,
        )

    # Fallback to synchronous version
    print("⚠️ Using synchronous embedding generation (slower)...")
    _ensure_pyarrow()

    if not VLLM_EMBEDDING_URL:
        raise ValueError("VLLM_EMBEDDING_URL environment variable not set")
    if not EMBEDDING_MODEL:
        raise ValueError("DEFAULT_EMBEDDING_MODEL environment variable not set")
    if chunk_size <= 0:
        # Guards the batch-count arithmetic below against division by zero.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    os.makedirs(TEMP_FILES_DIR, exist_ok=True)

    # Use provided corpus_df or load from file
    if corpus_df is None:
        if input_parquet is None:
            raise ValueError("Either corpus_df or input_parquet must be provided")
        corpus_df = pd.read_parquet(input_parquet)
        print(f"Loaded {len(corpus_df):,} records from {input_parquet}")
    else:
        print(f"Using provided corpus_df with {len(corpus_df):,} records")

    print(f"Columns: {corpus_df.columns.tolist()}")

    # The same column decides both what is embedded and what is returned as codes.
    if 'tag' in corpus_df.columns:
        source_column: Optional[str] = 'tag'
    elif 'text' in corpus_df.columns:
        source_column = 'text'
    else:
        source_column = None

    if source_column is not None:
        texts = corpus_df[source_column].tolist()
        print(f"Using '{source_column}' column for embeddings")
    else:
        # Fall back to original method, which needs a path to read from.
        if input_parquet is None:
            raise ValueError("Either corpus_df or input_parquet must be provided")
        texts = load_texts_from_parquet(input_parquet)
        print("Using original text loading method")

    print(f"Generating embeddings for {len(texts):,} texts")

    # Honour the configured request timeout, as the async path does.
    client = client or DirectVllmClient(VLLM_EMBEDDING_URL, REQUEST_TIMEOUT)
    print(f"🔗 Connected to vLLM server: {VLLM_EMBEDDING_URL}")
    print(f"🤖 Using embedding model: {EMBEDDING_MODEL}")

    # Texts of a failed batch are stored as None and later replaced by zero
    # vectors, so embeddings always stay aligned with the input rows.
    all_embeddings: List[Optional[List[float]]] = []
    failed_texts = 0

    start_time = time.time()
    total_batches = (len(texts) + chunk_size - 1) // chunk_size
    loop = asyncio.get_running_loop()

    for batch_idx, batch in tqdm(_batched(texts, chunk_size), total=total_batches, desc="Embedding batches"):
        try:
            t0 = time.time()
            # The client is blocking; run it in the default executor so the
            # event loop is not stalled for the duration of the HTTP call.
            result = await loop.run_in_executor(None, client.embeddings, EMBEDDING_MODEL, batch)
            t1 = time.time()
            batch_embeddings = [item["embedding"] for item in result.get("data", [])]
            if len(batch_embeddings) != len(batch):
                raise RuntimeError(f"Embedding count mismatch: got {len(batch_embeddings)} for {len(batch)} inputs")
        except Exception as e:
            print(f"  ✗ Batch {batch_idx}: Error - {e}")
            # Keep one placeholder per text instead of dropping the batch.
            all_embeddings.extend([None] * len(batch))
            failed_texts += len(batch)
        else:
            all_embeddings.extend(batch_embeddings)
            print(f"  ✓ Batch {batch_idx}: {len(batch)} texts in {t1 - t0:.2f}s")
        finally:
            gc.collect()

    # The placeholder size comes from the first real embedding; if there is
    # none, nothing usable was produced.
    embedding_dim = next((len(vec) for vec in all_embeddings if vec is not None), None)
    if embedding_dim is None:
        raise RuntimeError("No embeddings were produced.")

    if failed_texts:
        print(f"⚠️ {failed_texts}/{len(texts)} texts failed; using zero embeddings for them")
    all_embeddings = [
        vec if vec is not None else [0.0] * embedding_dim
        for vec in all_embeddings
    ]

    # Create enhanced output with all original columns plus embeddings
    if source_column is not None:
        # Add embeddings to the original corpus DataFrame
        corpus_df['embedding'] = all_embeddings

        # Write enhanced Parquet with all original columns plus embeddings
        corpus_df.to_parquet(output_parquet)
        print(f"✓ Wrote {len(corpus_df):,} embeddings to {output_parquet}")
        print(f"✓ Preserved all original columns: {corpus_df.columns.tolist()}")

        # Extract codes for return
        codes = corpus_df[source_column].tolist()
    else:
        # Fall back to original simple format
        emb_array = pa.array(all_embeddings, type=pa.list_(pa.float32()))
        text_array = pa.array(texts, type=pa.string())
        table = pa.Table.from_arrays([text_array, emb_array], names=["text", "embedding"])
        pq.write_table(table, output_parquet)
        print(f"✓ Wrote {len(texts):,} embeddings to {output_parquet}")

        # For fallback case, create a simple corpus_df
        corpus_df = pd.DataFrame({
            'text': texts,
            'embedding': all_embeddings
        })
        codes = texts

    total_time = time.time() - start_time
    print(f"✅ Embedding generation completed in {total_time:.2f}s")

    # Drop the local reference; the vectors stay alive in corpus_df.
    del all_embeddings
    gc.collect()

    # Return embeddings array, codes list, and enhanced corpus DataFrame
    embeddings_array = np.array(corpus_df['embedding'].tolist())
    return output_parquet, embeddings_array, codes, corpus_df

# Backward-compatible alias
get_embeddings_chunked = build_embeddings_parquet