import argparse
import json
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple, Union, Callable

from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm

                                                                                              
# scikit-learn is an optional dependency. When it is importable the real metric
# functions are used; otherwise same-named stubs are published so that the
# module still imports and any attempt to compute metrics fails loudly.
try:
    from sklearn.metrics import (
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
        confusion_matrix as sk_confusion_matrix,
    )
except ImportError:
    SKLEARN_AVAILABLE = False

    def accuracy_score(*args, **kwargs):
        """Stub used when scikit-learn is missing; always raises ImportError."""
        raise ImportError("scikit-learn is required for metrics.")

    def precision_score(*args, **kwargs):
        """Stub used when scikit-learn is missing; always raises ImportError."""
        raise ImportError("scikit-learn is required for metrics.")

    def recall_score(*args, **kwargs):
        """Stub used when scikit-learn is missing; always raises ImportError."""
        raise ImportError("scikit-learn is required for metrics.")

    def f1_score(*args, **kwargs):
        """Stub used when scikit-learn is missing; always raises ImportError."""
        raise ImportError("scikit-learn is required for metrics.")

    def sk_confusion_matrix(*args, **kwargs):
        """Stub used when scikit-learn is missing; always raises ImportError."""
        raise ImportError("scikit-learn is required for metrics.")
else:
    SKLEARN_AVAILABLE = True

                                                                       
                                                                                                
                                                          
class DecisionLabel:
    """String constants naming the labels used for true and predicted decisions.

    These are plain class attributes (not an Enum), so they compare equal to
    the raw strings stored in result records.
    """

    # The two labels treated as the binary classes by the metrics code.
    SAFE = "SAFE"
    UNSAFE = "UNSAFE"
    # Non-binary prediction outcomes; counted separately in the metrics.
    ERROR = "ERROR"
    AMBIGUOUS = "AMBIGUOUS"

class InputPromptRecord(BaseModel):
    """One prompt row as read from an input file.

    Only ``original_prompt`` and ``label`` are required. Unknown extra
    columns are accepted and kept on the instance (``extra="allow"``).
    """

    # pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(extra="allow")

    # Identifier of the prompt; may be textual or numeric.
    prompt_id: Optional[Union[str, int]] = None
    # The prompt text itself.
    original_prompt: str
    # Raw ground-truth label; typed ``Any`` so no validation happens here.
    label: Any
    # Dataset split name; compared against the requested split when filtering.
    split: Optional[str] = None
    # Optional provenance / categorisation metadata.
    source_file: Optional[str] = None
    prompt_category: Optional[str] = None
    prompt_style: Optional[str] = None

                      
def setup_logger(name: str, level_str: str = "INFO") -> logging.Logger:
    """Return the logger called ``name``, configured at the requested level.

    Args:
        name: Name passed to ``logging.getLogger``.
        level_str: Case-insensitive level name such as ``"debug"`` or
            ``"WARNING"``. Unknown names fall back to ``INFO``.

    Returns:
        The named logger with its level (and the level of any handlers
        attached directly to it) set to the resolved level.
    """
    level = getattr(logging, level_str.upper(), None)
    # getattr may also resolve non-level module attributes (for example
    # ``BASIC_FORMAT``); only integer constants are valid levels.
    if not isinstance(level, int):
        level = logging.INFO

    logger = logging.getLogger(name)
    if not logger.handlers:
        # Installs a root handler on first use. This is a no-op when the root
        # logger is already configured, which is why the level is also set
        # explicitly on the named logger below.
        logging.basicConfig(level=level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Always apply the level so repeated calls can change verbosity.
    logger.setLevel(level)
    for handler in logger.handlers:
        handler.setLevel(level)
    return logger

                                                         
class BaseBenchmarkSingleResult(BaseModel):
    """Outcome of running the benchmark on a single prompt."""

    # Accepts the same id types as InputPromptRecord.prompt_id; pydantic v2
    # does not coerce int to str, so a str-only field would reject numeric ids.
    prompt_id: Optional[Union[str, int]] = None
    # The prompt text that was evaluated.
    original_prompt: str
    # Ground-truth and predicted labels as strings (see DecisionLabel).
    true_label: str
    predicted_label: str
    # Processing time for this prompt, in milliseconds (per the field name).
    processing_time_ms: float
    # Metadata fields; presumably copied from the input record -- confirm
    # against the caller that builds these results.
    source_file_input: Optional[str] = None
    prompt_category_input: Optional[str] = None
    prompt_style_input: Optional[str] = None
    # Free-text error description, if any.
    error_info: Optional[str] = None

class CommonBenchmarkMetrics(BaseModel):
    """Aggregate metrics for one benchmark run, as filled in by calculate_common_metrics."""

    # Total number of result records.
    num_samples: int
    # Share of correct predictions among records whose true label is one of
    # the two binary labels.
    accuracy: float
    # Metrics for the positive ("unsafe") class; None when no record had a
    # binary true label.
    precision_unsafe: Optional[float] = None
    recall_unsafe: Optional[float] = None
    f1_unsafe: Optional[float] = None
    # FP / (FP + TN) and FN / (FN + TP) respectively.
    fpr_unsafe: Optional[float] = None
    fnr_unsafe: Optional[float] = None
    # Confusion-matrix counts keyed by "TN", "FP", "FN", "TP".
    confusion_matrix_values: Dict[str, int]
    # Counts of true / predicted labels over all records.
    num_true_safe: int
    num_true_unsafe: int
    num_pred_safe: int
    num_pred_unsafe: int
    num_pred_error: int
    # None when no ambiguous label is tracked.
    num_pred_ambiguous: Optional[int] = None

class BaseBenchmarkSuiteResults(BaseModel):
    """Summary of a complete benchmark run: identity, timing, file locations and metrics."""

    # Identifier of the run (see create_run_id for one way to build it).
    suite_run_id: str
    # Start / end timestamps stored as strings; the exact format is chosen by
    # the caller and is not validated here.
    timestamp_start: str
    timestamp_end: str
    # Wall-clock duration of the run in seconds (per the field name).
    duration_seconds: float
    # Input files and output locations, stored as strings.
    input_csv_files: List[str]
    output_results_file: str
    output_report_file: str
    # Aggregate metrics for the run.
    metrics: CommonBenchmarkMetrics
                                                                                                                 

                              
def add_common_benchmark_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the CLI options shared by all benchmark scripts.

    Adds ``--input-csvs`` (required, one or more), ``--output-dir``,
    ``--run-name-prefix`` and ``--log-level`` to ``parser``.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ("--input-csvs", {
            "type": str,
            "required": True,
            "nargs": '+',
            "help": "Path(s) to the input CSV file(s) containing prompts.",
        }),
        ("--output-dir", {
            "type": str,
            "default": "benchmarks",
            "help": "Directory to save benchmark results and reports. (Default: benchmarks)",
        }),
        ("--run-name-prefix", {
            "type": str,
            "default": None,
            "help": "A custom prefix for the benchmark run ID. (Optional)",
        }),
        ("--log-level", {
            "type": str,
            "default": "INFO",
            "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            "help": "Set the logging level. (Default: INFO)",
        }),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

def create_run_id(base_name: str, prefix: Optional[str] = None) -> str:
    """Build a run identifier of the form ``[prefix_]base_name_YYYYMMDD_HHMMSS``.

    The identifier is used as a file-name stem, so path separators are
    replaced with underscores in both parts. Hyphens in ``base_name`` are
    also replaced; both parts are lower-cased.

    Args:
        base_name: Name of the benchmark or model being run.
        prefix: Optional user-supplied prefix; ignored when empty or None.

    Returns:
        The run identifier, stamped with the current local time.
    """
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    clean_base_name = (
        base_name.replace("-", "_").replace("/", "_").replace("\\", "_").lower()
    )
    if prefix:
        # A separator in the prefix would turn the id into a path inside a
        # directory that does not exist, and saving the outputs would fail.
        clean_prefix = prefix.replace("/", "_").replace("\\", "_").lower()
        return f"{clean_prefix}_{clean_base_name}_{timestamp_str}"
    return f"{clean_base_name}_{timestamp_str}"

def prepare_output_paths(
    output_dir_str: str,
    suite_run_id: str,
    results_filename_suffix: str = "_results.json",
    report_filename_suffix: str = "_report.md"
) -> Tuple[Path, Path, Path]:
    """Create the output folders and derive the result and report file paths.

    ``<output_dir>/results_data`` and ``<output_dir>/reports`` are created if
    missing. No files are written here.

    Args:
        output_dir_str: Base output directory.
        suite_run_id: Run identifier used as the file-name stem.
        results_filename_suffix: Suffix appended to the stem for the results file.
        report_filename_suffix: Suffix appended to the stem for the report file.

    Returns:
        ``(results_file_path, report_file_path, base_output_dir)``.
    """
    base_dir = Path(output_dir_str)
    layout = (
        ("results_data", results_filename_suffix),
        ("reports", report_filename_suffix),
    )
    file_paths = []
    for folder_name, suffix in layout:
        folder = base_dir / folder_name
        folder.mkdir(parents=True, exist_ok=True)
        file_paths.append(folder / f"{suite_run_id}{suffix}")

    results_path, report_path = file_paths
    return results_path, report_path, base_dir

                                        
def load_and_filter_prompts(
    csv_file_paths: List[str],
    split_to_filter: str,
    load_prompts_func: Callable[[str], List[InputPromptRecord]],
    logger: logging.Logger
) -> List[InputPromptRecord]:
    """Load prompts from every file and keep those belonging to one split.

    A file that fails to load or filter is logged (with traceback) and
    skipped; the remaining files are still processed.

    Args:
        csv_file_paths: Input files, processed in order.
        split_to_filter: Value a record's ``split`` attribute must equal.
        load_prompts_func: Callable that turns one file path into records.
        logger: Logger used for progress and error messages.

    Returns:
        All matching records in file order; may be empty.
    """
    selected: List[InputPromptRecord] = []

    for path in csv_file_paths:
        logger.info(f"Loading prompts from: {path}")
        try:
            loaded = load_prompts_func(path)
            # Records without a ``split`` attribute are treated as split=None.
            matching = []
            for record in loaded:
                if getattr(record, 'split', None) == split_to_filter:
                    matching.append(record)
            selected.extend(matching)
            logger.info(f"Loaded {len(loaded)} total prompts, found {len(matching)} for '{split_to_filter}' split from {path}.")
        except Exception as e:
            # Best effort: one bad file must not abort the whole run.
            logger.error(f"Failed to load or filter prompts from {path}: {e}", exc_info=True)

    if selected:
        logger.info(f"Total prompts for split '{split_to_filter}' to be processed: {len(selected)}")
    else:
        logger.warning(f"No prompts found for split '{split_to_filter}' across all provided CSV files. Execution might halt or produce empty results.")
    return selected

def map_true_label_to_str(
    label_val: Any,
    logger: logging.Logger,
    safe_numeric_val: int = 0,
    unsafe_numeric_val: int = 1
) -> str:
    """Map a raw ground-truth label to 'SAFE', 'UNSAFE', or 'N/A'.

    Accepted inputs:
      * the strings 'SAFE' / 'UNSAFE' (case-insensitive, surrounding
        whitespace ignored);
      * integers equal to ``safe_numeric_val`` / ``unsafe_numeric_val``;
      * numeric strings such as ``"1"`` or ``"1.0"``;
      * other values convertible with ``float()`` (for example floats or
        NumPy scalars) whose value equals one of the numeric codes.

    Args:
        label_val: Raw label value.
        logger: Logger that receives a warning for unmappable values.
        safe_numeric_val: Numeric code meaning SAFE.
        unsafe_numeric_val: Numeric code meaning UNSAFE.

    Returns:
        ``DecisionLabel.SAFE``, ``DecisionLabel.UNSAFE``, or ``"N/A"`` when
        the value cannot be mapped.
    """
    safe_str = DecisionLabel.SAFE
    unsafe_str = DecisionLabel.UNSAFE

    numeric_val = None
    if isinstance(label_val, str):
        val_norm = label_val.strip().upper()
        if val_norm == safe_str:
            return safe_str
        if val_norm == unsafe_str:
            return unsafe_str
        try:
            # Exact integer parse first so large integer codes are not
            # rounded by a float conversion.
            numeric_val = int(val_norm)
        except ValueError:
            try:
                # CSV exports often render integer labels as "0.0" / "1.0".
                numeric_val = float(val_norm)
            except ValueError:
                numeric_val = None
    elif isinstance(label_val, int):
        numeric_val = label_val
    elif label_val is not None:
        try:
            # Covers floats and numeric scalar types that are not ``int``
            # subclasses.
            numeric_val = float(label_val)
        except (TypeError, ValueError, OverflowError):
            numeric_val = None

    # NaN never compares equal, so missing values still end up as 'N/A'.
    if numeric_val is not None:
        if numeric_val == safe_numeric_val:
            return safe_str
        if numeric_val == unsafe_numeric_val:
            return unsafe_str

    logger.warning(
        f"Unmappable true label value: '{label_val}' (type: {type(label_val)}). "
        f"Expected values like {safe_numeric_val}, {unsafe_numeric_val}, '{safe_str}', or '{unsafe_str}'. "
        "Defaulting to 'N/A'."
    )
    return "N/A"

                             
def calculate_common_metrics(
    results: List[BaseBenchmarkSingleResult],
    logger: logging.Logger,
    positive_label: str = DecisionLabel.UNSAFE,
    negative_label: str = DecisionLabel.SAFE,
    error_label: str = DecisionLabel.ERROR,
    ambiguous_label: Optional[str] = DecisionLabel.AMBIGUOUS
) -> CommonBenchmarkMetrics:
    """Compute accuracy, positive-class metrics and label counts for a run.

    Only records whose true label is ``negative_label`` or ``positive_label``
    take part in accuracy, precision/recall/F1 and the confusion matrix.

    * Accuracy counts a record as correct only when the predicted label
      equals the true label, so error/ambiguous predictions are wrong.
    * For precision, recall, F1 and the confusion matrix, a prediction that
      is neither of the two binary labels is treated as ``negative_label``.

    The ``num_true_*`` / ``num_pred_*`` counts are taken over all records.

    Args:
        results: Per-prompt results; may be empty.
        logger: Logger for warnings and errors.
        positive_label: Label treated as the positive class.
        negative_label: Label treated as the negative class.
        error_label: Predicted label counted in ``num_pred_error``.
        ambiguous_label: Predicted label counted in ``num_pred_ambiguous``;
            a falsy value disables that count (reported as None).

    Returns:
        The populated metrics. When scikit-learn is unavailable or
        ``results`` is empty, a zeroed placeholder is returned; in the
        scikit-learn case ``num_pred_error`` equals the number of records.
    """
    num_samples_total = len(results) if results else 0

    def _placeholder_metrics(num_pred_error: int) -> CommonBenchmarkMetrics:
        # Zeroed metrics used when nothing can be computed.
        return CommonBenchmarkMetrics(
            num_samples=num_samples_total, accuracy=0.0,
            confusion_matrix_values={"TN": 0, "FP": 0, "FN": 0, "TP": 0},
            num_true_safe=0, num_true_unsafe=0, num_pred_safe=0,
            num_pred_unsafe=0, num_pred_error=num_pred_error,
            num_pred_ambiguous=0 if ambiguous_label else None
        )

    if not SKLEARN_AVAILABLE:
        logger.error("scikit-learn is not installed. Cannot calculate metrics.")
        return _placeholder_metrics(num_pred_error=num_samples_total)

    if not results:
        logger.warning("No results provided to calculate_common_metrics. Returning empty metrics.")
        return _placeholder_metrics(num_pred_error=0)

    true_labels_all = [r.true_label for r in results]
    pred_labels_all = [r.predicted_label for r in results]
    binary_labels = [negative_label, positive_label]

    # Restrict to records with a binary true label. y_pred holds the
    # prediction as fed to scikit-learn (non-binary -> negative_label);
    # accuracy compares against the raw prediction.
    y_true = []
    y_pred = []
    correct_predictions = 0
    for true_label, pred_label in zip(true_labels_all, pred_labels_all):
        if true_label not in binary_labels:
            continue
        y_true.append(true_label)
        y_pred.append(pred_label if pred_label in binary_labels else negative_label)
        if true_label == pred_label:
            correct_predictions += 1

    accuracy_val = correct_predictions / len(y_true) if y_true else 0.0

    # Label counts over every record, regardless of true-label validity.
    num_pred_error_val = sum(1 for p in pred_labels_all if p == error_label)
    num_pred_ambiguous_val = (
        sum(1 for p in pred_labels_all if p == ambiguous_label)
        if ambiguous_label else None
    )
    num_pred_safe_val = sum(1 for p in pred_labels_all if p == negative_label)
    num_pred_unsafe_val = sum(1 for p in pred_labels_all if p == positive_label)
    num_true_safe_val = sum(1 for t in true_labels_all if t == negative_label)
    num_true_unsafe_val = sum(1 for t in true_labels_all if t == positive_label)

    tn, fp, fn, tp = 0, 0, 0, 0
    precision_unsafe_val, recall_unsafe_val, f1_unsafe_val = None, None, None
    fpr_unsafe_val, fnr_unsafe_val = None, None

    if y_true:
        score_kwargs = {
            "labels": binary_labels,
            "pos_label": positive_label,
            "zero_division": 0,
        }
        precision_unsafe_val = precision_score(y_true, y_pred, **score_kwargs)
        recall_unsafe_val = recall_score(y_true, y_pred, **score_kwargs)
        f1_unsafe_val = f1_score(y_true, y_pred, **score_kwargs)

        # With two explicit labels the matrix is 2x2, laid out row-major as
        # [[TN, FP], [FN, TP]].
        cm_sklearn = sk_confusion_matrix(y_true, y_pred, labels=binary_labels)
        if cm_sklearn.size == 4:
            tn, fp, fn, tp = (int(cell) for cell in cm_sklearn.ravel())
        else:
            # Defensive only; not expected with two explicit labels.
            logger.warning(f"Unexpected confusion matrix shape: {cm_sklearn.shape}. CM values (tn, fp, fn, tp) will be 0.")

        fpr_unsafe_val = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        fnr_unsafe_val = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    else:
        logger.warning("No valid samples for binary classification metrics (precision, recall, F1, CM).")

    metrics_payload = {
        "num_samples": num_samples_total,
        "accuracy": accuracy_val,
        "precision_unsafe": precision_unsafe_val,
        "recall_unsafe": recall_unsafe_val,
        "f1_unsafe": f1_unsafe_val,
        "fpr_unsafe": fpr_unsafe_val,
        "fnr_unsafe": fnr_unsafe_val,
        "confusion_matrix_values": {"TN": tn, "FP": fp, "FN": fn, "TP": tp},
        "num_true_safe": num_true_safe_val,
        "num_true_unsafe": num_true_unsafe_val,
        "num_pred_safe": num_pred_safe_val,
        "num_pred_unsafe": num_pred_unsafe_val,
        "num_pred_error": num_pred_error_val,
    }
    # Left unset (model default) when the caller passes ambiguous_label=None.
    if ambiguous_label is not None:
        metrics_payload["num_pred_ambiguous"] = num_pred_ambiguous_val

    return CommonBenchmarkMetrics(**metrics_payload)

                           
def generate_markdown_report_parts(
    suite_results: BaseBenchmarkSuiteResults,
    model_specific_desc: str,
    model_specific_config_lines: Optional[List[str]] = None,
    report_title: str = "Benchmark Report"
) -> str:
    """Render a benchmark run as a Markdown report.

    The report contains a header with run metadata, optional
    caller-supplied configuration lines, the overall metrics table, label
    counts, the confusion matrix and explanatory notes.

    Args:
        suite_results: Run summary including the computed metrics.
        model_specific_desc: Description of the evaluated system/model.
        model_specific_config_lines: Extra Markdown lines inserted after the
            header, verbatim.
        report_title: Top-level heading of the report.

    Returns:
        The report as a single newline-joined string.
    """
    metrics = suite_results.metrics
    cm = metrics.confusion_matrix_values

    safe_label_str = DecisionLabel.SAFE
    unsafe_label_str = DecisionLabel.UNSAFE
    error_label_str = DecisionLabel.ERROR
    ambiguous_label_str = DecisionLabel.AMBIGUOUS

    def _fmt_optional(value: Optional[float]) -> str:
        # Four-decimal rendering, or 'N/A' when the metric was not computed.
        return f'{value:.4f}' if value is not None else 'N/A'

    results_file_name = Path(suite_results.output_results_file).name
    report_file_name = Path(suite_results.output_report_file).name

    report_content = [
        f"# {report_title}\n",
        f"**Suite Run ID:** {suite_results.suite_run_id}",
        f"**System/Model:** {model_specific_desc}",
        f"**Timestamp Start:** {suite_results.timestamp_start}",
        f"**Timestamp End:** {suite_results.timestamp_end}",
        f"**Duration:** {suite_results.duration_seconds:.2f} seconds",
        f"**Input CSV Files:** {', '.join(suite_results.input_csv_files)}",
        f"**Output Results File:** `{results_file_name}` (in results_data folder)",
        f"**Output Report File:** `{report_file_name}` (in reports folder)",
    ]

    if model_specific_config_lines:
        report_content.extend(model_specific_config_lines)

    report_content.extend([
        "\n## Overall Metrics\n",
        "| Metric                      | Value      |",
        "|-----------------------------|------------|",
        f"| Number of Samples           | {metrics.num_samples} |",
        f"| Accuracy                    | {metrics.accuracy:.4f} |",
        "| **UNSAFE Class Metrics**    |            |",
        f"| Precision (Unsafe)          | {_fmt_optional(metrics.precision_unsafe)} |",
        f"| Recall (Unsafe)             | {_fmt_optional(metrics.recall_unsafe)} |",
        f"| F1-score (Unsafe)           | {_fmt_optional(metrics.f1_unsafe)} |",
        f"| False Positive Rate (Unsafe)| {_fmt_optional(metrics.fpr_unsafe)} |",
        f"| False Negative Rate (Unsafe)| {_fmt_optional(metrics.fnr_unsafe)} |",
        "\n## Counts\n",
        "| Category                    | Count      |",
        "|-----------------------------|------------|",
        f"| True Safe                   | {metrics.num_true_safe} |",
        f"| True Unsafe                 | {metrics.num_true_unsafe} |",
        f"| Predicted Safe              | {metrics.num_pred_safe} |",
        f"| Predicted Unsafe            | {metrics.num_pred_unsafe} |",
        f"| Prediction Errors (Model)   | {metrics.num_pred_error} |",
    ])

    # The ambiguous row is shown only when that count is tracked.
    if metrics.num_pred_ambiguous is not None:
        report_content.append(f"| Predicted Ambiguous         | {metrics.num_pred_ambiguous} |")

    report_content.extend([
        "\n## Confusion Matrix (Unsafe as Positive)\n",
        "|                   | Predicted SAFE | Predicted UNSAFE |",
        "|-------------------|----------------|------------------|",
        f"| **Actual SAFE**   | {cm.get('TN',0)} (TN)     | {cm.get('FP',0)} (FP)       |",
        f"| **Actual UNSAFE** | {cm.get('FN',0)} (FN)     | {cm.get('TP',0)} (TP)       |",
        "\n---\n",
        "**Notes:**",
        f"- Accuracy is calculated on samples where the true label was '{safe_label_str}' or '{unsafe_label_str}'. Predictions of '{error_label_str}' (or '{ambiguous_label_str}' if applicable) against these true labels are considered misclassifications for accuracy.",
        # This note mirrors calculate_common_metrics, which folds non-binary
        # predictions into the negative class before calling scikit-learn.
        f"- Precision, Recall, F1, FPR, FNR and the confusion matrix for the '{unsafe_label_str}' class are calculated on samples whose true label was '{safe_label_str}' or '{unsafe_label_str}'. Predictions that were neither '{safe_label_str}' nor '{unsafe_label_str}' (e.g. '{error_label_str}' or '{ambiguous_label_str}') are counted as '{safe_label_str}' predictions for these metrics.",
        f"- A detailed list of results for each prompt can be found in the data output file: `{results_file_name}`",
    ])

    return "\n".join(report_content)

                               
def save_json_output(data_model: BaseModel, file_path: Path, logger: logging.Logger):
    """Serialise a pydantic model to a JSON file (4-space indent, None fields omitted).

    Failures are logged with a traceback and not re-raised, so a failed save
    does not abort the calling benchmark run.

    Args:
        data_model: Model to serialise via ``model_dump(mode='json')``.
        file_path: Destination file; overwritten if it exists.
        logger: Logger for the success or failure message.
    """
    try:
        # Serialise before opening the file: opening in 'w' mode truncates
        # it, so a serialisation error must not destroy an existing file or
        # leave a partial one behind.
        json_data = data_model.model_dump(mode='json', exclude_none=True)
        serialized = json.dumps(json_data, indent=4)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(serialized)
        logger.info(f"Successfully saved JSON output to: {file_path}")
    except Exception as e:
        logger.error(f"Failed to save JSON output to {file_path}: {e}", exc_info=True)

def save_markdown_report(report_content: str, file_path: Path, logger: logging.Logger):
    """Write the Markdown report text to ``file_path`` as UTF-8.

    Any failure is logged with a traceback and swallowed; nothing is returned.

    Args:
        report_content: Full report text.
        file_path: Destination file; overwritten if it exists.
        logger: Logger for the success or failure message.
    """
    try:
        with open(file_path, mode='w', encoding='utf-8') as report_file:
            report_file.write(report_content)
        logger.info(f"Successfully saved Markdown report to: {file_path}")
    except Exception as write_error:
        logger.error(f"Failed to save Markdown report to {file_path}: {write_error}", exc_info=True)

                                    
def get_progress_bar(iterable, desc: str, total: Optional[int] = None, logger: Optional[logging.Logger] = None):
    """Wrap ``iterable`` in a tqdm progress bar and optionally log the start.

    Args:
        iterable: Items to iterate over.
        desc: Description shown on the bar and in the log message.
        total: Number of items; inferred with ``len(iterable)`` when omitted
            and the iterable supports it.
        logger: If given, receives an info message announcing the start.

    Returns:
        The tqdm wrapper around ``iterable``.
    """
    # Resolve the total before logging so the message reports the real item
    # count for sized iterables instead of 'unknown'.
    if total is None:
        try:
            total = len(iterable)
        except TypeError:
            total = None

    if logger:
        logger.info(f"Starting processing: {desc} (Total items: {total if total is not None else 'unknown'})")

    # NOTE(review): position=10 draws the bar on a fixed row offset; the
    # reason is not visible in this file -- confirm it is intentional.
    progress_bar = tqdm(iterable, desc=desc, total=total, unit="item",
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
                        position=10,
                       )
    return progress_bar
