import pandas as pd
import numpy as np
from typing import Optional, Sequence

def attach_cost_columns(df: pd.DataFrame, eps: float = 1.0, tie_positive: bool = True) -> pd.DataFrame:
    """Return a copy of ``df`` with smoothed log-odds cost columns appended.

    Args:
        df: Frame holding ``n_yes`` and ``n_no`` vote-count columns. It is not
            modified; ``DataFrame.assign`` builds a new frame.
        eps: Additive smoothing added to both counts; must be strictly > 0.
        tie_positive: Label given to rows where ``n_yes == n_no``
            (True -> 1, False -> 0).

    Returns:
        New DataFrame with three extra columns:
        ``delta_signed`` = log((n_yes + eps) / (n_no + eps)) as float32,
        ``abs_delta`` = absolute value of ``delta_signed``,
        ``y_star`` = 1 if ``n_yes > n_no`` (or on ties when ``tie_positive``),
        else 0, as int8.

    Raises:
        ValueError: If ``eps`` is not strictly positive (NaN included, since
            ``nan > 0`` is False).
    """
    eps_ok = eps > 0
    if not eps_ok:
        raise ValueError("Epsilon must be positive")

    yes_votes = df["n_yes"]
    no_votes = df["n_no"]

    # Smoothed counts in float32 so the log-ratio stays single precision.
    smoothed_yes = yes_votes.astype("float32") + eps
    smoothed_no = no_votes.astype("float32") + eps
    log_ratio = np.log(smoothed_yes / smoothed_no).astype("float32")

    # Majority label; ties resolved by the caller-chosen rule.
    wins = yes_votes > no_votes
    ties = yes_votes == no_votes
    label = (wins | (ties & tie_positive)).astype(np.int8)

    return df.assign(
        delta_signed=log_ratio,
        abs_delta=np.abs(log_ratio),
        y_star=label,
    )


def build_jigsaw_frame(df_raw: pd.DataFrame, eps: float = 1.0, tie_positive: bool = True) -> pd.DataFrame:
    """Turn raw Jigsaw rows into per-comment vote counts plus cost columns.

    Rows are dropped when ``target`` or ``toxicity_annotator_count`` is NA,
    when the annotator count is not positive, or when ``comment_text`` is NA
    or blank after stripping. ``target`` is clipped to [0, 1]. ``n_yes`` is
    then rebuilt as round-half-up of ``target * toxicity_annotator_count``,
    and ``n_no`` is the remainder.

    Args:
        df_raw: Input frame; must contain ``comment_text``, ``target`` and
            ``toxicity_annotator_count``. It is not modified.
        eps: Smoothing forwarded to ``attach_cost_columns``.
        tie_positive: Tie rule forwarded to ``attach_cost_columns``.

    Returns:
        Frame with columns ``comment_text`` (pandas ``string`` dtype),
        ``n_yes`` / ``n_no`` (int32), ``delta_signed`` / ``abs_delta``
        (float32) and ``y_star`` (int8). Surviving rows keep their original
        index labels.

    Raises:
        ValueError: If required columns are missing, if the vote-count
            invariant ``n_yes + n_no == toxicity_annotator_count`` fails, or
            (from ``attach_cost_columns``) if ``eps`` is not positive.
    """
    required_columns = ["comment_text", "target", "toxicity_annotator_count"]
    missing_columns = [c for c in required_columns if c not in df_raw.columns]
    if missing_columns:
        raise ValueError(f"Input DataFrame is missing required columns: {missing_columns}")

    # Drop rows with unusable numeric fields. The trailing .copy() matters:
    # boolean indexing can return a frame pandas tracks as a slice of its
    # parent, and the column assignments below would then emit
    # SettingWithCopyWarning.
    df = df_raw.dropna(subset=["target", "toxicity_annotator_count"])
    df = df.loc[df["toxicity_annotator_count"] > 0].copy()

    # Normalize and filter texts:
    # - coerce to pandas "string" (so vectorizers won't see a float NaN)
    # - drop rows where text is NA/empty after strip
    df["comment_text"] = df["comment_text"].astype("string")
    text_mask = df["comment_text"].notna() & (df["comment_text"].str.strip() != "")
    df = df.loc[text_mask].copy()

    # Numeric coercions. The rounding below is done in float64: downcasting
    # target to float32 first would add representation error to the product
    # right before the .5 rounding step.
    counts = df["toxicity_annotator_count"].astype("int32")
    target = df["target"].astype("float64").clip(0.0, 1.0)
    df["toxicity_annotator_count"] = counts
    df["target"] = target.astype("float32")

    # Rebuild votes: round half up, then clip as a guard so 0 <= n_yes <= count.
    n_yes = np.floor(target * counts + 0.5).clip(lower=0, upper=counts).astype("int32")
    df["n_yes"] = n_yes
    df["n_no"] = (counts - n_yes).astype("int32")
    if not (df["n_yes"] + df["n_no"] == counts).all():
        raise ValueError("Invariant failed: number of annotators mismatch")

    df = attach_cost_columns(df, eps, tie_positive)

    # Pin the final schema explicitly so downstream code can rely on dtypes.
    out_columns = ["comment_text", "n_yes", "n_no", "delta_signed", "abs_delta", "y_star"]
    return df[out_columns].astype(
        {
            "n_yes": "int32",
            "n_no": "int32",
            "y_star": "int8",
            "delta_signed": "float32",
            "abs_delta": "float32",
        }
    )


def load_jigsaw_frame(
    path: str = "data/jigsaw/train.csv",
    eps: float = 1.0,
    tie_positive: bool = True,
    n_rows: Optional[int] = None,
    sample: Optional[int] = None,
    random_state: int = 42,
    usecols: Optional[Sequence[str]] = ("comment_text", "target", "toxicity_annotator_count"),
) -> pd.DataFrame:
    """Read a Jigsaw CSV and convert it with ``build_jigsaw_frame``.

    Args:
        path: CSV file to read.
        eps: Smoothing forwarded to ``build_jigsaw_frame``.
        tie_positive: Tie rule forwarded to ``build_jigsaw_frame``.
        n_rows: If given, parse only the first ``n_rows`` data rows of the
            file. This is applied before any filtering.
        sample: If given, draw this many rows at random (without replacement)
            from the rows that survive filtering.
        random_state: Seed for the sampling step.
        usecols: Columns to parse from the CSV; ``None`` parses every column.

    Returns:
        The frame produced by ``build_jigsaw_frame``.

    Raises:
        ValueError: If a required column is absent, or (from
            ``DataFrame.sample``) if ``sample`` exceeds the number of usable
            rows.
    """
    # Forward `usecols` to the parser. It used to be accepted but ignored,
    # so every column of the CSV was parsed and kept in memory.
    df_raw = pd.read_csv(
        path,
        nrows=n_rows,
        usecols=None if usecols is None else list(usecols),
    )

    required_columns = ["comment_text", "target", "toxicity_annotator_count"]
    missing_columns = [c for c in required_columns if c not in df_raw.columns]
    if missing_columns:
        raise ValueError(f"Input DataFrame is missing required columns: {missing_columns}")

    # Apply the same row filters build_jigsaw_frame applies *before* sampling.
    # The sample is then drawn only from usable rows, so `sample=k` yields k
    # rows rather than k minus the sampled rows discarded later.
    df_raw = df_raw.dropna(subset=["target", "toxicity_annotator_count"])
    df_raw = df_raw.loc[df_raw["toxicity_annotator_count"] > 0].copy()
    df_raw["comment_text"] = df_raw["comment_text"].astype("string")
    df_raw = df_raw.loc[df_raw["comment_text"].notna() & (df_raw["comment_text"].str.strip() != "")]

    if sample is not None:
        df_raw = df_raw.sample(n=sample, random_state=random_state)
    return build_jigsaw_frame(df_raw=df_raw, eps=eps, tie_positive=tie_positive)
