import glob
import hashlib
import json
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import List, Dict, Optional, Tuple


class TextAttackDB:
    """Database manager for text attack batch processing with recovery support.

    State lives in ``<db_dir>/text_attacks.db`` in two tables:

    * ``attack_batches``: one row per ``(experiment_hash, batch_id)`` holding
      the half-open prompt range ``[start_idx, end_idx)`` and a ``status`` of
      ``'pending'`` (default), ``'processing'`` or ``'completed'``.
    * ``attacked_texts``: one row per ``(experiment_hash, batch_id, node_id)``
      holding the original and the attacked text.

    Every public method opens its own short-lived connection and closes it
    before returning, so an instance keeps no file handles open between calls.
    """

    def __init__(self, db_dir: str = "./text_attack_cache"):
        """Create ``db_dir`` if needed and make sure the schema exists.

        Args:
            db_dir: Directory that holds the ``text_attacks.db`` SQLite file.
        """
        self.db_dir = db_dir
        os.makedirs(db_dir, exist_ok=True)
        self.db_path = os.path.join(db_dir, "text_attacks.db")
        self.init_database()

    @contextmanager
    def _connect(self):
        """Yield a connection that is committed and always closed.

        The transaction is committed on normal exit and rolled back if the
        body raises. ``with sqlite3.connect(...) as conn`` on its own only
        manages the transaction and does NOT close the connection, which is
        why this wrapper adds an explicit ``close()``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the ``attack_batches`` and ``attacked_texts`` tables if missing."""
        with self._connect() as conn:
            # Create table for batch processing status
            conn.execute('''
                CREATE TABLE IF NOT EXISTS attack_batches (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    experiment_hash TEXT NOT NULL,
                    batch_id INTEGER NOT NULL,
                    batch_size INTEGER NOT NULL,
                    start_idx INTEGER NOT NULL,
                    end_idx INTEGER NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,
                    UNIQUE(experiment_hash, batch_id)
                )
            ''')

            # Create table for attacked text results
            conn.execute('''
                CREATE TABLE IF NOT EXISTS attacked_texts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    experiment_hash TEXT NOT NULL,
                    batch_id INTEGER NOT NULL,
                    node_id INTEGER NOT NULL,
                    original_text TEXT NOT NULL,
                    attacked_text TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(experiment_hash, batch_id, node_id)
                )
            ''')

    def get_experiment_hash(self, dataset: str, llm_provider: str, llm_model: str,
                            emb_type: str, ptb_rate: float, seed: int,
                            setting: str = "transductive") -> str:
        """Return a deterministic key identifying one experiment configuration.

        The parameters are joined with ``_`` and hashed with MD5. The hash
        only serves as a cache key, not for security. The formula must stay
        stable or previously cached runs can no longer be found.

        Returns:
            The 32-character hexadecimal MD5 digest.
        """
        params = f"{dataset}_{llm_provider}_{llm_model}_{emb_type}_{ptb_rate}_{seed}_{setting}"
        return hashlib.md5(params.encode()).hexdigest()

    def check_experiment_completion(self, experiment_hash: str,
                                    log_dir: str = "./logs_text_attack") -> bool:
        """Return True if a final log file for this experiment already exists.

        Looks for ``<log_dir>/*/llm_*/results_*_<hash8>_*.json`` or
        ``<log_dir>/*/llm_*/attacked_texts_*_<hash8>.json``, where ``<hash8>``
        is the first 8 characters of ``experiment_hash``.

        Args:
            experiment_hash: Key returned by :meth:`get_experiment_hash`.
            log_dir: Root directory of the logs. It is resolved against the
                current working directory when relative.
        """
        # Escape the literal parts so characters like '[' in a path are not
        # interpreted as glob syntax.
        root = glob.escape(log_dir)
        short_hash = glob.escape(experiment_hash[:8])
        log_patterns = [
            os.path.join(root, "*", "llm_*", f"results_*_{short_hash}_*.json"),
            os.path.join(root, "*", "llm_*", f"attacked_texts_*_{short_hash}.json"),
        ]
        return any(glob.glob(pattern) for pattern in log_patterns)

    def init_experiment_batches(self, experiment_hash: str, total_prompts: int,
                                batch_size: int = 10) -> List[Tuple[int, int, int]]:
        """Create (or reuse) the batch plan for an experiment.

        If batches are already recorded for ``experiment_hash``, they are
        returned unchanged. ``total_prompts`` and ``batch_size`` are then
        ignored, which is what makes resuming an interrupted run possible.
        Otherwise ``[0, total_prompts)`` is split into consecutive chunks of
        ``batch_size``; the last chunk may be shorter.

        Returns:
            ``(batch_id, start_idx, end_idx)`` tuples ordered by ``batch_id``.

        Raises:
            ValueError: If new batches must be created and ``batch_size`` is
                not positive.
        """
        with self._connect() as conn:
            # Check if batches already exist
            existing_batches = conn.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches
                WHERE experiment_hash = ? ORDER BY batch_id
            ''', (experiment_hash,)).fetchall()
            if existing_batches:
                print(f"✓ Found {len(existing_batches)} existing batches for experiment")
                return existing_batches

            if batch_size <= 0:
                raise ValueError(f"batch_size must be positive, got {batch_size}")

            # Create new batches
            batches = [
                (start // batch_size, start, min(start + batch_size, total_prompts))
                for start in range(0, total_prompts, batch_size)
            ]
            conn.executemany('''
                INSERT OR IGNORE INTO attack_batches
                (experiment_hash, batch_id, batch_size, start_idx, end_idx)
                VALUES (?, ?, ?, ?, ?)
            ''', [(experiment_hash, b_id, batch_size, start, end)
                  for b_id, start, end in batches])

        print(f"✓ Initialized {len(batches)} batches for experiment")
        return batches

    def get_pending_batches(self, experiment_hash: str) -> List[Tuple[int, int, int]]:
        """Return batches whose status is not ``'completed'``, ordered by batch_id.

        This includes batches left in ``'processing'`` (for example after a
        crash), so they get retried.
        """
        with self._connect() as conn:
            return conn.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches
                WHERE experiment_hash = ? AND status != 'completed'
                ORDER BY batch_id
            ''', (experiment_hash,)).fetchall()

    def get_completed_batches(self, experiment_hash: str) -> List[Tuple[int, int, int]]:
        """Return batches whose status is ``'completed'``, ordered by batch_id."""
        with self._connect() as conn:
            return conn.execute('''
                SELECT batch_id, start_idx, end_idx FROM attack_batches
                WHERE experiment_hash = ? AND status = 'completed'
                ORDER BY batch_id
            ''', (experiment_hash,)).fetchall()

    def mark_batch_started(self, experiment_hash: str, batch_id: int):
        """Set the batch's status to ``'processing'``.

        An unknown ``(experiment_hash, batch_id)`` is a no-op.
        """
        with self._connect() as conn:
            conn.execute('''
                UPDATE attack_batches
                SET status = 'processing'
                WHERE experiment_hash = ? AND batch_id = ?
            ''', (experiment_hash, batch_id))

    def save_batch_results(self, experiment_hash: str, batch_id: int,
                           node_ids: List[int], original_texts: List[str],
                           attacked_texts: List[str]):
        """Store a batch's texts and mark the batch ``'completed'`` atomically.

        Rows with the same ``(experiment_hash, batch_id, node_id)`` are
        overwritten. The inserts and the status update share one transaction,
        so a crash never leaves a batch completed with only some of its texts.

        Raises:
            ValueError: If the three lists differ in length. Previously
                ``zip`` silently dropped the extra items while still marking
                the batch completed.
        """
        if not len(node_ids) == len(original_texts) == len(attacked_texts):
            raise ValueError(
                "node_ids, original_texts and attacked_texts must have equal "
                f"lengths, got {len(node_ids)}, {len(original_texts)}, "
                f"{len(attacked_texts)}"
            )

        rows = [
            (experiment_hash, batch_id, node_id, orig_text, att_text)
            for node_id, orig_text, att_text in zip(node_ids, original_texts, attacked_texts)
        ]
        with self._connect() as conn:
            # Save attacked texts
            conn.executemany('''
                INSERT OR REPLACE INTO attacked_texts
                (experiment_hash, batch_id, node_id, original_text, attacked_text)
                VALUES (?, ?, ?, ?, ?)
            ''', rows)

            # Mark batch as completed
            conn.execute('''
                UPDATE attack_batches
                SET status = 'completed', completed_at = CURRENT_TIMESTAMP
                WHERE experiment_hash = ? AND batch_id = ?
            ''', (experiment_hash, batch_id))

        print(f"✓ Saved batch {batch_id} with {len(node_ids)} results")

    def get_all_attacked_texts(self, experiment_hash: str) -> Dict[int, str]:
        """Return ``{node_id: attacked_text}`` for an experiment, in node_id order.

        If one node_id was stored under more than one batch, only one of its
        texts ends up in the result.
        """
        with self._connect() as conn:
            return dict(conn.execute('''
                SELECT node_id, attacked_text FROM attacked_texts
                WHERE experiment_hash = ?
                ORDER BY node_id
            ''', (experiment_hash,)).fetchall())

    def is_experiment_complete(self, experiment_hash: str) -> bool:
        """Return True if the experiment has batches and all are completed."""
        with self._connect() as conn:
            total, completed = conn.execute('''
                SELECT COUNT(*) as total,
                       SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed
                FROM attack_batches
                WHERE experiment_hash = ?
            ''', (experiment_hash,)).fetchone()
        # SUM() is NULL when there are no rows; the total > 0 guard covers that.
        return total > 0 and total == completed

    @staticmethod
    def _query_progress(conn, experiment_hash: str) -> Dict:
        """Compute the progress dict for one experiment on an open connection."""
        total, completed, processing, pending = conn.execute('''
            SELECT
                COUNT(*) as total_batches,
                SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_batches,
                SUM(CASE WHEN status = 'processing' THEN 1 ELSE 0 END) as processing_batches,
                SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending_batches
            FROM attack_batches
            WHERE experiment_hash = ?
        ''', (experiment_hash,)).fetchone()

        # Get total attacked texts count
        total_texts = conn.execute('''
            SELECT COUNT(*) FROM attacked_texts WHERE experiment_hash = ?
        ''', (experiment_hash,)).fetchone()[0]

        # SUM() yields NULL (None) when no rows match.
        completed = completed or 0
        return {
            "total_batches": total or 0,
            "completed_batches": completed,
            "processing_batches": processing or 0,
            "pending_batches": pending or 0,
            "total_attacked_texts": total_texts,
            "completion_percentage": (completed / total * 100) if total > 0 else 0
        }

    def get_experiment_progress(self, experiment_hash: str) -> Dict:
        """Return batch counts by status, the stored text count and percent done.

        Keys: ``total_batches``, ``completed_batches``, ``processing_batches``,
        ``pending_batches``, ``total_attacked_texts``, ``completion_percentage``
        (0 when the experiment has no batches).
        """
        with self._connect() as conn:
            return self._query_progress(conn, experiment_hash)

    def cleanup_experiment(self, experiment_hash: str):
        """Delete all batch and text rows for an experiment in one transaction."""
        with self._connect() as conn:
            conn.execute('DELETE FROM attacked_texts WHERE experiment_hash = ?', (experiment_hash,))
            conn.execute('DELETE FROM attack_batches WHERE experiment_hash = ?', (experiment_hash,))
        print(f"✓ Cleaned up database entries for experiment {experiment_hash[:8]}")

    def list_experiments(self) -> List[Dict]:
        """List every experiment that has batches, with its progress.

        Each entry holds ``experiment_hash``, ``short_hash`` (the first 8
        characters) and the keys from :meth:`get_experiment_progress`.
        Everything is read on a single connection.
        """
        with self._connect() as conn:
            hashes = [row[0] for row in conn.execute(
                'SELECT DISTINCT experiment_hash FROM attack_batches'
            ).fetchall()]
            return [
                {
                    "experiment_hash": exp_hash,
                    "short_hash": exp_hash[:8],
                    **self._query_progress(conn, exp_hash),
                }
                for exp_hash in hashes
            ]