import logging
import os
import sys
import argparse
from tqdm import tqdm
import multiprocessing as mp
from multiprocessing import Queue, Process, Event
import threading
from threading import Thread
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional, Tuple
import time
import torch
import yaml
from pathlib import Path
from dataclasses import dataclass, asdict
from datetime import datetime
import json

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from fortress.config import get_config
from config.constants import BENCHMARK_CSVS, BASE_PROJECT_DIR, SETTINGS_PATH

                                  
# Prefer the project's full logging configuration. If the helper module cannot be
# imported, fall back to logging.basicConfig (stderr handler by default) so the
# script still produces formatted log output.
# NOTE(review): only ImportError is caught — an exception raised *inside*
# setup_logging() propagates and aborts the script at import time.
try:
    from scripts.utils.logging_setup import setup_logging
    setup_logging()
except ImportError as e:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s"
    )
    # Reported through the root logger because the module logger is created below.
    logging.error(f"Failed to initialize full logging setup, falling back to basicConfig. Error: {e}")

# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)

                                                    
from fortress.core.embedding_model import EmbeddingModel
from fortress.core.nlp_analyzer import NLPAnalyzer
from fortress.common.data_models import DatabasePromptRecord, InputPromptRecord
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.data_management.prompt_processor import PromptProcessor
from fortress.core.vector_store_interface import ChromaVectorStore
from fortress.common.constants import SPLIT_DATABASE, SPLIT_BENCHMARK
from dataclasses import dataclass, asdict

# Optional "rich" UI. When rich is importable, `console` is a rich Console and the
# interactive helpers in this file render panels, tables and prompts with it.
# Otherwise `console` stays None and those helpers fall back to print()/input().
# NOTE(review): if rich is missing, Table/Panel/Prompt/IntPrompt/... are never
# bound, so they may only be used on code paths guarded by `if console:`.
console = None
try:
    from rich.console import Console
    from rich.table import Table
    from rich.panel import Panel
    from rich.prompt import Prompt, IntPrompt
    from rich.progress import Progress, SpinnerColumn, TextColumn
    console = Console()
except ImportError:
    console = None
    logger.warning("Rich library not installed. Using basic console output.")


                             
# Embedding models offered by the interactive picker (select_embedding_model).
# Identifiers are in "org/name" form, presumably Hugging Face Hub model IDs.
EMBEDDING_MODELS = [
    # Qwen3 family
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-4B",
    # Gemma 3 instruction-tuned family
    "google/gemma-3-4b-it",
    "google/gemma-3-1b-it",
    # BAAI
    "BAAI/bge-m3",
]

                                           



@dataclass
class IngestionTask:
    """A single queued request to embed a set of CSV files into a vector database.

    Instances are serialized with ``dataclasses.asdict`` and rebuilt with
    ``IngestionTask(**data)`` by the queue manager.
    """

    # Identifier of the embedding model to use (e.g. an entry of EMBEDDING_MODELS).
    embedding_model: str
    # (dataset_name, csv_path) pairs selected for ingestion.
    csv_files: List[Tuple[str, str]]
    # Database folder name shown to the user.
    database_name: str
    # Filesystem location of the vector database.
    database_path: str
    # When the task was queued, as a preformatted string.
    timestamp: str
    # One of "pending", "processing", "completed", "failed".
    status: str = "pending"
    # Optional error description recorded when the task's status is updated.
    error_message: Optional[str] = None


class IngestionQueueManager:
    """Manages the queue of ingestion tasks, persisting it to a JSON file.

    Paths under ``BASE_PROJECT_DIR`` are written relative to it and turned back
    into absolute paths on load; paths outside the project are stored unchanged.
    Every mutation (add / status update / clear) is saved immediately.
    """

    def __init__(self, queue_file: Path = BASE_PROJECT_DIR / "data/ingestion_queue.json"):
        self.queue_file = queue_file
        self.tasks = self._load_queue()

    @staticmethod
    def _to_project_relative(path: str) -> Optional[str]:
        """Return *path* relative to ``BASE_PROJECT_DIR``, or None if it lies outside it."""
        try:
            return str(Path(path).relative_to(BASE_PROJECT_DIR))
        except ValueError:
            return None

    def _load_queue(self) -> List[IngestionTask]:
        """Read tasks from the queue file, converting stored relative paths to absolute.

        Returns an empty list when the file does not exist or is not valid JSON.
        Malformed entries are logged and skipped individually, so one bad record
        does not discard the whole queue.
        """
        if not self.queue_file.exists():
            return []
        try:
            with open(self.queue_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            logger.error(f"Error loading ingestion queue: {e}")
            # NOTE(review): the next save overwrites the unreadable file; consider
            # backing it up before continuing.
            return []

        if not isinstance(data, list):
            logger.error(
                f"Error loading ingestion queue: expected a JSON list, got {type(data).__name__}"
            )
            return []

        tasks = []
        for index, task_data in enumerate(data):
            try:
                # Joining with an absolute path yields that absolute path unchanged,
                # so entries stored outside the project round-trip correctly.
                task_data['database_path'] = str(BASE_PROJECT_DIR / task_data['database_path'])
                task_data['csv_files'] = [
                    (name, str(BASE_PROJECT_DIR / path)) for name, path in task_data['csv_files']
                ]
                tasks.append(IngestionTask(**task_data))
            except (KeyError, TypeError, ValueError) as e:
                logger.error(f"Skipping malformed ingestion task #{index}: {e}")
        return tasks

    def _save_queue(self):
        """Persist all tasks, storing in-project paths relative to ``BASE_PROJECT_DIR``.

        The JSON is written to a temporary sibling file and then moved over the
        queue file with ``os.replace``, so an interrupted write cannot leave a
        truncated queue behind. I/O and serialization errors are logged, not raised.
        """
        try:
            tasks_to_save = []
            for task in self.tasks:
                task_dict = asdict(task)

                relative_db = self._to_project_relative(task_dict['database_path'])
                if relative_db is None:
                    # Path is not within the project: store as is, but log a warning.
                    logger.warning(f"Database path {task_dict['database_path']} is outside the project directory.")
                else:
                    task_dict['database_path'] = relative_db

                task_dict['csv_files'] = [
                    (name, self._to_project_relative(path) or path)
                    for name, path in task_dict['csv_files']
                ]
                tasks_to_save.append(task_dict)

            self.queue_file.parent.mkdir(parents=True, exist_ok=True)
            tmp_file = self.queue_file.with_name(self.queue_file.name + ".tmp")
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(tasks_to_save, f, indent=4)
            os.replace(tmp_file, self.queue_file)
        except (OSError, TypeError) as e:
            logger.error(f"Error saving ingestion queue: {e}")

    def add_task(self, task: IngestionTask):
        """Append *task* to the queue and persist it."""
        self.tasks.append(task)
        self._save_queue()

    def get_pending_tasks(self) -> List[IngestionTask]:
        """Get all pending tasks"""
        return [task for task in self.tasks if task.status == "pending"]

    def update_task_status(self, index: int, status: str, error_message: Optional[str] = None):
        """Update the status of the task at *index* (0-based) and persist the queue.

        Out-of-range indices are ignored. A falsy *error_message* leaves any
        previously recorded message untouched.
        """
        if 0 <= index < len(self.tasks):
            self.tasks[index].status = status
            if error_message:
                self.tasks[index].error_message = error_message
            self._save_queue()

    def clear_completed(self):
        """Clear completed and failed tasks from queue (pending/processing are kept)."""
        self.tasks = [task for task in self.tasks if task.status in ("pending", "processing")]
        self._save_queue()


def update_settings_embedding_model(model_name: str):
    """Set ``embedding_model.model_name`` in settings.yaml.

    Args:
        model_name: Model identifier to store.

    Returns:
        True on success; False if the settings file could not be read, parsed
        or written (the error is logged).
    """
    try:
        with open(SETTINGS_PATH, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat it as an empty mapping.
            settings = yaml.safe_load(f) or {}

        # Create the section when it is missing *or* present but null (`embedding_model:`).
        if settings.get('embedding_model') is None:
            settings['embedding_model'] = {}
        settings['embedding_model']['model_name'] = model_name

        with open(SETTINGS_PATH, 'w', encoding='utf-8') as f:
            yaml.dump(settings, f, default_flow_style=False, sort_keys=False)

        logger.info(f"Updated settings.yaml with embedding model: {model_name}")
        return True
    except Exception as e:
        logger.error(f"Error updating settings.yaml: {e}")
        return False


def select_embedding_model() -> Optional[str]:
    """Ask the user to pick one entry of EMBEDDING_MODELS and return it.

    With rich available, shows a table and validates the answer with IntPrompt.
    Otherwise prints a numbered list and keeps re-prompting via input() until a
    valid 1-based index is entered.
    """
    model_count = len(EMBEDDING_MODELS)

    if not console:
        print("\nAvailable Embedding Models:")
        for position, model_id in enumerate(EMBEDDING_MODELS, 1):
            print(f"{position}. {model_id}")

        while True:
            try:
                picked = int(input("\nEnter the index of the embedding model: "))
            except ValueError:
                print("Please enter a valid number.")
                continue
            if 1 <= picked <= model_count:
                return EMBEDDING_MODELS[picked - 1]
            print("Invalid choice. Please try again.")

    console.print(Panel("[bold cyan]Select Embedding Model[/bold cyan]", expand=False))

    model_table = Table(show_header=True, header_style="bold magenta")
    model_table.add_column("Index", style="dim", width=6)
    model_table.add_column("Model Name", style="cyan")
    for position, model_id in enumerate(EMBEDDING_MODELS, 1):
        model_table.add_row(str(position), model_id)
    console.print(model_table)

    valid_choices = [str(position) for position in range(1, model_count + 1)]
    picked = IntPrompt.ask(
        "[bold]Enter the index of the embedding model",
        choices=valid_choices
    )
    return EMBEDDING_MODELS[picked - 1]


def _parse_csv_selection(selection: str, csv_list: List[Tuple[str, str]]) -> Optional[List[Tuple[str, str]]]:
    """Translate a user selection string into entries of *csv_list*.

    Accepts ``all`` (case-insensitive, surrounding whitespace ignored) or a
    comma-separated list of 1-based indices. Blank tokens (e.g. from a trailing
    comma) are ignored, out-of-range indices are skipped, and repeated indices are
    kept once in first-seen order so the same CSV is never queued twice.

    Returns:
        The selected entries, or None if a token is not an integer.
    """
    cleaned = selection.strip()
    if cleaned.lower() == 'all':
        return list(csv_list)

    chosen = []
    seen = set()
    for token in cleaned.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            idx = int(token)
        except ValueError:
            return None
        if 1 <= idx <= len(csv_list) and idx not in seen:
            seen.add(idx)
            chosen.append(csv_list[idx - 1])
    return chosen


def select_csv_files() -> List[Tuple[str, str]]:
    """Interactively choose which entries of BENCHMARK_CSVS to ingest.

    Returns:
        (dataset_name, csv_path) pairs; an empty list when the input is neither
        ``all`` nor a comma-separated list of integers.
    """
    csv_list = list(BENCHMARK_CSVS.items())
    prompt_text = "Enter indices separated by commas (e.g., 1,3,5) or 'all' for all files"

    if console:
        console.print(Panel("[bold cyan]Select CSV Files to Ingest[/bold cyan]", expand=False))

        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Index", style="dim", width=6)
        table.add_column("Dataset Name", style="cyan")
        table.add_column("File Path", style="dim")
        for i, (name, path) in enumerate(csv_list, 1):
            table.add_row(str(i), name, path)
        console.print(table)

        console.print(f"\n[bold]{prompt_text}:[/bold]")
        selection = Prompt.ask("Selection")
    else:
        print("\nAvailable CSV Files:")
        for i, (name, path) in enumerate(csv_list, 1):
            print(f"{i}. {name} - {path}")
        selection = input(f"\n{prompt_text}: ")

    selected_files = _parse_csv_selection(selection, csv_list)
    if selected_files is None:
        if console:
            console.print("[red]Invalid selection format.[/red]")
        else:
            print("Invalid selection format.")
        return []
    return selected_files


def get_database_path() -> Tuple[str, str]:
    """Prompt for the vector-database folder name and its full path.

    The suggested path is ``BASE_PROJECT_DIR / "data/07_vector_db/<folder name>"``.
    The collection name is not asked for; it comes from the configuration.

    Returns:
        (db_folder_name, db_path)
    """
    fallback_folder = "fortress_vector_db"
    collection_note = "Data will be added to the 'fortress_prompts_collection' as configured in settings.yaml"

    if not console:
        print(f"\nNote: {collection_note}")
        db_folder_name = input(f"\nEnter database folder name (default: {fallback_folder}): ").strip() or fallback_folder
        suggested = str(BASE_PROJECT_DIR / f"data/07_vector_db/{db_folder_name}")
        db_path = input(f"Enter database path (default: {suggested}): ").strip() or suggested
        return db_folder_name, db_path

    console.print("\n[bold cyan]Database Path Configuration[/bold cyan]")
    console.print(f"[yellow]Note: {collection_note}[/yellow]")
    db_folder_name = Prompt.ask("[bold]Enter database folder name", default=fallback_folder)
    suggested = str(BASE_PROJECT_DIR / f"data/07_vector_db/{db_folder_name}")
    db_path = Prompt.ask("[bold]Enter database path", default=suggested)
    return db_folder_name, db_path


def display_queue_status(queue_manager: IngestionQueueManager):
    """Show every queued task: model, CSV names, database, status and timestamp.

    Renders a coloured rich table when rich is available, otherwise plain text.
    """
    tasks = queue_manager.tasks

    if not console:
        print("\n=== Ingestion Queue Status ===")
        if not tasks:
            print("No tasks in queue.")
            return
        for number, task in enumerate(tasks, 1):
            joined_names = ", ".join(name for name, _ in task.csv_files)
            print(f"\n{number}. Model: {task.embedding_model}")
            print(f"   CSV Files: {joined_names}")
            print(f"   Database: {task.database_name}")
            print(f"   Status: {task.status}")
            print(f"   Timestamp: {task.timestamp}")
        return

    console.print("\n[bold magenta]Ingestion Queue Status[/bold magenta]")
    if not tasks:
        console.print("[yellow]No tasks in queue.[/yellow]")
        return

    # Status -> rich colour; anything unrecognised is shown in white.
    palette = {
        "pending": "yellow",
        "processing": "cyan",
        "completed": "green",
        "failed": "red",
    }

    status_table = Table(show_header=True, header_style="bold magenta")
    status_table.add_column("Index", style="dim", width=6)
    status_table.add_column("Model", style="cyan")
    status_table.add_column("CSV Files", style="green")
    status_table.add_column("Database", style="blue")
    status_table.add_column("Status", style="yellow")
    status_table.add_column("Timestamp", style="dim")

    for number, task in enumerate(tasks, 1):
        colour = palette.get(task.status, "white")
        short_model = task.embedding_model.split('/')[-1]
        status_table.add_row(
            str(number),
            short_model,
            ", ".join(name for name, _ in task.csv_files),
            task.database_name,
            f"[{colour}]{task.status}[/{colour}]",
            task.timestamp,
        )

    console.print(status_table)


                                                            
# Per-process cache for the analyzer used by extract_nlp_features_batch_worker.
# Every pool process has its own copy of this global, so the analyzer is built
# at most once per process rather than once per chunk.
_worker_nlp_analyzer = None


def _get_worker_nlp_analyzer():
    """Return this process's NLPAnalyzer, constructing it on first use."""
    global _worker_nlp_analyzer
    if _worker_nlp_analyzer is None:
        from fortress.core.nlp_analyzer import NLPAnalyzer
        _worker_nlp_analyzer = NLPAnalyzer()
    return _worker_nlp_analyzer


def extract_nlp_features_batch_worker(texts: List[str]) -> List[Dict[str, Any]]:
    """Extract NLP features for a batch of texts in a separate process.

    Intended to run inside a multiprocessing pool worker. The analyzer is cached
    per process because the pool's processes are reused across ``map`` calls,
    and rebuilding it for every chunk is wasted work.

    Returns:
        One feature dict per input text, in input order. A text whose extraction
        raises gets ``{}``; if the analyzer cannot be created, every text gets ``{}``.
    """
    try:
        nlp_analyzer = _get_worker_nlp_analyzer()
    except Exception as e:
        logger.error(f"Error in NLP batch worker: {e}")
        return [{} for _ in texts]

    results = []
    for text in texts:
        try:
            results.append(nlp_analyzer.extract_all_features(text))
        except Exception as e:
            logger.error(f"Error extracting features for text: {e}")
            results.append({})
    return results


class OptimizedPipelineDBManager:
    """
    Optimized pipeline-based database manager with parallel DB insertion.

    ``ingest_csv_to_db`` connects three stages with bounded in-process queues:

    1. a CPU thread that splits each batch across a ``multiprocessing`` pool to
       extract NLP features,
    2. a GPU thread that computes embeddings with ``embedding_model``,
    3. ``db_workers`` threads that build ``DatabasePromptRecord`` objects and
       insert them into ``vector_store`` in chunks of ``db_batch_size``.

    One manager may run ``ingest_csv_to_db`` repeatedly: the stop/error flags,
    statistics and queues are reset at the start of every call, and all worker
    threads are joined before the call returns.
    """
    def __init__(self,
                 embedding_model: EmbeddingModel,
                 vector_store: ChromaVectorStore,
                 collection_name: Optional[str] = None,
                 num_cpu_workers: int = None,
                 gpu_batch_size: int = 1,
                 cpu_batch_size: int = 24,
                 db_batch_size: int = 500,
                 pipeline_buffer_size: int = 5,
                 db_workers: int = 2):
        """
        Args:
            embedding_model: Model whose ``get_embedding(texts)`` produces the vectors.
            vector_store: Target store; records are written with ``add_documents``.
            collection_name: If given, assigned to ``vector_store.collection_name``.
            num_cpu_workers: NLP process-pool size. Defaults to ``cpu_count() - 4``
                capped at 20, and never less than 1.
            gpu_batch_size: Stored only. NOTE(review): not read by ``_gpu_worker``,
                which embeds each CPU batch in a single call.
            cpu_batch_size: Prompts per pipeline batch.
            db_batch_size: Records accumulated per ``add_documents`` call.
            pipeline_buffer_size: Capacity of the CPU and GPU queues (the DB queue
                gets twice this).
            db_workers: Number of DB insertion threads.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        if collection_name:
            self.vector_store.collection_name = collection_name

        # mp.Pool rejects processes < 1 and _cpu_worker divides by this value, so
        # clamp the default to at least one even on machines with <= 4 cores.
        self.num_cpu_workers = num_cpu_workers or max(1, min(mp.cpu_count() - 4, 20))
        self.gpu_batch_size = gpu_batch_size
        self.cpu_batch_size = cpu_batch_size
        self.db_batch_size = db_batch_size
        self.pipeline_buffer_size = pipeline_buffer_size
        self.db_workers = db_workers

        # NOTE(review): not used by the pipeline stages in this class (pool workers
        # build their own analyzer); kept in case external code reads it.
        self.nlp_analyzer = NLPAnalyzer()

        # Bounded queues give back-pressure between stages.
        self.cpu_input_queue = queue.Queue(maxsize=pipeline_buffer_size)
        self.gpu_input_queue = queue.Queue(maxsize=pipeline_buffer_size)
        self.db_input_queue = queue.Queue(maxsize=pipeline_buffer_size * 2)

        # stop_event: fatal shutdown request; error_event: some batch failed.
        self.stop_event = threading.Event()
        self.error_event = threading.Event()

        self.stats = self._empty_stats()
        self.stats_lock = threading.Lock()

        logger.info("OptimizedPipelineDBManager initialized with:")
        logger.info(f"  CPU workers: {self.num_cpu_workers}")
        logger.info(f"  CPU batch size: {self.cpu_batch_size}")
        logger.info(f"  GPU batch size: {self.gpu_batch_size}")
        logger.info(f"  DB batch size: {self.db_batch_size}")
        logger.info(f"  DB workers: {self.db_workers}")
        logger.info(f"  Pipeline buffer size: {self.pipeline_buffer_size}")

    @staticmethod
    def _empty_stats() -> Dict[str, float]:
        """Fresh per-run counters: items handled and seconds spent per stage."""
        return {
            'cpu_processed': 0,
            'gpu_processed': 0,
            'db_processed': 0,
            'db_inserted': 0,
            'errors': 0,
            'cpu_time': 0,
            'gpu_time': 0,
            'db_time': 0
        }

    @staticmethod
    def _drain(q: "queue.Queue") -> None:
        """Discard everything currently in *q* (e.g. batches left by an aborted run)."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def _reset_run_state(self) -> None:
        """Prepare for a new ingest run: clear flags, counters and stale queue items."""
        self.stop_event.clear()
        self.error_event.clear()
        for q in (self.cpu_input_queue, self.gpu_input_queue, self.db_input_queue):
            self._drain(q)
        with self.stats_lock:
            self.stats = self._empty_stats()

    def _put(self, q: "queue.Queue", item) -> bool:
        """Put *item* on the bounded queue *q*, giving up once ``stop_event`` is set.

        A plain blocking ``put`` hangs forever if the consuming stage has already
        exited; polling with a timeout lets producers notice a stop request.

        Returns:
            True if the item was enqueued, False if the pipeline was stopped first.
        """
        while not self.stop_event.is_set():
            try:
                q.put(item, timeout=0.5)
                return True
            except queue.Full:
                continue
        return False

    def get_db_size(self) -> int:
        """Returns the number of items in the vector store's collection."""
        return self.vector_store.get_collection_size()

    def _cpu_worker(self):
        """CPU stage: compute NLP features for each batch with a process pool.

        Consumes ``(prompts, batch_id)`` from ``cpu_input_queue`` and produces
        ``(prompts, nlp_features, batch_id)`` on ``gpu_input_queue``. Exits on a
        ``None`` sentinel or when ``stop_event`` is set.
        """
        logger.info("CPU worker started")
        try:
            cpu_pool = mp.Pool(processes=self.num_cpu_workers)
        except Exception as e:
            # Without the pool nothing can flow through the pipeline: stop every
            # stage so the producer and the other workers do not block forever.
            logger.error(f"Error in CPU worker: {e}", exc_info=True)
            self.error_event.set()
            self.stop_event.set()
            logger.info("CPU worker stopped")
            return

        try:
            while not self.stop_event.is_set():
                try:
                    batch_data = self.cpu_input_queue.get(timeout=1.0)
                    if batch_data is None:
                        break

                    start_time = time.time()
                    prompts, batch_id = batch_data
                    texts = [p.original_prompt for p in prompts]

                    # Chunks of len(texts) // num_cpu_workers texts (at least one each).
                    chunk_size = max(1, len(texts) // self.num_cpu_workers)
                    text_chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]

                    nlp_results = cpu_pool.map(extract_nlp_features_batch_worker, text_chunks)
                    # Pool.map preserves chunk order, so features stay aligned with prompts.
                    nlp_features = [features for chunk in nlp_results for features in chunk]

                    elapsed = time.time() - start_time
                    with self.stats_lock:
                        self.stats['cpu_processed'] += len(prompts)
                        self.stats['cpu_time'] += elapsed

                    if not self._put(self.gpu_input_queue, (prompts, nlp_features, batch_id)):
                        break

                except queue.Empty:
                    continue
                except Exception as e:
                    logger.error(f"Error in CPU worker: {e}", exc_info=True)
                    self.error_event.set()

        finally:
            cpu_pool.close()
            cpu_pool.join()
            logger.info("CPU worker stopped")

    def _gpu_worker(self):
        """GPU stage: compute embeddings for each batch.

        Consumes ``(prompts, nlp_features, batch_id)`` from ``gpu_input_queue`` and
        produces ``(prompts, nlp_features, embeddings_list, perplexities, batch_id)``
        on ``db_input_queue``. ``embeddings_list`` is None when the model returned
        nothing or a non-2-D result.
        """
        logger.info("GPU worker started")

        try:
            while not self.stop_event.is_set():
                try:
                    batch_data = self.gpu_input_queue.get(timeout=1.0)
                    if batch_data is None:
                        break

                    start_time = time.time()
                    prompts, nlp_features, batch_id = batch_data
                    texts = [p.original_prompt for p in prompts]

                    embedding_tensor = self.embedding_model.get_embedding(texts)

                    embeddings_list = None
                    if embedding_tensor is not None:
                        if embedding_tensor.ndim == 2:
                            # One list of floats per row. A single tolist() call replaces
                            # a per-row conversion (and a per-row host copy for GPU tensors).
                            embeddings_list = embedding_tensor.tolist()
                        else:
                            logger.error(f"Unexpected embedding tensor shape: {embedding_tensor.shape}")

                    # Perplexity is not computed here; placeholders keep indices aligned.
                    perplexities = [None] * len(texts)

                    elapsed = time.time() - start_time
                    with self.stats_lock:
                        self.stats['gpu_processed'] += len(prompts)
                        self.stats['gpu_time'] += elapsed

                    if not self._put(self.db_input_queue, (prompts, nlp_features, embeddings_list, perplexities, batch_id)):
                        break

                    # Periodically release cached GPU memory.
                    if batch_id % 20 == 0 and torch.cuda.is_available():
                        torch.cuda.empty_cache()

                except queue.Empty:
                    continue
                except Exception as e:
                    logger.error(f"Error in GPU worker: {e}", exc_info=True)
                    self.error_event.set()

        finally:
            logger.info("GPU worker stopped")

    def _db_worker(self, worker_id: int, progress_bar):
        """DB stage: build records for each batch and insert them in chunks.

        Records accumulate until ``db_batch_size`` is reached; a partial chunk is
        flushed whenever the input queue is momentarily empty and again on exit.
        Prompts without an embedding are logged and skipped.

        Args:
            worker_id: Index used in log messages.
            progress_bar: tqdm bar advanced by the number of prompts per batch.
        """
        logger.info(f"Database worker {worker_id} started")

        db_batch = []

        try:
            # After a stop request keep going until the queue is drained so batches
            # that were already embedded are still written.
            while not self.stop_event.is_set() or not self.db_input_queue.empty():
                try:
                    batch_data = self.db_input_queue.get(timeout=1.0)
                    if batch_data is None:
                        break

                    start_time = time.time()
                    prompts, nlp_features, embeddings_list, perplexities, _ = batch_data

                    for i, prompt in enumerate(prompts):
                        try:
                            nlp_feat = nlp_features[i] if i < len(nlp_features) else {}
                            embedding = embeddings_list[i] if embeddings_list and i < len(embeddings_list) else None
                            perplexity = perplexities[i] if i < len(perplexities) else None

                            if embedding is None:
                                logger.error(f"No embedding for prompt ID {prompt.prompt_id}")
                                continue

                            nlp_feat['perplexity'] = perplexity

                            # Later keys win: NLP features override same-named prompt
                            # fields, and "embedding" overrides both.
                            record_data = {
                                **prompt.model_dump(),
                                **nlp_feat,
                                "embedding": embedding,
                            }
                            db_batch.append(DatabasePromptRecord(**record_data))

                        except Exception as e:
                            logger.error(f"Error creating DB record: {e}")

                    if len(db_batch) >= self.db_batch_size:
                        self._insert_db_batch(db_batch, worker_id)
                        db_batch = []

                    elapsed = time.time() - start_time
                    with self.stats_lock:
                        self.stats['db_processed'] += len(prompts)
                        self.stats['db_time'] += elapsed
                    progress_bar.update(len(prompts))

                except queue.Empty:
                    # Idle moment: flush what we have instead of holding it back.
                    if db_batch:
                        self._insert_db_batch(db_batch, worker_id)
                        db_batch = []
                except Exception as e:
                    logger.error(f"Error in DB worker {worker_id}: {e}", exc_info=True)
                    self.error_event.set()

            if db_batch:
                self._insert_db_batch(db_batch, worker_id)

        finally:
            logger.info(f"Database worker {worker_id} stopped")

    def _insert_db_batch(self, db_batch: List[DatabasePromptRecord], worker_id: int):
        """Insert a batch of records into the vector store; errors are logged, not raised."""
        try:
            insert_start = time.time()
            added_ids, failed_ids = self.vector_store.add_documents(db_batch)
            insert_time = time.time() - insert_start

            with self.stats_lock:
                self.stats['db_inserted'] += len(added_ids)

            if failed_ids:
                logger.error(f"Worker {worker_id} failed to add {len(failed_ids)} documents")

            logger.debug(f"Worker {worker_id} inserted {len(added_ids)} documents in {insert_time:.2f}s")

        except Exception as e:
            logger.error(f"Worker {worker_id} error inserting batch to DB: {e}", exc_info=True)

    def ingest_csv_to_db(self, csv_path: str, split_filter: str = SPLIT_DATABASE):
        """Pipeline-based ingestion with concurrent processing.

        Loads prompts from *csv_path*, keeps those whose ``split`` equals
        *split_filter*, and pushes them through the CPU -> GPU -> DB stages.
        Pipeline errors are logged rather than raised; statistics reported at
        the end cover this call only.
        """
        logger.info(f"Starting optimized pipeline ingestion from CSV: {csv_path} for split: '{split_filter}'")

        all_input_prompts = load_prompts_from_csv(csv_path)
        if not all_input_prompts:
            logger.warning(f"No prompts loaded from {csv_path}")
            return

        prompts_to_process = [p for p in all_input_prompts if p.split == split_filter]
        if not prompts_to_process:
            logger.warning(f"No prompts found for split '{split_filter}'")
            return

        num_total = len(prompts_to_process)
        logger.info(f"Processing {num_total} prompts with optimized pipeline")

        # A previous call may have left flags set, counters filled or batches queued.
        self._reset_run_state()

        progress_bar = tqdm(
            total=num_total,
            desc=f"Pipeline {os.path.basename(csv_path)}",
            unit="prompt"
        )

        cpu_thread = Thread(target=self._cpu_worker, name="CPU-Worker")
        gpu_thread = Thread(target=self._gpu_worker, name="GPU-Worker")
        db_threads = [
            Thread(target=self._db_worker, args=(i, progress_bar), name=f"DB-Worker-{i}")
            for i in range(self.db_workers)
        ]
        workers = [cpu_thread, gpu_thread, *db_threads]

        for worker in workers:
            worker.start()

        start_time = time.time()
        last_stats_time = start_time

        try:
            batch_id = 0
            for i in range(0, num_total, self.cpu_batch_size):
                if self.error_event.is_set():
                    logger.error("Error detected, stopping pipeline")
                    break

                batch = prompts_to_process[i:i + self.cpu_batch_size]
                if not self._put(self.cpu_input_queue, (batch, batch_id)):
                    break
                batch_id += 1

                current_time = time.time()
                if current_time - last_stats_time > 15:
                    self._report_stats(current_time - start_time)
                    last_stats_time = current_time

            # Shut stages down in order so each drains its input first. _put returns
            # immediately if a fatal error has already stopped the pipeline.
            self._put(self.cpu_input_queue, None)
            cpu_thread.join()
            self._put(self.gpu_input_queue, None)
            gpu_thread.join()

            # One sentinel per DB worker.
            for _ in db_threads:
                self._put(self.db_input_queue, None)
            for db_thread in db_threads:
                db_thread.join()

        except Exception as e:
            logger.error(f"Error in pipeline: {e}", exc_info=True)
            self.stop_event.set()
        finally:
            # Never leave workers running past this call (they would race with the
            # next run). Anything still alive means we got here abnormally
            # (exception or KeyboardInterrupt), so request a stop before joining.
            if any(worker.is_alive() for worker in workers):
                self.stop_event.set()
            for worker in workers:
                worker.join()
            progress_bar.close()

        total_time = time.time() - start_time
        self._report_final_stats(total_time, num_total)

    def _report_stats(self, elapsed_time):
        """Log per-stage throughput and queue fill levels for the current run."""
        with self.stats_lock:
            cpu_rate = self.stats['cpu_processed'] / elapsed_time if elapsed_time > 0 else 0
            gpu_rate = self.stats['gpu_processed'] / elapsed_time if elapsed_time > 0 else 0
            db_rate = self.stats['db_processed'] / elapsed_time if elapsed_time > 0 else 0

            logger.info(f"Rates - CPU: {cpu_rate:.1f}/s, GPU: {gpu_rate:.1f}/s, DB: {db_rate:.1f}/s")

            cpu_queue_pct = (self.cpu_input_queue.qsize() / self.cpu_input_queue.maxsize * 100) if self.cpu_input_queue.maxsize > 0 else 0
            gpu_queue_pct = (self.gpu_input_queue.qsize() / self.gpu_input_queue.maxsize * 100) if self.gpu_input_queue.maxsize > 0 else 0
            db_queue_pct = (self.db_input_queue.qsize() / self.db_input_queue.maxsize * 100) if self.db_input_queue.maxsize > 0 else 0

            logger.info(f"Queue usage - CPU: {cpu_queue_pct:.0f}%, GPU: {gpu_queue_pct:.0f}%, DB: {db_queue_pct:.0f}%")

    def _report_final_stats(self, total_time, num_total):
        """Log the summary for the run that just finished."""
        throughput = num_total / total_time if total_time > 0 else 0.0
        with self.stats_lock:
            logger.info(f"\n{'='*50}")
            logger.info(f"Pipeline completed in {total_time:.2f} seconds")
            logger.info(f"Total throughput: {throughput:.1f} prompts/second")
            logger.info(f"CPU processed: {self.stats['cpu_processed']}")
            logger.info(f"GPU processed: {self.stats['gpu_processed']}")
            logger.info(f"DB processed: {self.stats['db_processed']}")
            logger.info(f"DB inserted: {self.stats['db_inserted']}")
            logger.info(f"{'='*50}\n")


def write_database_readme(database_path: str, database_name: str, embedding_model: str, csv_files: List[Tuple[str, str]], timestamp: str):
    """
    Write ``README.md`` into *database_path* describing how the database was built.

    Lists the database name, creation timestamp, embedding model and every
    (name, path) source CSV. Missing parent directories are created and an
    existing README is overwritten.
    """
    header = [
        f"# Database: {database_name}\n",
        f"**Created:** {timestamp}\n",
        f"**Embedding Model:** `{embedding_model}`\n",
        "\n## Source CSV Files\n",
    ]
    sources = [f"- **{name}**: `{path}`" for name, path in csv_files]
    footer = [
        "\n---\n",
        "This database was generated by the FORTRESS data ingestion pipeline.\n",
    ]
    # Entries are joined with "\n"; those that already end in "\n" produce the
    # blank lines Markdown needs between paragraphs.
    document = "\n".join(header + sources + footer)

    readme = Path(database_path) / "README.md"
    readme.parent.mkdir(parents=True, exist_ok=True)
    readme.write_text(document, encoding="utf-8")


                                        
def update_embedding_model_in_settings(model_name: str, settings_path: Path = SETTINGS_PATH) -> bool:
    """
    Update the embedding model and vector database path in settings.yaml file.

    Sets ``embedding_model.model_name`` (adding ``device: auto`` when no device
    key exists) and points ``vector_database.path`` at
    ``BASE_PROJECT_DIR/data/07_vector_db/<identifier>``, where the identifier is
    *model_name* lower-cased with ``/`` and ``-`` replaced by ``_``.

    Returns:
        True on success; False if reading, parsing or writing failed (logged).
    """
    try:
        with open(settings_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            settings = yaml.safe_load(f) or {}

        # Create sections that are missing *or* present but null.
        if settings.get('embedding_model') is None:
            settings['embedding_model'] = {}
        embedding_cfg = settings['embedding_model']
        embedding_cfg['model_name'] = model_name
        embedding_cfg.setdefault('device', 'auto')

        # One database directory per model so different embedding spaces never mix.
        model_identifier = model_name.replace('/', '_').replace('-', '_').lower()
        if settings.get('vector_database') is None:
            settings['vector_database'] = {}
        base_vector_db_path = BASE_PROJECT_DIR / "data" / "07_vector_db"
        settings['vector_database']['path'] = str(base_vector_db_path / model_identifier)

        with open(settings_path, 'w', encoding='utf-8') as f:
            yaml.dump(settings, f, default_flow_style=False, sort_keys=False)
        logger.info(f"Updated settings.yaml with embedding model: {model_name}")
        return True
    except Exception as e:
        logger.error(f"Error updating embedding model: {e}")
        return False

                                           
def backup_settings_yaml(settings_path: Path = SETTINGS_PATH) -> Path:
    """Copy *settings_path* to a sibling file named ``<name>.backup``.

    ``shutil.copy2`` copies the contents together with file metadata such as
    the modification time. An existing backup at that location is overwritten.

    Returns:
        Path of the backup file.
    """
    import shutil

    destination = settings_path.parent / f"{settings_path.name}.backup"
    shutil.copy2(settings_path, destination)
    return destination

def restore_settings_yaml(backup_path: Path, settings_path: Path = SETTINGS_PATH):
    """Copy *backup_path* back over *settings_path*.

    Does nothing when the backup file does not exist. The backup file itself
    is left on disk.
    """
    import shutil

    if not backup_path.exists():
        return
    shutil.copy2(backup_path, settings_path)


def force_reload_config():
    """Force reload the configuration module to pick up settings changes.

    Removes a module attribute named ``_cached_config`` (when present) from
    ``fortress.config`` and then re-executes the module with
    ``importlib.reload``.

    NOTE(review): ``get_config`` was imported by name at the top of this file
    and is not re-bound here. Module-level state is shared with the reloaded
    module, but a decorator-held cache (e.g. ``functools.lru_cache``) would
    survive on the old function object. Confirm how ``get_config`` caches.
    """
    import importlib
    from fortress import config as cfg_module

    # Drop the module-level cache so the reload cannot keep a stale config.
    cache_present = hasattr(cfg_module, '_cached_config')
    if cache_present:
        delattr(cfg_module, '_cached_config')

    importlib.reload(cfg_module)
    logger.info("Configuration reloaded")


def process_ingestion_queue(queue_manager: IngestionQueueManager):
    """Process all pending tasks in the ingestion queue.

    For every task in ``queue_manager.tasks`` whose status is ``"pending"``:
      1. mark it ``"processing"``;
      2. back up settings.yaml and rewrite it for ``task.embedding_model``;
      3. reload the configuration and build an ``EmbeddingModel``, a
         ``ChromaVectorStore`` at ``task.database_path`` and an
         ``OptimizedPipelineDBManager``;
      4. ingest each of ``task.csv_files`` with ``split_filter=SPLIT_DATABASE``
         (missing CSV files are logged and skipped);
      5. write a README into the database directory and mark the task
         ``"completed"``.

    Any exception marks the task ``"failed"`` with the error text, and
    processing continues with the next task. Whenever a settings backup was
    taken for a task, settings.yaml is restored from it afterwards, whether
    the task succeeded or failed.

    Args:
        queue_manager: Queue to process. Status changes are recorded through
            ``queue_manager.update_task_status(index, status[, error])``,
            where ``index`` is the task's position in ``queue_manager.tasks``.
    """
    pending_tasks = queue_manager.get_pending_tasks()
    total_pending = len(pending_tasks)

    if not pending_tasks:
        if console:
            console.print("[yellow]No pending tasks in queue.[/yellow]")
        else:
            print("No pending tasks in queue.")
        return

    if console:
        console.print(f"\n[bold green]Processing {total_pending} pending tasks...[/bold green]")
    else:
        print(f"\nProcessing {total_pending} pending tasks...")

    # 1-based position among the *pending* tasks, used for progress output.
    # `i` indexes the full task list (which may contain completed/failed
    # tasks) and is only used to address the task in queue_manager.
    task_number = 0
    for i, task in enumerate(queue_manager.tasks):
        if task.status != "pending":
            continue
        task_number += 1
        # Reset every iteration. If this task fails before its own backup is
        # taken, the finally clause must neither hit an unbound name nor
        # restore a previous task's backup.
        backup_path = None

        try:
            queue_manager.update_task_status(i, "processing")

            if console:
                console.print(f"\n[bold cyan]Processing Task {task_number}/{total_pending}[/bold cyan]")
                console.print(f"Model: {task.embedding_model}")
                console.print(f"Database: {task.database_name}")
            else:
                print(f"\nProcessing Task {task_number}/{total_pending}")
                print(f"Model: {task.embedding_model}")
                print(f"Database: {task.database_name}")

            # Point settings.yaml at this task's model. The original file is
            # restored from the backup in the finally clause.
            backup_path = backup_settings_yaml()
            if not update_embedding_model_in_settings(task.embedding_model):
                raise RuntimeError("Failed to update settings.yaml")

            # Reload so get_config() reflects the settings written above.
            force_reload_config()
            config = get_config()

            # The collection name comes from config; only the DB location is
            # task-specific.
            vector_db_config = config.get('vector_database', {})
            collection_name = vector_db_config.get('collection_name', 'fortress_prompts_collection')

            # NOTE(review): EmbeddingModel() takes no arguments here;
            # presumably it reads the model from the reloaded config. Confirm.
            embedding_model = EmbeddingModel()

            vector_store = ChromaVectorStore(
                collection_name=collection_name,
                db_path=task.database_path
            )

            logger.info(f"Using collection: {collection_name} at path: {task.database_path}")

            pipeline_manager = OptimizedPipelineDBManager(
                embedding_model=embedding_model,
                vector_store=vector_store,
                gpu_batch_size=1,
                cpu_batch_size=12,
                db_batch_size=500,
                pipeline_buffer_size=5,
                db_workers=2
            )

            initial_size = pipeline_manager.get_db_size()
            logger.info(f"Initial collection '{collection_name}' size: {initial_size}")

            for csv_name, csv_path in task.csv_files:
                if console:
                    console.print(f"\nIngesting: {csv_name}")
                else:
                    print(f"\nIngesting: {csv_name}")

                if not Path(csv_path).exists():
                    logger.error(f"CSV file not found: {csv_path}")
                    continue

                pipeline_manager.ingest_csv_to_db(csv_path, split_filter=SPLIT_DATABASE)

            final_size = pipeline_manager.get_db_size()

            if console:
                console.print(f"[green]Task completed. Added {final_size - initial_size} records to '{collection_name}'.[/green]")
            else:
                print(f"Task completed. Added {final_size - initial_size} records to '{collection_name}'.")

            # NOTE(review): the README is titled with the Chroma collection
            # name rather than task.database_name. Confirm this is intended.
            write_database_readme(
                database_path=task.database_path,
                database_name=collection_name,
                embedding_model=task.embedding_model,
                csv_files=task.csv_files,
                timestamp=task.timestamp
            )

            queue_manager.update_task_status(i, "completed")

        except Exception as e:
            logger.error(f"Error processing task: {e}", exc_info=True)
            queue_manager.update_task_status(i, "failed", str(e))
            if console:
                console.print(f"[red]Task failed: {e}[/red]")
            else:
                print(f"Task failed: {e}")
        finally:
            # Only restore when this task actually took a backup.
            if backup_path is not None:
                restore_settings_yaml(backup_path)


def interactive_menu():
    """Interactive menu for managing ingestion tasks.

    Loops until the user picks "Exit". Options:
      1. Add a task: choose an embedding model, CSV files and a database
         location, then append an ``IngestionTask`` to the queue.
      2. Show the queue status.
      3. Process all pending tasks.
      4. Remove completed tasks from the queue.
      5. Exit.

    With ``rich`` available, the screen is cleared every round and
    ``Prompt.ask`` restricts input to the valid choices. Without it, a plain
    ``input()`` prompt is used. Surrounding whitespace is ignored, and
    unrecognised entries are reported before the menu is shown again. If
    model or CSV selection returns a falsy value, the menu is shown again
    without adding a task.
    """
    queue_manager = IngestionQueueManager()

    while True:
        if console:
            console.clear()
            console.print(Panel(
                "[bold magenta]FORTRESS Data Ingestion Manager[/bold magenta]\n\n"
                "[cyan]Systematic data ingestion with queue management[/cyan]",
                expand=False,
                border_style="magenta"
            ))
            console.rule()

            console.print("[bold]Menu Options:[/bold]\n")
            console.print("  [cyan]1.[/cyan] Add new ingestion task")
            console.print("  [cyan]2.[/cyan] View queue status")
            console.print("  [cyan]3.[/cyan] Process pending tasks")
            console.print("  [cyan]4.[/cyan] Clear completed tasks")
            console.print("  [cyan]5.[/cyan] Exit\n")

            choice = Prompt.ask(
                "[bold]Enter your choice",
                choices=["1", "2", "3", "4", "5"],
                default="1"
            )
        else:
            print("\n=== FORTRESS Data Ingestion Manager ===")
            print("\nMenu Options:")
            print("1. Add new ingestion task")
            print("2. View queue status")
            print("3. Process pending tasks")
            print("4. Clear completed tasks")
            print("5. Exit")

            # Strip so that e.g. "1 " is still recognised.
            choice = input("\nEnter your choice (1-5): ").strip()

        if choice == "1":
            model = select_embedding_model()
            if not model:
                continue

            csv_files = select_csv_files()
            if not csv_files:
                continue

            db_name, db_path = get_database_path()

            task = IngestionTask(
                embedding_model=model,
                csv_files=csv_files,
                database_name=db_name,
                database_path=db_path,
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            )

            queue_manager.add_task(task)

            if console:
                console.print("\n[green]Task added to queue successfully![/green]")
                console.print("[dim]Press Enter to continue...[/dim]")
            else:
                print("\nTask added to queue successfully!")
                print("Press Enter to continue...")
            input()

        elif choice == "2":
            display_queue_status(queue_manager)
            if console:
                console.print("\n[dim]Press Enter to continue...[/dim]")
            else:
                print("\nPress Enter to continue...")
            input()

        elif choice == "3":
            process_ingestion_queue(queue_manager)
            if console:
                console.print("\n[dim]Press Enter to continue...[/dim]")
            else:
                print("\nPress Enter to continue...")
            input()

        elif choice == "4":
            queue_manager.clear_completed()
            if console:
                console.print("\n[green]Completed tasks cleared from queue.[/green]")
                console.print("[dim]Press Enter to continue...[/dim]")
            else:
                print("\nCompleted tasks cleared from queue.")
                print("Press Enter to continue...")
            input()

        elif choice == "5":
            if console:
                console.print("\n[bold green]Thank you for using FORTRESS Data Ingestion Manager![/bold green]")
            else:
                print("\nThank you for using FORTRESS Data Ingestion Manager!")
            break

        else:
            # Only reachable in plain mode: rich's Prompt.ask rejects values
            # outside `choices`. Previously the menu was silently redrawn.
            print(f"\nInvalid choice {choice!r}; please enter a number from 1 to 5.")


def main():
    """Command-line entry point for FORTRESS data ingestion.

    Modes are checked in this order:
      * ``--interactive`` / ``-i``, or no queue flag and no CSV files:
        interactive menu;
      * ``--view-queue``: print the queue status;
      * ``--process-queue``: run every pending queued task;
      * positional CSV files (legacy mode): ingest them directly into the
        vector store configured in settings, filtered by ``SPLIT_DATABASE``.

    In legacy mode any exception is logged with its traceback and the process
    exits with status 1.
    """
    parser = argparse.ArgumentParser(description="FORTRESS Data Ingestion with automatic queue management")
    parser.add_argument("--interactive", "-i", action="store_true", help="Run in interactive mode")
    parser.add_argument("--process-queue", action="store_true", help="Process pending tasks in queue")
    parser.add_argument("--view-queue", action="store_true", help="View current queue status")
    parser.add_argument("csv_files", nargs='*', help="CSV files to process (for legacy compatibility)")
    args = parser.parse_args()

    nothing_requested = not (args.process_queue or args.view_queue or args.csv_files)
    if args.interactive or nothing_requested:
        interactive_menu()
        return

    if args.view_queue:
        display_queue_status(IngestionQueueManager())
        return

    if args.process_queue:
        process_ingestion_queue(IngestionQueueManager())
        return

    # Reaching this point means CSV paths were given (nothing_requested is False).
    logger.info("Running in legacy mode")

    try:
        config = get_config()

        logger.info("Initializing embedding model...")
        embedding_model = EmbeddingModel()

        db_settings = config.get('vector_database', {})
        chroma_db_path = db_settings.get('path', "data/07_vector_db/chroma_persistent")
        chroma_collection_name = db_settings.get('collection_name', 'fortress_prompts_collection')

        if chroma_db_path != ":memory:":
            # Creates the parent directory of the store path, or the path
            # itself when it has no directory component.
            target_dir = os.path.dirname(chroma_db_path) or chroma_db_path
            os.makedirs(target_dir, exist_ok=True)

        logger.info(f"Initializing ChromaVectorStore: {chroma_collection_name}")
        vector_store = ChromaVectorStore(collection_name=chroma_collection_name, db_path=chroma_db_path)

        pipeline_manager = OptimizedPipelineDBManager(
            embedding_model=embedding_model,
            vector_store=vector_store,
            gpu_batch_size=1,
            cpu_batch_size=12,
            db_batch_size=500,
            pipeline_buffer_size=5,
            db_workers=2,
        )

        size_before = pipeline_manager.get_db_size()
        logger.info(f"Initial DB size: {size_before}")

        for csv_file in args.csv_files:
            logger.info(f"Processing: {csv_file}")
            pipeline_manager.ingest_csv_to_db(csv_file, split_filter=SPLIT_DATABASE)

        size_after = pipeline_manager.get_db_size()
        logger.info(f"Final DB size: {size_after} (added {size_after - size_before})")

    except Exception as e:
        logger.error(f"Fatal error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    # Set the multiprocessing start method before main() runs, because it
    # only affects processes created afterwards. "spawn" starts each child in
    # a fresh interpreter instead of fork()-ing this one. force=True
    # overrides a start method that was already set in this process.
    # NOTE(review): presumably "spawn" is chosen because torch/CUDA state is
    # not fork-safe. Confirm against the pipeline's worker code.
    mp.set_start_method('spawn', force=True)
    main()
