import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple

import asyncio                                    
import aiohttp                                        
from asyncio import Semaphore                                 
from pydantic import BaseModel                                

                                                                    
# Append the parent of this file's directory to the import search path.
# NOTE(review): presumably this is what lets the `fortress.*` and `scripts.*`
# imports below resolve when the script is run directly — confirm against the
# repository layout.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

                           
from fortress.common.data_models import DecisionLabel as FortressDecisionLabel
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.common.constants import SPLIT_BENCHMARK, LABEL_SAFE, LABEL_UNSAFE

                             
from scripts.utils.benchmark_utils import (
    setup_logger,
    BaseBenchmarkSingleResult,
    CommonBenchmarkMetrics,
    add_common_benchmark_args,
    create_run_id,
    prepare_output_paths,
    load_and_filter_prompts,
    map_true_label_to_str,
    calculate_common_metrics,
    generate_markdown_report_parts,
    save_json_output,
    save_markdown_report,
                                                           
    InputPromptRecord,
    DecisionLabel as UtilDecisionLabel
)
from tqdm import tqdm                   

                                               

class OpenAIModerationSingleResult(BaseBenchmarkSingleResult):
    """Per-prompt benchmark record extended with OpenAI Moderation fields.

    Every field added here is optional and defaults to ``None``.
    """
    # The ``flagged`` verdict taken from the moderation result for the prompt.
    flagged: Optional[bool] = None
    # Per-category boolean verdicts taken from the moderation result.
    categories: Optional[Dict[str, bool]] = None
    # Per-category scores taken from the moderation result.
    category_scores: Optional[Dict[str, float]] = None
    # Raw response dict for the prompt; the producer in this file only fills
    # it when debug-level logging is active and leaves it as None otherwise.
    openai_raw_response: Optional[Dict[str, Any]] = None

class OpenAIModerationSuiteResults(BaseModel):
    """Top-level record for one benchmark run: run metadata, metrics, results."""
    # Identifier generated for this run.
    suite_run_id: str
    # Run start / end timestamps as ISO-8601 strings.
    timestamp_start: str
    timestamp_end: str
    # Elapsed time between the two timestamps, in seconds.
    duration_seconds: float
    # Input CSV paths as given on the command line.
    input_csv_files: List[str]
    # Paths of the JSON results file and the Markdown report, as strings.
    output_results_file: str
    output_report_file: str
    # Moderation model ID used for the run.
    openai_model_id: str
    # Aggregate metrics computed over ``individual_results``.
    metrics: CommonBenchmarkMetrics
    # One record per prompt that produced a result.
    individual_results: List[OpenAIModerationSingleResult]

                                                     

def parse_openai_moderation_output(
    response_json: Dict[str, Any],
    logger: logging.Logger
) -> Tuple[str, bool, Dict[str, bool], Dict[str, float]]:
    """Extract the verdict from an OpenAI Moderation API response.

    Only the first entry of ``response_json["results"]`` is inspected.

    Args:
        response_json: Decoded JSON body returned by the moderation endpoint.
        logger: Logger used to report malformed or unparseable responses.

    Returns:
        A ``(predicted_label, flagged, categories, category_scores)`` tuple.
        ``predicted_label`` is the UNSAFE label value when the result is
        flagged and the SAFE label value otherwise. When the response is
        malformed or parsing fails, the tuple is
        ``(ERROR label value, False, {}, {})``.
    """
    error_outcome = (FortressDecisionLabel.ERROR.value, False, {}, {})
    try:
        results = response_json.get("results")
        if not results or not isinstance(results, list) or not results[0]:
            logger.warning(f"Malformed OpenAI Moderation API response: 'results' array is missing or empty. Full response: {response_json}")
            return error_outcome

        first_result = results[0]
        flagged = first_result.get("flagged")
        # A result without a boolean verdict is malformed. Defaulting it to
        # False would silently score the prompt as SAFE and skew the metrics.
        if not isinstance(flagged, bool):
            logger.warning(f"Malformed OpenAI Moderation API response: 'flagged' is missing or not a boolean. Full response: {response_json}")
            return error_outcome

        # `or {}` also covers an explicit JSON null, keeping the declared
        # dict return types honest.
        categories = first_result.get("categories") or {}
        category_scores = first_result.get("category_scores") or {}

        if flagged:
            predicted_label = FortressDecisionLabel.UNSAFE.value
        else:
            predicted_label = FortressDecisionLabel.SAFE.value
        return predicted_label, flagged, categories, category_scores
    except Exception as e:
        # Covers non-dict payloads (e.g. `results[0]` being a string or
        # `response_json` not supporting `.get`).
        logger.error(f"Error parsing OpenAI Moderation API response: {e}. Response: {response_json}", exc_info=True)
        return error_outcome

                                                            
class OpenAIModerator:
    """Async client for the OpenAI Moderation endpoint with bounded retries.

    HTTP, transport, timeout and body-decoding failures are not raised by
    :meth:`moderate_prompt`; they are reported as a dict carrying an
    ``"error"`` key (see :meth:`_error_payload`).
    """

    def __init__(self, api_key: str, model_id: str, logger: logging.Logger,
                 max_retries: int = 3, initial_backoff_secs: float = 1.0, backoff_factor: float = 2.0):
        """Store the configuration and validate the API key.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model_id: Moderation model ID sent in the request body.
            logger: Logger used for retry and error reporting.
            max_retries: Number of retries after the first attempt, so
                ``max_retries + 1`` requests are made at most.
            initial_backoff_secs: Delay before the first retry.
            backoff_factor: Multiplier applied to the delay after each
                backoff-driven retry.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        self.api_key = api_key
        self.model_id = model_id
        self.logger = logger
        self.api_url = "https://api.openai.com/v1/moderations"
        self.max_retries = max_retries
        self.initial_backoff_secs = initial_backoff_secs
        self.backoff_factor = backoff_factor

        if not self.api_key:
            self.logger.critical("OPENAI_API_KEY environment variable not set or empty.")
            raise ValueError("OpenAI API key is required.")

    @staticmethod
    def _error_payload(message: str) -> Dict[str, Any]:
        """Build the dict returned for a failed call: the error plus one unflagged placeholder result."""
        return {
            "error": message,
            "results": [{"flagged": False, "categories": {}, "category_scores": {}}],
        }

    def _rate_limit_delay(self, retry_after: Optional[str], current_backoff: float) -> Tuple[float, bool]:
        """Pick the delay after a 429; returns ``(seconds, taken_from_header)``."""
        if not retry_after:
            self.logger.info(f"Retrying in {current_backoff:.2f} seconds (exponential backoff).")
            return current_backoff, False
        try:
            header_delay = int(retry_after)
        except ValueError:
            self.logger.warning(f"Invalid Retry-After header value: {retry_after}. Using exponential backoff: {current_backoff:.2f}s.")
            return current_backoff, False
        self.logger.info(f"Retrying after {header_delay} seconds (from Retry-After header).")
        return header_delay, True

    async def moderate_prompt(self, prompt_text: str, session: "aiohttp.ClientSession") -> Dict[str, Any]:
        """POST one prompt to the moderation endpoint, retrying transient failures.

        Retried: HTTP 429 (using ``Retry-After`` when it parses as an integer,
        exponential backoff otherwise), non-4xx ``ClientResponseError``s,
        other ``aiohttp.ClientError``s, timeouts, and response bodies that
        cannot be decoded as JSON. A 4xx status other than 429 is returned as
        an error immediately.

        Args:
            prompt_text: Text to moderate.
            session: Open aiohttp session used for the request.

        Returns:
            The decoded JSON response on success, otherwise an error payload
            containing an ``"error"`` message.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        data = {
            "input": prompt_text,
            "model": self.model_id,
        }
        preview = prompt_text[:50]
        total_attempts = self.max_retries + 1
        current_backoff = self.initial_backoff_secs

        for attempt in range(total_attempts):
            is_last_attempt = attempt == self.max_retries
            attempt_label = f"{attempt + 1}/{total_attempts}"
            try:
                async with session.post(
                    self.api_url,
                    headers=headers,
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response:
                    if response.status == 429:
                        self.logger.warning(
                            f"Rate limit hit (status 429) for prompt '{preview}...'. Attempt {attempt_label}."
                        )
                        if is_last_attempt:
                            self.logger.error(f"Max retries reached for rate limit for prompt '{preview}...'.")
                            return self._error_payload("Rate limit exceeded after multiple retries")

                        sleep_duration, from_header = self._rate_limit_delay(
                            response.headers.get("Retry-After"), current_backoff
                        )
                        await asyncio.sleep(sleep_duration)
                        # The backoff only grows when it was actually used
                        # for this wait, not when the server dictated it.
                        if not from_header:
                            current_backoff *= self.backoff_factor
                        continue

                    response.raise_for_status()
                    return await response.json()

            except aiohttp.ClientResponseError as e:
                self.logger.error(f"HTTP error {e.status} calling OpenAI API for prompt '{preview}...': {e.message}. Attempt {attempt_label}")
                # Other 4xx statuses are caller errors; retrying cannot help.
                non_retryable = 400 <= e.status < 500 and e.status != 429
                if is_last_attempt or non_retryable:
                    return self._error_payload(f"HTTP error: {e.status} - {e.message}")
            except aiohttp.ClientError as e:
                self.logger.error(f"ClientError calling OpenAI API for prompt '{preview}...': {e}. Attempt {attempt_label}")
                if is_last_attempt:
                    return self._error_payload(str(e))
            except asyncio.TimeoutError:
                self.logger.error(f"Timeout calling OpenAI API for prompt '{preview}...'. Attempt {attempt_label}")
                if is_last_attempt:
                    return self._error_payload("API call timed out after multiple retries")
            except ValueError as e:
                # response.json() surfaces JSON / text decoding failures as
                # ValueError subclasses; without this handler they would
                # escape the retry loop entirely.
                self.logger.error(f"Invalid JSON body from OpenAI API for prompt '{preview}...': {e}. Attempt {attempt_label}")
                if is_last_attempt:
                    return self._error_payload(f"Invalid JSON in API response: {e}")

            # Reached only from a handler that decided to retry.
            await asyncio.sleep(current_backoff)
            current_backoff *= self.backoff_factor

        # Only reachable when the loop body never ran (negative max_retries).
        self.logger.error(f"Max retries exhausted for prompt '{preview}...'. Returning error.")
        return self._error_payload("Max retries reached after all attempts")

                                        
async def process_single_prompt_for_gather(
    prompt_record: InputPromptRecord,
    moderator: OpenAIModerator,
    session: aiohttp.ClientSession,
    logger: logging.Logger,
    semaphore: asyncio.Semaphore,
    idx_for_fallback_id: int
) -> OpenAIModerationSingleResult:
    """Moderate one prompt under the concurrency semaphore and build its result.

    Args:
        prompt_record: Input record; its ``original_prompt`` is sent for moderation.
        moderator: Client used to call the moderation endpoint.
        session: Shared aiohttp session passed through to ``moderator``.
        logger: Logger for label mapping and parse diagnostics. Its
            debug-enabled state also decides whether the raw API response is
            kept on the result.
        semaphore: Held for the whole call, bounding concurrent requests.
        idx_for_fallback_id: Index used to build an ID when the record has
            neither a ``prompt_id`` nor a ``csv_row_number``.

    Returns:
        The populated result. An API failure or unparseable response yields
        the ERROR predicted label with ``error_info`` describing the cause.
    """
    async with semaphore:
        true_label_str = map_true_label_to_str(
            prompt_record.label, logger,
            safe_numeric_val=LABEL_SAFE, unsafe_numeric_val=LABEL_UNSAFE
        )

        # Timed span covers the API call including any retries and backoff
        # sleeps, but not the wait for the semaphore.
        start_time = time.perf_counter()
        raw_response_json = await moderator.moderate_prompt(prompt_record.original_prompt, session)
        processing_time_ms = (time.perf_counter() - start_time) * 1000

        error_info_str: Optional[str] = None
        if "error" in raw_response_json:
            error_info_str = f"OpenAI API call failed: {raw_response_json['error']}"
            predicted_label_str = FortressDecisionLabel.ERROR.value
            flagged_val, categories_dict, category_scores_dict = False, {}, {}
        else:
            predicted_label_str, flagged_val, categories_dict, category_scores_dict = parse_openai_moderation_output(raw_response_json, logger)
            if predicted_label_str == FortressDecisionLabel.ERROR.value:
                error_info_str = f"Could not parse OpenAI Moderation API response (first 100 chars of raw): {str(raw_response_json)[:100]}"

        # ID preference: explicit prompt_id, then CSV row number, then the
        # caller-supplied index.
        if prompt_record.prompt_id is not None:
            prompt_id_val = str(prompt_record.prompt_id)
        else:
            csv_row_number = getattr(prompt_record, 'csv_row_number', None)
            if csv_row_number is not None:
                prompt_id_val = f"csvidx_{csv_row_number}"
            else:
                prompt_id_val = f"fallbackidx_{idx_for_fallback_id}"

        # Use the effective level: `logger.level` is 0 (NOTSET) for a logger
        # that inherits its level, which made `logger.level <= DEBUG` true
        # and stored raw responses even when debug logging was off.
        keep_raw_response = logger.isEnabledFor(logging.DEBUG)

        return OpenAIModerationSingleResult(
            prompt_id=prompt_id_val,
            original_prompt=prompt_record.original_prompt,
            true_label=true_label_str,
            predicted_label=predicted_label_str,
            flagged=flagged_val,
            categories=categories_dict,
            category_scores=category_scores_dict,
            openai_raw_response=raw_response_json if keep_raw_response else None,
            processing_time_ms=processing_time_ms,
            source_file_input=prompt_record.source_file,
            prompt_category_input=prompt_record.prompt_category,
            prompt_style_input=prompt_record.prompt_style,
            error_info=error_info_str
        )

                          
def _build_openai_moderation_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: the shared benchmark options plus the OpenAI-specific ones."""
    parser = argparse.ArgumentParser(description="Run OpenAI Moderation Benchmark Suite (Async).")
    parser = add_common_benchmark_args(parser)

    parser.add_argument(
        "--openai-model", type=str, default="text-moderation-latest",
        help="OpenAI Moderation model ID. (Default: text-moderation-latest)"
    )
    parser.add_argument(
        "--concurrency-limit", type=int, default=10,
        help="Maximum number of concurrent requests to OpenAI API. (Default: 10)"
    )
    parser.add_argument(
        "--max-retries", type=int, default=3,
        help="Maximum number of retries for API calls (0 means 1 attempt). (Default: 3)"
    )
    parser.add_argument(
        "--initial-backoff-secs", type=float, default=1.0,
        help="Initial backoff delay in seconds for retries. (Default: 1.0)"
    )
    parser.add_argument(
        "--backoff-factor", type=float, default=2.0,
        help="Factor by which backoff delay increases for retries. (Default: 2.0)"
    )
    return parser


async def _moderate_prompts_in_input_order(
    all_prompts: List[InputPromptRecord],
    moderator: OpenAIModerator,
    logger: logging.Logger,
    concurrency_limit: int,
) -> List[OpenAIModerationSingleResult]:
    """Moderate every prompt concurrently and return the results in input order."""
    semaphore = asyncio.Semaphore(concurrency_limit)

    async with aiohttp.ClientSession() as session:
        tasks = [
            asyncio.create_task(
                process_single_prompt_for_gather(
                    prompt_record, moderator, session, logger, semaphore, idx_for_fallback_id=idx
                )
            )
            for idx, prompt_record in enumerate(all_prompts)
        ]

        # as_completed only drives the progress bar and surfaces failures as
        # they happen; results are read from the tasks afterwards.
        for next_finished in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing prompts (OpenAI Mod Async)"):
            try:
                await next_finished
            except Exception as e:
                logger.error(f"Task coroutine raised an unexpected exception: {e}", exc_info=True)

    # Walking the task list (not completion order) keeps the output file
    # stable from run to run.
    ordered_results: List[OpenAIModerationSingleResult] = []
    for task in tasks:
        if task.cancelled():
            continue
        task_exception = task.exception()
        if task_exception is None:
            ordered_results.append(task.result())
        else:
            logger.error(f"Encountered an exception object in results: {task_exception}. This might indicate a bug.")
    return ordered_results


async def main_async_logic():
    """Run the whole benchmark: parse args, moderate prompts, write results and report.

    Exits the process with status 1 when ``OPENAI_API_KEY`` is missing or the
    moderator cannot be constructed, with status 0 when no benchmark prompts
    are loaded, and through ``parser.error`` when ``--concurrency-limit`` or
    ``--max-retries`` is out of range.
    """
    parser = _build_openai_moderation_arg_parser()
    args = parser.parse_args()

    # asyncio.Semaphore(0) would leave every task waiting forever and a
    # negative value raises a bare ValueError, so reject both up front.
    if args.concurrency_limit < 1:
        parser.error("--concurrency-limit must be a positive integer.")
    # A negative retry count would make zero attempts per prompt.
    if args.max_retries < 0:
        parser.error("--max-retries must be zero or greater.")

    logger = setup_logger(__name__, level_str=args.log_level)

    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        logger.critical("OPENAI_API_KEY environment variable not found. Please set it.")
        sys.exit(1)

    timestamp_start_dt = datetime.now()
    suite_run_id = create_run_id(base_name=args.openai_model.replace("/", "_").replace("-", "_"),
                                 prefix=(args.run_name_prefix or "openai_mod_async"))

    output_results_file, output_report_file, _ = prepare_output_paths(
        args.output_dir, suite_run_id,
        results_filename_suffix="_results.json",
        report_filename_suffix="_report.md"
    )

    logger.info(f"Starting OpenAI Moderation Benchmark Suite (Async): {suite_run_id}")
    logger.info(f"Input CSVs: {args.input_csvs}")
    logger.info(f"Outputting results to: {output_results_file}")
    logger.info(f"Outputting report to: {output_report_file}")
    logger.info(f"Using OpenAI Moderation model: {args.openai_model}")
    logger.info(f"Concurrency limit: {args.concurrency_limit}, Max retries: {args.max_retries}")
    logger.info(f"Initial backoff: {args.initial_backoff_secs}s, Backoff factor: {args.backoff_factor}")

    try:
        moderator = OpenAIModerator(
            api_key=openai_api_key, model_id=args.openai_model, logger=logger,
            max_retries=args.max_retries,
            initial_backoff_secs=args.initial_backoff_secs,
            backoff_factor=args.backoff_factor
        )
    except ValueError as e:
        logger.critical(f"Failed to initialize OpenAIModerator: {e}")
        sys.exit(1)
    except Exception as e:
        logger.critical(f"An unexpected error occurred during OpenAIModerator initialization: {e}", exc_info=True)
        sys.exit(1)

    all_prompts: List[InputPromptRecord] = load_and_filter_prompts(
        csv_file_paths=args.input_csvs,
        split_to_filter=SPLIT_BENCHMARK,
        load_prompts_func=load_prompts_from_csv,
        logger=logger
    )

    if not all_prompts:
        logger.warning("No benchmark prompts loaded. Exiting.")
        sys.exit(0)

    individual_results_collected = await _moderate_prompts_in_input_order(
        all_prompts, moderator, logger, args.concurrency_limit
    )

    overall_metrics = calculate_common_metrics(
        individual_results_collected, logger,
        positive_label=FortressDecisionLabel.UNSAFE.value,
        negative_label=FortressDecisionLabel.SAFE.value,
        error_label=FortressDecisionLabel.ERROR.value,
        ambiguous_label=None
    )

    timestamp_end_dt = datetime.now()
    duration_seconds = (timestamp_end_dt - timestamp_start_dt).total_seconds()

    suite_results_data = OpenAIModerationSuiteResults(
        suite_run_id=suite_run_id,
        timestamp_start=timestamp_start_dt.isoformat(),
        timestamp_end=timestamp_end_dt.isoformat(),
        duration_seconds=duration_seconds,
        input_csv_files=args.input_csvs,
        output_results_file=str(output_results_file.resolve()),
        output_report_file=str(output_report_file.resolve()),
        openai_model_id=args.openai_model,
        metrics=overall_metrics,
        individual_results=individual_results_collected,
    )

    save_json_output(suite_results_data, output_results_file, logger)

    model_specific_desc = (f"OpenAI Moderation API (Model: `{args.openai_model}`), "
                           f"Concurrency: {args.concurrency_limit}, Max Retries: {args.max_retries}")
    report_content = generate_markdown_report_parts(
        suite_results=suite_results_data,
        model_specific_desc=model_specific_desc,
        report_title="OpenAI Moderation Benchmark Report (Async)"
    )
    save_markdown_report(report_content, output_report_file, logger)

    logger.info("OpenAI Moderation Benchmark Suite (Async) finished.")

    logger.info(f"Total prompts processed: {overall_metrics.num_samples}, Errors: {overall_metrics.num_pred_error}")
    logger.info(f"Accuracy: {overall_metrics.accuracy:.4f}, F1 (Unsafe): {overall_metrics.f1_unsafe if overall_metrics.f1_unsafe is not None else 'N/A'}")
    logger.info(f"Duration: {duration_seconds:.2f} seconds")

                                                          
def main():
    """Synchronous entry point: run the async benchmark logic to completion."""
    benchmark_coroutine = main_async_logic()
    asyncio.run(benchmark_coroutine)

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
