import argparse
import logging
import os
import random
import sys
from typing import List, Dict, Any

                                                                             
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from fortress.config import get_config
from fortress.core.vector_store_interface import ChromaVectorStore
from fortress.common.constants import LABEL_SAFE, LABEL_UNSAFE

                     
# Script-level logging: timestamped INFO-and-above records on the root logger,
# with a module-scoped logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def chunker(seq, size: int):
    """Yield successive ``size``-length slices of ``seq``.

    The final slice may be shorter than ``size``. Because plain slicing is
    used, each chunk has the same type as ``seq`` (a list yields lists, a
    string yields strings).

    Args:
        seq: A sliceable sequence supporting ``len`` (list, tuple, str, ...).
        size: Maximum length of each chunk; must be at least 1.

    Yields:
        Consecutive, non-overlapping slices of ``seq`` that cover it in order.

    Raises:
        ValueError: If ``size`` is less than 1. As this is a generator, the
            error surfaces on the first iteration.
    """
    # A negative step would make range() silently empty (dropping all data),
    # and a zero step raises an unhelpful range() error; reject both clearly.
    if size < 1:
        raise ValueError(f"chunk size must be a positive integer, got {size!r}")
    for start in range(0, len(seq), size):
        yield seq[start:start + size]

def inject_noise(
    noise_level: float,
    db_path: str = None,
    collection_name: str = None,
    seed: int = 42,
    batch_size: int = 4096,
) -> int:
    """
    Flip the labels of a random fraction of documents in the 'database' split.

    Connects to ChromaDB and fetches every document whose metadata matches
    ``{"split": "database"}``. Among those carrying a valid label (LABEL_SAFE
    or LABEL_UNSAFE), it randomly selects ``noise_level`` of them and swaps
    each selected label for the other one. The changed metadata is written
    back in batches of at most ``batch_size`` documents.

    Args:
        noise_level: Fraction of validly-labelled documents to flip, in
            [0.0, 1.0]. The count is floored, so small collections may
            select zero documents.
        db_path: ChromaDB directory passed to ChromaVectorStore; None defers
            to its default (the CLI help says settings.yaml — confirm).
        collection_name: Collection passed to ChromaVectorStore; None defers
            to its default.
        seed: Seed for a private random generator. The same collection
            contents and seed always select the same documents.
        batch_size: Maximum number of documents per ``collection.update`` call.

    Returns:
        The number of documents whose flipped labels were written back. This
        is 0 when nothing was selected or when connecting/fetching failed
        (those failures are logged, not raised). If a batch update fails
        midway, it is the count written before the failure.

    Raises:
        ValueError: If ``noise_level`` is outside [0.0, 1.0] (NaN included)
            or ``batch_size`` is less than 1.
    """
    if not (0.0 <= noise_level <= 1.0):
        logger.error("Noise level must be between 0.0 and 1.0.")
        raise ValueError("Noise level must be between 0.0 and 1.0.")
    if batch_size < 1:
        logger.error(f"batch_size must be a positive integer, got {batch_size}.")
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    # A private generator keeps the run reproducible without reseeding the
    # process-wide `random` module as a hidden side effect.
    rng = random.Random(seed)
    logger.info(f"Starting label noise injection with a noise level of {noise_level:.2%} and seed {seed}.")

    try:
        vector_store = ChromaVectorStore(collection_name=collection_name, db_path=db_path)
        logger.info(f"Successfully connected to collection '{vector_store.collection_name}' at '{vector_store.db_path}'.")
    except Exception as e:
        logger.error(f"Failed to initialize ChromaVectorStore: {e}", exc_info=True)
        return 0

    try:
        logger.info("Fetching all document IDs and metadata from the 'database' split...")
        db_split_docs = vector_store.collection.get(
            where={"split": "database"},
            include=["metadatas"]
        )
    except Exception as e:
        logger.error(f"Failed to fetch documents with 'split: database' filter: {e}", exc_info=True)
        logger.error("Please ensure your ingested data includes a 'split' metadata field.")
        return 0

    if not db_split_docs or not db_split_docs.get('ids'):
        logger.warning("No documents found in the 'database' split. No noise will be injected.")
        return 0

    logger.info(f"Found {len(db_split_docs['ids'])} documents in the 'database' split.")

    # Sort by ID so the seeded selection depends only on which documents
    # exist, not on the (unspecified) order the store returns them in.
    all_docs = sorted(
        zip(db_split_docs['ids'], db_split_docs['metadatas']),
        key=lambda pair: pair[0],
    )

    # Filter to flippable documents *before* sampling. Otherwise documents
    # with a bad label are picked and then skipped, silently lowering the
    # achieved noise rate.
    valid_labels = (LABEL_SAFE, LABEL_UNSAFE)
    eligible = [
        (doc_id, metadata)
        for doc_id, metadata in all_docs
        if (metadata or {}).get('label') in valid_labels
    ]
    skipped = len(all_docs) - len(eligible)
    if skipped:
        logger.warning(f"Skipping {skipped} document(s) with an invalid or missing label.")
    if not eligible:
        logger.warning("No valid documents were selected for label flipping. Exiting.")
        return 0

    # Round away float representation error before flooring. For example,
    # 100 * 0.29 == 28.999999999999996 would otherwise truncate to 28.
    num_to_flip = int(round(len(eligible) * noise_level, 9))
    if num_to_flip == 0:
        logger.info("Noise level is too low to select any documents to flip. Exiting.")
        return 0

    docs_to_flip = rng.sample(eligible, num_to_flip)
    logger.info(f"Selected {len(docs_to_flip)} documents to flip their labels.")

    opposite_label = {LABEL_SAFE: LABEL_UNSAFE, LABEL_UNSAFE: LABEL_SAFE}
    ids_to_update: List[str] = [doc_id for doc_id, _ in docs_to_flip]
    # Build fresh dicts rather than mutating the metadata the store returned.
    metadatas_to_update: List[Dict[str, Any]] = [
        {**metadata, 'label': opposite_label[metadata['label']]}
        for _, metadata in docs_to_flip
    ]

    total = len(ids_to_update)
    num_batches = (total + batch_size - 1) // batch_size
    logger.info(f"Performing batch update on {total} documents in {num_batches} chunks of max size {batch_size}...")

    total_updated_count = 0
    batches = zip(chunker(ids_to_update, batch_size), chunker(metadatas_to_update, batch_size))
    for batch_number, (batch_ids, batch_metadatas) in enumerate(batches, start=1):
        logger.info(f"Updating batch {batch_number}/{num_batches} with {len(batch_ids)} documents...")
        try:
            vector_store.collection.update(
                ids=batch_ids,
                metadatas=batch_metadatas
            )
        except Exception as e:
            # Earlier batches are already committed; report partial progress
            # instead of claiming the run completed.
            logger.error(
                f"An error occurred during the batch update process: {e} "
                f"({total_updated_count} of {total} documents were updated before the failure).",
                exc_info=True,
            )
            return total_updated_count
        total_updated_count += len(batch_ids)

    logger.info(f"Successfully updated labels for {total_updated_count} documents.")
    logger.info("Noise injection process complete.")
    return total_updated_count


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the label-noise injection script.

    argparse applies %-formatting to help strings, so any literal percent
    sign in help text must be written as ``%%``. A bare ``%`` makes
    ``--help`` crash.
    """
    parser = argparse.ArgumentParser(description="Inject label noise into the Fortress ChromaDB for robustness testing.")
    parser.add_argument(
        "--noise-level",
        type=float,
        required=True,
        help="The proportion of database labels to flip (e.g., 0.05 for 5%%)."
    )
    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to the ChromaDB directory. If not provided, uses the path from settings.yaml."
    )
    parser.add_argument(
        "--collection-name",
        type=str,
        default=None,
        help="Name of the ChromaDB collection. If not provided, uses the name from settings.yaml."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility."
    )
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    # Report an out-of-range value as a usage error (exit code 2) rather than
    # letting inject_noise raise a ValueError with a traceback.
    if not (0.0 <= args.noise_level <= 1.0):
        parser.error(f"--noise-level must be between 0.0 and 1.0, got {args.noise_level}.")

    inject_noise(
        noise_level=args.noise_level,
        db_path=args.db_path,
        collection_name=args.collection_name,
        seed=args.seed
    )
