from pathlib import Path
import csv
import numpy as np
import pandas as pd
import shutil

SEED = 20_240_914  # fixed RNG seed so the split and sample submission are reproducible


def _read_source_dataframe(src_csv: Path) -> pd.DataFrame:
    assert src_csv.exists(), f"Source CSV not found: {src_csv}"
    # Handle potential leading index column in the provided dataset
    df = pd.read_csv(src_csv)
    if "text" not in df.columns or "Emotion" not in df.columns:
        df = pd.read_csv(src_csv, index_col=0)
    assert "text" in df.columns and "Emotion" in df.columns, (
        f"Expected columns 'text' and 'Emotion' in dataset. Found: {list(df.columns)}"
    )
    df = df[["text", "Emotion"]].copy()
    df["text"] = (
        df["text"].astype(str).str.replace("\r", " ", regex=False).str.replace("\n", " ", regex=False).str.strip()
    )
    df["Emotion"] = df["Emotion"].astype(str).str.strip()
    # Drop rows with empty values
    df = df.replace({"text": {"": np.nan}, "Emotion": {"": np.nan}})
    df = df.dropna(subset=["text", "Emotion"]).reset_index(drop=True)
    return df


def _generate_ids(n: int) -> list[str]:
    return [f"E{idx:08d}" for idx in range(n)]


def _stratified_split(df: pd.DataFrame, test_frac: float = 0.2) -> tuple[np.ndarray, np.ndarray]:
    """Split the row positions of ``df`` into train/test, stratified by ``Emotion``.

    A single ``RandomState`` seeded with ``SEED`` shuffles each class in turn.
    Classes are visited in ``np.unique`` (sorted) order, so the draw sequence
    is reproducible. From each class, ``round(count * test_frac)`` rows go to
    test, clamped as follows:

    - a class with at least two rows contributes at least one test row and
      keeps at least one train row;
    - a single-row class stays entirely in train.

    Args:
        df: Frame with an ``Emotion`` column; only row positions are returned.
        test_frac: Target fraction of each class assigned to the test split.

    Returns:
        ``(train_idx, test_idx)``: sorted, disjoint positional index arrays
        that together cover ``range(len(df))``.
    """
    generator = np.random.RandomState(SEED)
    emotion = df["Emotion"].values
    total = len(df)
    positions = np.arange(total)
    in_test = np.zeros(total, dtype=bool)

    uniq, sizes = np.unique(emotion, return_counts=True)
    for label, size in zip(uniq, sizes):
        # Boolean-mask indexing already yields a fresh array, safe to shuffle.
        members = positions[emotion == label]
        generator.shuffle(members)
        wanted = int(round(size * test_frac))
        take = min(max(1, wanted), size - 1) if size > 1 else 0
        # A zero-length slice selects nothing, so no guard is needed here.
        in_test[members[:take]] = True

    test_idx = np.sort(positions[in_test])
    train_idx = np.sort(positions[~in_test])

    assert len(set(train_idx).intersection(set(test_idx))) == 0
    assert len(train_idx) + len(test_idx) == total

    return train_idx, test_idx


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private competition files from the raw emotion dataset.

    Reads ``raw / "emotion_sentimen_dataset.csv"``, assigns sequential ids,
    and performs a seeded per-class split with roughly 20% of each class in
    test. It writes:

    * ``public/train.csv``             -- columns ``id, text, Emotion``
    * ``public/test.csv``              -- columns ``id, text`` (no labels)
    * ``public/sample_submission.csv`` -- columns ``id, Emotion``, with labels
      drawn at random (seeded) from the training label set
    * ``private/test_answer.csv``      -- columns ``id, Emotion`` (ground truth)

    If ``description.txt`` exists in the parent directory of ``raw``, it is
    copied into ``public``. The written files are then re-read and checked.

    Args:
        raw: Directory containing the source CSV.
        public: Output directory for participant-visible files (created if missing).
        private: Output directory for grader-only files (created if missing).

    Raises:
        AssertionError: if the source data is missing/malformed or any
            post-write integrity check fails.
    """
    # Work with absolute paths and make sure both output directories exist.
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Read and clean the raw data.
    src_csv = raw / "emotion_sentimen_dataset.csv"
    df = _read_source_dataframe(src_csv)

    # Assign deterministic ids by row position.
    df = df.reset_index(drop=True)
    df.insert(0, "id", _generate_ids(len(df)))

    # Deterministic stratified split (positional indices).
    train_idx, test_idx = _stratified_split(df, test_frac=0.2)
    train_df = df.iloc[train_idx].reset_index(drop=True)
    test_df = df.iloc[test_idx].reset_index(drop=True)

    # Output frames: labels are withheld from the public test file.
    train_out = train_df[["id", "text", "Emotion"]].copy()
    test_out = test_df[["id", "text"]].copy()
    test_ans = test_df[["id", "Emotion"]].copy()

    # Seeded random baseline predictions drawn from the training label set.
    label_set = sorted(train_out["Emotion"].astype(str).unique().tolist())
    rng = np.random.RandomState(SEED)
    sample_pred = rng.choice(label_set, size=len(test_out), replace=True)
    sample_sub = pd.DataFrame({"id": test_out["id"].values, "Emotion": sample_pred})

    # Save to the public/private locations.
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_submission_csv = public / "sample_submission.csv"

    for path, df_out in [
        (train_csv, train_out),
        (test_csv, test_out),
        (test_answer_csv, test_ans),
        (sample_submission_csv, sample_sub),
    ]:
        df_out.to_csv(path, index=False, encoding="utf-8", quoting=csv.QUOTE_MINIMAL)

    # Copy description.txt into the public directory (participants only see public).
    repo_desc = (raw.parent / "description.txt").resolve()
    if repo_desc.exists():
        shutil.copy(repo_desc, public / "description.txt")

    # ---- Integrity checks on the files as written ----
    # 1) Existence
    assert train_csv.exists() and test_csv.exists() and sample_submission_csv.exists()
    assert test_answer_csv.exists()

    # 2) Schemas (header row only)
    expected_columns = {
        train_csv: ["id", "text", "Emotion"],
        test_csv: ["id", "text"],
        test_answer_csv: ["id", "Emotion"],
        sample_submission_csv: ["id", "Emotion"],
    }
    for path, columns in expected_columns.items():
        assert list(pd.read_csv(path, nrows=0).columns) == columns, f"Invalid {path.name} schema"

    # 3) Id uniqueness, coverage and disjointness. Only the id column is loaded.
    #    Set equality alone would not notice duplicated or extra rows, so
    #    uniqueness and row counts are checked as well.
    def _read_ids(path: Path) -> pd.Series:
        return pd.read_csv(path, usecols=["id"])["id"]

    train_ids = _read_ids(train_csv)
    test_ids = _read_ids(test_csv)
    ans_ids = _read_ids(test_answer_csv)
    sample_ids = _read_ids(sample_submission_csv)

    for path, ids in (
        (train_csv, train_ids),
        (test_csv, test_ids),
        (test_answer_csv, ans_ids),
        (sample_submission_csv, sample_ids),
    ):
        assert ids.is_unique, f"Duplicate ids in {path.name}"
    assert len(test_ids) == len(ans_ids) == len(sample_ids), "Row count mismatch among test files"
    assert set(test_ids) == set(ans_ids) == set(sample_ids), "Mismatch among test ids across files"
    assert set(train_ids).isdisjoint(set(ans_ids)), "Train and test ids must be disjoint"

    # 4) No label leakage in public test.csv
    assert "Emotion" not in pd.read_csv(test_csv, nrows=0).columns

    # 5) Row counts survive the CSV round trip, and every cleaned row is in exactly one split.
    assert len(train_ids) == len(train_out) and len(test_ids) == len(test_out), "Row count changed on write"
    assert len(train_out) + len(test_out) == len(df)
