"""
Mode 1 — dataset sanitization (model-agnostic).
Applies tiered biosecurity keyword filtering to the CORD-19 metadata (or any provided CSV corpus with title/abstract columns).
Counts the records removed at each of three filter levels (L1, L2, L3) and plots the removal counts and rates as bar charts.
Outputs:
 - A sanitized CSV per level (sanitized_L1_custom.csv, sanitized_L2_human.csv, sanitized_L3_all.csv).
 - mode1_sanitization_results.csv with total, kept, removed, and removal_rate (a fraction in [0, 1]) per level.
 - A figure (mode1_<dataset>_analysis.png) in a `figures` directory next to the output directory, showing removal counts and percentages by level.
"""
import os
import re
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from guard import VIRUS_FILTER_LEVELS

def _filter_by_level(df: pd.DataFrame, fields: List[str], level: str) -> pd.DataFrame:
    """Filter the DataFrame by removing any row containing level-specific risky terms."""
    terms = set(VIRUS_FILTER_LEVELS.get(level, []))
    # Determine which rows are allowed (no term present)
    mask = df.apply(lambda row: not any((k in str(row[f]).lower()) for f in fields for k in terms), axis=1)
    return df.loc[mask].copy()

def _load_corpus(corpus_csv: str, encoding: str) -> Tuple[pd.DataFrame, List[str]]:
    """Read the optional ``cord_uid`` column plus title/abstract from ``corpus_csv``.

    Returns the DataFrame and the list of text fields (``title`` and/or
    ``abstract``) present in the header, in that order.

    Raises:
        ValueError: if the header has neither ``title`` nor ``abstract``.
        UnicodeDecodeError: if the file is not valid in ``encoding``.
    """
    raw_cols = pd.read_csv(corpus_csv, nrows=0, encoding=encoding).columns.tolist()
    fields = [f for f in ["title", "abstract"] if f in raw_cols]
    if not fields:
        raise ValueError(f"Mode1: no title/abstract fields found in {corpus_csv}. Columns available: {raw_cols}")
    usecols = (["cord_uid"] if "cord_uid" in raw_cols else []) + fields
    df = pd.read_csv(corpus_csv, usecols=usecols, encoding=encoding, low_memory=False)
    return df, fields


def _plot_removals(
    dataset_label: str,
    removal_counts: List[int],
    removal_pcts: List[float],
    fig_path: str,
) -> None:
    """Save a two-panel bar chart (records removed, removal %) for the three levels."""
    levels_labels = ["Level 1", "Level 2", "Level 3"]
    colors = ['#ffcccc', '#ff6666', '#cc0000']
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Left: absolute removed counts
        bars1 = ax1.bar(levels_labels, removal_counts, color=colors, edgecolor='black', linewidth=1)
        ax1.set_title(f"{dataset_label}: Records Removed by Level", fontsize=13, fontweight='bold')
        ax1.set_ylabel("Number of Records Removed", fontsize=11)
        ax1.set_xlabel("Biosecurity Filter Level", fontsize=11)
        for rect in bars1:
            h = rect.get_height()
            ax1.text(rect.get_x() + rect.get_width() / 2, h, f'{int(h):,}', ha="center", va="bottom", fontsize=10)
        ax1.grid(True, axis='y', alpha=0.3)
        # Right: percentage removal rates
        bars2 = ax2.bar(levels_labels, removal_pcts, color=colors, edgecolor='black', linewidth=1)
        ax2.set_title(f"{dataset_label}: Removal Rate by Level", fontsize=13, fontweight='bold')
        ax2.set_ylabel("Removal Rate (%)", fontsize=11)
        ax2.set_xlabel("Biosecurity Filter Level", fontsize=11)
        ax2.set_ylim(0, 100)
        for rect in bars2:
            h = rect.get_height()
            ax2.text(rect.get_x() + rect.get_width() / 2, h, f'{h:.1f}%', ha="center", va="bottom", fontsize=10)
        ax2.grid(True, axis='y', alpha=0.3)
        fig.suptitle(f"Biosecurity Filtering Analysis - {dataset_label}", fontsize=15, fontweight='bold')
        fig.tight_layout()
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing/saving fails, so repeated runs don't leak figures.
        plt.close(fig)


def run_mode1(corpus_csv: str, out_dir: str, *, encoding: Optional[str] = None) -> pd.DataFrame:
    """
    Run Mode 1: apply biosecurity keyword filters at L1, L2, L3 on the corpus.

    Writes ``sanitized_<level>.csv`` for each level and
    ``mode1_sanitization_results.csv`` into ``out_dir``. Also writes a bar
    chart ``mode1_<name>_analysis.png`` into a ``figures`` directory that is a
    sibling of ``out_dir``.

    Args:
        corpus_csv: Path to the corpus CSV. It must have a ``title`` and/or
            ``abstract`` column; ``cord_uid`` is carried through if present.
        out_dir: Directory for the sanitized CSVs and the summary CSV
            (created if missing).
        encoding: File encoding. When None, UTF-8 (with optional BOM) is
            tried first, and latin1 is used only if the file is not valid
            UTF-8.

    Returns:
        Summary DataFrame with one row per level and columns ``level``,
        ``total``, ``kept``, ``removed`` and ``removal_rate`` (a fraction in
        [0, 1], rounded to 4 places).

    Raises:
        ValueError: if the corpus has neither a title nor an abstract column.
    """
    os.makedirs(out_dir, exist_ok=True)
    if encoding is not None:
        df, fields = _load_corpus(corpus_csv, encoding)
    else:
        try:
            # CORD-19 metadata is UTF-8. Decoding it as latin1 turns every
            # non-ASCII character into mojibake, and turns a BOM into junk
            # prefixed to the first column name.
            df, fields = _load_corpus(corpus_csv, "utf-8-sig")
        except UnicodeDecodeError:
            # latin1 decodes any byte sequence, so this fallback never fails on decoding.
            df, fields = _load_corpus(corpus_csv, "latin1")
    total = len(df)

    summary_rows = []
    removal_counts = []
    removal_pcts = []
    for lvl in ["L1_custom", "L2_human", "L3_all"]:
        kept_df = _filter_by_level(df, fields, lvl)
        kept_df.to_csv(os.path.join(out_dir, f"sanitized_{lvl}.csv"), index=False)
        kept_count = len(kept_df)
        removed_count = total - kept_count
        fraction = removed_count / total if total else 0.0
        summary_rows.append({
            "level": lvl,
            "total": total,
            "kept": kept_count,
            "removed": removed_count,
            "removal_rate": round(fraction, 4),
        })
        removal_counts.append(removed_count)
        removal_pcts.append(round(fraction * 100, 2))

    summary_df = pd.DataFrame(summary_rows)
    summary_df.to_csv(os.path.join(out_dir, "mode1_sanitization_results.csv"), index=False)

    # Derive a filesystem-friendly dataset name from the corpus file name.
    file_stem = Path(corpus_csv).stem
    safe_name = file_stem.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "_").lower()
    dataset_label = "CORD-19" if "cord" in safe_name else safe_name
    # Figures go to a shared "figures" directory next to out_dir.
    fig_dir = os.path.join(Path(out_dir).parent, "figures")
    os.makedirs(fig_dir, exist_ok=True)
    fig_path = os.path.join(fig_dir, f"mode1_{safe_name}_analysis.png")
    _plot_removals(dataset_label, removal_counts, removal_pcts, fig_path)
    return summary_df
