import os
import sys
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
import torch
import sqlite3
import hashlib
from tqdm import tqdm
from datetime import datetime

# Make project-local packages (the ``common.*`` imports below) importable.
# NOTE(review): "../" is resolved against the current working directory, not
# this file's directory, so the imports only work when the script is launched
# from a directory whose parent contains ``common`` -- confirm this is intended.
sys.path.append("../")

# Import common utilities  
from common.sft_prompts import CLASSES
from common.model_path import LLM_API_CONFIGS


class TextAttackDB:
    """SQLite-backed progress cache that lets an LLM text-attack run resume.

    Data lives in ``<db_dir>/text_attacks.db`` in two tables:

    * ``attack_batches``: one row per ``(experiment_hash, batch_id)`` holding
      the prompt index range ``[start_idx, end_idx)`` and a ``status`` of
      ``'pending'``, ``'processing'`` or ``'completed'``.
    * ``attacked_texts``: one row per ``(experiment_hash, batch_id, node_id)``
      holding the stored original/attacked text pair.

    Every method opens its own connection and closes it before returning.
    """

    def __init__(self, db_dir: str = "./text_attack_cache"):
        """Create ``db_dir`` if needed and make sure both tables exist.

        Args:
            db_dir: Directory that holds ``text_attacks.db``.
        """
        self.db_dir = db_dir
        os.makedirs(db_dir, exist_ok=True)
        self.db_path = os.path.join(db_dir, "text_attacks.db")
        self.init_database()

    def _connect(self):
        """Open a connection that is closed when the ``with`` block exits.

        A bare ``sqlite3.Connection`` used as a context manager only commits or
        rolls back; it never closes the connection, so file handles would
        linger until garbage collection. Write paths below still call
        ``conn.commit()`` explicitly. If an exception escapes first, closing
        the connection discards the uncommitted changes.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def init_database(self):
        """Create the ``attack_batches`` and ``attacked_texts`` tables if missing."""
        with self._connect() as conn:
            cursor = conn.cursor()

            # Per-batch processing status
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS attack_batches (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    experiment_hash TEXT NOT NULL,
                    batch_id INTEGER NOT NULL,
                    batch_size INTEGER NOT NULL,
                    start_idx INTEGER NOT NULL,
                    end_idx INTEGER NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,
                    UNIQUE(experiment_hash, batch_id)
                )
            ''')

            # Per-node attacked text results
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS attacked_texts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    experiment_hash TEXT NOT NULL,
                    batch_id INTEGER NOT NULL,
                    node_id INTEGER NOT NULL,
                    original_text TEXT NOT NULL,
                    attacked_text TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(experiment_hash, batch_id, node_id)
                )
            ''')

            conn.commit()

    def get_experiment_hash(self, dataset: str, llm_provider: str, llm_model: str,
                            emb_type: str, ptb_rate: float, seed: int,
                            setting: str = "transductive") -> str:
        """Return the MD5 hex digest that identifies one experiment configuration.

        The number of target nodes and the batch size are NOT part of the key.
        A changed node set with otherwise identical parameters therefore maps to
        the same cache rows; ``init_experiment_batches`` warns when that shows up
        as a size mismatch.
        """
        params = f"{dataset}_{llm_provider}_{llm_model}_{emb_type}_{ptb_rate}_{seed}_{setting}"
        return hashlib.md5(params.encode()).hexdigest()

    def check_experiment_completion(self, experiment_hash: str) -> bool:
        """Return True if a final results/attacked-texts JSON log for this hash exists.

        Logs are matched under ``./logs_text_attack`` (relative to the current
        working directory) by the first 8 hex characters of the hash.
        """
        import glob

        short_hash = experiment_hash[:8]
        log_patterns = [
            f"./logs_text_attack/*/llm_*/results_*_{short_hash}_*.json",
            f"./logs_text_attack/*/llm_*/attacked_texts_*_{short_hash}.json",
        ]
        return any(glob.glob(pattern) for pattern in log_patterns)

    def init_experiment_batches(self, experiment_hash: str, total_prompts: int,
                                batch_size: int = 10) -> List[Tuple[int, int, int]]:
        """Create (or reuse) the batch rows for an experiment.

        Args:
            experiment_hash: Key from ``get_experiment_hash``.
            total_prompts: Number of prompts to split into batches.
            batch_size: Maximum prompts per batch; must be >= 1.

        Returns:
            ``(batch_id, start_idx, end_idx)`` tuples ordered by ``batch_id``.
            If rows already exist for ``experiment_hash``, they are returned
            unchanged and ``total_prompts``/``batch_size`` are ignored.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # range() raises for 0 and silently yields no batches for negatives.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches 
                WHERE experiment_hash = ? ORDER BY batch_id
            ''', (experiment_hash,))
            existing_batches = cursor.fetchall()
            if existing_batches:
                print(f"✓ Found {len(existing_batches)} existing batches for experiment")
                covered = existing_batches[-1][2]
                if covered != total_prompts:
                    # The hash ignores the node set, so this cache may come from a different run.
                    print(f"⚠ Cached batches cover {covered} prompts but {total_prompts} were "
                          f"requested; cached results may be stale")
                return existing_batches

            batches = [
                (start // batch_size, start, min(start + batch_size, total_prompts))
                for start in range(0, total_prompts, batch_size)
            ]
            cursor.executemany('''
                INSERT OR IGNORE INTO attack_batches 
                (experiment_hash, batch_id, batch_size, start_idx, end_idx)
                VALUES (?, ?, ?, ?, ?)
            ''', [(experiment_hash, batch_id, batch_size, start_idx, end_idx)
                  for batch_id, start_idx, end_idx in batches])

            conn.commit()
            print(f"✓ Initialized {len(batches)} batches for experiment")
            return batches

    def get_pending_batches(self, experiment_hash: str) -> List[Tuple[int, int, int]]:
        """Return ``(batch_id, start_idx, end_idx)`` for every batch not yet completed.

        Batches left in ``'processing'`` by an interrupted run are included, so
        they are retried.
        """
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches 
                WHERE experiment_hash = ? AND status != 'completed'
                ORDER BY batch_id
            ''', (experiment_hash,))
            return cursor.fetchall()

    def get_completed_batches(self, experiment_hash: str) -> List[Tuple[int, int, int]]:
        """Return ``(batch_id, start_idx, end_idx)`` for every completed batch."""
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches 
                WHERE experiment_hash = ? AND status = 'completed'
                ORDER BY batch_id
            ''', (experiment_hash,))
            return cursor.fetchall()

    def mark_batch_started(self, experiment_hash: str, batch_id: int):
        """Set a batch's status to ``'processing'``."""
        with self._connect() as conn:
            conn.execute('''
                UPDATE attack_batches 
                SET status = 'processing' 
                WHERE experiment_hash = ? AND batch_id = ?
            ''', (experiment_hash, batch_id))
            conn.commit()

    def save_batch_results(self, experiment_hash: str, batch_id: int,
                           node_ids: List[int], original_texts: List[str],
                           attacked_texts: List[str]):
        """Store a batch's texts and mark the batch completed, in one transaction.

        The three lists are zipped, so any surplus elements are ignored.
        Existing rows for the same ``(experiment_hash, batch_id, node_id)`` are
        replaced.
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.executemany('''
                INSERT OR REPLACE INTO attacked_texts 
                (experiment_hash, batch_id, node_id, original_text, attacked_text)
                VALUES (?, ?, ?, ?, ?)
            ''', [(experiment_hash, batch_id, node_id, orig_text, att_text)
                  for node_id, orig_text, att_text in zip(node_ids, original_texts, attacked_texts)])

            cursor.execute('''
                UPDATE attack_batches 
                SET status = 'completed', completed_at = CURRENT_TIMESTAMP
                WHERE experiment_hash = ? AND batch_id = ?
            ''', (experiment_hash, batch_id))

            conn.commit()
            print(f"✓ Saved batch {batch_id} with {len(node_ids)} results")

    def get_all_attacked_texts(self, experiment_hash: str) -> Dict[int, str]:
        """Return ``{node_id: attacked_text}`` for an experiment, ordered by node_id."""
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT node_id, attacked_text FROM attacked_texts 
                WHERE experiment_hash = ?
                ORDER BY node_id
            ''', (experiment_hash,))
            return {node_id: attacked_text for node_id, attacked_text in cursor.fetchall()}

    def is_experiment_complete(self, experiment_hash: str) -> bool:
        """Return True if the experiment has batches and all of them are completed."""
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT COUNT(*) as total, 
                       SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed
                FROM attack_batches 
                WHERE experiment_hash = ?
            ''', (experiment_hash,))
            total, completed = cursor.fetchone()
            # With no rows, SUM() is NULL; the `total > 0` guard handles that case.
            return total > 0 and total == completed

    def get_experiment_progress(self, experiment_hash: str) -> Dict:
        """Return batch counts by status, the stored text count and the percent completed."""
        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.execute('''
                SELECT 
                    COUNT(*) as total_batches,
                    SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_batches,
                    SUM(CASE WHEN status = 'processing' THEN 1 ELSE 0 END) as processing_batches,
                    SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending_batches
                FROM attack_batches 
                WHERE experiment_hash = ?
            ''', (experiment_hash,))
            total, completed, processing, pending = cursor.fetchone()

            cursor.execute('''
                SELECT COUNT(*) FROM attacked_texts WHERE experiment_hash = ?
            ''', (experiment_hash,))
            total_texts = cursor.fetchone()[0]

            # The SUM() columns are NULL when no rows match; normalize them to 0.
            return {
                "total_batches": total or 0,
                "completed_batches": completed or 0,
                "processing_batches": processing or 0,
                "pending_batches": pending or 0,
                "total_attacked_texts": total_texts,
                "completion_percentage": (completed / total * 100) if total > 0 else 0
            }

    def cleanup_experiment(self, experiment_hash: str):
        """Delete every batch and text row belonging to an experiment."""
        with self._connect() as conn:
            conn.execute('DELETE FROM attacked_texts WHERE experiment_hash = ?', (experiment_hash,))
            conn.execute('DELETE FROM attack_batches WHERE experiment_hash = ?', (experiment_hash,))
            conn.commit()
            print(f"✓ Cleaned up database entries for experiment {experiment_hash[:8]}")

    def list_experiments(self) -> List[Dict]:
        """Return one progress dict (plus full and short hash) per cached experiment."""
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT DISTINCT experiment_hash FROM attack_batches
            ''')
            hashes = [row[0] for row in cursor.fetchall()]

        # Progress is gathered after this connection is closed; each call opens its own.
        return [
            {"experiment_hash": exp_hash, "short_hash": exp_hash[:8],
             **self.get_experiment_progress(exp_hash)}
            for exp_hash in hashes
        ]


async def call_chat_api(messages: List[Dict], api_key: str, base_url: str, model: str) -> str:
    """Send one chat-completion request and return the reply text.

    POSTs an OpenAI-style payload (``model``, ``messages``, ``temperature=0.7``,
    ``max_tokens=1000``) to ``{base_url}/chat/completions`` with a 30 s total
    timeout. At most 3 attempts are made.

    Retry policy:
      * HTTP 429: wait ``retry_delay * (attempt + 1)`` seconds, then retry
        (no wait after the final attempt).
      * DNS/connection failures and timeouts: wait ``retry_delay`` seconds, then retry.
      * Any other error (non-200 status, malformed body, null content): wait
        and retry, unless the message mentions "quota" or "authentication",
        which is raised immediately.

    Args:
        messages: Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
        api_key: Bearer token sent in the ``Authorization`` header.
        base_url: API root; ``/chat/completions`` is appended.
        model: Model identifier placed in the payload.

    Returns:
        ``choices[0].message.content`` from the JSON response.

    Raises:
        Exception: on authentication/quota errors, or once all attempts fail.
    """
    import aiohttp

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": 1000
    }

    max_retries = 3
    retry_delay = 2

    # ClientConnectorDNSError may be missing from older aiohttp releases; there,
    # referencing it in an `except` clause would raise AttributeError and mask
    # the real error. An empty tuple matches nothing, so DNS failures then fall
    # through to the ClientConnectorError handler below.
    dns_error = getattr(aiohttp, "ClientConnectorDNSError", ())

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            timeout = aiohttp.ClientTimeout(total=30)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(f"{base_url}/chat/completions",
                                        json=payload, headers=headers) as response:

                    if response.status != 200:
                        error_text = await response.text()
                        print(f"⚠ HTTP {response.status}: {error_text}")
                        if response.status == 429:  # Rate limit
                            if not is_last_attempt:
                                print(f"Rate limited, waiting {retry_delay * (attempt + 1)} seconds...")
                                await asyncio.sleep(retry_delay * (attempt + 1))
                            continue
                        elif response.status in [401, 403]:  # Auth errors
                            raise Exception(f"Authentication error: {error_text}")
                        else:
                            raise Exception(f"HTTP {response.status}: {error_text}")

                    data = await response.json()

                    if "choices" not in data:
                        print(f"⚠ Invalid response format: {data}")
                        if "error" in data:
                            error = data["error"]
                            # Some providers send "error" as a plain string rather than an object.
                            if isinstance(error, dict):
                                error_msg = error.get("message", str(error))
                            else:
                                error_msg = str(error)
                            print(f"API error: {error_msg}")

                            if "quota" in error_msg.lower() or "limit" in error_msg.lower():
                                raise Exception(f"API quota/limit error: {error_msg}")
                            elif "model" in error_msg.lower():
                                raise Exception(f"Model error: {error_msg}")
                            else:
                                raise Exception(f"API error: {error_msg}")
                        else:
                            raise Exception("Invalid response format: missing 'choices' field")

                    if len(data["choices"]) == 0:
                        raise Exception("Empty choices in API response")

                    if "message" not in data["choices"][0]:
                        raise Exception("Invalid choice format: missing 'message' field")

                    message = data["choices"][0]["message"]
                    if "content" not in message:
                        raise Exception("Invalid message format: missing 'content' field")

                    content = message["content"]
                    # A null content (e.g. a filtered reply) would crash downstream
                    # text cleaning; treat it as a retryable failure instead.
                    if content is None:
                        raise Exception("Invalid message format: 'content' is null")

                    return content

        except dns_error as e:
            print(f"⚠ DNS resolution failed for {base_url}: {str(e)}")
            if not is_last_attempt:
                print(f"Retrying in {retry_delay} seconds... (attempt {attempt + 1}/{max_retries})")
                await asyncio.sleep(retry_delay)
            else:
                print("❌ Max retries reached. Please check:")
                print("1. Your internet connection")
                print("2. DNS settings")
                print(f"3. If {base_url} is accessible")
                raise Exception(f"DNS resolution failed after {max_retries} attempts: {str(e)}")

        except aiohttp.ClientConnectorError as e:
            print(f"⚠ Connection error to {base_url}: {str(e)}")
            if not is_last_attempt:
                print(f"Retrying in {retry_delay} seconds... (attempt {attempt + 1}/{max_retries})")
                await asyncio.sleep(retry_delay)
            else:
                raise Exception(f"Connection failed after {max_retries} attempts: {str(e)}")

        except asyncio.TimeoutError:
            print(f"⚠ Request timeout (attempt {attempt + 1}/{max_retries})")
            if not is_last_attempt:
                await asyncio.sleep(retry_delay)
            else:
                raise Exception("Request timeout after multiple attempts")

        except Exception as e:
            if "quota" in str(e).lower() or "authentication" in str(e).lower():
                # Retrying cannot fix auth/quota problems.
                raise
            elif not is_last_attempt:
                print(f"⚠ Unexpected error: {str(e)}")
                print(f"Retrying in {retry_delay} seconds... (attempt {attempt + 1}/{max_retries})")
                await asyncio.sleep(retry_delay)
            else:
                raise

    # Only reachable when the final attempt was rate limited (HTTP 429).
    raise Exception(f"Failed after {max_retries} attempts")


async def generate_adversarial_texts_batch_with_db(prompts: List[str], target_nodes: List[int], 
                                                  api_key: str, base_url: str, model: str,
                                                  db: "TextAttackDB", experiment_hash: str,
                                                  batch_size: int = 10,
                                                  original_texts: Optional[List[str]] = None) -> List[str]:
    """Generate adversarial texts batch by batch, persisting each batch so runs can resume.

    Batches already marked completed in ``db`` are skipped. Every other batch
    is sent to the chat API (all prompts of one batch concurrently), cleaned,
    and saved before the next batch starts. Consecutive batches are separated
    by a 1 s pause.

    Args:
        prompts: One generation prompt per target node.
        target_nodes: Node ids parallel to ``prompts``; used as the cache key.
        api_key: Forwarded to ``call_chat_api``.
        base_url: Forwarded to ``call_chat_api``.
        model: Forwarded to ``call_chat_api``.
        db: Progress cache object.
        experiment_hash: Key identifying this experiment in ``db``.
        batch_size: Prompts per batch; only used when batches are first created.
        original_texts: Optional original node texts parallel to ``prompts``.
            When given, they are stored as ``original_text`` in the cache and
            returned for nodes missing from a completed cache. When omitted, the
            legacy behavior is kept: the prompts themselves are stored and used
            as that fallback.

    Returns:
        One text per prompt, in input order.

    Raises:
        Exception: whatever a failing batch raised. That batch is not marked
            completed, so the next run retries it.
    """
    results = [""] * len(prompts)

    def _fallback(i: int) -> str:
        # The prompt is NOT attacked text; it is only used when no original was supplied.
        return original_texts[i] if original_texts is not None else prompts[i]

    # Creates the batch rows on the first run, reuses them on resume.
    db.init_experiment_batches(experiment_hash, len(prompts), batch_size)

    progress = db.get_experiment_progress(experiment_hash)
    print(f"Experiment progress: {progress['completed_batches']}/{progress['total_batches']} batches completed "
          f"({progress['completion_percentage']:.1f}%)")

    # Everything is already cached: just map stored results back to input order.
    if db.is_experiment_complete(experiment_hash):
        print("✓ All batches completed, loading results from database")
        attacked_texts_dict = db.get_all_attacked_texts(experiment_hash)
        for i, node_id in enumerate(target_nodes):
            results[i] = attacked_texts_dict[node_id] if node_id in attacked_texts_dict else _fallback(i)
        return results

    pending_batches = db.get_pending_batches(experiment_hash)
    num_pending = len(pending_batches)
    print(f"Processing {len(pending_batches)} pending batches...")

    with tqdm(total=num_pending, desc="Processing batches") as pbar:
        for position, (batch_id, start_idx, end_idx) in enumerate(pending_batches):
            try:
                db.mark_batch_started(experiment_hash, batch_id)

                batch_prompts = prompts[start_idx:end_idx]
                batch_node_ids = target_nodes[start_idx:end_idx]
                batch_originals = (original_texts[start_idx:end_idx]
                                   if original_texts is not None else batch_prompts)

                batch_tasks = [
                    call_chat_api([{"role": "user", "content": prompt}], api_key, base_url, model)
                    for prompt in batch_prompts
                ]
                batch_results = await asyncio.gather(*batch_tasks)
                batch_cleaned_results = [clean_generated_text(result) for result in batch_results]

                db.save_batch_results(experiment_hash, batch_id, batch_node_ids,
                                      batch_originals, batch_cleaned_results)

                for offset, result in enumerate(batch_cleaned_results):
                    results[start_idx + offset] = result

                pbar.update(1)

                # Pause between consecutive *pending* batches. Compare the position in
                # the pending list, not batch_id: on resume, batch ids do not start at 0.
                if position < num_pending - 1:
                    await asyncio.sleep(1.0)

            except Exception as e:
                print(f"❌ Batch {batch_id} failed: {str(e)}")
                # The batch is not marked completed, so it is retried on the next run.
                raise

    # Pull in results of batches completed by earlier runs.
    attacked_texts_dict = db.get_all_attacked_texts(experiment_hash)
    for i, node_id in enumerate(target_nodes):
        if node_id in attacked_texts_dict:
            results[i] = attacked_texts_dict[node_id]

    return results


async def generate_adversarial_texts_batch(prompts: List[str], api_key: str, 
                                         base_url: str, model: str, batch_size: int = 10) -> List[str]:
    """Generate cleaned adversarial texts for ``prompts`` without any caching.

    Legacy path kept for compatibility. Prompts are sent in groups of
    ``batch_size``; the requests within a group run concurrently, and there is
    a 1 s pause before each following group.

    Returns:
        One cleaned text per prompt, in input order.
    """
    outputs = []
    prompt_count = len(prompts)

    for chunk_start in tqdm(range(0, prompt_count, batch_size)):
        chunk = prompts[chunk_start:chunk_start + batch_size]
        coroutines = [
            call_chat_api([{"role": "user", "content": text}], api_key, base_url, model)
            for text in chunk
        ]
        raw_replies = await asyncio.gather(*coroutines)
        outputs.extend([clean_generated_text(reply) for reply in raw_replies])

        # Throttle between groups; no pause after the final one.
        more_remaining = chunk_start + batch_size < prompt_count
        if more_remaining:
            await asyncio.sleep(1.0)

    return outputs


def create_adversarial_prompts(dataset: str, target_nodes: torch.Tensor, texts: List[str], 
                              labels: torch.Tensor, label_names: List[str], 
                              edge_index: torch.Tensor, all_labels: torch.Tensor) -> List[str]:
    """Build one neighbor-aware adversarial-rewrite prompt per target node.

    Each prompt lists the dataset's classes, the node's text and label, and the
    labels of its out-neighbors (edges whose source is the node). The target
    class is the first class in the class list that is neither the node's own
    label nor a neighbor label. If every class is excluded, the least frequent
    neighbor label is used instead (ties go to the label seen first).

    Args:
        dataset: Dataset name; ``CLASSES[dataset.lower()]`` supplies the class
            list, falling back to ``label_names``.
        target_nodes: Node ids to attack; each element must support ``.item()``.
        texts: Original text per target node (parallel to ``target_nodes``).
        labels: Label index per target node (parallel to ``target_nodes``).
        label_names: Maps a label index to its class name.
        edge_index: Edge list with row 0 = sources and row 1 = targets
            (presumably the PyG ``(2, num_edges)`` layout -- confirm with caller).
        all_labels: Label index for every node, indexed by neighbor id.

    Returns:
        One prompt string per target node, in order.
    """
    all_classes = CLASSES.get(dataset.lower(), label_names)
    classes_str = ", ".join([f"'{cls}'" for cls in all_classes])

    prompts = []
    for node_id, text, label_idx in zip(target_nodes, texts, labels):
        current_label = label_names[label_idx.item()]

        # Out-neighbors of the target node.
        neighbors = edge_index[1][edge_index[0] == node_id]
        neighbor_label_counts = {}
        if len(neighbors) > 0:
            for idx in all_labels[neighbors]:
                name = label_names[idx.item()]
                neighbor_label_counts[name] = neighbor_label_counts.get(name, 0) + 1

            # Take unique labels in first-seen (dict insertion) order rather than
            # list(set(...)): set order of strings changes between processes
            # because str hashing is randomized. That would make the prompt text
            # and the tie-break below irreproducible for a fixed seed.
            unique_neighbor_labels = list(neighbor_label_counts)
            neighbor_info = f"Neighbor labels: {', '.join(unique_neighbor_labels)} (counts: {neighbor_label_counts})"
        else:
            unique_neighbor_labels = []
            neighbor_info = "No neighbors found"

        # Labels the rewrite should avoid: the node's own plus every neighbor's.
        forbidden_labels = {current_label, *unique_neighbor_labels}
        allowed_labels = [cls for cls in all_classes if cls not in forbidden_labels]

        if allowed_labels:
            target_label = allowed_labels[0]  # Could be enhanced with semantic distance
            target_instruction = f"Generate text for a class that differs from both node and neighbors. Preferred target: '{target_label}'"
        elif unique_neighbor_labels:
            # Every class is taken: pick the least frequent neighbor label (first seen wins ties).
            target_label = min(unique_neighbor_labels, key=neighbor_label_counts.__getitem__)
            target_instruction = f"Since all classes appear in neighbors, generate text for the least frequent neighbor class: '{target_label}'"
        else:
            # No neighbors and no allowed class: just avoid the current label if possible.
            allowed_labels = [cls for cls in all_classes if cls != current_label]
            target_label = allowed_labels[0] if allowed_labels else all_classes[0]
            target_instruction = f"Generate text for class: '{target_label}'"

        prompt = f"""Graph node classification task:

        Available classes: {classes_str}

        Target node {node_id.item()}:
        Original text: "{text}"
        Original label: "{current_label}"
        {neighbor_info}

        Task: Modify the text to make it as different as possible from the original, while maintaining similar length.
        Requirements:
        1. MUST NOT belong to the original class: "{current_label}"
        2. SHOULD NOT belong to neighbor classes: {unique_neighbor_labels if unique_neighbor_labels else "None"}
        3. {target_instruction}
        4. Make the content as dissimilar as possible from the original text
        5. Keep similar text length (word count)
        6. Generate the most unlikely/different content for the target class

        The goal is to create text that is maximally different from both the original content and the surrounding graph context.
        Only return the modified text without any additional explanation.

        Modified text:"""

        prompts.append(prompt)

    return prompts


def _strip_wrapping_quotes(text: str) -> str:
    """Remove one pair of matching ``"`` or ``'`` quotes enclosing all of ``text``."""
    if text and text[0] in "\"'" and text[-1] == text[0]:
        return text[1:-1].strip()
    return text


def clean_generated_text(text: Optional[str]) -> str:
    """Normalize a raw LLM reply into plain attacked text.

    Steps:
      1. ``None`` becomes ``""``; surrounding whitespace is stripped.
      2. One pair of quotes wrapping the whole reply is removed.
      3. Boilerplate prefixes such as ``"Modified text:"`` are removed
         (case-insensitively).
      4. Quotes wrapping what remains after the prefix are removed, e.g. for
         ``Modified text: "..."``.
      5. Runs of whitespace collapse to single spaces.
    """
    if text is None:
        return ""

    # Strip inside the quotes too, so a prefix following `"  ` is still recognized.
    text = _strip_wrapping_quotes(text.strip())

    prefixes = ["Modified text:", "Here is the modified text:", "Modified version:", "Updated text:", "Adversarial text:"]
    for prefix in prefixes:
        if text.lower().startswith(prefix.lower()):
            text = text[len(prefix):].strip()

    # The model often quotes the content that follows the prefix.
    text = _strip_wrapping_quotes(text)

    return " ".join(text.split())


async def apply_llm_text_attack(dataset: str, target_nodes: torch.Tensor, texts: List[str], 
                               labels: torch.Tensor, label_names: List[str], 
                               llm_provider: str, model: str, edge_index: torch.Tensor, 
                               all_labels: torch.Tensor, emb_type: str, ptb_rate: float,
                               seed: int = 0, setting: str = "transductive") -> List[str]:
    """Generate LLM-based adversarial replacements for ``texts`` with resumable progress.

    Flow:
      1. Build one neighbor-aware prompt per target node.
      2. Run the prompts in batches of 10 through the DB-backed generator, so an
         interrupted run resumes from cached batches.
      3. On success, delete this experiment's cache rows.

    Args:
        dataset: Dataset name, used for the class list and the cache key.
        target_nodes: Node ids to attack.
        texts: Original text per target node.
        labels: Label index per target node.
        label_names: Maps a label index to its class name.
        llm_provider: Key into ``LLM_API_CONFIGS`` (supplies api_key/base_url).
        model: Chat model name.
        edge_index: Graph edges, used for neighbor context in prompts.
        all_labels: Label index for every node.
        emb_type: Part of the cache key.
        ptb_rate: Part of the cache key.
        seed: Part of the cache key.
        setting: Part of the cache key.

    Returns:
        One attacked text per prompt, in ``target_nodes`` order.

    Raises:
        KeyError: if ``llm_provider`` is not in ``LLM_API_CONFIGS``.
        Exception: the last endpoint error when every endpoint fails. Cached
            progress is left in the database for the next run.
    """
    print(f"Generating LLM-based adversarial texts for {len(texts)} nodes using {model}...")

    db = TextAttackDB()

    # Cache key built from the run parameters, so a rerun resumes the same batches.
    experiment_hash = db.get_experiment_hash(dataset, llm_provider, model, emb_type,
                                             ptb_rate, seed, setting)

    config = LLM_API_CONFIGS[llm_provider]
    api_key = config["api_key"]

    base_urls = [config["base_url"]]
    print(f"Available API endpoints: {base_urls}")

    prompts = create_adversarial_prompts(dataset, target_nodes, texts, labels, label_names,
                                         edge_index, all_labels)

    # Plain ints for use as SQLite keys.
    target_node_ids = [int(node.item()) for node in target_nodes]

    last_exception = None
    for i, base_url in enumerate(base_urls):
        try:
            print(f"Trying endpoint {i+1}/{len(base_urls)}: {base_url}")

            attacked_texts = await generate_adversarial_texts_batch_with_db(
                prompts, target_node_ids, api_key, base_url, model,
                db, experiment_hash, batch_size=10
            )

            # Defensive: the DB-backed generator can substitute the generation
            # *prompt* for nodes without a cached result. A prompt must never be
            # returned as attacked text, so fall back to the original text.
            attacked_texts = [
                orig if att == prompt else att
                for orig, att, prompt in zip(texts, attacked_texts, prompts)
            ]

            success_count = sum(1 for orig, att in zip(texts, attacked_texts) if orig != att)
            print(f"✅ Generated {success_count}/{len(texts)} adversarial texts using LLM (endpoint: {base_url})")

            # Progress is no longer needed once generation has fully succeeded.
            db.cleanup_experiment(experiment_hash)

            return attacked_texts

        except Exception as e:
            last_exception = e
            print(f"❌ Failed with endpoint {base_url}: {str(e)}")

            if i < len(base_urls) - 1:
                print("Trying next endpoint...")
                await asyncio.sleep(2)
            else:
                print("❌ All endpoints failed. Progress saved to database for recovery.")
                print("Run the script again to resume from where it left off.")

    if last_exception:
        raise last_exception
    else:
        raise Exception("No API endpoints available")


async def apply_llm_text_attack_legacy(dataset: str, target_nodes: torch.Tensor, texts: List[str], 
                                     labels: torch.Tensor, label_names: List[str], 
                                     llm_provider: str, model: str, edge_index: torch.Tensor, 
                                     all_labels: torch.Tensor, seed: int = 0) -> List[str]:
    """Generate LLM-based adversarial texts without the resumable DB cache.

    Tries each configured endpoint in turn and returns the first successful
    generation. ``seed`` is accepted for interface compatibility but unused.

    Returns:
        One attacked text per prompt, in ``target_nodes`` order.

    Raises:
        KeyError: if ``llm_provider`` is not in ``LLM_API_CONFIGS``.
        Exception: the error from the last endpoint when every endpoint fails.
    """
    print(f"Generating LLM-based adversarial texts for {len(texts)} nodes using {model} (legacy mode)...")

    provider_cfg = LLM_API_CONFIGS[llm_provider]
    key = provider_cfg["api_key"]
    endpoints = [provider_cfg["base_url"]]
    print(f"Available API endpoints: {endpoints}")

    attack_prompts = create_adversarial_prompts(dataset, target_nodes, texts, labels, label_names,
                                                edge_index, all_labels)

    endpoint_total = len(endpoints)
    final_error = None
    for position, endpoint in enumerate(endpoints, start=1):
        try:
            print(f"Trying endpoint {position}/{endpoint_total}: {endpoint}")
            generated = await generate_adversarial_texts_batch(
                attack_prompts, key, endpoint, model, batch_size=10
            )
            changed = sum(before != after for before, after in zip(texts, generated))
            print(f"✅ Generated {changed}/{len(texts)} adversarial texts using LLM (endpoint: {endpoint})")
            return generated
        except Exception as err:
            final_error = err
            print(f"❌ Failed with endpoint {endpoint}: {str(err)}")
            if position == endpoint_total:
                print(f"❌ All endpoints failed. Last error: {str(err)}")
            else:
                print("Trying next endpoint...")
                await asyncio.sleep(2)

    if final_error is None:
        raise Exception("No API endpoints available")
    raise final_error


def save_llm_attack_examples(dataset: str, target_nodes: torch.Tensor, original_texts: List[str], 
                            attacked_texts: List[str], labels: torch.Tensor, label_names: List[str], 
                            edge_index: torch.Tensor, all_labels: torch.Tensor, seed: int = 0) -> List[Dict[str, Any]]:
    """Write up to 5 before/after attack examples, with neighbor context, to JSON.

    The file goes to
    ``./logs_text_attack/<dataset>/llm_attack_examples/examples_seed<seed>_<YYYYmmdd_HHMMSS>.json``.
    Each example records the node id, label, both texts, neighbor label counts,
    the prompt that would be used for it, and word-count statistics.

    Returns:
        The list of example dicts that was written.
    """
    examples = []
    # Also bound by attacked_texts so a short/partial result list cannot raise IndexError.
    num_examples = min(5, len(original_texts), len(attacked_texts))

    # Regenerate the prompts for the logged examples so they can be inspected.
    example_prompts = create_adversarial_prompts(
        dataset, target_nodes[:num_examples], original_texts[:num_examples], 
        labels[:num_examples], label_names, edge_index, all_labels
    )

    for i in range(num_examples):
        node_id = target_nodes[i]
        original = original_texts[i]
        attacked = attacked_texts[i]

        # Out-neighbors of this node (edges whose source is node_id).
        neighbors = edge_index[1][edge_index[0] == node_id]
        if len(neighbors) > 0:
            neighbor_label_counts = {}
            for idx in all_labels[neighbors]:
                name = label_names[idx.item()]
                neighbor_label_counts[name] = neighbor_label_counts.get(name, 0) + 1
            neighbor_info = {
                # First-seen order is deterministic, unlike list(set(...)) of strings.
                "neighbor_labels": list(neighbor_label_counts),
                "neighbor_counts": neighbor_label_counts,
                "total_neighbors": len(neighbors)
            }
        else:
            neighbor_info = {
                "neighbor_labels": [],
                "neighbor_counts": {},
                "total_neighbors": 0
            }

        original_len = len(original.split())
        attacked_len = len(attacked.split())
        examples.append({
            "node_id": int(node_id),
            "original_text": original,
            "attacked_text": attacked,
            "label": int(labels[i]),
            "label_name": label_names[labels[i].item()],
            "neighbor_info": neighbor_info,
            "prompt_used": example_prompts[i],
            "changed": original != attacked,
            "length_original": original_len,
            "length_attacked": attacked_len,
            "length_ratio": attacked_len / original_len if original_len > 0 else 0
        })

    examples_dir = f"./logs_text_attack/{dataset}/llm_attack_examples"
    os.makedirs(examples_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    examples_file = f"{examples_dir}/examples_seed{seed}_{timestamp}.json"

    with open(examples_file, 'w', encoding='utf-8') as f:
        json.dump(examples, f, indent=4)

    print(f"✓ Saved attack examples to {examples_file}")
    return examples