# datasets.py
"""
Datasets utilities for HMNS experiments.

Features
--------
- Load benchmarks from Hugging Face Datasets OR local files:
    AdvBench, HarmBench, JBB-Behaviors, StrongREJECT
- Normalize records to a common schema: {id, source, prompt, label}
- Filter to malicious/policy-violating items
- Deduplicate prompts (exact match after trimming, lowercasing and collapsing whitespace)
- Merge benchmarks into a single pool
- Deterministic 3-way split: analysis/dev/test (150/579/196 by default)
- JSONL save/load helpers
- Cohen's kappa utility for grader agreement

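Note
----
When this module is imported as ``datasets`` (as in the usage below), it
shadows the Hugging Face ``datasets`` package: the guarded
``from datasets import load_dataset`` further down then resolves to this
(partially initialized) module, fails, and HF loading is silently disabled
(``_HAS_HF = False``). Renaming the file (e.g. ``hmns_datasets.py``) avoids
the collision.

Usage (from main.py)
--------------------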
    from datasets import (
        load_all_benchmarks,
        build_main_pool_and_splits,
        save_jsonl,
    )

    pool, splits = build_main_pool_and_splits(
        prefer_hf=True,          # try HF Datasets
        local_roots=None,        # or provide dict of local paths per benchmark
        seed=1234,
        sizes=(150, 579, 196)
    )
    save_jsonl("data/analysis.jsonl", splits["analysis"])
    save_jsonl("data/dev.jsonl",      splits["dev"])
    save_jsonl("data/test.jsonl",     splits["test"])
"""

from __future__ import annotations
import json
import hashlib
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

# Optional imports (guarded)
try:
    from datasets import load_dataset  # Hugging Face Datasets
    _HAS_HF = True
except Exception:
    _HAS_HF = False

# ----------------------------- Common record -----------------------------

@dataclass
class Example:
    """A single benchmark record in the shared {id, source, prompt, label} schema."""

    id: str  # record id, e.g. "advbench_0"
    source: str  # benchmark name, e.g. "HarmBench"
    prompt: str  # prompt text
    label: int  # 1 = malicious/policy-violating, 0 = benign

    def to_dict(self) -> Dict:
        """Return the record as a plain dict (keys in field order) for JSON output."""
        return dict(id=self.id, source=self.source, prompt=self.prompt, label=self.label)


# ----------------------------- Normalization -----------------------------

# Compiled once at import time and reused by the normalisation helpers below.
_WHITESPACE = re.compile(r"\s+")
# NOTE(review): not referenced anywhere else in this module; kept in case
# other code imports it.
_PUNCT = re.compile(r"[^\w\s]")

def _norm_text(s: str) -> str:
    """Canonical form used for dedup keys: trimmed, lowercased, single-spaced."""
    lowered = s.strip().lower()
    return _WHITESPACE.sub(" ", lowered)

def _hash_key(s: str) -> str:
    """Hex SHA-1 digest of ``_norm_text(s)`` encoded as UTF-8."""
    digest = hashlib.sha1()
    digest.update(_norm_text(s).encode("utf-8"))
    return digest.hexdigest()


# ----------------------------- Loaders (HF) ------------------------------

def _hf_or_raise():
    """Guard for the HF loaders: raise RuntimeError if `datasets` could not be imported."""
    if _HAS_HF:
        return
    raise RuntimeError(
        "Hugging Face Datasets not available. Install `datasets` or use local file loaders."
    )

# Shared row -> Example helpers, used by both the HF and the local loaders.
# Label vocabularies for rows that carry a textual tag instead of a 0/1 label.
_MALICIOUS_TAGS = frozenset({"malicious", "unsafe", "violation", "harmful"})
# NOTE(review): JBB uses substring matching and has no "harmful" entry, unlike
# the exact vocabulary above; substring matching also makes e.g.
# "non-malicious" count as malicious. Kept as originally written — confirm.
_JBB_MALICIOUS_SUBSTRINGS = ("malicious", "unsafe", "violation")
# Column names tried (in order) for the prompt text.
_ADVBENCH_PROMPT_KEYS = ("prompt", "input", "question")
_PROMPT_KEYS = ("prompt", "input")

def _row_prompt(r: Dict, keys: Sequence[str]) -> str:
    """Return the first truthy value among ``r[k]`` for *keys*, else ""."""
    for k in keys:
        value = r.get(k)
        if value:
            return value
    return ""

def _first_set(r: Dict, keys: Sequence[str]):
    """Return the first value among ``r[k]`` for *keys* that is not None or "".

    Unlike an ``a or b`` chain this keeps falsy-but-meaningful labels such as
    ``0`` and ``False`` instead of skipping past them to the next key.
    Returns None when no key is set.
    """
    for k in keys:
        value = r.get(k)
        if value is None or (isinstance(value, str) and value == ""):
            continue
        return value
    return None

def _tag_label(raw, vocab, *, substring: bool = False) -> int:
    """Map a raw label/tag value to 1 (malicious) or 0 (benign).

    - ``None`` (no usable field) -> 1, the "default to malicious" fallback.
    - bool / int / float -> 1 iff the value equals 1 (schema: 1 = malicious).
    - the strings "0" / "1" -> their integer value.
    - anything else -> ``str(raw).lower()`` matched against *vocab*, by exact
      membership, or by substring containment when ``substring=True``.
    """
    if raw is None:
        return 1
    # Numbers must be checked before str(): str(1) == "1" is not in the tag
    # vocabulary, which previously turned numeric label 1 into benign.
    if isinstance(raw, (bool, int, float)):
        return 1 if raw == 1 else 0
    tag = str(raw).lower()
    if tag in ("0", "1"):
        return int(tag)
    if substring:
        return 1 if any(t in tag for t in vocab) else 0
    return 1 if tag in vocab else 0

def _binary_label(raw) -> int:
    """AdvBench-style label: missing -> 1 (malicious), otherwise ``int(raw)``.

    Values ``int()`` cannot convert (e.g. "unsafe") fall back to the textual
    vocabulary used by the other benchmarks instead of raising.
    """
    if raw is None:
        return 1
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return _tag_label(raw, _MALICIOUS_TAGS)

def _advbench_examples(rows: Iterable[Dict]) -> List[Example]:
    """Convert AdvBench-style rows (prompt + optional numeric label) to Examples."""
    return [
        Example(
            id=f"advbench_{i}",
            source="AdvBench",
            prompt=_row_prompt(r, _ADVBENCH_PROMPT_KEYS),
            label=_binary_label(r.get("label")),
        )
        for i, r in enumerate(rows)
    ]

def _tagged_examples(
    rows: Iterable[Dict],
    *,
    id_prefix: str,
    source: str,
    tag_keys: Sequence[str],
    vocab,
    substring: bool = False,
) -> List[Example]:
    """Convert rows whose label may be numeric or a textual tag to Examples.

    The prompt comes from 'prompt' or 'input'; the label from the first of
    *tag_keys* that is set (see `_first_set`), mapped with `_tag_label`.
    Ids are ``f"{id_prefix}_{row_index}"``.
    """
    return [
        Example(
            id=f"{id_prefix}_{i}",
            source=source,
            prompt=_row_prompt(r, _PROMPT_KEYS),
            label=_tag_label(_first_set(r, tag_keys), vocab, substring=substring),
        )
        for i, r in enumerate(rows)
    ]

def _harmbench_examples(rows: Iterable[Dict]) -> List[Example]:
    """HarmBench rows: label from 'label', else 'category'; exact tag match."""
    return _tagged_examples(
        rows, id_prefix="harmbench", source="HarmBench",
        tag_keys=("label", "category"), vocab=_MALICIOUS_TAGS,
    )

def _jbb_examples(rows: Iterable[Dict]) -> List[Example]:
    """JBB-Behaviors rows: label from 'label', else 'behavior'; substring match."""
    return _tagged_examples(
        rows, id_prefix="jbb", source="JBB-Behaviors",
        tag_keys=("label", "behavior"), vocab=_JBB_MALICIOUS_SUBSTRINGS, substring=True,
    )

def _strongreject_examples(rows: Iterable[Dict]) -> List[Example]:
    """StrongREJECT rows: label from 'label', else 'policy'; exact tag match."""
    return _tagged_examples(
        rows, id_prefix="strongreject", source="StrongREJECT",
        tag_keys=("label", "policy"), vocab=_MALICIOUS_TAGS,
    )

def load_advbench_hf(split: str = "test", repo_id: str = "AIBreaker/AdvBench") -> List[Example]:
    """
    Load AdvBench from the Hugging Face Hub.

    Expects a column with harmful prompts; many repos mirror AdvBench.
    Supported layout: {'prompt' | 'input' | 'question': str, 'label': int or bool}.
    Rows without a 'label' are treated as malicious (label 1).

    Args:
        split: split name passed to ``load_dataset``.
        repo_id: Hub dataset id. The default is the mirror this module was
            written against; pick a stable mirror you use internally.

    Raises:
        RuntimeError: if the Hugging Face ``datasets`` package is unavailable.
    """
    _hf_or_raise()
    return _advbench_examples(load_dataset(repo_id, split=split))

def load_harmbench_hf(split: str = "test", repo_id: str = "microsoft/harmbench") -> List[Example]:
    """
    Load HarmBench from the Hugging Face Hub.

    Label: taken from 'label', else 'category' (0/False count as set).
    Numeric/bool values are used directly (1 = malicious); strings are
    lowercased and matched exactly against {"malicious", "unsafe",
    "violation", "harmful"}. Rows with neither field default to malicious.

    NOTE(review): if 'category' holds a harm-taxonomy name rather than a
    safety verdict, such rows get label 0 and are dropped by filter_malicious
    — confirm against the dataset schema.

    Args:
        split: split name passed to ``load_dataset``.
        repo_id: Hub dataset id; replace if you use a different mirror.

    Raises:
        RuntimeError: if the Hugging Face ``datasets`` package is unavailable.
    """
    _hf_or_raise()
    return _harmbench_examples(load_dataset(repo_id, split=split))

def load_jbb_behaviors_hf(
    split: str = "test", repo_id: str = "JBB-Behaviors/jbb_behaviors"
) -> List[Example]:
    """
    Load JBB-Behaviors from the Hugging Face Hub.

    Label: taken from 'label', else 'behavior'. Numeric/bool values are used
    directly (1 = malicious); strings count as malicious when they contain
    "malicious", "unsafe" or "violation". Rows with neither field default to
    malicious.

    Args:
        split: split name passed to ``load_dataset``.
        repo_id: Hub dataset id; replace with the JBB-Behaviors canonical id you use.

    Raises:
        RuntimeError: if the Hugging Face ``datasets`` package is unavailable.
    """
    _hf_or_raise()
    return _jbb_examples(load_dataset(repo_id, split=split))

def load_strongreject_hf(split: str = "test", repo_id: str = "allenai/strongreject") -> List[Example]:
    """
    Load StrongREJECT from the Hugging Face Hub.

    Label: taken from 'label', else 'policy', with the same mapping as
    HarmBench. Rows with neither field default to malicious.

    Args:
        split: split name passed to ``load_dataset``.
        repo_id: Hub dataset id; replace if you use a different mirror.

    Raises:
        RuntimeError: if the Hugging Face ``datasets`` package is unavailable.
    """
    _hf_or_raise()
    return _strongreject_examples(load_dataset(repo_id, split=split))


# --------------------------- Loaders (local) -----------------------------

def _load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of objects, skipping blank lines.

    The file is decoded as "utf-8-sig": identical to UTF-8, except that a
    leading byte-order mark (common in files saved by Windows tools) is
    stripped instead of making ``json.loads`` fail on the first line.
    """
    data = []
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    return data

def load_advbench_local(path: str) -> List[Example]:
    """
    Load AdvBench from a local JSONL file.

    Expects keys: prompt (or input / question), label (optional).
    If label is missing, it defaults to 1 (malicious).
    """
    return _advbench_examples(_load_jsonl(path))

def load_harmbench_local(path: str) -> List[Example]:
    """Load HarmBench from a local JSONL file (same field mapping as load_harmbench_hf)."""
    return _harmbench_examples(_load_jsonl(path))

def load_jbb_behaviors_local(path: str) -> List[Example]:
    """Load JBB-Behaviors from a local JSONL file (same field mapping as load_jbb_behaviors_hf)."""
    return _jbb_examples(_load_jsonl(path))

def load_strongreject_local(path: str) -> List[Example]:
    """Load StrongREJECT from a local JSONL file (same field mapping as load_strongreject_hf)."""
    return _strongreject_examples(_load_jsonl(path))


# --------------------------- Public high-level ---------------------------

def load_all_benchmarks(
    prefer_hf: bool = True,
    local_roots: Optional[Dict[str, str]] = None,
    hf_splits: Optional[Dict[str, str]] = None,
) -> Dict[str, List[Example]]:
    """
    Load the four benchmarks and return a dict source -> list[Example].

    Args:
        prefer_hf: use the Hugging Face loaders when the `datasets` import
            succeeded (module flag ``_HAS_HF``).
        local_roots: JSONL paths used when the HF path is not taken; must map
            {"advbench": "/path/adv.jsonl", "harmbench": "...", "jbb": "...", "strongreject": "..."}
        hf_splits: per-benchmark split names for the HF loaders, keyed like
            local_roots. Keys that are not given default to "test", so a
            partial dict such as {"advbench": "train"} is accepted.

    Returns:
        {"AdvBench": [...], "HarmBench": [...], "JBB-Behaviors": [...], "StrongREJECT": [...]}

    Raises:
        RuntimeError: if the HF path is not taken and local_roots is empty/None.
        KeyError: if local_roots lacks one of the four keys.
    """
    # Overlay the caller's choices on the defaults (without mutating the
    # caller's dict) so partial overrides no longer raise KeyError.
    splits = {"advbench": "test", "harmbench": "test", "jbb": "test", "strongreject": "test"}
    if hf_splits:
        splits.update(hf_splits)

    if prefer_hf and _HAS_HF:
        return {
            "AdvBench": load_advbench_hf(splits["advbench"]),
            "HarmBench": load_harmbench_hf(splits["harmbench"]),
            "JBB-Behaviors": load_jbb_behaviors_hf(splits["jbb"]),
            "StrongREJECT": load_strongreject_hf(splits["strongreject"]),
        }

    # Fall back to local files
    if not local_roots:
        raise RuntimeError("No HF access and no local_roots provided.")

    return {
        "AdvBench": load_advbench_local(local_roots["advbench"]),
        "HarmBench": load_harmbench_local(local_roots["harmbench"]),
        "JBB-Behaviors": load_jbb_behaviors_local(local_roots["jbb"]),
        "StrongREJECT": load_strongreject_local(local_roots["strongreject"]),
    }


def filter_malicious(examples: Iterable[Example]) -> List[Example]:
    """Return, as a new list, the examples whose label converts to int 1, in input order."""
    def _is_malicious(ex: Example) -> bool:
        return int(ex.label) == 1

    return list(filter(_is_malicious, examples))


def dedupe_examples(examples: Iterable[Example]) -> List[Example]:
    """Drop repeated prompts, keeping the first occurrence in input order.

    Two prompts are duplicates when their `_hash_key` digests match, i.e. they
    are equal after trimming, lowercasing and collapsing whitespace.
    """
    kept: Dict[str, Example] = {}
    for ex in examples:
        # setdefault only stores the first example seen for each key.
        kept.setdefault(_hash_key(ex.prompt), ex)
    return list(kept.values())


def merge_benchmarks(dsets: Dict[str, List[Example]], only_malicious: bool = True) -> List[Example]:
    """Concatenate every benchmark's examples (in dict order) and deduplicate.

    When only_malicious is True, each benchmark is first passed through
    filter_malicious. Deduplication (dedupe_examples) runs over the whole
    concatenation, so a prompt shared by two benchmarks is kept from the one
    that comes first in iteration order.
    """
    pooled: List[Example] = []
    for examples in dsets.values():
        pooled.extend(filter_malicious(examples) if only_malicious else examples)
    return dedupe_examples(pooled)


def split_analysis_dev_test(
    items: Sequence[Example],
    sizes: Tuple[int, int, int] = (150, 579, 196),
    seed: int = 0,
) -> Dict[str, List[Example]]:
    """Deterministically split *items* into disjoint analysis/dev/test lists.

    Each item is ranked by the SHA-1 of ``f"{seed}::{id}::{_hash_key(prompt)}"``,
    which gives a seed-dependent, reproducible order that does not depend on
    the input order. The first sizes[0] ranked items form "analysis", the next
    sizes[1] form "dev", and the next sizes[2] form "test"; any remaining items
    are unused.

    Raises:
        ValueError: if any size is negative (negative slice bounds would make
            the splits overlap, leaking items between them), or if the sizes
            sum to more than len(items).
    """
    a, d, t = sizes
    if min(a, d, t) < 0:
        raise ValueError(f"Split sizes must be non-negative, got {(a, d, t)}.")

    # Score each item by a seed-hash for reproducible shuffling
    def score(ex: Example) -> int:
        h = hashlib.sha1(f"{seed}::{ex.id}::{_hash_key(ex.prompt)}".encode("utf-8")).hexdigest()
        return int(h, 16)

    ranked = sorted(items, key=score)
    n = len(ranked)
    if a + d + t > n:
        raise ValueError(f"Requested split ({a+d+t}) exceeds dataset size ({n}).")

    return {
        "analysis": ranked[:a],
        "dev": ranked[a:a + d],
        "test": ranked[a + d:a + d + t],
    }


def build_main_pool_and_splits(
    prefer_hf: bool = True,
    local_roots: Optional[Dict[str, str]] = None,
    hf_splits: Optional[Dict[str, str]] = None,
    sizes: Tuple[int, int, int] = (150, 579, 196),
    seed: int = 0,
) -> Tuple[List[Example], Dict[str, List[Example]]]:
    """
    Build the merged malicious pool and its analysis/dev/test split.

    Steps:
      1) load every benchmark (load_all_benchmarks),
      2) keep malicious items and deduplicate across benchmarks (merge_benchmarks),
      3) split the pool deterministically (split_analysis_dev_test).

    Returns:
        (pool, splits), where splits has keys "analysis", "dev", "test".
    """
    benchmarks = load_all_benchmarks(prefer_hf=prefer_hf, local_roots=local_roots, hf_splits=hf_splits)
    main_pool = merge_benchmarks(benchmarks, only_malicious=True)
    return main_pool, split_analysis_dev_test(main_pool, sizes=sizes, seed=seed)


# ---------------------------- JSONL helpers -----------------------------

def save_jsonl(path: str, rows: Iterable[Example | Dict]):
    """Write *rows* to *path* as JSON Lines (one UTF-8 JSON object per line).

    Example instances are serialized with Example.to_dict(); any other row is
    passed to json.dumps unchanged. Non-ASCII text is written as-is
    (ensure_ascii=False). Missing parent directories are created, so paths
    like "data/dev.jsonl" work without a prior mkdir. An existing file is
    overwritten.
    """
    import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            obj = r.to_dict() if isinstance(r, Example) else r
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def load_jsonl_examples(path: str, default_source: str = "custom") -> List[Example]:
    """Load Examples from a JSONL file, e.g. one written by save_jsonl.

    Per row:
      - prompt: 'prompt', else 'input', else "".
      - label: int('label'); a missing label defaults to 1 (malicious).
      - source: 'source', else *default_source*.
      - id: str('id') when present (including 0); when it is missing or "",
        the id becomes f"{source}_{row_index}".
    """
    out: List[Example] = []
    for i, r in enumerate(_load_jsonl(path)):
        prompt = r.get("prompt") or r.get("input") or ""
        raw_label = r.get("label")
        lbl = 1 if raw_label is None else int(raw_label)
        src = r.get("source") or default_source
        raw_id = r.get("id")
        # An `or` chain here would replace a legitimate id of 0; also coerce
        # numeric ids to str to honour Example.id's declared type.
        if raw_id is None or raw_id == "":
            _id = f"{src}_{i}"
        else:
            _id = str(raw_id)
        out.append(Example(id=_id, source=src, prompt=prompt, label=lbl))
    return out


# ------------------------- Agreement (optional) -------------------------

def cohens_kappa(y_a: Sequence[int], y_b: Sequence[int]) -> float:
    """
    Cohen's kappa for two labelers over the same items.

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement rate
    and p_e is the agreement expected by chance from each labeler's label
    frequencies. Labels are compared after int() conversion. Binary {0, 1}
    is the intended use, but any integer categories work.

    Returns 0.0 for empty input, and 1.0 when both labelers used one and the
    same label throughout (p_e == 1, where kappa is otherwise undefined).

    Raises:
        ValueError: if the sequences differ in length.
    """
    from collections import Counter

    if len(y_a) != len(y_b):
        raise ValueError("Label sequences must have the same length.")
    n = len(y_a)
    if n == 0:
        return 0.0
    a = [int(v) for v in y_a]
    b = [int(v) for v in y_b]
    matches = sum(1 for x, y in zip(a, b) if x == y)
    # n^2 * p_e as an integer: sum over categories of count_a[k] * count_b[k].
    count_a, count_b = Counter(a), Counter(b)
    chance = sum(count_a[k] * count_b[k] for k in count_a)
    if chance == n * n:
        return 1.0
    # (p_o - p_e) / (1 - p_e) multiplied through by n^2: a single exact
    # division, with no epsilon needed since the zero denominator is handled above.
    return (n * matches - chance) / (n * n - chance)
