import shutil
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd

# Single seed shared by the train/test split, the row shuffles and the
# sample-submission noise, so repeated runs give identical outputs.
RANDOM_SEED = 20240914

# Name of the raw CSV; it is looked up inside the `raw` directory given to prepare().
DATA_SOURCE = "all_disciplines_combined.csv"

# Output filenames. TEST_ANSWER_FILE is written to the private directory;
# the other files are written (or copied) into the public directory.
TRAIN_FILE = "train.csv"
TEST_FILE = "test.csv"
TEST_ANSWER_FILE = "test_answer.csv"
SAMPLE_SUB_FILE = "sample_submission.csv"
DESCRIPTION_FILE = "description.txt"

# Columns removed from both public train and public test to avoid target leakage.
DROP_COLS = ["rank", "mark", "mark_numeric", "position"]

# Columns that must be present in the source file; additional columns are allowed.
EXPECTED_SOURCE_COLS = [
    "rank",
    "mark",
    "wind",
    "competitor",
    "dob",
    "nationality",
    "position",
    "venue",
    "date",
    "result_score",
    "discipline",
    "type",
    "sex",
    "age_cat",
    "normalized_discipline",
    "track_field",
    "mark_numeric",
    "venue_country",
    "age_at_event",
    "season",
]


def _read_source(src_csv: Path) -> pd.DataFrame:
    """Load the raw results CSV and keep only rows usable for the task.

    Steps:
    - verify the file exists and contains every column in ``EXPECTED_SOURCE_COLS``
    - coerce ``result_score``, ``wind`` and ``age_at_event`` to numeric
      (unparseable values become NaN) and ``season`` to nullable ``Int64``
    - drop rows whose ``result_score`` is not finite or whose
      ``normalized_discipline``, ``sex`` or ``age_cat`` is missing
    - clip ``result_score`` to the range [0, 2000]

    Args:
        src_csv: Path of the source CSV file.

    Returns:
        An independent DataFrame (original row labels kept) with the filtered,
        normalised rows.

    Raises:
        AssertionError: If the file is missing, a required column is absent,
            or no row survives filtering. The checks are explicit ``raise``
            statements so they still run under ``python -O``.
    """
    if not src_csv.exists():
        raise AssertionError(f"Source data not found: {src_csv}")
    df = pd.read_csv(src_csv, low_memory=False)

    # Report the first missing column, in EXPECTED_SOURCE_COLS order.
    missing = [c for c in EXPECTED_SOURCE_COLS if c not in df.columns]
    if missing:
        raise AssertionError(f"Missing required column in source: {missing[0]}")

    # Normalize numeric types
    for col in ("result_score", "wind", "age_at_event"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    # season may be missing for some rows; keep as integer nullable
    df["season"] = pd.to_numeric(df["season"], errors="coerce").astype("Int64")

    # Keep only rows with a valid target and essential grouping attributes.
    # The finite test runs on a float view so the column's own dtype is untouched.
    score_is_finite = np.isfinite(df["result_score"].to_numpy(dtype="float64"))
    has_group_keys = df[["normalized_discipline", "sex", "age_cat"]].notna().all(axis=1)
    # Copy so the assignment below writes to an independent frame, not a slice.
    df = df.loc[score_is_finite & has_group_keys].copy()
    if len(df) == 0:
        raise AssertionError("No valid rows after filtering.")

    # Clip target for robustness
    df["result_score"] = df["result_score"].clip(lower=0, upper=2000)
    return df


def _deterministic_split(df: pd.DataFrame, test_frac: float = 0.2) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train and test frames, reproducibly and by whole units.

    Rows are stratified by the ``normalized_discipline|sex|age_cat`` group.
    Within each group, rows sharing the same signature (the ``sig_cols``
    values joined with ``|``) form one unit, and a unit is never divided
    between train and test. Groups with three or fewer units go entirely to
    train. Otherwise roughly ``test_frac`` of the group's units (at least one,
    at most all but one) go to test, chosen with a ``RANDOM_SEED``-seeded
    shuffle.

    Args:
        df: Source rows; must contain the grouping and signature columns.
        test_frac: Target fraction of units per group assigned to test.

    Returns:
        ``(train_df, test_df)`` with the original columns and row order.

    Raises:
        AssertionError: If an internal invariant fails (overlapping units, an
            empty side, lost rows, or a test group absent from train).
    """
    rng = np.random.RandomState(RANDOM_SEED)

    df = df.copy()
    # Stratify by core groups
    df["__grp__"] = (
        df["normalized_discipline"].astype(str)
        + "|"
        + df["sex"].astype(str)
        + "|"
        + df["age_cat"].astype(str)
    )

    # Split by unique units to avoid leakage across train/test
    sig_cols = [
        "competitor",
        "dob",
        "nationality",
        "venue",
        "date",
        "normalized_discipline",
        "sex",
        "age_cat",
        "type",
        "track_field",
        "venue_country",
        "season",
    ]
    # Use full, unhashed signature to avoid collisions across groups.
    # Stringify first and blank out missing cells afterwards: calling
    # fillna("") directly fails on nullable-integer columns such as `season`
    # when they contain <NA>.
    sig_source = df[sig_cols]
    sig_text = sig_source.astype(str).mask(sig_source.isna(), "")
    df["__sig__"] = sig_text.agg("|".join, axis=1)

    train_units_set = set()
    test_units_set = set()

    # groupby iterates groups in sorted key order, so RNG consumption is stable.
    for _, group_sigs in df.groupby("__grp__")["__sig__"]:
        units = group_sigs.drop_duplicates().tolist()
        rng.shuffle(units)
        n = len(units)
        if n <= 3:
            # Too few units to hold any out without starving train.
            train_units_set.update(units)
            continue
        k = max(1, int(round(test_frac * n)))
        k = min(k, n - 1)  # always leave at least one unit in train
        test_units_set.update(units[:k])
        train_units_set.update(units[k:])

    # Ensure no overlap between unit sets
    assert train_units_set.isdisjoint(test_units_set), "Internal error: overlapping units between train and test."

    train_mask = df["__sig__"].isin(train_units_set)
    test_mask = df["__sig__"].isin(test_units_set)

    # Ensure group coverage, reusing the group key computed above.
    tr_g = set(df.loc[train_mask, "__grp__"])
    te_g = set(df.loc[test_mask, "__grp__"])

    train_df = df[train_mask].drop(columns=["__grp__", "__sig__"]).copy()
    test_df = df[test_mask].drop(columns=["__grp__", "__sig__"]).copy()

    # Sanity
    assert len(train_df) > 0 and len(test_df) > 0
    assert len(train_df) + len(test_df) == len(df), "Split lost or duplicated rows."
    assert te_g.issubset(tr_g), "Every test group must appear in train"

    return train_df, test_df


def _build_and_save(train_df: pd.DataFrame, test_df: pd.DataFrame, public: Path, private: Path, project_root: Path) -> None:
    """Write the public and private competition files, then verify them on disk.

    Both frames are shuffled with ``RANDOM_SEED`` and given consecutive ``id``
    values (train starts at 0, test continues after the last train id). The
    ``DROP_COLS`` columns are removed from the public files and the target is
    removed from public test. The sample submission holds clipped Gaussian
    noise around the train target mean.

    Args:
        train_df: Training rows, including ``result_score``.
        test_df: Test rows, including ``result_score``.
        public: Directory receiving train, test, sample submission and, when
            available, the description file.
        private: Directory receiving the test answers.
        project_root: Directory searched for ``DESCRIPTION_FILE``; the copy is
            skipped when the file is absent.

    Raises:
        AssertionError: If any post-write integrity check fails.
    """
    for out_dir in (public, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    # Reproducible row order, then ids: train gets 0..n-1, test carries on from n.
    train_full = train_df.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)
    test_full = test_df.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)
    first_test_id = len(train_full)
    train_full.insert(0, "id", np.arange(0, first_test_id, dtype=np.int64))
    test_full.insert(0, "id", np.arange(first_test_id, first_test_id + len(test_full), dtype=np.int64))

    # Strip leakage columns (only those actually present).
    leaky = [name for name in DROP_COLS if name in train_full.columns]
    train_out = train_full.drop(columns=leaky)
    test_out = test_full.drop(columns=[name for name in leaky if name in test_full.columns])

    # The target must be published with train and withheld from test.
    assert "result_score" in train_out.columns
    test_out = test_out.drop(columns=["result_score"], errors="ignore")

    # Canonical column order: id, alphabetical features, then (train only) the target.
    test_out = test_out[["id"] + sorted(col for col in test_out.columns if col != "id")]
    train_feature_names = sorted(col for col in train_out.columns if col not in ("id", "result_score"))
    train_out = train_out[["id", *train_feature_names, "result_score"]]

    # Private answer key.
    answer_key = test_full[["id", "result_score"]].copy()

    # Sample submission: noise around the train mean, clipped to the target range.
    score_mean = float(train_out["result_score"].mean())
    score_std = float(train_out["result_score"].std(ddof=0))
    if not (np.isfinite(score_std) and score_std > 1e-6):
        # Degenerate spread: fall back to a fixed or mean-relative scale.
        score_std = max(50.0, abs(score_mean) * 0.1)
    noise = np.random.RandomState(RANDOM_SEED).normal(size=len(test_out))
    baseline = np.clip(score_mean + score_std * noise, 0.0, 2000.0)
    sample_sub = pd.DataFrame({"id": test_out["id"], "result_score": baseline.astype(float)})

    # Persist everything to its destination.
    outputs = (
        (train_out, public / TRAIN_FILE),
        (test_out, public / TEST_FILE),
        (answer_key, private / TEST_ANSWER_FILE),
        (sample_sub, public / SAMPLE_SUB_FILE),
    )
    for frame, destination in outputs:
        frame.to_csv(destination, index=False)

    # Ship the task description alongside the public data when it exists.
    description = project_root / DESCRIPTION_FILE
    if description.exists():
        shutil.copy(description, public / DESCRIPTION_FILE)

    # Integrity checks against what actually landed on disk.
    for _, destination in outputs:
        assert destination.exists()

    train_on_disk = pd.read_csv(public / TRAIN_FILE)
    test_on_disk = pd.read_csv(public / TEST_FILE)
    answers_on_disk = pd.read_csv(private / TEST_ANSWER_FILE)

    for frame in (train_on_disk, test_on_disk, answers_on_disk):
        assert frame["id"].is_unique
    # Test, answers and sample submission must cover exactly the same ids.
    assert set(test_on_disk["id"]) == set(answers_on_disk["id"]) == set(sample_sub["id"])

    # No target in test
    assert "result_score" not in test_on_disk.columns

    # Feature alignment
    train_feature_order = [col for col in train_on_disk.columns if col != "result_score"]
    assert train_feature_order == list(test_on_disk.columns), "Train and test features must align (including id as first column)."

    # Group coverage: every group seen in test must also occur in train.
    grp_cols = ["normalized_discipline", "sex", "age_cat"]
    groups_in_train = set(train_on_disk[grp_cols].astype(str).agg("|".join, axis=1).unique())
    groups_in_test = set(test_on_disk[grp_cols].astype(str).agg("|".join, axis=1).unique())
    assert groups_in_test.issubset(groups_in_train)


def prepare(raw: Path, public: Path, private: Path):
    """Run the full data preparation pipeline.

    - Loads and filters `raw/all_disciplines_combined.csv`
    - Splits it deterministically into train and test (group-aware, 20% test)
    - Writes public/train.csv, public/test.csv and public/sample_submission.csv
    - Writes private/test_answer.csv
    - Copies description.txt from the current working directory to public/
      when that file exists
    """
    cleaned = _read_source(raw / DATA_SOURCE)
    train_part, test_part = _deterministic_split(cleaned, test_frac=0.2)
    _build_and_save(
        train_part,
        test_part,
        public=public,
        private=private,
        project_root=Path.cwd(),
    )
