                                

import argparse
import json
import logging
import os
import sys
import shutil
import random
import subprocess
import time
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm
from sklearn.model_selection import KFold

                            
                                                                   
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.constants import BASE_PROJECT_DIR

# All locations are resolved against the project root exported by config.constants
# and stored as plain strings (they are passed to subprocesses / shutil / chromadb).

# External benchmark script invoked once per (fold, held-out category) run.
BENCHMARK_SCRIPT_PATH = str(BASE_PROJECT_DIR / "scripts" / "run_fortress_benchmark.py")

# Master FORTRESS dataset that is split into K folds.
FORTRESS_DATASET_CSV = str(
    BASE_PROJECT_DIR / "data" / "05_stitched" / "fortress_dataset_english.csv"
)

# Base vector DB; each run copies it before pruning documents.
BASE_DB_PATH = str(BASE_PROJECT_DIR / "data" / "07_vector_db" / "gemma3_1b_base")

# Settings file whose vector_database.path is temporarily rewritten per run.
SETTINGS_YAML_PATH = str(BASE_PROJECT_DIR / "config" / "settings.yaml")


# Category labels (values of the dataset's `prompt_category` column) that are
# held out one at a time.
UNSAFE_CATEGORIES = [
    "s1_violent_crimes",
    "s2_non_violent_crimes",
    "s3_sex_related_crimes",
    "s4_child_sexual_exploitation",
    "s5_defamation",
    "s6_specialized_advice",
    "s7_privacy",
    "s8_intellectual_property",
    "s9_indiscriminate_weapons",
    "s10_hate",
    "s11_suicide_self_harm",
    "s12_sexual_content",
    "s13_elections",
]

                           

def setup_logger(log_dir: Path, level_str: str = "INFO") -> logging.Logger:
    """Set up the experiment logger with a console handler and a file handler.

    Console output is filtered at ``level_str``. ``log_dir / "experiment.log"``
    always receives DEBUG and above, so the full trace is kept on disk
    regardless of the console verbosity.

    Args:
        log_dir: Existing directory in which ``experiment.log`` is created/appended.
        level_str: Console level name (e.g. "INFO"); unknown names fall back to INFO.

    Returns:
        The shared ``"LOCO_Experiment"`` logger. Handlers are only attached on
        the first call. Later calls just update the console level.
    """
    log_level = getattr(logging, level_str.upper(), logging.INFO)
    logger = logging.getLogger("LOCO_Experiment")

    # The logger itself must let DEBUG through; filtering is done per handler.
    # Setting the logger to `log_level` would also cap the file handler, which
    # would make its DEBUG level ineffective.
    logger.setLevel(logging.DEBUG)
    # Keep DEBUG records from bubbling up into any root handlers that a
    # third-party library might configure.
    logger.propagate = False

    # Only attach handlers once; getLogger returns the same object on every call.
    if not logger.handlers:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
        logger.addHandler(console_handler)

        file_handler = logging.FileHandler(log_dir / "experiment.log", encoding="utf-8")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(file_handler)

    # (Re-)apply the requested console level on every call so a repeated call
    # with a different level_str still changes console verbosity.
    for handler in logger.handlers:
        if not isinstance(handler, logging.FileHandler):
            handler.setLevel(log_level)

    return logger

                                   
                                                          

class ChromaDBManager:
    """Minimal ChromaDB wrapper used to prune documents from a database copy.

    The database must contain at least one collection; the first collection
    reported by ``list_collections()`` is the one operated on.
    """

    def __init__(self, db_path: str, logger: logging.Logger):
        """Open the persistent ChromaDB at ``db_path`` and bind its first collection.

        Exits the process with status 1 if ``chromadb`` is not installed.

        Raises:
            ValueError: if the database contains no collections.
            Exception: any ChromaDB error is logged (with traceback) and re-raised.
        """
        self.logger = logger
        # Imported here rather than at module level, so the module can be
        # imported without chromadb installed.
        try:
            import chromadb
        except ImportError:
            self.logger.error("chromadb is not installed. Please run 'pip install chromadb'")
            sys.exit(1)

        try:
            self.client = chromadb.PersistentClient(path=db_path)
            collections = self.client.list_collections()
            if not collections:
                raise ValueError("No collections found in the database.")
            first = collections[0]
            # Depending on the chromadb version, list_collections() returns
            # either Collection objects or plain collection-name strings.
            self.collection_name = first if isinstance(first, str) else first.name
            if len(collections) > 1:
                self.logger.warning(
                    f"Database at '{db_path}' has {len(collections)} collections; "
                    f"using '{self.collection_name}'."
                )
            self.collection = self.client.get_collection(name=self.collection_name)
            self.logger.debug(f"ChromaDB collection '{self.collection_name}' loaded from '{db_path}'.")
        except Exception as e:
            self.logger.error(f"Failed to initialize ChromaDB at '{db_path}': {e}", exc_info=True)
            raise

    def _fetch_all_documents(self, batch_size: int) -> List[Tuple[str, str]]:
        """Page through the collection and return ``(id, text)`` for every document with non-empty text."""
        pairs: List[Tuple[str, str]] = []
        offset = 0
        while True:
            results = self.collection.get(limit=batch_size, offset=offset, include=['documents'])
            if not results or not results['ids']:
                break
            pairs.extend(
                (doc_id, document_text)
                for doc_id, document_text in zip(results['ids'], results['documents'])
                if document_text
            )
            # A short page means the end of the collection has been reached.
            if len(results['ids']) < batch_size:
                break
            offset += len(results['ids'])
        return pairs

    def get_all_documents_with_text(self, batch_size: int = 2000) -> Dict[str, str]:
        """Map each non-empty document text to a document ID.

        If several documents share the same text, only the last one retrieved
        is kept in the map. A warning reports how many IDs were shadowed. Use
        ``get_ids_by_text`` when every matching ID is needed.
        """
        self.logger.debug("Retrieving all documents to map text to ID...")
        pairs = self._fetch_all_documents(batch_size)
        # Later pairs overwrite earlier ones: last retrieved ID wins.
        text_to_id_map = {document_text: doc_id for doc_id, document_text in pairs}
        n_shadowed = len(pairs) - len(text_to_id_map)
        if n_shadowed:
            self.logger.warning(
                f"{n_shadowed} document(s) share their text with another document; "
                "only one ID per text is kept in the text->ID map."
            )
        self.logger.debug(f"Mapped {len(text_to_id_map)} documents.")
        return text_to_id_map

    def get_ids_by_text(self, batch_size: int = 2000) -> Dict[str, List[str]]:
        """Map each non-empty document text to ALL IDs carrying that text (retrieval order)."""
        ids_by_text: Dict[str, List[str]] = {}
        for doc_id, document_text in self._fetch_all_documents(batch_size):
            ids_by_text.setdefault(document_text, []).append(doc_id)
        return ids_by_text

    def delete_documents_by_id(self, ids: List[str], batch_size: int = 500):
        """Delete the given IDs from the collection in batches of ``batch_size``.

        Duplicate IDs are dropped first (order preserved). Deleting an ID twice
        is pointless, and ChromaDB's ID validation is expected to reject
        duplicate IDs within a single call.
        """
        unique_ids = list(dict.fromkeys(ids))
        if not unique_ids:
            return
        for start in range(0, len(unique_ids), batch_size):
            self.collection.delete(ids=unique_ids[start:start + batch_size])

                               

def run_and_parse_benchmark(
    csv_path: str,
    temp_output_dir: Path,
    logger: logging.Logger,
    timeout: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
    """Run the external benchmark script on one CSV and return its metrics.

    The script is run with ``BASE_PROJECT_DIR`` as its working directory and is
    told to write its results under ``temp_output_dir``. Its JSON output is
    looked up in ``temp_output_dir / "results_data"`` by run id.

    Args:
        csv_path: Input CSV handed to the script via ``--input-csvs``.
        temp_output_dir: Directory passed as ``--output-dir``.
        logger: Logger for diagnostics.
        timeout: Optional wall-clock limit in seconds for the subprocess
            (``None`` = wait indefinitely, as before).

    Returns:
        The ``"metrics"`` mapping from the results JSON. Returns ``None`` in
        these cases, each of which is logged, so one bad run does not abort
        the whole experiment:
        - the script cannot be launched, exits non-zero, or times out;
        - the JSON cannot be found, read or parsed;
        - the JSON has no mapping under ``"metrics"``.
    """
    # Random suffix distinguishes repeated benchmarks of the same CSV stem.
    run_id = f"loco_run_{Path(csv_path).stem}_{random.randint(1000, 9999)}"
    csv_name = Path(csv_path).name

    command = [
        sys.executable,
        BENCHMARK_SCRIPT_PATH,
        "--input-csvs", csv_path,
        "--output-dir", str(temp_output_dir),
        "--run-id", run_id,
        "--log-level", "WARNING",
    ]

    logger.debug(f"Executing command: {' '.join(command)}")

    try:
        # Output is captured so it only surfaces (below) when the run fails.
        subprocess.run(
            command,
            check=True,
            cwd=str(BASE_PROJECT_DIR),
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Benchmark script for {csv_name} failed.")
        logger.error(f"STDOUT: {e.stdout}")
        logger.error(f"STDERR: {e.stderr}")
        return None
    except subprocess.TimeoutExpired:
        logger.error(f"Benchmark script for {csv_name} timed out after {timeout} s.")
        return None
    except OSError as e:
        logger.error(f"Could not launch benchmark script for {csv_name}: {e}")
        return None

    results_data_dir = temp_output_dir / "results_data"
    try:
        json_file = next(results_data_dir.glob(f"{run_id}*.json"))
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        metrics = data["metrics"]
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError; TypeError
    # covers a top-level JSON value that is not an object.
    except (StopIteration, KeyError, TypeError, ValueError, OSError) as e:
        logger.error(f"Failed to find or parse output JSON for run {run_id}: {e}")
        return None

    if not isinstance(metrics, dict):
        logger.error(f"Output JSON for run {run_id} has non-mapping 'metrics': {type(metrics).__name__}")
        return None
    return metrics

def modify_settings_yaml(temp_db_path: str, logger: logging.Logger, settings_path: Optional[str] = None):
    """Point ``vector_database.path`` in the settings YAML at ``temp_db_path``.

    The file is rewritten in place; ``main`` backs it up beforehand and
    restores the backup when the experiment ends.

    Args:
        temp_db_path: Database directory to write into ``vector_database.path``.
        logger: Logger for diagnostics.
        settings_path: YAML file to modify; defaults to ``SETTINGS_YAML_PATH``.

    Raises:
        Any error (missing file, missing ``vector_database`` section, YAML or
        I/O error) is logged with a traceback and re-raised.
    """
    target = Path(SETTINGS_YAML_PATH if settings_path is None else settings_path)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(target, 'r', encoding='utf-8') as f:
            settings = yaml.safe_load(f)

        settings['vector_database']['path'] = temp_db_path

        # Write to a sibling temp file, then atomically swap it in. A failure
        # mid-dump can then never leave a truncated settings file behind.
        # sort_keys=False keeps the original key order.
        with open(tmp_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(settings, f, sort_keys=False)
        os.replace(tmp_path, target)
        logger.debug(f"settings.yaml modified to use DB: {temp_db_path}")
    except Exception as e:
        if tmp_path.exists():
            tmp_path.unlink()
        logger.error(f"Failed to modify settings.yaml: {e}", exc_info=True)
        raise

                     

def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments of the LOCO experiment."""
    parser = argparse.ArgumentParser(description="Run FORTRESS Leave-One-Category-Out (LOCO) Experiment.")
    parser.add_argument("--k-folds", type=int, default=5, help="Number of folds for cross-validation.")
    parser.add_argument("--output-dir", type=str, default=f"benchmarks/loco_experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}", help="Directory to save results.")
    parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
    args = parser.parse_args()
    # sklearn's KFold rejects n_splits < 2. Fail fast with a CLI error instead
    # of after settings.yaml has been backed up and the dataset loaded.
    if args.k_folds < 2:
        parser.error("--k-folds must be at least 2.")
    return args


def _remove_held_out_docs(
    temp_db_path: Path,
    prompts_to_remove: List[str],
    category: str,
    run_name: str,
    logger: logging.Logger,
) -> bool:
    """Delete documents whose text matches ``prompts_to_remove`` from the DB copy.

    Returns ``False`` if the DB operation failed; the run should then be skipped.
    """
    if not prompts_to_remove:
        return True
    # A prompt text that occurs twice in the dataset would otherwise yield the
    # same document ID twice in a single delete call.
    unique_prompts = list(dict.fromkeys(prompts_to_remove))
    try:
        db_manager = ChromaDBManager(str(temp_db_path), logger)
        text_to_id_map = db_manager.get_all_documents_with_text()
        ids_to_delete = [text_to_id_map[p] for p in unique_prompts if p in text_to_id_map]
        db_manager.delete_documents_by_id(ids_to_delete)
    except Exception as e:
        logger.error(f"Failed DB operation for {run_name}: {e}")
        return False

    logger.debug(f"Removed {len(ids_to_delete)} '{category}' docs from temp DB.")
    # Matching is by exact text; unmatched prompts mean the hold-out silently
    # left some of this category's documents in the DB.
    n_missing = len(unique_prompts) - len(ids_to_delete)
    if n_missing:
        logger.warning(
            f"{n_missing} '{category}' prompt(s) were not found in the DB for {run_name}; "
            "the hold-out may be incomplete."
        )
    return True


def _run_holdout(
    fold_idx: int,
    category: str,
    training_pool_df: pd.DataFrame,
    test_fold_df: pd.DataFrame,
    temp_experiment_dir: Path,
    logger: logging.Logger,
) -> Optional[Dict[str, Any]]:
    """Benchmark one (fold, held-out category) pair on a pruned copy of the base DB.

    Returns the result row for this run, or ``None`` in either case:
    - there are no test samples for the category in this fold;
    - the DB pruning or the benchmark failed.
    """
    run_name = f"fold{fold_idx}_hout_{category}"

    # Check for test samples first, so no DB copy is made for a run that
    # would be skipped anyway.
    benchmark_df = test_fold_df[test_fold_df['prompt_category'] == category]
    if benchmark_df.empty:
        logger.debug(f"No samples for '{category}' in test fold {fold_idx}. Skipping.")
        return None

    temp_db_path = temp_experiment_dir / f"db_{run_name}"
    temp_csv_path = temp_experiment_dir / f"benchmark_{run_name}.csv"
    try:
        if temp_db_path.exists():
            shutil.rmtree(temp_db_path)
        shutil.copytree(BASE_DB_PATH, temp_db_path)

        # NOTE(review): only TRAINING-pool prompts of the held-out category are
        # removed. If BASE_DB_PATH also contains the test-fold prompts, the very
        # prompts benchmarked below remain retrievable from the DB. Confirm
        # that this is intended.
        prompts_to_remove = training_pool_df.loc[
            training_pool_df['prompt_category'] == category, 'original_prompt'
        ].tolist()
        if not _remove_held_out_docs(temp_db_path, prompts_to_remove, category, run_name, logger):
            return None

        benchmark_df.to_csv(temp_csv_path, index=False)
        # Point settings.yaml at this run's pruned DB copy before launching the benchmark.
        modify_settings_yaml(str(temp_db_path), logger)
        benchmark_metrics = run_and_parse_benchmark(str(temp_csv_path), temp_experiment_dir, logger)
    finally:
        # Every run makes a full DB copy. Drop it right away so disk usage stays
        # at one copy instead of k_folds * len(UNSAFE_CATEGORIES) copies.
        # Best effort: a failure here is retried by the final cleanup in main().
        shutil.rmtree(temp_db_path, ignore_errors=True)
        if temp_csv_path.exists():
            temp_csv_path.unlink()

    if not benchmark_metrics:
        return None
    return {
        "fold": fold_idx,
        "held_out_category": category,
        "f1_unsafe": benchmark_metrics.get("f1_unsafe", 0.0),
        "recall_unsafe": benchmark_metrics.get("recall_unsafe", 0.0),
        "precision_unsafe": benchmark_metrics.get("precision_unsafe", 0.0),
        "num_samples": benchmark_df.shape[0],
    }


def _summarize_results(all_results: List[Dict[str, Any]], output_dir: Path, logger: logging.Logger):
    """Write raw and per-category summary CSVs to ``output_dir`` and print the summary."""
    logger.info("Aggregating and saving results...")
    if not all_results:
        # groupby on an empty, column-less frame would raise KeyError.
        logger.warning("No benchmark run produced results; nothing to aggregate.")
        return

    results_df = pd.DataFrame(all_results)
    results_df.to_csv(output_dir / "loco_raw_results.csv", index=False)

    summary_df = results_df.groupby('held_out_category').agg(
        avg_f1=('f1_unsafe', 'mean'),
        avg_recall=('recall_unsafe', 'mean'),
        avg_precision=('precision_unsafe', 'mean'),
        std_f1=('f1_unsafe', 'std'),
        total_samples=('num_samples', 'sum')
    ).reset_index()

    # Unweighted (macro) average over categories for the metric columns.
    overall_avg = summary_df.mean(numeric_only=True)
    # Sample counts add up across categories; averaging them would be misleading.
    overall_avg['total_samples'] = summary_df['total_samples'].sum()
    overall_avg['held_out_category'] = '--- OVERALL AVERAGE ---'
    summary_df = pd.concat([summary_df, pd.DataFrame([overall_avg])], ignore_index=True)

    summary_df.to_csv(output_dir / "loco_summary_results.csv", index=False)

    logger.info("Final Summary:")
    print("\n" + summary_df.to_string())


def main():
    """Run the K-fold x leave-one-category-out experiment and write summary CSVs.

    settings.yaml is backed up before the run and restored afterwards. The
    temporary assets directory is removed at the end. Exits with status 1 if
    an unhandled error aborted the experiment, or if a leftover backup from an
    earlier run is found.
    """
    args = _parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger = setup_logger(output_dir, level_str=args.log_level)

    logger.info("--- FORTRESS LOCO Experiment ---")
    logger.info(f"Running with {args.k_folds}-fold cross-validation.")
    logger.info(f"Results will be saved in: {output_dir.resolve()}")

    settings_yaml_path = Path(SETTINGS_YAML_PATH)
    backup_yaml_path = settings_yaml_path.with_suffix('.yaml.loco_backup')
    # A leftover backup means an earlier run was killed before restoring it, so
    # settings.yaml may still point at a temp DB. Overwriting the backup now
    # would permanently lose the original settings.
    if backup_yaml_path.exists():
        logger.error(
            f"Found leftover backup {backup_yaml_path} from an earlier run that did not finish "
            "cleanup. Restore or remove it manually, then re-run."
        )
        sys.exit(1)
    shutil.copy2(settings_yaml_path, backup_yaml_path)
    logger.info(f"Backed up settings.yaml to {backup_yaml_path}")

    temp_experiment_dir = output_dir / "temp_assets"
    all_results: List[Dict[str, Any]] = []
    failed = False

    try:
        logger.info(f"Loading master dataset from {FORTRESS_DATASET_CSV}")
        master_df = pd.read_csv(FORTRESS_DATASET_CSV)
        master_df = master_df.dropna(subset=['original_prompt', 'prompt_category'])
        temp_experiment_dir.mkdir(parents=True, exist_ok=True)

        # Fixed random_state keeps fold assignment reproducible across invocations.
        kf = KFold(n_splits=args.k_folds, shuffle=True, random_state=42)

        pbar_folds = tqdm(enumerate(kf.split(master_df)), total=args.k_folds, desc="K-Folds")
        for fold_idx, (train_index, test_index) in pbar_folds:
            pbar_folds.set_description(f"K-Fold {fold_idx + 1}/{args.k_folds}")

            training_pool_df = master_df.iloc[train_index]
            test_fold_df = master_df.iloc[test_index]

            pbar_cats = tqdm(UNSAFE_CATEGORIES, desc="Categories", leave=False)
            for category_to_hold_out in pbar_cats:
                pbar_cats.set_description(f"Holding out: {category_to_hold_out}")
                result = _run_holdout(
                    fold_idx,
                    category_to_hold_out,
                    training_pool_df,
                    test_fold_df,
                    temp_experiment_dir,
                    logger,
                )
                if result is not None:
                    all_results.append(result)

        _summarize_results(all_results, output_dir, logger)

    except Exception as e:
        failed = True
        logger.critical(f"An unhandled exception occurred: {e}", exc_info=True)
    finally:
        logger.info("Cleaning up...")
        if backup_yaml_path.exists():
            shutil.move(str(backup_yaml_path), SETTINGS_YAML_PATH)
            logger.info("Restored original settings.yaml.")
        if temp_experiment_dir.exists():
            shutil.rmtree(temp_experiment_dir)
            logger.info("Removed temporary experiment directory.")

    logger.info("--- LOCO Experiment Finished ---")
    # Surface failure to shells/CI instead of exiting 0 after a crash.
    if failed:
        sys.exit(1)

if __name__ == "__main__":
    main()