#!/usr/bin/env python3
"""
NLI Classification with Load Balancer Integration - Enhanced with Advanced Features
Updated version that combines load balancing with all advanced features from main nli_classify.py
"""

import os
import sys
import asyncio
import aiohttp
import pandas as pd
import numpy as np
import time
from typing import List, Dict, Any, Tuple
from pathlib import Path
from tqdm import tqdm

# Add the current directory to path to import the load balancer
sys.path.append('.')

from .nli_load_balancer import NLIClassifierLoadBalancer

class EnhancedNLIClassifierWithLoadBalancer:
    """Enhanced NLI classifier with load balancing and all advanced features"""
    
    def __init__(
        self,
        batch_size: int = 10,
        primary_url: str = None,
        secondary_url: str = None,
        max_concurrency: int = 64,
    ):
        """Store the batching settings and build the load balancer client.

        Args:
            batch_size: Requested number of pairs per processing batch.
            primary_url: Forwarded unchanged to the load balancer as its
                ``primary_url`` keyword (``None`` is passed through as-is).
            secondary_url: Forwarded unchanged to the load balancer as its
                ``secondary_url`` keyword (``None`` is passed through as-is).
            max_concurrency: Stored on the instance as ``max_concurrency``.
        """
        self.batch_size = batch_size
        self.max_concurrency = max_concurrency

        balancer_options = {
            'primary_url': primary_url,
            'secondary_url': secondary_url,
        }
        self.load_balancer = NLIClassifierLoadBalancer(**balancer_options)

        # Per-pair label tallies; replaced by create_relationship_matrix().
        self.relationship_frequencies = {}
    
    async def __aenter__(self):
        """Enter the load balancer's async context and return this classifier."""
        balancer = self.load_balancer
        await balancer.__aenter__()
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Forward context exit to the load balancer; its return value is discarded."""
        exit_details = (exc_type, exc_val, exc_tb)
        await self.load_balancer.__aexit__(*exit_details)

    async def process_batch(self, batch: List[Dict], batch_id: int,
                            http_batch_size: int = 20) -> List[Dict]:
        """Classify one batch of similarity pairs through the load balancer.

        The batch is cut into sub-batches of ``http_batch_size`` pairs. Each
        sub-batch is sent with a single ``load_balancer.predict`` call, and the
        sub-batches run concurrently behind a semaphore cached on the instance
        (at most 4 calls in flight).

        Args:
            batch: Pair dicts. ``code_a``, ``code_b`` and ``similarity`` are
                required keys; ``cluster_id``, ``datapoints_a`` and
                ``datapoints_b`` are optional.
            batch_id: Identifier copied into every returned record.
            http_batch_size: Number of pairs sent per request (default 20).

        Returns:
            Exactly one record per input pair, in input order. When a request
            fails, or the response does not contain exactly one prediction per
            pair, every pair of that sub-batch gets the fallback label
            ``'not_mergeable'`` with confidence ``0.3``.

        Raises:
            ValueError: If ``http_batch_size`` is smaller than 1.
        """
        if http_batch_size < 1:
            raise ValueError(f"http_batch_size must be >= 1, got {http_batch_size}")

        def build_record(pair: Dict, label: str, confidence: float) -> Dict:
            """Assemble one output record; key order fixes the saved column order."""
            return {
                'code_a': pair['code_a'],
                'code_b': pair['code_b'],
                'similarity': pair['similarity'],
                'cluster_id': pair.get('cluster_id', -1),
                'nli_label': label,
                'confidence': confidence,
                'batch_id': batch_id,
                'datapoints_a': pair.get('datapoints_a', []),
                'datapoints_b': pair.get('datapoints_b', [])
            }

        async def process_http_batch(pairs_batch: List[Dict]) -> List[Dict]:
            """Process multiple pairs in a single load balancer request."""
            # Created lazily so the semaphore is built while an event loop is running.
            if getattr(self, 'semaphore', None) is None:
                self.semaphore = asyncio.Semaphore(4)

            async with self.semaphore:
                try:
                    pairs_list = [[pair['code_a'], pair['code_b']] for pair in pairs_batch]

                    result = await self.load_balancer.predict(pairs_list)
                    predictions = result['predictions']

                    # zip() would silently drop the pairs that have no
                    # prediction, so a short (or long) response is treated as a
                    # failed request and handled by the fallback below.
                    if len(predictions) != len(pairs_batch):
                        raise ValueError(
                            f"expected {len(pairs_batch)} predictions, "
                            f"got {len(predictions)}"
                        )

                    return [
                        build_record(pair, prediction['pred_label'],
                                     prediction.get('prob_max', 0.8))
                        for pair, prediction in zip(pairs_batch, predictions)
                    ]

                except Exception as e:
                    # Deliberate best-effort: one bad request must not abort the run.
                    print(f"⚠️ Error processing HTTP batch: {e}")
                    return [build_record(pair, 'not_mergeable', 0.3) for pair in pairs_batch]

        # Split batch into smaller HTTP batches
        http_batches = [batch[i:i + http_batch_size] for i in range(0, len(batch), http_batch_size)]

        # gather() keeps task order, so the flattened output follows input order.
        per_request_results = await asyncio.gather(
            *(process_http_batch(http_batch) for http_batch in http_batches)
        )

        results = []
        for http_result in per_request_results:
            results.extend(http_result)

        return results
    
    def create_relationship_matrix(self, nli_results: List[Dict]) -> Tuple[np.ndarray, List[str]]:
        """Build a code-by-code relationship matrix plus per-pair label tallies.

        Cell encoding, with rows/columns indexed by the sorted unique codes:
            * ``"A_into_B"``      -> ``matrix[a, b] = 1``
            * ``"B_into_A"``      -> ``matrix[b, a] = 1``
            * ``"mutual"``        -> 3 written in both directions
            * ``"not_mergeable"`` -> 4 written in both directions
        Any other label leaves the matrix untouched, and 0 means nothing was
        recorded. Results are applied in order, so a later result for the same
        cell overwrites an earlier one.

        As a side effect ``self.relationship_frequencies`` is replaced with a
        dict keyed by the ordered pair ``(code_a, code_b)`` holding the number
        of occurrences (``'count'``) and a per-label tally (``'labels'``).

        Args:
            nli_results: Records with ``code_a``, ``code_b`` and ``nli_label``.

        Returns:
            ``(matrix, unique_codes)``; ``unique_codes[i]`` names row/column ``i``.
        """
        print("🔗 Creating relationship matrix...")

        # Every code seen on either side of a pair, in sorted order.
        unique_codes = sorted(
            {entry[side] for entry in nli_results for side in ('code_a', 'code_b')}
        )
        position = {code: pos for pos, code in enumerate(unique_codes)}
        size = len(unique_codes)
        matrix = np.zeros((size, size), dtype=int)

        symmetric_codes = {"mutual": 3, "not_mergeable": 4}
        tallies = {}

        for entry in nli_results:
            first, second, label = entry['code_a'], entry['code_b'], entry['nli_label']
            row, col = position[first], position[second]

            # Tally how often each label was produced for this ordered pair.
            stats = tallies.setdefault((first, second), {'count': 0, 'labels': {}})
            stats['count'] += 1
            stats['labels'][label] = stats['labels'].get(label, 0) + 1

            if label in symmetric_codes:
                matrix[row, col] = symmetric_codes[label]
                matrix[col, row] = symmetric_codes[label]
            elif label == "A_into_B":
                matrix[row, col] = 1
            elif label == "B_into_A":
                matrix[col, row] = 1

        print(f"   Created {size}x{size} relationship matrix")

        # Kept on the instance for pruning and statistics export.
        self.relationship_frequencies = tallies

        return matrix, unique_codes
    
    def prune_low_frequency_relationships(self, relationship_matrix: np.ndarray, 
                                        unique_codes: List[str], 
                                        min_frequency: int = 2,
                                        min_frequency_ratio: float = 0.1) -> np.ndarray:
        """
        Zero out matrix entries for pairs whose labels are rare or inconsistent.

        A pair recorded in ``self.relationship_frequencies`` is pruned when it
        was seen fewer than ``min_frequency`` times, or when its most common
        label accounts for less than ``min_frequency_ratio`` of its
        occurrences. Pruning clears both ``[a, b]`` and ``[b, a]``.

        Args:
            relationship_matrix: Original relationship matrix (not modified)
            unique_codes: Codes naming the matrix rows/columns, in index order
            min_frequency: Minimum absolute frequency for a relationship to be kept
            min_frequency_ratio: Minimum ratio of most common label for a relationship to be kept

        Returns:
            A pruned copy of the matrix. If the instance has no
            ``relationship_frequencies`` attribute at all, the original matrix
            object is returned unchanged.
        """
        print(f"🔍 Pruning low-frequency relationships...")
        print(f"   Min frequency: {min_frequency}, Min frequency ratio: {min_frequency_ratio}")

        if not hasattr(self, 'relationship_frequencies'):
            print("   ⚠️ No frequency data available, returning original matrix")
            return relationship_matrix

        code_to_idx = {code: idx for idx, code in enumerate(unique_codes)}
        pruned_matrix = relationship_matrix.copy()

        # Track pruning statistics
        total_relationships = len(self.relationship_frequencies)
        pruned_relationships = 0

        for (code_a, code_b), freq_data in self.relationship_frequencies.items():
            count = freq_data['count']
            labels = freq_data['labels']

            # Share of occurrences that agree with the most common label.
            frequency_ratio = max(labels.values()) / count if count > 0 else 0

            if count >= min_frequency and frequency_ratio >= min_frequency_ratio:
                continue

            # Pairs whose codes are not in this matrix still count as pruned.
            if code_a in code_to_idx and code_b in code_to_idx:
                idx_a = code_to_idx[code_a]
                idx_b = code_to_idx[code_b]
                pruned_matrix[idx_a, idx_b] = 0
                pruned_matrix[idx_b, idx_a] = 0
            pruned_relationships += 1

        kept_relationships = total_relationships - pruned_relationships

        # With no recorded pairs the rate is reported as 0% instead of
        # dividing by zero.
        if total_relationships:
            pruning_rate = (pruned_relationships / total_relationships) * 100
        else:
            pruning_rate = 0.0

        # Print pruning statistics
        print(f"   📊 Pruning statistics:")
        print(f"     • Total relationships analyzed: {total_relationships}")
        print(f"     • Relationships kept: {kept_relationships}")
        print(f"     • Relationships pruned: {pruned_relationships}")
        print(f"     • Pruning rate: {pruning_rate:.1f}%")

        # Number of non-zero cells left (symmetric relationships count twice).
        remaining_relationships = np.sum(pruned_matrix > 0)
        print(f"   📈 Remaining relationships in matrix: {remaining_relationships}")

        return pruned_matrix
    
    def save_frequency_statistics(self, output_dir: str):
        """Write the per-pair label tallies to ``relationship_frequencies.parquet``.

        One row is written per key of ``self.relationship_frequencies`` with the
        pair's codes, its occurrence count, the share held by its most common
        label, that label, and the stringified label tally. A short summary is
        printed when at least one row exists. Nothing is written when the
        instance has no ``relationship_frequencies`` attribute.

        Args:
            output_dir: Directory that receives the parquet file.
        """
        if not hasattr(self, 'relationship_frequencies'):
            print("   ⚠️ No frequency data available to save")
            return

        print("📊 Saving relationship frequency statistics...")

        def summarise(pair, stats):
            """Flatten one tally entry into a row for the output table."""
            code_a, code_b = pair
            occurrences = stats['count']
            label_counts = stats['labels']

            if occurrences > 0:
                ratio = max(label_counts.values()) / occurrences
                top_label = max(label_counts, key=label_counts.get)
            else:
                ratio = 0
                top_label = 'none'

            return {
                'code_a': code_a,
                'code_b': code_b,
                'total_count': occurrences,
                'frequency_ratio': ratio,
                'most_common_label': top_label,
                'label_distribution': str(label_counts)
            }

        rows = [
            summarise(pair, stats)
            for pair, stats in self.relationship_frequencies.items()
        ]

        target = os.path.join(output_dir, "relationship_frequencies.parquet")
        pd.DataFrame(rows).to_parquet(target, index=False)

        print(f"   💾 Saved frequency statistics: {target}")

        # The summary is only meaningful when there is at least one pair.
        if not rows:
            return

        mean_count = np.mean([row['total_count'] for row in rows])
        mean_ratio = np.mean([row['frequency_ratio'] for row in rows])

        print("   📈 Frequency summary:")
        print(f"     • Total relationship pairs: {len(rows)}")
        print(f"     • Average frequency: {mean_count:.2f}")
        print(f"     • Average frequency ratio: {mean_ratio:.2f}")

    async def classify_similarities_optimized(self, similarity_pairs: List[Dict], 
                                           output_dir: str) -> Tuple[List[Dict], np.ndarray, List[str]]:
        """Classify all similarity pairs and derive the pruned relationship matrix.

        Pairs are processed sequentially in batches of
        ``max(self.batch_size, 100)``; each batch's results are written to
        ``nli_batch_XXXX.parquet`` in ``output_dir``. Afterwards the
        relationship matrix is built and pruned with ``min_frequency=2`` and
        ``min_frequency_ratio=0.6``, so a pair that was classified only once is
        removed from the returned matrix. Frequency statistics are saved and
        label / load balancer statistics are printed.

        Args:
            similarity_pairs: Pair dicts as accepted by ``process_batch``.
            output_dir: Directory for batch files and statistics; created if
                missing.

        Returns:
            ``(all_results, pruned_matrix, unique_codes)``. For empty input the
            result is ``([], <0x0 integer matrix>, [])`` and nothing is written.
        """
        print("🧠 Running enhanced NLI classification with load balancing (64 concurrent)...")

        if not similarity_pairs:
            print("⚠️ No similarity pairs to classify")
            # Keep the empty result 2-D and integer-typed like the regular
            # return value, so callers can rely on ``shape[0]``/``shape[1]``.
            return [], np.zeros((0, 0), dtype=int), []

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        all_results = []

        # Use larger batch size for better concurrency
        effective_batch_size = max(self.batch_size, 100)  # Minimum 100 pairs per batch
        batch_starts = range(0, len(similarity_pairs), effective_batch_size)

        # Process in batches with timing
        total_start_time = time.time()

        progress = tqdm(batch_starts, total=len(batch_starts), desc="NLI batches",
                        ncols=100, mininterval=1)
        for batch_id, start in enumerate(progress):
            batch = similarity_pairs[start:start + effective_batch_size]

            batch_start_time = time.time()
            batch_results = await self.process_batch(batch, batch_id)
            batch_time = time.time() - batch_start_time

            # Save batch independently
            batch_df = pd.DataFrame(batch_results)
            batch_path = os.path.join(output_dir, f"nli_batch_{batch_id:04d}.parquet")
            batch_df.to_parquet(batch_path)

            # Reduced verbose logging - only print every 10th batch
            if batch_id % 10 == 0:
                throughput = len(batch) / batch_time if batch_time > 0 else 0
                print(f"   ⚡ Batch {batch_id}: {len(batch)} pairs in {batch_time:.2f}s ({throughput:.1f} pairs/sec)")

            all_results.extend(batch_results)

        num_batches = len(batch_starts)
        total_time = time.time() - total_start_time
        overall_throughput = len(all_results) / total_time if total_time > 0 else 0
        print(f"🚀 NLI Performance: {len(all_results)} total pairs in {total_time:.2f}s ({overall_throughput:.1f} pairs/sec)")

        print(f"✅ NLI classification completed: {len(all_results)} pairs labeled in {num_batches} batches")

        # Create relationship matrix
        relationship_matrix, unique_codes = self.create_relationship_matrix(all_results)

        # Apply frequency-based pruning
        pruned_matrix = self.prune_low_frequency_relationships(
            relationship_matrix, unique_codes, 
            min_frequency=2,  # Minimum 2 occurrences
            min_frequency_ratio=0.6  # At least 60% consistency in labeling
        )

        # Save frequency statistics for analysis
        self.save_frequency_statistics(output_dir)

        # Print label distribution
        if all_results:
            labels = [result['nli_label'] for result in all_results]
            unique_labels, counts = np.unique(labels, return_counts=True)
            print("   Label distribution:")
            for label, count in zip(unique_labels, counts):
                percentage = (count / len(labels)) * 100
                print(f"     {label}: {count} ({percentage:.1f}%)")

        # Print server statistics
        stats = self.load_balancer.get_server_stats()
        print(f"\n📊 Load Balancer Statistics:")
        print(f"   Total requests: {stats['total_requests']}")
        print(f"   Healthy servers: {stats['healthy_servers']}/{stats['total_servers']}")

        for server_stat in stats['servers']:
            print(f"   {server_stat['url']}: {server_stat['status']} "
                  f"(requests: {server_stat['request_count']}, "
                  f"errors: {server_stat['error_count']}, "
                  f"response_time: {server_stat['response_time']:.2f}s)")

        return all_results, pruned_matrix, unique_codes
    
    def save_relationship_matrix(self, relationship_matrix: np.ndarray, 
                               unique_codes: List[str], output_path: str):
        """Save the relationship matrix and its code-to-index mapping.

        The mapping is written as a parquet table (columns ``code`` and
        ``index``) at ``output_path``. The matrix is written next to it with
        ``numpy.save``: the extension of ``output_path`` is replaced by
        ``_matrix.npy``, e.g. ``out/codes.parquet`` -> ``out/codes_matrix.npy``.

        Args:
            relationship_matrix: 2-D matrix to store.
            unique_codes: Codes naming the matrix rows/columns, in index order.
            output_path: Destination of the parquet mapping file.
        """
        print(f"💾 Saving relationship matrix to {output_path}")

        # Only the final extension is swapped. A plain str.replace() would also
        # rewrite ".parquet" inside directory names, and a path already ending
        # in ".npy" would be overwritten by the parquet file written below.
        path_root, _extension = os.path.splitext(output_path)
        matrix_path = path_root + '_matrix.npy'
        np.save(matrix_path, relationship_matrix)

        # Save code mapping
        codes_df = pd.DataFrame({
            'code': unique_codes,
            'index': range(len(unique_codes))
        })
        codes_df.to_parquet(output_path)

        print(f"   Saved {relationship_matrix.shape[0]}x{relationship_matrix.shape[1]} matrix")
        print(f"   Saved {len(unique_codes)} unique codes")

# Standalone function for compatibility
async def classify_similarities_optimized(similarity_pairs: List[Dict], 
                                        output_dir: str,
                                        batch_size: int = 10,
                                        max_concurrency: int = 64) -> Tuple[List[Dict], np.ndarray, List[str]]:
    """
    Enhanced NLI classification function with load balancing and advanced features

    Creates an ``EnhancedNLIClassifierWithLoadBalancer`` with default server
    URLs, runs its ``classify_similarities_optimized`` method inside the
    classifier's async context, and returns that method's result.

    Args:
        similarity_pairs: List of similarity pairs from cosine similarity
        output_dir: Directory to save independent batches
        batch_size: Number of pairs to process in each batch
        max_concurrency: Maximum concurrent requests, forwarded to the classifier

    Returns:
        Tuple of (all_results, relationship_matrix, unique_codes)
    """
    # max_concurrency used to be accepted here but silently dropped; it is now
    # handed to the classifier together with batch_size.
    async with EnhancedNLIClassifierWithLoadBalancer(
        batch_size=batch_size,
        max_concurrency=max_concurrency,
    ) as classifier:
        return await classifier.classify_similarities_optimized(similarity_pairs, output_dir)

# Example usage
async def test_enhanced_nli_with_load_balancer():
    """Smoke-test the module-level classifier entry point on three sample pairs.

    Prints a short success summary, or the error and its traceback when the
    classification raises.
    """
    print("🧪 Testing Enhanced NLI Classifier with Load Balancer")
    print("=" * 60)

    # (code_a, code_b, similarity, cluster_id)
    samples = (
        ('user authentication', 'login process', 0.85, 1),
        ('database connection', 'query execution', 0.72, 2),
        ('file upload', 'data processing', 0.68, 3),
    )
    test_pairs = [
        {
            'code_a': left,
            'code_b': right,
            'similarity': score,
            'cluster_id': cluster,
        }
        for left, right, score, cluster in samples
    ]

    output_dir = "test_nli_output"

    try:
        outcome = await classify_similarities_optimized(
            test_pairs, output_dir, batch_size=2
        )
        results, relationship_matrix, unique_codes = outcome

        print(f"✅ Classification successful: {len(results)} results")
        print(f"✅ Unique codes: {len(unique_codes)}")
        print(f"✅ Relationship matrix: {relationship_matrix.shape}")

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()

# Script entry point: run the example/smoke-test coroutine on a fresh event loop.
# NOTE(review): this module imports `.nli_load_balancer` relatively, so direct
# execution presumably needs `python -m <package>.<module>` - confirm.
if __name__ == "__main__":
    asyncio.run(test_enhanced_nli_with_load_balancer())
