import argparse
import json                                                        
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Dict, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing
import queue
import threading

from pydantic import BaseModel, Field

# Put the parent of this script's directory on sys.path. This is intended to let
# the project packages imported below (``fortress``, ``scripts``) resolve when the
# file is run directly rather than as an installed package.
#
# NOTE: the former ``try: pass / except ImportError`` "scikit-learn is not
# installed" guard was dead code (``pass`` can never raise ImportError), so it is
# gone. If a fail-fast dependency check is wanted, test
# ``importlib.util.find_spec("sklearn") is None`` explicitly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

                           
from fortress.config import get_config
from fortress.core.embedding_model import EmbeddingModel
from fortress.core.nlp_analyzer import NLPAnalyzer
from fortress.core.vector_store_interface import ChromaVectorStore
from fortress.data_management.prompt_processor import PromptProcessor
from fortress.detection_pipeline.primary_detector import PrimaryDetector
from fortress.detection_pipeline.secondary_analyzer import SecondaryAnalyzer
from fortress.detection_pipeline.main_pipeline import DetectionPipeline
from fortress.common.data_models import (
    FinalDetectionOutput,
    DecisionLabel as FortressDecisionLabel,                                  
)
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.common.constants import SPLIT_BENCHMARK, LABEL_SAFE, LABEL_UNSAFE

                             
from scripts.utils.benchmark_utils import (
    setup_logger,
    BaseBenchmarkSingleResult,
    CommonBenchmarkMetrics,
    add_common_benchmark_args,
    create_run_id,
    prepare_output_paths,
    load_and_filter_prompts,
    map_true_label_to_str,
    calculate_common_metrics,
    generate_markdown_report_parts,
    save_json_output,
    save_markdown_report,
    get_progress_bar,
    InputPromptRecord                                        
)

class FortressBenchmarkSingleResult(BaseBenchmarkSingleResult):
    """Per-prompt benchmark record extended with FORTRESS prediction details.

    Every extra field is optional and defaults to ``None``.
    """

    # Copied from the detection output's ``overall_confidence``.
    confidence: Optional[float] = Field(default=None)
    # Copied from the detection output's ``is_ambiguous`` flag.
    is_ambiguous_pred: Optional[bool] = Field(default=None)
    # Copied from the detection output's free-text ``justification``.
    justification_pred: Optional[str] = Field(default=None)
                            

class FortressBenchmarkSuiteResults(BaseModel):
    """Top-level document describing one complete FORTRESS benchmark run.

    Every field is required (``Field(...)``).
    """

    # Run identity and timing; timestamps are stored as strings.
    suite_run_id: str = Field(...)
    timestamp_start: str = Field(...)
    timestamp_end: str = Field(...)
    duration_seconds: float = Field(...)

    # Inputs consumed and artifacts produced by the run.
    input_csv_files: List[str] = Field(...)
    output_results_file: str = Field(...)
    output_report_file: str = Field(...)

    # Free-form, JSON-friendly snapshot of the configuration in effect.
    config_snapshot: Dict[str, Any] = Field(...)

    # Aggregate metrics followed by the per-prompt records they were computed from.
    metrics: CommonBenchmarkMetrics = Field(...)
    individual_results: List[FortressBenchmarkSingleResult] = Field(...)




def process_prompt_batch(detection_pipeline, prompt_batch: List[InputPromptRecord], batch_idx: int, logger) -> List[FortressBenchmarkSingleResult]:
    """Score each prompt in ``prompt_batch`` one at a time via ``detection_pipeline.run``.

    Prompts whose true label cannot be mapped are logged and skipped. A pipeline
    exception does not abort the batch: it is logged, and the prompt is recorded
    with an ERROR decision that carries the exception text.

    Args:
        detection_pipeline: Object exposing ``run(prompt_text)``, which returns a
            ``FinalDetectionOutput`` (a falsy return is tolerated and reported as ERROR).
        prompt_batch: Prompt records to evaluate.
        batch_idx: Batch number, used only to build IDs for prompts lacking one.
        logger: Logger used for warnings and errors.

    Returns:
        One ``FortressBenchmarkSingleResult`` per non-skipped prompt, in input order.
    """
    results: List[FortressBenchmarkSingleResult] = []

    for position, record in enumerate(prompt_batch):
        # Placeholder identifier for records that carry no prompt_id of their own.
        synthetic_id = f"batch{batch_idx}_idx{position}"

        true_label = map_true_label_to_str(
            record.label, logger,
            safe_numeric_val=LABEL_SAFE, unsafe_numeric_val=LABEL_UNSAFE
        )
        if true_label == "N/A":
            logger.warning(f"Skipping prompt ID {record.prompt_id or synthetic_id} due to unmappable true label: {record.label}")
            continue

        error_text: Optional[str] = None
        output: Optional[FinalDetectionOutput]
        started = time.perf_counter()
        try:
            output = detection_pipeline.run(record.original_prompt)
        except Exception as exc:
            logger.error(f"Error running Fortress pipeline for prompt ID {record.prompt_id or synthetic_id}: {exc}", exc_info=True)
            error_text = str(exc)
            output = FinalDetectionOutput(
                query_text=record.original_prompt,
                final_decision=FortressDecisionLabel.ERROR,
                error_info=error_text
            )
        # Timing deliberately includes building the fallback ERROR output.
        elapsed_ms = (time.perf_counter() - started) * 1000

        if not output:
            predicted = FortressDecisionLabel.ERROR.value
            confidence = ambiguous = justification = None
            error_info = error_text
        else:
            predicted = output.final_decision.value
            confidence = output.overall_confidence
            ambiguous = output.is_ambiguous
            justification = output.justification
            # Prefer the pipeline's own error text; fall back to the caught exception.
            error_info = output.error_info if output.error_info else error_text

        results.append(FortressBenchmarkSingleResult(
            prompt_id=str(record.prompt_id) if record.prompt_id else synthetic_id,
            original_prompt=record.original_prompt,
            true_label=true_label,
            predicted_label=predicted,
            confidence=confidence,
            is_ambiguous_pred=ambiguous,
            justification_pred=justification,
            source_file_input=record.source_file,
            prompt_category_input=record.prompt_category,
            prompt_style_input=record.prompt_style,
            processing_time_ms=elapsed_ms,
            error_info=error_info,
        ))

    return results





def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser: the shared benchmark flags plus FORTRESS-specific options."""
    parser = argparse.ArgumentParser(description="Run FORTRESS Detection Pipeline Benchmark Suite.")
    parser = add_common_benchmark_args(parser)

    parser.add_argument(
        "--collection-name",
        type=str,
        default=None,
        help="Name of the ChromaDB collection to use. Uses default from settings if not provided."
    )
    # Bound of the producer -> consumer queue (how far the producer may run ahead).
    parser.add_argument(
        "--prefetch-size",
        type=int,
        default=10,
        help="Number of prompts to prefetch for GPU processing. Default: 10"
    )
    parser.add_argument(
        "--run-id",
        type=str,
        default=None,
        help="Custom run ID for this benchmark run. If not provided, a run ID will be generated."
    )
    parser.add_argument(
        "--gpu-batch-size",
        type=int,
        default=6,
        help="Batch size for GPU embedding processing"
    )
    return parser


def _init_detection_pipeline(collection_name: Optional[str], logger):
    """Build the FORTRESS component graph and return ``(config, detection_pipeline)``.

    Exits the process with status 1 if any component fails to initialize.
    """
    try:
        logger.info("Initializing FORTRESS components...")
        config = get_config()

        embedding_model = EmbeddingModel()
        nlp_analyzer = NLPAnalyzer()

        # Only override the collection when one was given; otherwise use the store's default.
        vector_store_params = {'collection_name': collection_name} if collection_name else {}
        vector_store = ChromaVectorStore(**vector_store_params)

        prompt_processor = PromptProcessor(embedding_model, nlp_analyzer, vector_store)
        primary_detector = PrimaryDetector(prompt_processor, vector_store)
        secondary_analyzer = SecondaryAnalyzer()

        detection_pipeline = DetectionPipeline(
            prompt_processor=prompt_processor,
            primary_detector=primary_detector,
            secondary_analyzer=secondary_analyzer
        )
        logger.info("FORTRESS components initialized successfully.")
    except Exception as e:
        logger.critical(f"Failed to initialize FORTRESS components: {e}", exc_info=True)
        sys.exit(1)
    return config, detection_pipeline


def _label_prompts(prompts: List[InputPromptRecord], logger) -> List[tuple]:
    """Return ``(index, record, true_label_str)`` for every prompt whose label maps.

    Unmappable labels are logged and dropped before any thread starts, so the
    progress-bar total equals the number of prompts that will actually be scored.
    """
    labelled = []
    for i, prompt_record in enumerate(prompts):
        true_label_str = map_true_label_to_str(
            prompt_record.label, logger,
            safe_numeric_val=LABEL_SAFE, unsafe_numeric_val=LABEL_UNSAFE
        )
        if true_label_str == "N/A":
            logger.warning(f"Skipping prompt ID {prompt_record.prompt_id or f'idx_{i}'} due to unmappable true label: {prompt_record.label}")
            continue
        labelled.append((i, prompt_record, true_label_str))
    return labelled


def _error_output(prompt_record, error_info: str) -> FinalDetectionOutput:
    """Build an ERROR detection output for a prompt the pipeline did not score."""
    return FinalDetectionOutput(
        query_text=prompt_record.original_prompt,
        final_decision=FortressDecisionLabel.ERROR,
        error_info=error_info,
    )


def _process_and_store_batch(detection_pipeline, batch, lock, results_list, progress_bar, logger) -> None:
    """Run one GPU batch and append exactly one result per queued prompt.

    ``batch`` holds ``(index, prompt_record, true_label_str)`` tuples. True labels
    are paired with outputs by position, never by ``prompt_id``, because ids may
    be missing or repeated. If ``run_batch`` raises, or returns fewer outputs than
    prompts, the affected prompts are recorded as ERROR instead of silently
    vanishing from the metrics.
    """
    if not batch:
        return

    prompt_records = [record for _, record, _ in batch]

    batch_error: Optional[str] = None
    start_time = time.perf_counter()
    try:
        outputs = list(detection_pipeline.run_batch(prompt_records))
    except Exception as e:
        logger.error(f"Error running Fortress pipeline on a batch of {len(prompt_records)} prompts: {e}", exc_info=True)
        outputs = []
        batch_error = str(e)
    batch_ms = (time.perf_counter() - start_time) * 1000
    # Only batch wall time is observable, so attribute it evenly to each prompt.
    per_prompt_ms = batch_ms / len(prompt_records)

    if batch_error is None and len(outputs) != len(prompt_records):
        logger.error(
            f"run_batch returned {len(outputs)} outputs for {len(prompt_records)} prompts; "
            "prompts without an output are recorded as errors."
        )

    for position, (index, prompt_record, true_label_str) in enumerate(batch):
        detection_output = outputs[position] if position < len(outputs) else None
        if detection_output is None:
            detection_output = _error_output(
                prompt_record, batch_error or "No detection output returned for this prompt."
            )

        single_result = FortressBenchmarkSingleResult(
            prompt_id=str(prompt_record.prompt_id) if prompt_record.prompt_id else f"idx_{index}",
            original_prompt=prompt_record.original_prompt,
            true_label=true_label_str,
            predicted_label=detection_output.final_decision.value,
            confidence=detection_output.overall_confidence,
            is_ambiguous_pred=detection_output.is_ambiguous,
            justification_pred=detection_output.justification,
            source_file_input=prompt_record.source_file,
            prompt_category_input=prompt_record.prompt_category,
            prompt_style_input=prompt_record.prompt_style,
            processing_time_ms=per_prompt_ms,
            error_info=detection_output.error_info
        )
        with lock:
            results_list.append(single_result)
        if progress_bar:
            progress_bar.update(1)


def _consume_batches(detection_pipeline, gpu_batch_size: int, lock, results_list,
                     work_queue: queue.Queue, progress_bar, logger) -> None:
    """Drain ``work_queue`` into batches of up to ``gpu_batch_size`` until a ``None`` sentinel.

    A partial batch is flushed whenever the queue stays empty for 0.1 s, so the
    pending prompts are scored while the producer catches up. Every flush is
    guarded: an unexpected error is logged instead of killing this thread,
    because a dead consumer would leave the producer blocked on the bounded queue.
    """
    def flush(pending):
        try:
            _process_and_store_batch(detection_pipeline, pending, lock, results_list, progress_bar, logger)
        except Exception as e:
            logger.error(f"Exception in consumer thread: {e}", exc_info=True)

    pending = []
    while True:
        try:
            item = work_queue.get(timeout=0.1)
        except queue.Empty:
            if pending:
                flush(pending)
                pending = []
            continue

        if item is None:
            # Sentinel: the producer is done; score whatever is left and stop.
            flush(pending)
            return

        pending.append(item)
        if len(pending) >= gpu_batch_size:
            flush(pending)
            pending = []


def _serialize_config(config: Any, logger) -> Dict[str, Any]:
    """Return a JSON-serializable snapshot of ``config`` that is safe to extend.

    The result is always a new dict, never the caller's object, so adding keys
    to it cannot mutate the live configuration. Anything that cannot be
    JSON-encoded falls back to ``{"config_string": str(config)}``.
    """
    try:
        if isinstance(config, dict):
            snapshot = dict(config)
        elif hasattr(config, 'model_dump'):
            snapshot = config.model_dump(mode='json')
        elif hasattr(config, 'dict'):
            snapshot = config.dict()
        else:
            return {"config_string": str(config)}
    except Exception as e:
        logger.warning(f"Could not serialize config object: {e}. Storing as string.")
        return {"config_string": str(config)}

    try:
        if not isinstance(snapshot, dict):
            raise TypeError("config snapshot is not a dict")
        json.dumps(snapshot)  # validate once here so the report's json.dumps cannot fail
    except (TypeError, ValueError):
        logger.warning("Config object is not directly JSON serializable. Storing as string representation.")
        return {"config_string": str(config)}
    return snapshot


def main():
    """Run the FORTRESS benchmark suite end to end.

    Steps: parse the CLI, build the detection pipeline, and load the benchmark
    split. Prompts are then scored in GPU batches: a producer thread feeds a
    bounded queue and a consumer thread batches it. Finally, metrics are
    computed and JSON results plus a Markdown report are written.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.gpu_batch_size < 1:
        # A non-positive batch size would make the consumer spin without ever scoring anything.
        parser.error("--gpu-batch-size must be a positive integer")

    logger = setup_logger(__name__, level_str=args.log_level)

    timestamp_start_dt = datetime.now()
    base_name = args.run_id if args.run_id else "fortress_gemma4b_toxic_chat_extension"
    suite_run_id = create_run_id(base_name=base_name, prefix=args.run_name_prefix)

    output_results_file, output_report_file, _ = prepare_output_paths(
        args.output_dir, suite_run_id
    )

    logger.info(f"Benchmark Run ID: {suite_run_id}")
    logger.info(f"Input CSVs: {args.input_csvs}")
    logger.info(f"Outputting results to: {output_results_file}")
    logger.info(f"Outputting report to: {output_report_file}")

    config, detection_pipeline = _init_detection_pipeline(args.collection_name, logger)

    all_prompts: List[InputPromptRecord] = load_and_filter_prompts(
        csv_file_paths=args.input_csvs,
        split_to_filter=SPLIT_BENCHMARK,
        load_prompts_func=load_prompts_from_csv,
        logger=logger
    )
    if not all_prompts:
        logger.warning("No benchmark prompts found or processed. Exiting.")
        sys.exit(0)

    labelled_items = _label_prompts(all_prompts, logger)
    total_prompts = len(labelled_items)
    individual_results: List[FortressBenchmarkSingleResult] = []
    results_lock = threading.Lock()
    process_queue: queue.Queue = queue.Queue(maxsize=args.prefetch_size)

    def producer(items):
        """Feed labelled prompts into the bounded queue, always ending with the sentinel."""
        try:
            for item in items:
                process_queue.put(item)
        finally:
            # Without the sentinel the consumer would poll the empty queue forever.
            process_queue.put(None)

    # Daemon threads so an interrupted run does not keep the interpreter alive.
    producer_thread = threading.Thread(target=producer, args=(labelled_items,), daemon=True)
    producer_thread.start()

    with get_progress_bar(range(total_prompts), desc="Processing prompts (Fortress)", logger=logger, total=total_prompts) as pbar_obj:
        consumer_thread = threading.Thread(
            target=_consume_batches,
            args=(detection_pipeline, args.gpu_batch_size, results_lock,
                  individual_results, process_queue, pbar_obj, logger),
            daemon=True,
        )
        consumer_thread.start()

        # Poll with short sleeps instead of a bare join() so the main thread stays
        # responsive (e.g. to Ctrl+C) and the progress bar closes only after all updates.
        while consumer_thread.is_alive():
            time.sleep(0.1)

    consumer_thread.join()
    # The consumer returns only after the sentinel, which the producer enqueues last,
    # so the producer should finish at once. The timeout guards against hanging if
    # the consumer died early and left the producer blocked on a full queue.
    producer_thread.join(timeout=5.0)
    if producer_thread.is_alive():
        logger.error("Producer thread still blocked after the consumer exited; some prompts were not processed.")

    # prompt_id is always a string here, so this is a plain lexicographic sort by id.
    individual_results.sort(key=lambda x: (x.prompt_id is None, x.prompt_id))

    logger.info("Calculating benchmark metrics...")
    final_metrics = calculate_common_metrics(
        individual_results, logger,
        positive_label=FortressDecisionLabel.UNSAFE.value,
        negative_label=FortressDecisionLabel.SAFE.value,
        error_label=FortressDecisionLabel.ERROR.value,
        ambiguous_label=FortressDecisionLabel.AMBIGUOUS.value
    )

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Calculated Metrics: {final_metrics.model_dump_json(indent=2)}")

    timestamp_end_dt = datetime.now()
    duration_seconds = (timestamp_end_dt - timestamp_start_dt).total_seconds()

    serializable_config = _serialize_config(config, logger)
    serializable_config["benchmark_optimization"] = {
        "prefetch_size": args.prefetch_size,
        "gpu_batch_size": args.gpu_batch_size,
        "pipelined_processing": True
    }

    suite_results_data = FortressBenchmarkSuiteResults(
        suite_run_id=suite_run_id,
        timestamp_start=timestamp_start_dt.isoformat(),
        timestamp_end=timestamp_end_dt.isoformat(),
        duration_seconds=duration_seconds,
        input_csv_files=args.input_csvs,
        output_results_file=str(output_results_file.resolve()),
        output_report_file=str(output_report_file.resolve()),
        config_snapshot=serializable_config,
        metrics=final_metrics,
        individual_results=individual_results
    )

    save_json_output(suite_results_data, output_results_file, logger)

    model_specific_desc = f"FORTRESS Detection Pipeline (Pipelined with prefetch={args.prefetch_size}, gpu_batch_size={args.gpu_batch_size})"
    config_snapshot_str = json.dumps(serializable_config, indent=2)
    model_specific_config_lines = [f"\n**Configuration Snapshot:**\n```json\n{config_snapshot_str}\n```"]

    report_content = generate_markdown_report_parts(
        suite_results=suite_results_data,
        model_specific_desc=model_specific_desc,
        model_specific_config_lines=model_specific_config_lines,
        report_title="FORTRESS Pipeline Benchmark Report"
    )
    save_markdown_report(report_content, output_report_file, logger)

    logger.info(f"Benchmark run {suite_run_id} completed in {duration_seconds:.2f} seconds.")
    if duration_seconds > 0 and individual_results:
        logger.info(f"Throughput: {len(individual_results) / duration_seconds:.2f} prompts/second")
    else:
        logger.info("Throughput: N/A (no results or zero duration)")
    accuracy_str = f"{final_metrics.accuracy:.4f}" if final_metrics.accuracy is not None else 'N/A'
    f1_str = final_metrics.f1_unsafe if final_metrics.f1_unsafe is not None else 'N/A'
    logger.info(f"Accuracy: {accuracy_str}, F1 (Unsafe): {f1_str}")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
