
from datasets import load_dataset
import pandas as pd, numpy as np
from scipy.stats import chi2_contingency
from tqdm.auto import tqdm

# ---------- 1. load the JSON Lines file ---------------------
# One dataset row per JSON record in train.jsonl.
docs = load_dataset("json", data_files="train.jsonl", split="train")

# ---------- 2. flatten to event mentions --------------------
# Produce one row per event mention, holding the owning document's id and the
# mention's factuality code. The code is upper-cased and stripped so the exact
# set-membership tests in step 3 are not defeated by case or stray whitespace.
# Assumed record layout (inferred only from the keys read here — confirm
# against the data):
#   {"id": ..., "events": [{"mention": [{"factuality": str, ...}, ...]}, ...]}
rows = [
    {"doc_id": doc["id"], "code": mention["factuality"].upper().strip()}
    for doc in tqdm(docs, desc="flatten")
    for event in doc["events"]
    for mention in event["mention"]
]

df = pd.DataFrame(rows)
print(f"{len(df):,} event mentions from {df['doc_id'].nunique():,} docs")

# ---------- 3. keep only certain-true / certain-false -------
# Factuality codes counted as certain-true and as certain-false; mentions
# carrying any other code are discarded below.
CERT_TRUE = {"CT+", "CT++"}
CERT_FALSE = {"CT-", "CT--"}

# Restrict to certain mentions (.copy() so the new column is written to an
# independent frame, not a view of the unfiltered one), then encode the target
# as an integer flag: 1 = certain-false, 0 = certain-true.
df = df.loc[df["code"].isin(CERT_TRUE | CERT_FALSE)].copy()
df["non_fact"] = df["code"].isin(CERT_FALSE).astype("int64")

print(
    f"{len(df):,} certain events retained "
    f"({df['non_fact'].mean():.3%} certain-false)"
)

# ---------- 4. group by document (≥2 events) ---------------
# Map doc_id -> that document's certain-event rows, keeping only documents with
# at least two certain events (a single event forms no within-document pair).
groups = dict(
    item for item in df.groupby("doc_id") if len(item[1]) >= 2
)

# ---------- 5. co-occurrence metrics ------------------------
# Overall certain-false rate across every retained certain event.
# NOTE(review): `p` (and the chi-square table below) also count documents with
# a single certain event, which `groups` excludes. If those documents have a
# different certain-false rate, the baselines (exp_var, p**2) describe a
# different population than the statistics they are compared with — confirm
# this is intended.
p = df["non_fact"].mean()


def choose2(n):
    """Return n * (n - 1) / 2, the number of unordered pairs among n items."""
    return n * (n - 1) / 2


if groups:
    # Variance of per-document certain-false rates, compared with the variance
    # an independent Bernoulli(p) draw per event would give for documents of
    # the same sizes; a ratio above 1 indicates certain-false events cluster.
    obs_var = np.var([g.non_fact.mean() for g in groups.values()], ddof=0)
    exp_var = np.mean([p * (1 - p) / len(g) for g in groups.values()])
    # p of exactly 0 or 1 makes exp_var 0: report "undefined", not a 0/0 warning.
    clustering_ratio = obs_var / exp_var if exp_var > 0 else np.nan

    # Fraction of within-document event pairs in which both are certain-false.
    pair_false = sum(choose2(g.non_fact.sum()) for g in groups.values())
    pair_total = sum(choose2(len(g)) for g in groups.values())
    p_nn = pair_false / pair_total  # > 0: every kept group has >= 2 events
else:
    # No document has two certain events, so there is nothing to measure;
    # previously this path raised ZeroDivisionError.
    obs_var = exp_var = clustering_ratio = np.nan
    pair_false = pair_total = 0
    p_nn = np.nan
p_ind = p ** 2

# Pearson chi-square test of independence on the doc_id x non_fact count
# table (all documents in df), without Yates' continuity correction.
# NOTE(review): with many documents and few events each, expected cell counts
# are small and the chi-square approximation is probably unreliable.
if df.empty:
    # chi2_contingency raises on an empty table.
    chi2, pval = np.nan, np.nan
else:
    chi2, pval, _, _ = chi2_contingency(
        pd.crosstab(df.doc_id, df.non_fact), correction=False
    )

# ---------- 6. results --------------------------------------
# Print the summary. The stray trailing `s` that followed these prints was an
# undefined name, so the script crashed with NameError after reporting.
print("\n——  Certain-false clustering ——")
print(f"Certain-false rate (p)            : {p:.4f}")
print(f"Clustering ratio (obs / exp var)  : {clustering_ratio:.3f}")
print(f"P(two certain-false | same doc)   : {p_nn:.5f}")
print(f"Expected under independence p²    : {p_ind:.5f}")
print(f"χ² = {chi2:.0f},   p-value = {pval:.3g}")