import argparse
import json
import pandas as pd
import numpy as np
from skopt import gp_minimize
from skopt.space import Real, Categorical
from skopt.utils import use_named_args
import logging
import sys
import os
from tqdm import tqdm

                                 
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.constants import SAFE_CATEGORIES, UNSAFE_CATEGORIES

                                    
try:
    from fortress.detection_pipeline.secondary_analyzer import PerplexityAnalyzerEngine
    from fortress.core.embedding_model import EmbeddingModel
    from fortress.config import get_config
except ImportError as e:
    print(f"Error importing fortress modules: {e}")
    print("Please ensure that the script is run from the 'scripts' directory or that PYTHONPATH is set up correctly.")
    sys.exit(1)

                                     
# Every category the script reports on (safe categories first, then unsafe);
# both optimization modes in main() iterate this to build per-category output.
ALL_CATEGORIES = SAFE_CATEGORIES + UNSAFE_CATEGORIES

# Configure root logging at import time (INFO level, with module:line context).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s')
logger = logging.getLogger(__name__)

# Search space for gp_minimize over the keyword arguments passed to
# PerplexityAnalyzerEngine in objective_function. Order matters:
# use_named_args maps each sampled point onto these names positionally, and
# run_optimization zips this list with result.x to name the best values.
search_space = [
    Real(-10.0, -1.0, name='adversarial_token_uniform_log_prob'),
    Real(0.1, 5.0, name='lambda_smoothness_penalty'),
    Real(-5.0, 5.0, name='mu_adversarial_token_prior'),
    Categorical([True, False], name='apply_first_token_neutral_bias')
]

# Data scored by objective_function. run_optimization sets it before calling
# gp_minimize and resets it to None afterwards; it is passed through module
# state because use_named_args fixes objective_function's signature.
_current_optimization_data = None

                        

def load_and_filter_data(csv_path: str) -> pd.DataFrame:
    """Load the dataset CSV and keep only usable rows of the 'database' split.

    A row is kept when ``split == 'database'`` and both ``label`` and
    ``prompt_category`` are non-missing.

    Args:
        csv_path: Path to the input CSV file.

    Returns:
        A filtered copy of the loaded DataFrame.

    Exits the process with status 1 (``sys.exit``) when the file does not
    exist, a required column is missing, or no usable rows remain.
    """
    logger.info(f"Loading data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        logger.error(f"Input CSV file not found at {csv_path}")
        sys.exit(1)

    logger.info(f"Initial dataset size: {len(df)} rows.")

    # Fail with a clear message rather than a bare KeyError further down.
    required_columns = ['split', 'label', 'prompt_category']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        logger.error(f"Input CSV is missing required column(s): {missing_columns}. Exiting.")
        sys.exit(1)

    df_filtered = df[df['split'] == 'database'].copy()
    logger.info(f"Dataset size after filtering for split == 'database': {len(df_filtered)} rows.")

    if df_filtered.empty:
        logger.error("No data found for split == 'database'. Exiting.")
        sys.exit(1)

    df_filtered.dropna(subset=['label', 'prompt_category'], inplace=True)
    logger.info(f"Dataset size after dropping rows with missing label/prompt_category: {len(df_filtered)} rows.")

    if df_filtered.empty:
        logger.error("No rows left after dropping missing label/prompt_category. Exiting.")
        sys.exit(1)

    return df_filtered

def get_token_log_probs(df: pd.DataFrame, embedding_model_instance: EmbeddingModel) -> pd.DataFrame:
    """Attach per-token source log probabilities to each prompt and drop failures.

    For every row, ``embedding_model_instance.get_token_source_log_probabilities``
    is called on the ``original_prompt`` text. A row is dropped when its prompt
    is missing or blank, when the call raises, or when the result is not a
    non-empty list.

    Args:
        df: Input rows; the ``original_prompt`` column is read if present.
        embedding_model_instance: Model exposing
            ``get_token_source_log_probabilities(text=...)``.

    Returns:
        A DataFrame with a new ``token_source_log_probs`` column, containing
        only the rows for which a log-probability list was produced.
    """
    logger.info(f"Generating token_source_log_probabilities for {len(df)} prompts...")

    # Iterate just the column we need instead of df.iterrows(), which builds a
    # Series per row. A missing column behaves like all prompts being missing.
    if 'original_prompt' in df.columns:
        prompts = df['original_prompt']
    else:
        logger.warning("Column 'original_prompt' not found; every prompt will be treated as missing.")
        prompts = [None] * len(df)

    log_probs_column_data = []
    for prompt_text in tqdm(prompts, total=len(df), desc="Generating Log Probs"):
        if pd.isna(prompt_text) or not str(prompt_text).strip():
            log_probs_column_data.append(None)
            continue

        try:
            current_log_probs = embedding_model_instance.get_token_source_log_probabilities(text=str(prompt_text))
        except Exception as e:
            logger.error(f"Error getting log_probs for prompt '{str(prompt_text)[:50]}...': {e}", exc_info=True)
            log_probs_column_data.append(None)
            continue

        # Type check first: truth-testing e.g. a multi-element numpy array
        # raises instead of returning True/False.
        if isinstance(current_log_probs, list) and current_log_probs:
            log_probs_column_data.append(current_log_probs)
        else:
            log_probs_column_data.append(None)

    df = df.assign(token_source_log_probs=log_probs_column_data)
    df_processed = df.dropna(subset=['token_source_log_probs'])

    logger.info(f"Successfully generated log_probs for {len(df_processed)} prompts out of {len(df)}.")
    return df_processed

@use_named_args(search_space)
def objective_function(adversarial_token_uniform_log_prob: float,
                       lambda_smoothness_penalty: float,
                       mu_adversarial_token_prior: float,
                       apply_first_token_neutral_bias: bool) -> float:
    """
    Objective function for Bayesian optimization. Calculates Mean Squared Error (MSE)
    against a target probability (1.0 for prompts in UNSAFE_CATEGORIES, 0.0 otherwise).

    The scored rows come from the module-level ``_current_optimization_data``
    (set by ``run_optimization``), because ``use_named_args`` fixes this
    function's signature to the search-space parameters.

    Returns:
        The MSE over all rows. Returns 1.0 when no data is set, when no rows
        were scored, or when the MSE is not finite.
    """
    # Read-only access to the module-level data; no `global` statement needed.
    data = _current_optimization_data
    if data is None:
        logger.error("Global optimization data not set for objective function.")
        return 1.0

    engine = PerplexityAnalyzerEngine(
        adversarial_token_uniform_log_prob=adversarial_token_uniform_log_prob,
        lambda_smoothness_penalty=lambda_smoothness_penalty,
        mu_adversarial_token_prior=mu_adversarial_token_prior,
        apply_first_token_neutral_bias=apply_first_token_neutral_bias
    )

    # Built once per call: O(1) membership instead of a list scan per row.
    unsafe_categories = set(UNSAFE_CATEGORIES)

    total_squared_error = 0.0
    num_processed = 0

    # This runs once per optimizer call, so iterate the two needed columns
    # directly rather than paying for df.iterrows()'s per-row Series.
    for log_probs, category_name in zip(data['token_source_log_probs'], data['prompt_category']):
        _, sentence_prob = engine.calculate_adversarial_probabilities(log_probs)
        if not np.isfinite(sentence_prob):
            # Midpoint between the two targets for non-finite engine output.
            sentence_prob = 0.5

        target = 1.0 if category_name in unsafe_categories else 0.0
        total_squared_error += (sentence_prob - target) ** 2
        num_processed += 1

    if num_processed == 0:
        logger.warning("No valid prompts processed in objective function. Returning worst score (1.0).")
        return 1.0

    mse = total_squared_error / num_processed
    return mse if np.isfinite(mse) else 1.0

def run_optimization(df_subset: pd.DataFrame, n_calls: int, n_initial_points: int) -> dict:
    """
    Runs the gp_minimize optimization process on a given subset of data.

    ``df_subset`` is published through ``_current_optimization_data`` for
    ``objective_function`` for the duration of the run and is always cleared
    afterwards, whether the run succeeds or fails.

    Args:
        df_subset: Rows with ``token_source_log_probs`` and ``prompt_category``.
        n_calls: Requested number of objective evaluations. It is reduced (never
            increased) when the sample is small relative to ``n_initial_points``.
        n_initial_points: Number of initial points passed to gp_minimize.

    Returns:
        On success, a dict mapping each ``search_space`` parameter name to its
        best value (as a plain Python scalar) plus ``"achieved_mse"`` (float).
        On failure, a dict with a single ``"error"`` message.
    """
    global _current_optimization_data

    if df_subset.empty:
        return {"error": "No data provided for optimization."}

    effective_n_calls = n_calls
    if len(df_subset) < 2 * n_initial_points:
        # Spend fewer evaluations on small samples. Cap at the requested
        # n_calls so this "reduction" can never increase the budget.
        effective_n_calls = min(n_calls, max(n_initial_points + 5, int(n_calls / 2)))
        if effective_n_calls < n_calls:
            logger.info(f"Reduced n_calls to {effective_n_calls} for small sample size ({len(df_subset)}).")

    _current_optimization_data = df_subset
    try:
        result = gp_minimize(
            func=objective_function,
            dimensions=search_space,
            n_calls=effective_n_calls,
            n_initial_points=n_initial_points,
            random_state=42,
            verbose=False
        )
    except Exception as e:
        logger.error(f"Error during gp_minimize: {e}", exc_info=True)
        return {"error": f"gp_minimize failed: {e}"}
    finally:
        # Never leave stale data behind for a later objective call.
        _current_optimization_data = None

    # The optimizer's result may hold numpy scalars (e.g. numpy.bool_ for the
    # Categorical dimension), which json cannot encode; convert them to Python values.
    best_params = {
        dim.name: (value.item() if isinstance(value, np.generic) else value)
        for dim, value in zip(search_space, result.x)
    }
    achieved_mse = float(result.fun)

    logger.info(f"Optimization finished. Best Params: {best_params} (Achieved MSE: {achieved_mse:.4f})")

    return {**best_params, "achieved_mse": achieved_mse}

                        

def _json_default(obj):
    """Fallback encoder for ``json.dumps``: convert numpy scalars to Python values.

    Optimization results may contain numpy scalar types (e.g. ``numpy.bool_``),
    which the standard-library JSON encoder rejects.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Command-line entry point.

    Loads the 'database' split from ``--input_csv`` and computes token log
    probabilities with ``EmbeddingModel``. It then tunes
    PerplexityAnalyzerEngine parameters, either once for all data
    (``--mode global``) or separately for each category
    (``--mode per_category``). The results, together with default engine
    settings, are written to ``--output_json``. Exits with status 1 on fatal
    errors.
    """
    parser = argparse.ArgumentParser(description="Optimize PerplexityAnalyzerEngine parameters.")
    parser.add_argument("--input_csv", required=True, help="Path to the input CSV file.")
    parser.add_argument("--output_json", required=True, help="Path to the output JSON file for optimized parameters.")
    parser.add_argument("--n_calls", type=int, default=30, help="Number of calls for gp_minimize.")
    parser.add_argument("--n_initial_points", type=int, default=10, help="Number of initial points for gp_minimize.")
    parser.add_argument(
        "--mode",
        type=str,
        choices=['global', 'per_category'],
        default='per_category',
        help="Optimization mode: 'global' for one set of params, 'per_category' for params specific to each category."
    )
    args = parser.parse_args()

    # --- Data loading ---
    df_database = load_and_filter_data(args.input_csv)

    logger.info("Initializing EmbeddingModel...")
    try:
        embedding_model = EmbeddingModel()
        # Smoke-test call so a broken model fails here, not once per prompt.
        embedding_model.get_token_source_log_probabilities("test")
        logger.info("EmbeddingModel initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize EmbeddingModel: {e}", exc_info=True)
        sys.exit(1)

    # --- Feature generation ---
    df_with_log_probs = get_token_log_probs(df_database, embedding_model)
    if df_with_log_probs.empty:
        logger.error("No valid prompts with log probabilities were generated. Cannot optimize. Exiting.")
        sys.exit(1)

    # --- Optimization ---
    all_optimized_results = {}
    logger.info(f"Starting optimization in '{args.mode}' mode.")

    if args.mode == 'global':
        logger.info(f"Running GLOBAL optimization with {len(df_with_log_probs)} samples.")
        global_results = run_optimization(df_with_log_probs, args.n_calls, args.n_initial_points)

        if "error" in global_results:
            logger.error(f"Global optimization failed: {global_results['error']}")
            sys.exit(1)

        # The decision threshold is fixed; only engine parameters are optimized.
        global_results["sentence_adversarial_probability_threshold"] = 0.5

        # Built once, from the rows the optimizer actually used (the same data
        # per_category mode checks), rather than recomputed per category.
        categories_with_data = set(df_with_log_probs['prompt_category'].unique())

        logger.info("Assembling final JSON with globally-optimized parameters...")
        for category in ALL_CATEGORIES:
            category_settings = global_results.copy()
            if category not in categories_with_data:
                category_settings["info"] = "No data for this category; using global params."
            all_optimized_results[category] = category_settings

    elif args.mode == 'per_category':
        for category in tqdm(ALL_CATEGORIES, desc="Optimizing categories"):
            logger.info(f"--- Processing category: {category} ---")
            df_category_subset = df_with_log_probs[df_with_log_probs['prompt_category'] == category].copy()

            if df_category_subset.empty:
                logger.warning(f"No data for category: {category}. Skipping.")
                all_optimized_results[category] = {
                    "error": "No data for this category in the 'database' split.",
                    "sentence_adversarial_probability_threshold": 0.5
                }
                continue

            category_results = run_optimization(df_category_subset, args.n_calls, args.n_initial_points)
            if "error" in category_results:
                logger.warning(f"Optimization failed for category {category}: {category_results['error']}")
            category_results["sentence_adversarial_probability_threshold"] = 0.5
            all_optimized_results[category] = category_results

    # --- Output assembly ---
    output_data = {"category_specific_settings": all_optimized_results}

    # Single source of truth for the fallback defaults (key order preserved).
    hardcoded_defaults = {
        "adversarial_token_uniform_log_prob": -5.0,
        "lambda_smoothness_penalty": 2.5,
        "mu_adversarial_token_prior": -2.0,
        "apply_first_token_neutral_bias": False,
        "sentence_adversarial_probability_threshold": 0.8,
    }
    engine_param_names = (
        "adversarial_token_uniform_log_prob",
        "lambda_smoothness_penalty",
        "mu_adversarial_token_prior",
        "apply_first_token_neutral_bias",
    )
    try:
        config_data = get_config()
        pa_config = getattr(config_data, 'perplexity_analyzer', {})
        dp_config = getattr(config_data, 'detection_pipeline', {})
        engine_defaults = {
            name: getattr(pa_config, name, hardcoded_defaults[name])
            for name in engine_param_names
        }
        engine_defaults["sentence_adversarial_probability_threshold"] = getattr(
            dp_config,
            "sentence_perplexity_unsafe_threshold",
            hardcoded_defaults["sentence_adversarial_probability_threshold"],
        )
        output_data["default_engine_settings"] = engine_defaults
    except Exception as e:
        logger.warning(f"Could not load default settings from config: {e}. Using hardcoded defaults.")
        output_data["default_engine_settings"] = {
            **hardcoded_defaults,
            "error": "Failed to load from settings.yaml, used hardcoded."
        }

    logger.info(f"Saving optimized parameters to {args.output_json}")
    # Serialize before opening the file so an encoding error cannot leave a
    # truncated/partial JSON file behind.
    try:
        serialized = json.dumps(output_data, indent=2, default=_json_default)
    except (TypeError, ValueError) as e:
        logger.error(f"Failed to serialize optimized parameters: {e}", exc_info=True)
        sys.exit(1)

    try:
        with open(args.output_json, 'w', encoding='utf-8') as f:
            f.write(serialized)
        logger.info("Optimization complete. Results saved.")
    except IOError as e:
        logger.error(f"Failed to write output JSON file: {e}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
