import argparse
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

from pydantic import BaseModel

                                                                    
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

try:
    from vllm import LLM, SamplingParams
except ImportError:
    print("vllm is not installed. Please install it to run this benchmark: pip install vllm")
    sys.exit(1)

                           
from fortress.common.data_models import DecisionLabel as FortressDecisionLabel
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.common.constants import SPLIT_BENCHMARK

                             
from scripts.utils.benchmark_utils import (
    setup_logger,
    BaseBenchmarkSingleResult,
    CommonBenchmarkMetrics,
    add_common_benchmark_args,
    create_run_id,
    prepare_output_paths,
    load_and_filter_prompts,
    map_true_label_to_str,
    calculate_common_metrics,
    generate_markdown_report_parts,
    save_json_output,
    save_markdown_report,
    get_progress_bar,
    InputPromptRecord
)

                                               

class GuardReasonerBenchmarkSingleResult(BaseBenchmarkSingleResult):
    """Per-prompt benchmark result extended with GuardReasoner-specific fields."""

    # Text generated by the model for this prompt, stored unparsed. main() stores
    # the sentinel "ERROR_INFERENCE" here when the inference batch failed.
    guardreasoner_raw_output: Optional[str] = None
    # Second element of the tuple returned by parse_guardreasoner_output();
    # main() stores an empty list when inference failed.
    violated_categories: Optional[List[str]] = None

class GuardReasonerBenchmarkSuiteResults(BaseModel):
    """Top-level record of one benchmark run: run metadata, aggregate metrics and per-prompt results."""

    # Identifier produced by create_run_id() for this run.
    suite_run_id: str
    # ISO-8601 strings from datetime.isoformat() taken at the start/end of main().
    timestamp_start: str
    timestamp_end: str
    # Wall-clock difference between the two timestamps above, in seconds.
    duration_seconds: float
    # CSV paths given on the command line (args.input_csvs).
    input_csv_files: List[str]
    # Resolved paths of the JSON results file and the Markdown report.
    output_results_file: str
    output_report_file: str
    # Model identifier the moderator was loaded with (GuardReasonerModerator.model_id).
    guardreasoner_model_id: str
    # Aggregate metrics returned by calculate_common_metrics().
    metrics: CommonBenchmarkMetrics
    # One entry per processed prompt, in input order.
    individual_results: List[GuardReasonerBenchmarkSingleResult]


                                                 

def parse_guardreasoner_output(output: str, logger: logging.Logger) -> Tuple[str, List[str]]:
    """
    Parse GuardReasoner's raw output string into a decision label.

    The last line that starts with 'Request:' (case-insensitive, surrounding
    whitespace ignored) carries the verdict. GuardReasoner phrases its verdict
    as 'harmful' / 'unharmful'; 'unsafe' / 'safe' are accepted as well. The
    negated form 'unharmful' is tested first because it contains 'harmful' as a
    substring and would otherwise be misread as an unsafe verdict.

    For harmful/unsafe verdicts, the non-empty line following each line that
    mentions 'reasoning step' is collected as a violated category, unless that
    following line is a heading ('#'), a rule ('---') or a 'Request:' line.

    Args:
        output: Raw text generated by the model for one prompt.
        logger: Logger used to report empty or unparsable outputs.

    Returns:
        (predicted_label, violated_categories). The label is the value of
        FortressDecisionLabel.SAFE, UNSAFE or ERROR; the category list is empty
        unless the label is UNSAFE.
    """
    output_clean = output.strip()
    if not output_clean:
        logger.warning("Empty GuardReasoner output. Defaulting to ERROR.")
        return FortressDecisionLabel.ERROR.value, []

    lines = output_clean.splitlines()

    # Scan bottom-up so the final answer wins over any 'Request:' line that
    # appears earlier in the reasoning text.
    request_label_str = None
    for line in reversed(lines):
        stripped = line.strip()
        if stripped.lower().startswith("request:"):
            request_label_str = stripped[len("request:"):].strip().lower()
            break

    if request_label_str:
        # Order matters: 'unharmful' contains 'harmful', 'unsafe' contains 'safe'.
        if "unharmful" in request_label_str:
            return FortressDecisionLabel.SAFE.value, []

        if "harmful" in request_label_str or "unsafe" in request_label_str:
            categories = []
            for i, l in enumerate(lines):
                if "reasoning step" not in l.lower():
                    continue
                # Only the line directly after a reasoning-step marker is considered.
                if i + 1 < len(lines) and lines[i + 1].strip():
                    next_line = lines[i + 1].strip()
                    if (
                        not next_line.startswith("#")
                        and not next_line.startswith("---")
                        and "request:" not in next_line.lower()
                    ):
                        categories.append(next_line)
            return FortressDecisionLabel.UNSAFE.value, categories

        if "safe" in request_label_str:
            return FortressDecisionLabel.SAFE.value, []

    # No 'Request:' line, an empty label, or a label with no recognised verdict.
    logger.warning(f"Unexpected GuardReasoner output format: Could not find or parse 'Request:' line in output: '{output}'. Defaulting to ERROR.")
    return FortressDecisionLabel.ERROR.value, []


                                   
class GuardReasonerModerator:
    """Thin wrapper around a vLLM engine loaded with a GuardReasoner checkpoint."""

    def __init__(self, model_size: str, gpu_memory_utilization: float, max_num_seqs: int, temperature: float, top_p: float, max_tokens: int, max_model_len: int, logger: logging.Logger):
        """
        Load the model and prepare the sampling configuration.

        Args:
            model_size: Size suffix appended to 'yueliu1999/GuardReasoner-' to form the model id.
            gpu_memory_utilization: Forwarded to the vLLM engine.
            max_num_seqs: Forwarded to the vLLM engine.
            temperature: Sampling temperature.
            top_p: Sampling top-p.
            max_tokens: Maximum number of tokens generated per prompt.
            max_model_len: Forwarded to the vLLM engine.
            logger: Logger used for progress and error messages.
        """
        self.logger = logger
        self.model_id = "yueliu1999/GuardReasoner-{}".format(model_size)

        self.logger.info(f"Loading GuardReasoner model: {self.model_id}")
        engine_options = {
            "model": self.model_id,
            "gpu_memory_utilization": gpu_memory_utilization,
            "max_num_seqs": max_num_seqs,
            "max_model_len": max_model_len,
        }
        self.llm = LLM(**engine_options)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        self.logger.info("GuardReasoner model loaded.")

    def moderate_prompts(self, prompts: List[str]) -> List[str]:
        """
        Generate one completion per prompt in a single batch call.

        Returns the text of the first completion for every prompt, in the order
        the engine returned them. If anything raises during generation or while
        collecting the texts, the error is logged and a list holding the
        sentinel "ERROR_INFERENCE" once per input prompt is returned instead.
        """
        self.logger.info(f"Moderating {len(prompts)} prompts in a batch.")
        try:
            generations = self.llm.generate(prompts, self.sampling_params)
            texts: List[str] = []
            for generation in generations:
                texts.append(generation.outputs[0].text)
        except Exception as e:
            # Best-effort: one failure marks the whole batch as failed.
            self.logger.error(f"Error during GuardReasoner inference: {e}", exc_info=True)
            return ["ERROR_INFERENCE"] * len(prompts)
        return texts

                       
def main():
    """
    Run the GuardReasoner benchmark end to end.

    Parses the command line, loads the benchmark-split prompts from the input
    CSVs, loads the model, generates all completions in one batch, parses each
    completion into a decision label, computes aggregate metrics and writes a
    JSON results file plus a Markdown report.

    Exits with status 0 when no benchmark prompts are found and with status 1
    when the model cannot be initialised.
    """
    parser = argparse.ArgumentParser(description="Run GuardReasoner Benchmark Suite.")
    parser = add_common_benchmark_args(parser)

    # GuardReasoner / vLLM specific options.
    parser.add_argument("--model-size", type=str, default="1B", choices=["1B", "3B", "8B"], help="Size of the GuardReasoner model to use.")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9, help="GPU memory utilization for vLLM.")
    parser.add_argument("--max-num-seqs", type=int, default=256, help="Maximum number of sequences for vLLM.")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=1.0, help="Sampling top-p.")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum number of tokens to generate.")
    parser.add_argument("--max-model-len", type=int, default=8192, help="Maximum model length for vLLM to prevent memory errors.")

    args = parser.parse_args()

    logger = setup_logger(__name__, level_str=args.log_level)

    timestamp_start_dt = datetime.now()
    suite_run_id = create_run_id(base_name=f"guardreasoner_{args.model_size}", prefix=args.run_name_prefix or "guardreasoner")

    output_results_file, output_report_file, _ = prepare_output_paths(
        args.output_dir, suite_run_id
    )

    logger.info(f"Starting GuardReasoner Benchmark Suite: {suite_run_id}")
    logger.info(f"Input CSVs: {args.input_csvs}")
    logger.info(f"Outputting results to: {output_results_file}")
    logger.info(f"Outputting report to: {output_report_file}")
    logger.info(f"Using GuardReasoner model size: {args.model_size}")

    # Load the prompts before the model: reading CSVs is cheap, whereas model
    # initialisation is expensive, so an empty or unreadable input should stop
    # the run before any GPU memory is claimed.
    all_prompts: List[InputPromptRecord] = load_and_filter_prompts(
        csv_file_paths=args.input_csvs,
        split_to_filter=SPLIT_BENCHMARK,
        load_prompts_func=load_prompts_from_csv,
        logger=logger
    )

    if not all_prompts:
        logger.warning("No benchmark prompts loaded. Exiting.")
        sys.exit(0)

    try:
        moderator = GuardReasonerModerator(
            model_size=args.model_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_num_seqs=args.max_num_seqs,
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens,
            max_model_len=args.max_model_len,
            logger=logger
        )
    except Exception as e:
        logger.critical(f"Failed to initialize GuardReasoner model: {e}", exc_info=True)
        sys.exit(1)

    # NOTE(review): prompts are sent to the model verbatim, without any
    # instruction/template text. Confirm this matches the input format the
    # GuardReasoner checkpoints expect.
    prompts_to_process = [p.original_prompt for p in all_prompts]
    individual_results: List[GuardReasonerBenchmarkSingleResult] = []

    start_time = time.time()
    raw_outputs = moderator.moderate_prompts(prompts_to_process)
    end_time = time.time()

    # The whole set is generated in one batch, so only an average per-prompt
    # latency is available; it is recorded on every individual result.
    total_processing_time_ms = (end_time - start_time) * 1000
    avg_processing_time_ms = total_processing_time_ms / len(prompts_to_process) if prompts_to_process else 0
    logger.info(f"Batch processing finished. Total time: {total_processing_time_ms:.2f} ms, Avg per prompt: {avg_processing_time_ms:.2f} ms")

    for i, prompt_record in enumerate(get_progress_bar(all_prompts, desc="Parsing results", logger=logger)):
        raw_output = raw_outputs[i]
        error_info = None

        if raw_output == "ERROR_INFERENCE":
            # Sentinel emitted by GuardReasonerModerator.moderate_prompts on failure.
            predicted_label = FortressDecisionLabel.ERROR.value
            violated_categories = []
            error_info = "Inference failed for this prompt."
        else:
            predicted_label, violated_categories = parse_guardreasoner_output(raw_output, logger)

        true_label_str = map_true_label_to_str(prompt_record.label, logger)

        result = GuardReasonerBenchmarkSingleResult(
            prompt_id=str(prompt_record.prompt_id) if prompt_record.prompt_id else f"record_{i}",
            original_prompt=prompt_record.original_prompt,
            true_label=true_label_str,
            predicted_label=predicted_label,
            processing_time_ms=avg_processing_time_ms,
            source_file_input=prompt_record.source_file,
            prompt_category_input=prompt_record.prompt_category,
            prompt_style_input=prompt_record.prompt_style,
            error_info=error_info,
            guardreasoner_raw_output=raw_output,
            violated_categories=violated_categories,
        )
        individual_results.append(result)

    overall_metrics = calculate_common_metrics(
        individual_results, logger,
        positive_label=FortressDecisionLabel.UNSAFE.value,
        negative_label=FortressDecisionLabel.SAFE.value,
        error_label=FortressDecisionLabel.ERROR.value,
    )

    timestamp_end_dt = datetime.now()
    duration_seconds = (timestamp_end_dt - timestamp_start_dt).total_seconds()

    suite_results = GuardReasonerBenchmarkSuiteResults(
        suite_run_id=suite_run_id,
        timestamp_start=timestamp_start_dt.isoformat(),
        timestamp_end=timestamp_end_dt.isoformat(),
        duration_seconds=duration_seconds,
        input_csv_files=args.input_csvs,
        output_results_file=str(output_results_file.resolve()),
        output_report_file=str(output_report_file.resolve()),
        guardreasoner_model_id=moderator.model_id,
        metrics=overall_metrics,
        individual_results=individual_results,
    )

    save_json_output(suite_results, output_results_file, logger)

    model_specific_desc = f"GuardReasoner Model: `{moderator.model_id}`"
    report_content = generate_markdown_report_parts(
        suite_results=suite_results,
        model_specific_desc=model_specific_desc,
        report_title="GuardReasoner Benchmark Report"
    )
    save_markdown_report(report_content, output_report_file, logger)

    logger.info("GuardReasoner Benchmark Suite finished.")
    logger.info(f"Accuracy: {overall_metrics.accuracy:.4f}, F1 (Unsafe): {overall_metrics.f1_unsafe if overall_metrics.f1_unsafe is not None else 'N/A'}")

# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
