from pathlib import Path
import shutil
import math
import random
import pandas as pd


def _anonymize_ids(n: int) -> list[str]:
    """Return ``n`` opaque row ids: ``LCC000001``, ``LCC000002``, ...

    Numbering starts at 1 and is zero-padded to at least six digits; a
    non-positive ``n`` yields an empty list.
    """
    return list(map("LCC{:06d}".format, range(1, n + 1)))


def _deterministic_group_shuffle(df: pd.DataFrame, by: str, seed: int) -> pd.DataFrame:
    """Shuffle rows within each group of ``by``, reproducibly.

    Groups are emitted in sorted key order. Rows whose key is missing form
    their own group instead of being dropped. Rows inside each group are
    permuted with one ``random.Random(seed)`` shared across all groups, so the
    same input and seed always yield the same row order. The original index
    labels are preserved.

    Index labels are assumed to be unique: ``df.loc`` would repeat rows for
    duplicated labels.

    Args:
        df: Frame to shuffle; it is not modified.
        by: Column (or other ``groupby`` key) defining the groups.
        seed: Seed for the row permutation.

    Returns:
        A new frame containing every row of ``df``, grouped by ``by``.
    """
    # dropna=False: the default groupby silently discards rows with a NaN key,
    # which would make this "shuffle" lose data.
    grouped = df.groupby(by, sort=True, dropna=False)
    if len(df) == 0:
        # No groups to iterate; avoid building an empty concatenation.
        return df.copy()

    rng = random.Random(seed)
    order: list = []
    for _, grp in grouped:
        idxs = list(grp.index)
        rng.shuffle(idxs)
        order.extend(idxs)
    return df.loc[order]


def _stratified_split(
    df: pd.DataFrame,
    label_col: str,
    test_fraction: float,
    seed: int,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train/test frames, stratified by ``label_col``.

    Rows are first shuffled deterministically within each label using
    ``_deterministic_group_shuffle`` with ``seed``. Then, for every label
    with ``n >= 2`` rows, the first ``floor(test_fraction * n)`` rows go to
    test and the rest go to train. That count is clamped to ``[1, n - 1]``.
    Labels with a single row go entirely to train, so every label seen in
    test also occurs in train.

    Args:
        df: Input frame; it is not modified.
        label_col: Column to stratify on.
        test_fraction: Target share of each label's rows placed in test.
        seed: Seed for the within-label shuffle.

    Returns:
        ``(train_df, test_df)`` as copies, each ordered by the original index.

    Raises:
        ValueError: If ``test_fraction`` is not strictly between 0 and 1, or
            if ``label_col`` contains missing values (they cannot be
            stratified, and would otherwise be dropped silently).
    """
    # Explicit checks (not asserts) so they survive ``python -O``.
    if not 0.0 < test_fraction < 1.0:
        raise ValueError(f"test_fraction must be in (0, 1), got {test_fraction!r}")
    n_missing = int(df[label_col].isna().sum())
    if n_missing:
        raise ValueError(
            f"Column {label_col!r} has {n_missing} missing value(s); cannot stratify"
        )

    # Remember the input size: the row-count check below must compare against
    # the caller's frame, not a derived one.
    n_input = len(df)
    shuffled = _deterministic_group_shuffle(df, by=label_col, seed=seed)

    test_idx: list = []
    train_idx: list = []
    for _, grp in shuffled.groupby(label_col, sort=True):
        ids = list(grp.index)
        n = len(ids)
        if n <= 1:
            # A singleton label cannot appear in both splits; keep it in train.
            train_idx.extend(ids)
            continue
        k = min(max(1, math.floor(test_fraction * n)), n - 1)
        test_idx.extend(ids[:k])
        train_idx.extend(ids[k:])

    train_df = shuffled.loc[sorted(train_idx)].copy()
    test_df = shuffled.loc[sorted(test_idx)].copy()

    # Internal invariants: the splits are disjoint and no input row was lost.
    assert set(train_df.index).isdisjoint(test_df.index)
    assert len(train_df) + len(test_df) == n_input

    # Every test label must be learnable from train.
    test_labels = set(test_df[label_col].unique().tolist())
    train_labels = set(train_df[label_col].unique().tolist())
    assert test_labels.issubset(train_labels)

    return train_df, test_df


def prepare(raw: Path, public: Path, private: Path) -> None:
    """Build the competition files from ``raw/legal_text_classification.csv``.

    Writes:
      * ``public/train.csv``: ``id, case_title, case_text, case_outcome``
      * ``public/test.csv``: ``id, case_title, case_text`` (no labels)
      * ``public/sample_submission.csv``: ``id, case_outcome``, with labels
        drawn deterministically from the training label set
      * ``private/test_answer.csv``: ``id, case_outcome`` for the test rows

    It also copies ``raw.parent/description.txt`` into ``public/`` when that
    file exists.

    Rows receive sequential anonymized ``LCC`` ids. They are split by
    ``case_outcome`` with a fixed seed, putting roughly ``TEST_FRACTION`` of
    each label into test, so reruns produce identical files. ``case_id`` is
    not written to any output.

    Raises:
        AssertionError: If the raw file or a required column is missing, the
            CSV is empty, or an integrity check on the outputs fails.
    """
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Load raw CSV
    raw_csv = raw / "legal_text_classification.csv"
    assert raw_csv.exists(), f"Raw data file not found: {raw_csv}"
    df = pd.read_csv(raw_csv)

    required_cols = ["case_id", "case_outcome", "case_title", "case_text"]
    missing = [c for c in required_cols if c not in df.columns]
    assert not missing, f"Missing required columns: {missing}"
    assert len(df) > 0

    # Add anonymized ids and reorder columns (case_id is dropped from every output below).
    df = df.copy()
    df.insert(0, "id", _anonymize_ids(len(df)))
    df = df[["id", "case_title", "case_text", "case_outcome", "case_id"]]

    # Stratified deterministic split
    SEED = 12345
    TEST_FRACTION = 0.2
    train_raw, test_raw = _stratified_split(
        df, label_col="case_outcome", test_fraction=TEST_FRACTION, seed=SEED
    )

    # Build outputs
    train_out = train_raw[["id", "case_title", "case_text", "case_outcome"]].copy()
    test_out = test_raw[["id", "case_title", "case_text"]].copy()
    test_answer_out = test_raw[["id", "case_outcome"]].copy()

    # Integrity checks
    assert train_out["id"].is_unique and test_out["id"].is_unique and test_answer_out["id"].is_unique
    assert set(test_out["id"]) == set(test_answer_out["id"])  # alignment
    assert set(train_out["id"]).isdisjoint(set(test_out["id"]))

    train_path = public / "train.csv"
    test_path = public / "test.csv"
    sample_path = public / "sample_submission.csv"
    answer_path = private / "test_answer.csv"

    # Let pandas open the files itself: it writes UTF-8 with newline="", so the
    # os.linesep terminators are not translated again (``to_csv()`` strings fed
    # to ``write_text`` become "\r\r\n" and use the locale encoding on Windows).
    train_out.to_csv(train_path, index=False, encoding="utf-8")
    test_out.to_csv(test_path, index=False, encoding="utf-8")
    test_answer_out.to_csv(answer_path, index=False, encoding="utf-8")

    # Sample submission: deterministic labels chosen from training label space
    labels = sorted(train_out["case_outcome"].astype(str).unique().tolist())
    assert len(labels) > 0
    rng = random.Random(SEED)
    sample_submission = pd.DataFrame(
        {
            "id": test_out["id"].tolist(),
            "case_outcome": [rng.choice(labels) for _ in range(len(test_out))],
        }
    )
    sample_submission.to_csv(sample_path, index=False, encoding="utf-8")

    # Copy description.txt to public/ if available at root
    root_description = raw.parent / "description.txt"
    if root_description.exists():
        shutil.copy(root_description, public / "description.txt")

    # Post checks
    for path in (train_path, test_path, sample_path, answer_path):
        assert path.exists(), f"Expected output not written: {path}"

    # No leakage: public/test.csv must not contain the ground-truth column,
    # and the sample submission must cover exactly the test ids.
    public_test = pd.read_csv(test_path)
    assert "case_outcome" not in public_test.columns
    sub_df = pd.read_csv(sample_path)
    assert sorted(public_test["id"].tolist()) == sorted(sub_df["id"].tolist())
