import argparse
import json
import logging
import os
import sys
import shutil
import random
import time
import subprocess
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

                                                                    
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

                                          
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import chromadb

from scripts.utils.benchmark_utils import setup_logger
from config.constants import (BASE_PROJECT_DIR)

                                  

                                                             
# Benchmark runner executed as a subprocess for every (db size, benchmark) measurement.
BENCHMARK_SCRIPT_PATH = str(BASE_PROJECT_DIR / "scripts" / "run_fortress_benchmark.py")
print(f"Using benchmark script path: {BENCHMARK_SCRIPT_PATH}")

# Default for --db-path: the ChromaDB store that main() backs up, shrinks step by step,
# and finally restores.
DB_PATH = str(BASE_PROJECT_DIR / "data" / "07_vector_db" / "gemma3_1b_exp_scale_experiment")

# Benchmark label -> stitched CSV evaluated at every database size (insertion order is kept).
TARGET_BENCHMARKS = {
    label: str(BASE_PROJECT_DIR / "data" / "05_stitched" / csv_file)
    for label, csv_file in (
        ("aegis_v2", "aegis_v2_english.csv"),
        ("fortress_dataset", "fortress_dataset_english.csv"),
        ("jailbreakbench", "jailbreakbench_judge_comparison_dataset.csv"),
        ("xstest", "xstest_english.csv"),
    )
}

                                   

class ChromaDBManager:
    """A minimal wrapper for ChromaDB to manage the database for the experiment."""

    def __init__(self, db_path: str, collection_name: Optional[str] = None, logger: Optional[logging.Logger] = None):
        """
        Open a persistent ChromaDB client and load one collection.

        Args:
            db_path: Directory of the persistent ChromaDB store.
            collection_name: Collection to load. If falsy, the first collection
                reported by ``list_collections()`` is used.
            logger: Logger to use; defaults to this module's logger.

        Raises:
            ValueError: If no collection name is given and the store has no collections.
            Exception: Any error raised while opening the client or collection is
                logged and re-raised unchanged.
        """
        self.logger = logger or logging.getLogger(__name__)
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            if not collection_name:
                collections = self.client.list_collections()
                if not collections:
                    raise ValueError("No collections found in the database and no collection_name provided.")
                collection_name = self._collection_entry_name(collections[0])
                self.logger.info(f"Auto-detected collection name: '{collection_name}'")

            self.collection = self.client.get_collection(name=collection_name)
            self.logger.info(f"ChromaDB collection '{collection_name}' loaded successfully from '{db_path}'.")
        except Exception as e:
            self.logger.error(f"Failed to initialize ChromaDB at '{db_path}': {e}", exc_info=True)
            raise

    @staticmethod
    def _collection_entry_name(entry: Any) -> str:
        """
        Return the collection name for one item of ``list_collections()``.

        Older ChromaDB releases return Collection objects (exposing ``.name``);
        newer releases (0.6.0 onward) return the names as plain strings.
        """
        return entry if isinstance(entry, str) else entry.name

    def get_collection_size(self) -> int:
        """Return the number of documents currently stored in the collection."""
        return self.collection.count()

    def delete_documents(self, ids: List[str], batch_size: int = 500):
        """
        Delete documents by ID, sending at most ``batch_size`` IDs per delete call.

        An empty ``ids`` list is a no-op.
        """
        if not ids:
            return
        # Batching keeps each delete request bounded in size.
        for i in range(0, len(ids), batch_size):
            batch = ids[i:i + batch_size]
            self.collection.delete(ids=batch)
            self.logger.debug(f"Deleted batch of {len(batch)} documents.")

    def get_all_document_ids(self, batch_size: int = 5000) -> List[str]:
        """
        Retrieve all document IDs from the collection using offset-based paging.

        Only IDs are fetched (``include=[]``). Paging stops on an empty page or on a
        page shorter than ``batch_size``. The collection is assumed not to change while
        paging; concurrent deletes would shift offsets.
        """
        all_ids = []
        offset = 0
        self.logger.info("Retrieving all document IDs from the collection...")
        while True:
            results = self.collection.get(limit=batch_size, offset=offset, include=[])
            if not results or not results['ids']:
                break
            batch_ids = results['ids']
            all_ids.extend(batch_ids)
            if len(batch_ids) < batch_size:
                break
            offset += len(batch_ids)
        self.logger.info(f"Retrieved {len(all_ids)} total document IDs.")
        return all_ids

                                 

def run_and_parse_benchmark(csv_path: str, temp_output_dir: Path, logger: logging.Logger) -> Optional[Tuple[float, float]]:
    """
    Run the external benchmark script on one CSV and parse the JSON results it writes.

    Args:
        csv_path: Benchmark CSV handed to the script via ``--input-csvs``.
        temp_output_dir: Directory the script writes into; the results file is
            looked up in its ``results_data`` subdirectory.
        logger: Logger for progress and error messages.

    Returns:
        ``(f1_unsafe_score, average_latency_ms)`` as plain floats, or ``None`` if the
        script could not be launched, exited non-zero, or its results file is missing
        or malformed.
    """
    # The child runs with cwd=BASE_PROJECT_DIR. An absolute path makes parent and child
    # agree on the output location regardless of where this script was launched from.
    temp_output_dir = Path(temp_output_dir).resolve()
    run_id = f"scalability_run_{datetime.now().strftime('%Y%m%d%H%M%S')}_{random.randint(1000, 9999)}"

    command = [
        sys.executable,
        BENCHMARK_SCRIPT_PATH,
        "--input-csvs", csv_path,
        "--output-dir", str(temp_output_dir),
        "--run-id", run_id,
        "--log-level", "INFO",
    ]

    logger.info(f"Executing command: {' '.join(command)}")

    try:
        # List form with shell=False: no shell interpretation of the paths.
        subprocess.run(command, check=True, cwd=str(BASE_PROJECT_DIR))
    except subprocess.CalledProcessError as e:
        logger.error(f"Benchmark script for {Path(csv_path).name} failed with return code {e.returncode}.")
        return None
    except OSError as e:
        logger.error(f"Could not launch benchmark script for {Path(csv_path).name}: {e}")
        return None

    # The script names its output after the run id; pick the first match deterministically.
    results_data_dir = temp_output_dir / "results_data"
    expected_file = None
    if results_data_dir.is_dir():
        expected_file = next(
            (
                f for f in sorted(results_data_dir.iterdir())
                if f.is_file() and f.name.startswith(run_id) and f.name.endswith("_results.json")
            ),
            None,
        )
    if expected_file is None:
        logger.error(f"Expected output file not found in results_data: {results_data_dir} for run_id {run_id}")
        return None

    try:
        with open(expected_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        metrics = data["metrics"]
        f1_raw = metrics.get("f1_unsafe")
        f1_unsafe = 0.0 if f1_raw is None else float(f1_raw)

        # Average only the entries that report a timing. Counting a missing timing as
        # 0 ms would pull the mean latency down.
        latencies = [
            float(res["processing_time_ms"])
            for res in (data.get("individual_results") or [])
            if res.get("processing_time_ms") is not None
        ]
        avg_latency = float(np.mean(latencies)) if latencies else 0.0

        return f1_unsafe, avg_latency

    except (KeyError, TypeError, ValueError, AttributeError, OSError) as e:
        # ValueError also covers json.JSONDecodeError.
        logger.error(f"Failed to parse output file {expected_file}: {e}")
        return None


                           

def generate_plots(results_df: pd.DataFrame, output_dir: str, num_runs: int, logger: logging.Logger):
    """
    Generate and save a TMLR-compliant figure with two subplots, plus a caption file.

    Args:
        results_df: One row per (run, step, benchmark) with at least the columns
            ``db_size``, ``latency_ms``, ``f1_unsafe`` and ``benchmark``.
        output_dir: Directory receiving ``scalability_figure.pdf`` and
            ``figure_1_caption.txt``. It is created if it does not exist.
        num_runs: Number of repeated runs, quoted in the caption.
        logger: Logger for progress messages.

    An empty ``results_df`` is logged and nothing is written.
    """
    if results_df.empty:
        logger.warning("Results DataFrame is empty. Skipping plot generation.")
        return

    logger.info("Generating final plots...")
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    plt.style.use('seaborn-v0_8-whitegrid')

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6.5, 7), sharex=True)

    # Top panel: latency vs. database size, with a +/- 1 SD band across runs.
    sns.lineplot(data=results_df, x='db_size', y='latency_ms', hue='benchmark',
                 marker='o', ax=ax1, errorbar='sd', err_style="band")
    ax1.set_ylabel("Latency (ms/query)")
    ax1.set_title("System Latency vs. Database Size")
    ax1.legend(title="Benchmark", fontsize='small')
    ax1.grid(True)

    # Bottom panel: unsafe-class F1. Both panels share the hue mapping, so the top
    # legend covers both. legend=False avoids creating (and then removing) a second one.
    sns.lineplot(data=results_df, x='db_size', y='f1_unsafe', hue='benchmark',
                 marker='o', ax=ax2, errorbar='sd', err_style="band", legend=False)
    ax2.set_xlabel("Database Size (Number of Documents)")
    ax2.set_ylabel("F1 Score (Unsafe Class)")
    ax2.set_title("Detection Performance vs. Database Size")
    ax2.grid(True)
    ax2.set_ylim(bottom=max(0, results_df['f1_unsafe'].min() - 0.1), top=1.05)

    fig.tight_layout()

    figure_path = out_path / "scalability_figure.pdf"
    fig.savefig(figure_path, format='pdf', bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Saved scalability plot to {figure_path}")

    # Describe the benchmarks that actually produced results; failed benchmarks have no rows.
    number_words = {1: "One", 2: "Two", 3: "Three", 4: "Four", 5: "Five",
                    6: "Six", 7: "Seven", 8: "Eight", 9: "Nine", 10: "Ten"}
    n_benchmarks = int(results_df['benchmark'].nunique())
    count_word = number_words.get(n_benchmarks, str(n_benchmarks))
    if n_benchmarks == 1:
        benchmark_sentence = f"{count_word} benchmark dataset was used to evaluate the system at each database size."
    else:
        benchmark_sentence = f"{count_word} different benchmark datasets were used to evaluate the system at each database size."

    caption_path = out_path / "figure_1_caption.txt"
    caption_text = " ".join([
        "Figure 1: System performance and latency as a function of vector database size.",
        "The experiment was conducted using the FORTRESS system with its default Gemma 1B-based configuration.",
        "The x-axis represents the number of documents in the ChromaDB vector store.",
        "The top subplot shows the average query latency in milliseconds.",
        "The bottom subplot shows the F1 score for correctly identifying 'unsafe' prompts.",
        f"Data points are averages over {num_runs} independent runs, where for each run, documents were randomly removed from the database in decremental steps.",
        f"The shaded areas represent the standard deviation across the {num_runs} runs for each data point.",
        benchmark_sentence,
    ])
    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption_text)
    logger.info(f"Saved figure caption to {caption_path}")


                                

def _drop_chroma_client_cache(logger: logging.Logger) -> None:
    """
    Best-effort reset of ChromaDB's in-process client cache.

    ChromaDB keeps one shared system per persistent path within a process (via
    ``SharedSystemClient``). After the files under a path are replaced, a newly built
    client could otherwise keep using the old, already-open state. If the API is
    unavailable in the installed version, log a warning and continue.
    """
    try:
        from chromadb.api.client import SharedSystemClient
        SharedSystemClient.clear_system_cache()
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not clear ChromaDB client cache ({e}); a reopened client may see stale state.")


def _restore_working_database(backup_db_path: Path, original_db_path: Path, logger: logging.Logger) -> None:
    """Overwrite the working database with a fresh copy of the backup; the backup is kept."""
    _drop_chroma_client_cache(logger)
    if original_db_path.exists():
        shutil.rmtree(original_db_path)
    shutil.copytree(backup_db_path, original_db_path)


def main():
    """
    Run the FORTRESS scalability experiment end to end.

    The target database is backed up first. Each run then:
      1. starts from the full database (restored from the backup after the first run),
      2. measures every benchmark in ``TARGET_BENCHMARKS``,
      3. removes an equal-sized random chunk of documents and repeats,
         for ``--steps + 1`` measurements in total.
    Results are written to CSV (live and final) and plotted. The original database is
    restored from the backup on exit. The process exits with status 1 if the database
    path is missing or the experiment raised an exception.
    """
    parser = argparse.ArgumentParser(description="Run FORTRESS Scalability Experiment.")
    parser.add_argument("--steps", type=int, default=16, help="Number of database size intervals to test.")
    parser.add_argument("--runs-per-step", type=int, default=3, help="Number of times to repeat the measurement at each size.")
    parser.add_argument("--db-path", type=str, default=DB_PATH, help="Path to the ChromaDB directory to test.")
    parser.add_argument("--output-dir", type=str, default=f"benchmarks/scalability_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}", help="Directory to save plots and results.")
    parser.add_argument("--collection-name", type=str, default=None, help="Name of the ChromaDB collection. Auto-detects if not provided.")
    parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])

    args = parser.parse_args()

    # Setup. Benchmark subprocesses run with cwd=BASE_PROJECT_DIR, so the output path is
    # made absolute to keep parent and child pointing at the same directory.
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = setup_logger(__name__, level_str=args.log_level)

    logger.info("--- FORTRESS Scalability Experiment ---")
    logger.info(f"Arguments: {vars(args)}")

    # Database backup: the experiment deletes documents in place.
    original_db_path = Path(args.db_path)
    if not original_db_path.exists():
        logger.error(f"Database path not found: {original_db_path}")
        sys.exit(1)

    backup_db_path = original_db_path.with_name(f"{original_db_path.name}_scalability_backup")
    temp_benchmark_output_dir = output_dir / "temp_benchmark_outputs"

    logger.info(f"Backing up database from '{original_db_path}' to '{backup_db_path}'...")
    if backup_db_path.exists():
        shutil.rmtree(backup_db_path)
    shutil.copytree(original_db_path, backup_db_path)
    logger.info("Backup complete.")

    all_experiment_results = []
    experiment_failed = False

    try:
        for run_idx in range(args.runs_per_step):
            logger.info(f"--- Starting Full Experiment Run {run_idx + 1}/{args.runs_per_step} ---")

            if run_idx > 0:
                # The previous run deleted documents in place. Without this restore, every
                # later run would start from the leftover remainder instead of the full DB.
                logger.info(f"Restoring full database from backup '{backup_db_path}' for this run...")
                _restore_working_database(backup_db_path, original_db_path, logger)

            db_manager = ChromaDBManager(db_path=str(original_db_path), collection_name=args.collection_name, logger=logger)
            all_doc_ids = db_manager.get_all_document_ids()
            random.shuffle(all_doc_ids)
            initial_size = len(all_doc_ids)
            # Any remainder (initial_size % steps) is never removed and stays in the final step.
            docs_to_remove_per_step = initial_size // args.steps if args.steps > 0 else 0

            for step_idx in range(args.steps + 1):
                current_db_size = db_manager.get_collection_size()
                logger.info(f"Run {run_idx + 1}, Step {step_idx}: DB size = {current_db_size}")

                if temp_benchmark_output_dir.exists():
                    shutil.rmtree(temp_benchmark_output_dir)
                temp_benchmark_output_dir.mkdir()

                for benchmark_name, csv_path in TARGET_BENCHMARKS.items():
                    results = run_and_parse_benchmark(csv_path, temp_benchmark_output_dir, logger)
                    if results:
                        f1, latency = results
                        all_experiment_results.append({
                            "run": run_idx + 1, "step": step_idx, "db_size": current_db_size,
                            "benchmark": benchmark_name, "f1_unsafe": f1, "latency_ms": latency,
                        })

                # Checkpoint after every step so partial results survive a crash.
                pd.DataFrame(all_experiment_results).to_csv(output_dir / "scalability_results_live.csv", index=False)

                if step_idx < args.steps:
                    # Walk the shuffled ID list in consecutive, non-overlapping chunks.
                    offset = step_idx * docs_to_remove_per_step
                    ids_to_delete = all_doc_ids[offset : offset + docs_to_remove_per_step]
                    if ids_to_delete:
                        logger.info(f"Removing {len(ids_to_delete)} documents for next step...")
                        db_manager.delete_documents(ids_to_delete)

            logger.info(f"--- Finished Full Experiment Run {run_idx + 1} ---")

        # Final results and plots.
        logger.info("Experiment loops complete. Finalizing results...")
        results_df = pd.DataFrame(all_experiment_results)
        final_csv_path = output_dir / "scalability_results_final.csv"
        results_df.to_csv(final_csv_path, index=False)
        logger.info(f"Final results saved to {final_csv_path}")

        generate_plots(results_df, str(output_dir), args.runs_per_step, logger)

    except Exception as e:
        experiment_failed = True
        logger.critical(f"An unhandled exception occurred: {e}", exc_info=True)
    finally:
        # Always put the original database back, even on failure or Ctrl-C.
        logger.info(f"Restoring original database from backup '{backup_db_path}'...")
        if backup_db_path.exists():
            if original_db_path.exists():
                shutil.rmtree(original_db_path)
            shutil.move(str(backup_db_path), str(original_db_path))
            logger.info("Database restored successfully.")
        else:
            logger.error("Backup path not found! Could not restore original database.")

        if temp_benchmark_output_dir.exists():
            shutil.rmtree(temp_benchmark_output_dir)

    logger.info("--- Experiment Finished ---")
    if experiment_failed:
        # Report failure to the caller/CI; the database has already been restored above.
        sys.exit(1)


if __name__ == "__main__":
    main()