# -*- coding: utf-8 -*-
"""
Mode 3 — guard evaluation across levels (L1/L2/L3), per model.
Computes precision/recall/F1/FPR/JSR. Outputs combined CSV across models.
"""
import os, pandas as pd
from typing import Dict, Any, List
from guard import biosecurity_alignment_guard

def _metrics_of(y_true: List[int], y_pred: List[int]) -> Dict[str, Any]:
    tp = sum(1 for t,p in zip(y_true,y_pred) if t==1 and p==1)
    tn = sum(1 for t,p in zip(y_true,y_pred) if t==0 and p==0)
    fp = sum(1 for t,p in zip(y_true,y_pred) if t==0 and p==1)
    fn = sum(1 for t,p in zip(y_true,y_pred) if t==1 and p==0)
    precision = tp/(tp+fp) if (tp+fp)>0 else 0.0
    recall    = tp/(tp+fn) if (tp+fn)>0 else 0.0
    f1 = 2*precision*recall/(precision+recall) if (precision+recall)>0 else 0.0
    fpr = fp/(fp+tn) if (fp+tn)>0 else 0.0
    jsr = fn/(tp+fn) if (tp+fn)>0 else 0.0
    return dict(tp=tp, tn=tn, fp=fp, fn=fn,
                precision=round(precision,4),
                recall=round(recall,4),
                f1=round(f1,4),
                fpr=round(fpr,4),
                jsr=round(jsr,4))

def run_mode3_for_model(model_id: str,
                        guard_levels: List[str],
                        prompts_csv: str,
                        out_dir: str,
                        use_semantic: bool = False) -> pd.DataFrame:
    """Evaluate the guard on a labelled prompt set at each filter level.

    Reads ``prompts_csv`` (must contain ``prompt`` and ``label`` columns),
    runs ``biosecurity_alignment_guard`` on every prompt once per entry in
    ``guard_levels``, and scores the guard's block decisions against the
    labels with ``_metrics_of``. A row counts as truly harmful only when its
    label, stripped and lower-cased, equals ``"harmful"``; a prompt counts as
    predicted harmful only when the guard's ``"action"`` is ``"block"``.

    The per-level metrics are written to
    ``<out_dir>/<model_id>_mode3_guard_metrics.csv`` and also returned.

    Args:
        model_id: Identifier stored in the ``model_id`` column and used as the
            output file-name prefix.
        guard_levels: Values passed to the guard as ``virus_filter_level``;
            one output row is produced per level.
        prompts_csv: Path of the labelled prompt CSV.
        out_dir: Directory for the metrics CSV; created if missing.
        use_semantic: Forwarded unchanged to the guard.

    Returns:
        DataFrame with one row per guard level (``model_id``, ``level`` and
        the metric columns produced by ``_metrics_of``).

    Raises:
        KeyError: If ``prompts_csv`` lacks the ``label`` or ``prompt`` column.
    """
    os.makedirs(out_dir, exist_ok=True)
    df = pd.read_csv(prompts_csv, engine='python')

    missing = [col for col in ("label", "prompt") if col not in df.columns]
    if missing:
        raise KeyError(
            f"{prompts_csv} is missing required column(s): {', '.join(missing)}"
        )

    # Ground truth and prompt text do not depend on the guard level, so
    # derive them once instead of re-parsing every row for every level.
    y_true = [
        1 if str(value).strip().lower() == "harmful" else 0
        for value in df["label"].tolist()
    ]
    prompts = [str(value) for value in df["prompt"].tolist()]

    rows = []
    for lvl in guard_levels:
        y_pred = []
        for prompt in prompts:
            g = biosecurity_alignment_guard(
                text=prompt, where="pre", mode="strict",
                virus_filter_level=lvl, use_semantic=use_semantic, use_blast=False
            )
            y_pred.append(1 if g["action"] == "block" else 0)
        rows.append(dict(model_id=model_id, level=lvl, **_metrics_of(y_true, y_pred)))

    df_out = pd.DataFrame(rows)
    path = os.path.join(out_dir, f"{model_id}_mode3_guard_metrics.csv")
    # A model id containing a path separator (e.g. "org/model") makes the
    # file land in a sub-directory of out_dir; create it so the write
    # does not fail on a missing directory.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df_out.to_csv(path, index=False)
    return df_out

def aggregate_mode3(all_per_model: List[pd.DataFrame], out_dir: str) -> str:
    """Concatenate per-model metric tables and write one summary CSV.

    Args:
        all_per_model: Per-model metric DataFrames; they are stacked in the
            given order with a fresh 0..n-1 index.
        out_dir: Directory for the summary file; created if missing.

    Returns:
        Path of the written ``mode3_guard_metrics_summary.csv``.

    Raises:
        ValueError: Propagated from ``pd.concat`` when ``all_per_model`` is
            empty.
    """
    df = pd.concat(all_per_model, ignore_index=True)
    # Create the output directory here as well, so this function can be
    # called on its own with a directory that does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "mode3_guard_metrics_summary.csv")
    df.to_csv(path, index=False)
    return path
