import argparse
import json                                                                     
import logging                                
import os
import sys

import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple

from pydantic import BaseModel, Field

                                                                    
                                                                                                                               
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))                                     

try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
except ImportError:
    print("torch or transformers is not installed. Please install them: pip install torch transformers")
    sys.exit(1)

                                                                                 
try:
                                                         
    pass 
except ImportError:
    print("scikit-learn is not installed. Please install it to run the benchmark suite: pip install scikit-learn")
    sys.exit(1)

                           
from fortress.common.data_models import DecisionLabel as FortressDecisionLabel                                  
from fortress.data_management.data_loader import load_prompts_from_csv
from fortress.common.constants import SPLIT_BENCHMARK, LABEL_SAFE, LABEL_UNSAFE                           

                             
from scripts.utils.benchmark_utils import (
    setup_logger,
    BaseBenchmarkSingleResult,
    CommonBenchmarkMetrics,                                 
                                                          
    add_common_benchmark_args,
    create_run_id,
    prepare_output_paths,
    load_and_filter_prompts,
    map_true_label_to_str,
    calculate_common_metrics,
    generate_markdown_report_parts,
    save_json_output,
    save_markdown_report,
    get_progress_bar,
    InputPromptRecord                                                                                                      
)

                                                                             

class LlamaGuardBenchmarkSingleResult(BaseBenchmarkSingleResult):
    """Per-prompt benchmark result extended with Llama Guard-specific fields."""

    # Decoded text returned by LlamaGuardModerator.moderate_prompt; holds the
    # "ERROR_INFERENCE" sentinel when inference raised.
    llama_guard_raw_output: Optional[str] = None
    # Category codes (strings starting with "S") extracted by
    # parse_llama_guard_output; empty unless the prompt was labelled unsafe.
    violated_categories: Optional[List[str]] = None
                                                                                             
                                                                                        

class LlamaGuardBenchmarkSuiteResults(BaseModel):
    """Top-level record of one Llama Guard benchmark run, written to the JSON results file."""

    # Run identifier produced by create_run_id.
    suite_run_id: str
    # Start/end wall-clock times as datetime.isoformat() strings.
    timestamp_start: str
    timestamp_end: str
    # Total run duration in seconds (end - start).
    duration_seconds: float
    # CSV paths the benchmark prompts were loaded from.
    input_csv_files: List[str]
    # Resolved (absolute) paths of the JSON results file and Markdown report.
    output_results_file: str
    output_report_file: str
    # Hugging Face model ID used for moderation.
    llama_guard_model_id: str
    # Aggregate metrics computed by calculate_common_metrics.
    metrics: CommonBenchmarkMetrics
    # One entry per evaluated prompt.
    individual_results: List[LlamaGuardBenchmarkSingleResult]


                                               

def _split_category_codes(lines: List[str]) -> List[str]:
    """Split category lines on commas and keep the tokens that start with "S"."""
    categories: List[str] = []
    for line in lines:
        for part in line.split(","):
            token = part.strip()
            if token.startswith("S"):
                categories.append(token)
    return categories


def parse_llama_guard_output(output: str, logger: logging.Logger) -> Tuple[str, List[str]]:
    """
    Parse Llama Guard's raw generated text into a decision label and category codes.

    The expected format is ``safe`` on its own, or ``unsafe`` on the first line
    followed by the violated category codes (e.g. ``S1`` or a comma-separated
    list such as ``S1,S10``). Tolerated variants:
      * literal two-character ``\\n`` escapes instead of real newlines;
      * the label and the first category concatenated (``unsafeS1``);
      * categories spread over several lines and/or separated by commas.

    Args:
        output: Raw decoded model output.
        logger: Logger used to report unusual or unparseable outputs.

    Returns:
        ``(predicted_label, violated_categories)`` where ``predicted_label`` is the
        ``.value`` of ``FortressDecisionLabel.SAFE``, ``.UNSAFE`` or ``.ERROR``, and
        ``violated_categories`` lists the individual ``S…`` codes (empty unless UNSAFE).
    """
    output_clean = output.strip()

    lines = output_clean.splitlines()
    # Some decodes contain escaped newlines (backslash + 'n') instead of real ones.
    if len(lines) == 1 and '\\n' in lines[0]:
        lines = lines[0].split('\\n')

    processed_lines = [line.strip() for line in lines if line.strip()]

    if not processed_lines:
        logger.warning(f"Empty or malformed Llama Guard output after processing: '{output_clean}'. Defaulting to ERROR.")
        return FortressDecisionLabel.ERROR.value, []

    # Checked after normalisation so that e.g. "safe\\n" is still recognised.
    if processed_lines == ["safe"]:
        return FortressDecisionLabel.SAFE.value, []

    first_line_content = processed_lines[0]

    if first_line_content == "unsafe":
        return FortressDecisionLabel.UNSAFE.value, _split_category_codes(processed_lines[1:])

    # Label glued to the first category code, e.g. "unsafeS1" or "unsafeS1,S2".
    if first_line_content.startswith("unsafeS"):
        potential_category = first_line_content[len("unsafe"):]
        logger.info(
            f"Llama Guard output '{output_clean}' (processed to '{first_line_content}') "
            f"parsed as UNSAFE with category '{potential_category}' due to concatenation."
        )
        return FortressDecisionLabel.UNSAFE.value, _split_category_codes(
            [potential_category, *processed_lines[1:]]
        )

    logger.warning(
        f"Unexpected Llama Guard output format: '{output_clean}'. "
        f"Processed lines: {processed_lines}. Defaulting to ERROR."
    )
    return FortressDecisionLabel.ERROR.value, []


                                  
class LlamaGuardModerator:
    """Load a Llama Guard checkpoint once and classify individual prompts with it.

    Constructor args:
        model_id: Hugging Face model ID or local path passed to ``from_pretrained``.
        device_str: Requested device. A value starting with ``"cuda"`` selects that
            CUDA device when CUDA is available; otherwise the model runs on CPU.
        dtype_str: ``"bfloat16"`` or ``"float16"``; any other value selects float32.
            bfloat16 falls back to float16 on CUDA devices without bf16 support.
        logger: Logger for load-time and inference diagnostics.
    """

    def __init__(self, model_id: str, device_str: str, dtype_str: str, logger: logging.Logger):
        self.model_id = model_id
        self.logger = logger

        wants_cuda = device_str.startswith("cuda")
        cuda_available = torch.cuda.is_available()
        if wants_cuda and not cuda_available:
            self.logger.warning(f"Device '{device_str}' requested but CUDA is not available. Falling back to CPU.")
        self.device = torch.device(device_str if wants_cuda and cuda_available else "cpu")

        if dtype_str == "bfloat16" and self.device.type == 'cuda' and not torch.cuda.is_bf16_supported():
            self.logger.warning("bfloat16 is not supported on this CUDA device. Falling back to float16.")
            self.dtype = torch.float16
        elif dtype_str == "bfloat16":
            self.dtype = torch.bfloat16
        elif dtype_str == "float16":
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32
            if self.device.type == 'cuda':
                self.logger.info("Using float32 on CUDA. Consider bfloat16 or float16 for better performance if supported.")
        if self.device.type == 'cpu' and self.dtype == torch.float16:
            self.logger.warning("float16 on CPU may be slow or unsupported for some operations; consider float32 or bfloat16.")

        self.logger.info(f"Loading Llama Guard model: {model_id} onto device: {self.device} with dtype: {self.dtype}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Some checkpoints ship without a pad token; reuse EOS so generate() has a pad id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        if self.device.type == 'cuda':
            # Materialise the weights directly on the target GPU while loading.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
                device_map=self.device,
            )
        else:
            # No device_map on CPU: "auto" could dispatch onto a visible GPU (and needs
            # accelerate), leaving the weights and the CPU inputs on different devices.
            self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
            self.model.to(self.device)
        self.model.eval()

    def moderate_prompt(self, prompt_text: str, max_new_tokens: int = 100) -> str:
        """Run Llama Guard on a single user prompt and return the decoded verdict text.

        The prompt is wrapped as one ``user`` chat turn and formatted with the
        tokenizer's chat template; only the newly generated tokens are decoded.

        Args:
            prompt_text: The user prompt to classify.
            max_new_tokens: Generation budget for the verdict (default 100).

        Returns:
            The decoded model output (intended for ``parse_llama_guard_output``), or
            the sentinel string ``"ERROR_INFERENCE"`` if any step raised.
        """
        chat = [{"role": "user", "content": prompt_text}]
        try:
            # Follow the model's actual placement rather than the requested device.
            input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to(self.model.device)
            # Single unpadded sequence: every position is real. Passing the mask avoids
            # generate() having to infer it when pad_token_id == eos_token_id.
            attention_mask = torch.ones_like(input_ids)
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            prompt_len = input_ids.shape[-1]
            return self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        except Exception as e:
            # Best-effort per prompt: the caller records the failure and continues.
            self.logger.error(f"Error during Llama Guard inference for prompt '{prompt_text[:50]}...': {e}", exc_info=True)
            return "ERROR_INFERENCE"

                        
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: shared benchmark arguments plus Llama Guard model options."""
    parser = argparse.ArgumentParser(description="Run Llama Guard Benchmark Suite.")
    parser = add_common_benchmark_args(parser)

    # NOTE(review): the default is a vision checkpoint, while LlamaGuardModerator loads it
    # with AutoModelForCausalLM and a plain-text chat turn; confirm this works, or switch
    # to a text-only checkpoint such as meta-llama/Llama-Guard-3-8B.
    parser.add_argument(
        "--model-id",
        type=str,
        default="meta-llama/Llama-Guard-3-11B-Vision",
        help="Hugging Face model ID for Llama Guard. (Default: %(default)s)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to run the model on ('cuda' or 'cpu'). (Default: %(default)s)"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Data type for model parameters (e.g., bfloat16, float16, float32). (Default: %(default)s)"
    )
    return parser


def _resolve_prompt_id(prompt_record: InputPromptRecord, index: int) -> str:
    """Return the record's own ID, or a ``csvidx_<row>`` fallback when it has none."""
    prompt_id = prompt_record.prompt_id
    # Only None / "" count as missing; a falsy but real ID such as 0 is kept.
    if prompt_id is not None and str(prompt_id) != "":
        return str(prompt_id)
    return f"csvidx_{getattr(prompt_record, 'csv_row_number', index)}"


def _evaluate_prompt(
    moderator: LlamaGuardModerator,
    prompt_record: InputPromptRecord,
    index: int,
    logger: logging.Logger,
) -> LlamaGuardBenchmarkSingleResult:
    """Moderate one prompt, parse the verdict, and package it as a single result."""
    true_label_str = map_true_label_to_str(
        prompt_record.label, logger,
        safe_numeric_val=LABEL_SAFE, unsafe_numeric_val=LABEL_UNSAFE
    )

    start_time = time.perf_counter()
    raw_output = moderator.moderate_prompt(prompt_record.original_prompt)
    processing_time_ms = (time.perf_counter() - start_time) * 1000

    error_info_str: Optional[str] = None
    if raw_output == "ERROR_INFERENCE":
        # Sentinel returned by LlamaGuardModerator.moderate_prompt on failure.
        predicted_label_str = FortressDecisionLabel.ERROR.value
        violated_categories_list: List[str] = []
        error_info_str = "Llama Guard inference failed."
    else:
        predicted_label_str, violated_categories_list = parse_llama_guard_output(raw_output, logger)
        if predicted_label_str == FortressDecisionLabel.ERROR.value:
            error_info_str = f"Could not parse Llama Guard output (first 100 chars): {raw_output[:100]}"

    return LlamaGuardBenchmarkSingleResult(
        prompt_id=_resolve_prompt_id(prompt_record, index),
        original_prompt=prompt_record.original_prompt,
        true_label=true_label_str,
        predicted_label=predicted_label_str,
        llama_guard_raw_output=raw_output,
        violated_categories=violated_categories_list,
        processing_time_ms=processing_time_ms,
        source_file_input=prompt_record.source_file,
        prompt_category_input=prompt_record.prompt_category,
        prompt_style_input=prompt_record.prompt_style,
        error_info=error_info_str
    )


def main():
    """Run the Llama Guard benchmark end to end.

    Parses CLI arguments, loads the model, moderates every benchmark-split prompt
    from the input CSVs, computes metrics, and writes the JSON results file and the
    Markdown report. Exits with status 1 if the model cannot be initialised, and with
    status 0 (writing nothing) if no benchmark prompts are loaded.
    """
    args = _build_arg_parser().parse_args()

    logger = setup_logger(__name__, level_str=args.log_level)

    timestamp_start_dt = datetime.now()
    suite_run_id = create_run_id(base_name=args.model_id.split('/')[-1], prefix=args.run_name_prefix or "llamaguard")

    output_results_file, output_report_file, _ = prepare_output_paths(
        args.output_dir, suite_run_id
    )

    logger.info(f"Starting Llama Guard Benchmark Suite: {suite_run_id}")
    logger.info(f"Input CSVs: {args.input_csvs}")
    logger.info(f"Outputting results to: {output_results_file}")
    logger.info(f"Outputting report to: {output_report_file}")
    logger.info(f"Using Llama Guard model: {args.model_id} on device: {args.device} with dtype: {args.dtype}")

    try:
        moderator = LlamaGuardModerator(
            model_id=args.model_id,
            device_str=args.device,
            dtype_str=args.dtype,
            logger=logger
        )
    except Exception as e:
        logger.critical(f"Failed to initialize Llama Guard model: {e}", exc_info=True)
        sys.exit(1)

    all_prompts: List[InputPromptRecord] = load_and_filter_prompts(
        csv_file_paths=args.input_csvs,
        split_to_filter=SPLIT_BENCHMARK,
        load_prompts_func=load_prompts_from_csv,
        logger=logger
    )

    if not all_prompts:
        logger.warning("No benchmark prompts loaded. Exiting.")
        sys.exit(0)

    progress = get_progress_bar(all_prompts, desc="Processing prompts (LlamaGuard)", logger=logger)
    individual_results: List[LlamaGuardBenchmarkSingleResult] = [
        _evaluate_prompt(moderator, prompt_record, i, logger)
        for i, prompt_record in enumerate(progress)
    ]

    overall_metrics = calculate_common_metrics(
        individual_results, logger,
        positive_label=FortressDecisionLabel.UNSAFE.value,
        negative_label=FortressDecisionLabel.SAFE.value,
        error_label=FortressDecisionLabel.ERROR.value,
        ambiguous_label=None
    )

    timestamp_end_dt = datetime.now()
    duration_seconds = (timestamp_end_dt - timestamp_start_dt).total_seconds()

    suite_results = LlamaGuardBenchmarkSuiteResults(
        suite_run_id=suite_run_id,
        timestamp_start=timestamp_start_dt.isoformat(),
        timestamp_end=timestamp_end_dt.isoformat(),
        duration_seconds=duration_seconds,
        input_csv_files=args.input_csvs,
        output_results_file=str(output_results_file.resolve()),
        output_report_file=str(output_report_file.resolve()),
        llama_guard_model_id=args.model_id,
        metrics=overall_metrics,
        individual_results=individual_results,
    )

    save_json_output(suite_results, output_results_file, logger)

    # Report the device/dtype actually used: the moderator may have fallen back
    # (CUDA -> CPU, bfloat16 -> float16) from what was requested on the CLI.
    actual_dtype = str(moderator.dtype).split(".")[-1]
    model_specific_desc = f"Llama Guard Model: `{args.model_id}` (Device: {moderator.device}, Dtype: {actual_dtype})"
    report_content = generate_markdown_report_parts(
        suite_results=suite_results,
        model_specific_desc=model_specific_desc,
        report_title="Llama Guard Benchmark Report"
    )
    save_markdown_report(report_content, output_report_file, logger)

    logger.info("Llama Guard Benchmark Suite finished.")
    f1_text = f"{overall_metrics.f1_unsafe:.4f}" if overall_metrics.f1_unsafe is not None else "N/A"
    logger.info(f"Accuracy: {overall_metrics.accuracy:.4f}, F1 (Unsafe): {f1_text}")


if __name__ == "__main__":
    main()
