                      
import argparse
import logging
import os
import sys
import time
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import xgboost as xgb
from pydantic import BaseModel, Field

                                  
try:
    from openai import OpenAI
except ImportError:
    print("OpenAI library not found. Please install it: pip install openai")
    sys.exit(1)

import asyncio
import httpx

                                                                    
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

                                                    
from scripts.utils.benchmark_utils import (
    setup_logger, BaseBenchmarkSingleResult, CommonBenchmarkMetrics,
    add_common_benchmark_args, create_run_id, prepare_output_paths,
    map_true_label_to_str, calculate_common_metrics,
    generate_markdown_report_parts, save_json_output, save_markdown_report,
    get_progress_bar, load_and_filter_prompts, InputPromptRecord
)
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.common.data_models import DecisionLabel as FortressDecisionLabel
from fortress.common.constants import LABEL_SAFE, LABEL_UNSAFE, SPLIT_BENCHMARK

                                                                 

class XGBoostBenchmarkSingleResult(BaseBenchmarkSingleResult):
    """Per-prompt benchmark record extended with the classifier's confidence."""

    # Probability the XGBoost model assigned to the label it predicted.
    # async_main stores 0.0 here when embedding generation or inference failed.
    prediction_probability: Optional[float] = Field(None, description="The model's confidence score for the predicted label.")

class XGBoostBenchmarkSuiteResults(BaseModel):
    """Top-level record of one benchmark run: run metadata, metrics and per-prompt results."""

    # Identifier produced by create_run_id(); also used as the output file stem.
    suite_run_id: str
    # ISO-8601 strings taken from datetime.now() at the start / end of the run.
    timestamp_start: str
    timestamp_end: str
    # Wall-clock difference between the two timestamps above, in seconds.
    duration_seconds: float
    # CSV paths exactly as passed on the command line.
    input_csv_files: List[str]
    # Absolute paths of the JSON results file and the Markdown report.
    output_results_file: str
    output_report_file: str
    # Path of the XGBoost model file that was loaded for this run.
    xgboost_model_path: str
    # OpenAI model ID used to embed each prompt before classification.
    openai_embedding_model_id: str
    # Aggregate metrics returned by calculate_common_metrics().
    metrics: CommonBenchmarkMetrics
    # One entry per benchmark prompt.
    individual_results: List[XGBoostBenchmarkSingleResult]

                                                          

class EmbeddingGenerator:
    """Blocking wrapper around the OpenAI embeddings endpoint."""

    def __init__(self, model_id: str, logger: logging.Logger):
        """Build the OpenAI client from the ``OPENAI_API_KEY`` environment variable.

        Args:
            model_id: OpenAI embedding model to request.
            logger: Logger used for status and error messages.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is unset or empty.
        """
        self.model_id = model_id
        self.logger = logger

        key_from_env = os.environ.get("OPENAI_API_KEY")
        if not key_from_env:
            self.logger.critical("OPENAI_API_KEY environment variable not set.")
            raise ValueError("OPENAI_API_KEY is not set. Please set it as an environment variable.")

        self.client = OpenAI(api_key=key_from_env)
        self.logger.info(f"OpenAI embedding client initialized for model: {self.model_id}")

    def get_embedding(self, text: str) -> Optional[List[float]]:
        """Embed a single string.

        Newlines in ``text`` are replaced by spaces before the request is sent.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector, or ``None`` if anything went wrong (the
            failure is logged rather than raised).
        """
        try:
            # Swap every newline for a single space before sending.
            single_line = " ".join(text.split("\n"))
            api_result = self.client.embeddings.create(
                model=self.model_id,
                input=[single_line],
            )
            first_item = api_result.data[0]
            return first_item.embedding
        except Exception as exc:
            self.logger.error(f"Failed to get embedding for text '{text[:70]}...': {exc}")
            return None


class AsyncEmbeddingGenerator:
    """Async client for the OpenAI embeddings REST endpoint with bounded concurrency.

    Requests are sent through a shared ``httpx.AsyncClient``. The instance can
    be used as an async context manager so the client is always closed::

        async with AsyncEmbeddingGenerator(model_id, logger) as generator:
            vector = await generator.get_embedding("some text")
    """

    def __init__(self, model_id: str, logger: logging.Logger, max_rps: int = 15):
        """Create the HTTP client and the concurrency limiter.

        Args:
            model_id: OpenAI embedding model to request.
            logger: Logger used for error messages.
            max_rps: Maximum number of requests allowed in flight at once.
                Despite the name this is a concurrency cap (the size of an
                ``asyncio.Semaphore``), not a requests-per-second rate.

        Raises:
            ValueError: If ``max_rps`` is smaller than 1, or if the
                ``OPENAI_API_KEY`` environment variable is unset or empty.
        """
        # A zero-sized semaphore can never be acquired, so every request
        # would wait forever; reject it up front instead.
        if max_rps < 1:
            raise ValueError(f"max_rps must be a positive integer, got {max_rps}")
        self.model_id = model_id
        self.logger = logger
        self.api_key = os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            self.logger.critical("OPENAI_API_KEY environment variable not set.")
            raise ValueError("OPENAI_API_KEY is not set. Please set it as an environment variable.")
        self.endpoint = "https://api.openai.com/v1/embeddings"
        self.semaphore = asyncio.Semaphore(max_rps)
        self.client = httpx.AsyncClient(timeout=30.0)

    async def __aenter__(self) -> "AsyncEmbeddingGenerator":
        """Return the generator itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the underlying HTTP client when the ``async with`` block ends."""
        await self.aclose()

    async def get_embedding(self, text: str) -> Optional[List[float]]:
        """Embed a single string.

        Newlines in ``text`` are replaced by spaces before the request is sent.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector, or ``None`` if the request or the response
            handling failed (the failure is logged rather than raised).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        async with self.semaphore:
            try:
                # Built inside the try block (as the synchronous generator
                # does) so a malformed prompt yields None instead of an
                # exception that would abort the whole run.
                payload = {
                    "input": [text.replace("\n", " ")],
                    "model": self.model_id
                }
                response = await self.client.post(self.endpoint, headers=headers, json=payload)
                response.raise_for_status()
                result = response.json()
                return result["data"][0]["embedding"]
            except Exception as e:
                # str() keeps the log call itself from failing on non-string input.
                self.logger.error(f"Failed to get embedding for text '{str(text)[:70]}...': {e}")
                return None

    async def aclose(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

                                                          

class XGBoostClassifier:
    """Loads a saved XGBoost classifier and labels embedding vectors as safe/unsafe."""

    def __init__(self, model_path: str, logger: logging.Logger):
        """Load the model stored at ``model_path``.

        Args:
            model_path: Path to the saved XGBoost model file.
            logger: Logger used for status and error messages.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        self.model_path = model_path
        self.logger = logger

        self.logger.info(f"Loading XGBoost model from: {model_path}")
        # Validate the path before building the estimator so a bad path fails
        # fast; isfile() also rejects a directory, which exists() would accept.
        if not os.path.isfile(model_path):
            self.logger.critical(f"Model file not found at {model_path}")
            raise FileNotFoundError(f"Model file not found at {model_path}")

        self.model = xgb.XGBClassifier()
        self.model.load_model(model_path)
        self.logger.info("XGBoost model loaded successfully.")

    def classify_embedding(self, embedding: List[float]) -> Tuple[str, float]:
        """Classify a single embedding vector.

        The model is evaluated once with ``predict_proba``; the predicted
        class is the highest-probability column, which is the class
        ``predict`` reports, so a second inference pass is not needed.

        Args:
            embedding: Feature vector for one prompt.

        Returns:
            ``(label, confidence)`` where ``label`` is the SAFE or UNSAFE
            decision value and ``confidence`` is the probability of the
            predicted class. On any inference error the ERROR decision value
            and ``0.0`` are returned and the error is logged.
        """
        try:
            features = np.asarray(embedding).reshape(1, -1)
            class_probabilities = self.model.predict_proba(features)[0]
            # A plain int is used for both the probability lookup and the
            # label comparison so the two can never disagree.
            predicted_class = int(np.argmax(class_probabilities))
            confidence = float(class_probabilities[predicted_class])
            if predicted_class == LABEL_SAFE:
                predicted_label_str = FortressDecisionLabel.SAFE.value
            else:
                predicted_label_str = FortressDecisionLabel.UNSAFE.value
            return predicted_label_str, confidence
        except Exception as e:
            self.logger.error(f"Error during XGBoost inference: {e}", exc_info=True)
            return FortressDecisionLabel.ERROR.value, 0.0

                                          

def sync_classify_embedding(classifier, embedding):
    """Delegate to ``classifier.classify_embedding`` and return its result unchanged."""
    classification = classifier.classify_embedding(embedding)
    return classification

async def async_main(args, logger):
    """Run the live-embedding XGBoost benchmark suite end to end.

    Loads the classifier and the async embedding client, embeds and
    classifies every benchmark prompt concurrently, computes the common
    metrics and writes the JSON results file and the Markdown report under
    ``args.output_dir``.

    Args:
        args: Parsed command-line arguments. Reads ``run_name_prefix``,
            ``output_dir``, ``input_csvs``, ``model_path`` and
            ``embedding_model_id``.
        logger: Logger used for progress and error messages.

    Raises:
        SystemExit: With status 1 if the classifier or embedding client
            cannot be initialized, with status 0 if no benchmark prompts
            remain after filtering.
    """
    timestamp_start_dt = datetime.now()
    suite_run_id = create_run_id(base_name="ayub_oai_xgb_guard", prefix=args.run_name_prefix)
    results_dir = os.path.join(args.output_dir, "results_data")
    reports_dir = os.path.join(args.output_dir, "reports")
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(reports_dir, exist_ok=True)
    output_results_file = os.path.join(results_dir, f"{suite_run_id}_results.json")
    output_report_file = os.path.join(reports_dir, f"{suite_run_id}_report.md")
    logger.info(f"Starting XGBoost Benchmark Suite: {suite_run_id}")
    logger.info(f"Input CSVs: {args.input_csvs}")
    logger.info(f"Outputting results to: {output_results_file}")
    logger.info(f"Outputting report to: {output_report_file}")
    logger.info(f"Using XGBoost model: {args.model_path}")
    logger.info(f"Using OpenAI embedding model: {args.embedding_model_id}")

    try:
        classifier = XGBoostClassifier(model_path=args.model_path, logger=logger)
        embedding_generator = AsyncEmbeddingGenerator(model_id=args.embedding_model_id, logger=logger, max_rps=15)
    except Exception as e:
        logger.critical(f"Failed to initialize components: {e}", exc_info=True)
        sys.exit(1)

    desc = "Generating embeddings and classifying"
    tasks = []
    indexed_results = []

    # Everything that runs while the HTTP client is open sits inside this
    # try block so the client is closed on every exit path: normal completion,
    # the early "no prompts" exit, and unexpected exceptions.
    try:
        all_prompts: List[InputPromptRecord] = load_and_filter_prompts(
            csv_file_paths=args.input_csvs,
            split_to_filter=SPLIT_BENCHMARK,
            load_prompts_func=load_prompts_from_csv,
            logger=logger
        )
        if not all_prompts:
            logger.warning("No benchmark prompts loaded after filtering. Exiting.")
            sys.exit(0)

        async def process_prompt(prompt_record):
            """Embed and classify one prompt and package the outcome."""
            true_label_str = map_true_label_to_str(
                prompt_record.label, logger,
                safe_numeric_val=LABEL_SAFE, unsafe_numeric_val=LABEL_UNSAFE
            )
            # NOTE(review): the timer starts before get_embedding() waits for
            # its concurrency slot, so processing_time_ms includes queueing
            # time as well as the request and the classification.
            start_time = time.perf_counter()
            embedding = await embedding_generator.get_embedding(prompt_record.original_prompt)
            error_info_str: Optional[str] = None
            predicted_label_str = FortressDecisionLabel.ERROR.value
            confidence = 0.0
            if embedding is None:
                error_info_str = "Failed to generate OpenAI embedding."
            else:
                predicted_label_str, confidence = sync_classify_embedding(classifier, embedding)
                if predicted_label_str == FortressDecisionLabel.ERROR.value:
                    error_info_str = "XGBoost inference failed."
            end_time = time.perf_counter()
            processing_time_ms = (end_time - start_time) * 1000
            return XGBoostBenchmarkSingleResult(
                prompt_id=str(prompt_record.prompt_id) if prompt_record.prompt_id is not None else f"csv_{prompt_record.csv_row_number}",
                original_prompt=prompt_record.original_prompt,
                true_label=true_label_str,
                predicted_label=predicted_label_str,
                prediction_probability=confidence,
                processing_time_ms=processing_time_ms,
                source_file_input=prompt_record.source_file,
                prompt_category_input=prompt_record.prompt_category,
                prompt_style_input=prompt_record.prompt_style,
                error_info=error_info_str
            )

        async def process_indexed(position, prompt_record):
            """Tag each result with its input position so order can be restored."""
            return position, await process_prompt(prompt_record)

        tasks = [
            asyncio.ensure_future(process_indexed(position, prompt_record))
            for position, prompt_record in enumerate(all_prompts)
        ]
        for finished in get_progress_bar(asyncio.as_completed(tasks), desc=desc, logger=logger):
            indexed_results.append(await finished)
    finally:
        # Only relevant when unwinding from an error: stop requests that are
        # still in flight before the shared client goes away. Cancelling a
        # task that has already finished is a no-op.
        for task in tasks:
            task.cancel()
        await embedding_generator.aclose()

    # as_completed() yields in completion order; sort back to input order so
    # the results file is reproducible from run to run.
    indexed_results.sort(key=lambda pair: pair[0])
    individual_results: List[XGBoostBenchmarkSingleResult] = [
        result for _, result in indexed_results
    ]

    overall_metrics = calculate_common_metrics(
        individual_results, logger,
        positive_label=FortressDecisionLabel.UNSAFE.value,
        negative_label=FortressDecisionLabel.SAFE.value,
        error_label=FortressDecisionLabel.ERROR.value
    )
    timestamp_end_dt = datetime.now()
    duration_seconds = (timestamp_end_dt - timestamp_start_dt).total_seconds()
    suite_results = XGBoostBenchmarkSuiteResults(
        suite_run_id=suite_run_id,
        timestamp_start=timestamp_start_dt.isoformat(),
        timestamp_end=timestamp_end_dt.isoformat(),
        duration_seconds=duration_seconds,
        input_csv_files=args.input_csvs,
        output_results_file=os.path.abspath(output_results_file),
        output_report_file=os.path.abspath(output_report_file),
        xgboost_model_path=args.model_path,
        openai_embedding_model_id=args.embedding_model_id,
        metrics=overall_metrics,
        individual_results=individual_results,
    )
    save_json_output(suite_results, output_results_file, logger)
    model_specific_desc = (
        f"**XGBoost Model**: `{args.model_path}`\n\n"
        f"**OpenAI Embedding Model**: `{args.embedding_model_id}`"
    )
    report_content = generate_markdown_report_parts(
        suite_results=suite_results,
        model_specific_desc=model_specific_desc,
        report_title="XGBoost Classifier Benchmark Report (Live Embeddings)"
    )
    save_markdown_report(report_content, output_report_file, logger)
    logger.info("XGBoost Benchmark Suite finished.")
    logger.info(f"Accuracy: {overall_metrics.accuracy:.4f}, F1 (Unsafe): {overall_metrics.f1_unsafe if overall_metrics.f1_unsafe is not None else 'N/A'}")

def main():
    """Parse command-line options and run the benchmark suite.

    Adds the directory two levels above this file to ``sys.path`` so that
    ``config.constants`` can be imported, registers the common benchmark
    arguments plus ``--model-path`` and ``--embedding-model-id``, falls back
    to the ``benchmarks`` output directory when none was given, and then
    drives ``async_main`` to completion.
    """
    two_levels_up = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
    sys.path.append(two_levels_up)
    from config.constants import BASE_PROJECT_DIR

    default_model_path = str(BASE_PROJECT_DIR / "data/08_models/xgboost_openai_classifier.json")

    parser = add_common_benchmark_args(
        argparse.ArgumentParser(description="Run XGBoost Classifier Benchmark Suite on live prompts.")
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=default_model_path,
        help="Path to the trained XGBoost model file."
    )
    parser.add_argument(
        "--embedding-model-id",
        type=str,
        default="text-embedding-3-small",
        help="OpenAI model ID to use for generating embeddings."
    )

    args = parser.parse_args()
    # An empty or missing output directory falls back to "benchmarks".
    args.output_dir = args.output_dir or "benchmarks"

    logger = setup_logger(__name__, level_str=args.log_level)
    asyncio.run(async_main(args, logger))


# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()