from pathlib import Path
import shutil
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Constants for schema
# Unique row-identifier column present in every CSV this script reads and writes.
ID_COL = "Id"
# Label column: kept in public/train.csv and private/test_answer.csv, stripped
# from public/test.csv.
LABEL_COL = "Y"
# The only label values accepted from the raw data; also the probability
# column order of public/sample_submission.csv.
CLASS_COLUMNS = ["HQ", "LQ_EDIT", "LQ_CLOSE"]
# Feature columns passed through to the public files. All of them (including
# CreationDate) are cast to str when the raw CSVs are read.
TEXT_COLUMNS = ["Title", "Body", "Tags", "CreationDate"]
# Single seed shared by the stratified split and the sample-submission
# generator so reruns produce identical outputs.
RANDOM_SEED = 2025


def _read_sources(raw: Path) -> pd.DataFrame:
    """Read one or more labeled CSVs from raw/ and concatenate deterministically.

    ``raw/train.csv`` and ``raw/valid.csv`` are read (in that order) when they
    exist and are stacked. The result keeps only ``[Id, *TEXT_COLUMNS, Y]``:

    - ``Id`` is converted to a numeric dtype and must not be missing;
    - text columns are cast to ``str`` with missing values mapped to ``""``
      (not the literal string ``"nan"``);
    - the label column is cast to ``str`` and must be one of ``CLASS_COLUMNS``;
    - duplicate Ids are dropped (first occurrence wins, so ``train.csv`` rows
      take precedence over ``valid.csv``) and rows are sorted by ``Id``.

    Args:
        raw: Directory holding the source CSVs.

    Returns:
        The cleaned, Id-sorted DataFrame with a fresh RangeIndex.

    Raises:
        FileNotFoundError: ``raw`` is not an existing directory, or it contains
            neither ``train.csv`` nor ``valid.csv``.
        ValueError: a required column is missing, an Id is missing or
            non-numeric, or a label falls outside ``CLASS_COLUMNS``.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not raw.is_dir():
        raise FileNotFoundError(f"Raw directory does not exist: {raw}")

    candidates = (raw / name for name in ("train.csv", "valid.csv"))
    paths = [p for p in candidates if p.is_file()]
    if not paths:
        raise FileNotFoundError(
            f"No source CSVs found in raw directory: {raw}. Expected at least train.csv and/or valid.csv"
        )

    data = pd.concat([pd.read_csv(p) for p in paths], axis=0, ignore_index=True)

    # Validate required columns
    expected = [ID_COL, *TEXT_COLUMNS, LABEL_COL]
    missing = [c for c in expected if c not in data.columns]
    if missing:
        raise ValueError(f"Missing required columns from raw CSVs: {missing}")

    # Normalize dtypes and column order
    data = data[expected].copy()
    data[ID_COL] = pd.to_numeric(data[ID_COL], errors="raise")
    if data[ID_COL].isna().any():
        # NaN Ids would survive drop_duplicates (NaNs are treated as duplicates
        # of each other) and end up in the splits as float Ids.
        raise ValueError(f"Missing values in {ID_COL!r} column of raw CSVs")
    for c in TEXT_COLUMNS:
        # fillna first: astype(str) alone turns NaN into the literal text "nan".
        data[c] = data[c].fillna("").astype(str)
    data[LABEL_COL] = data[LABEL_COL].astype(str)

    # Clean duplicates and sort by Id for determinism
    data = data.drop_duplicates(subset=[ID_COL], keep="first").sort_values(ID_COL).reset_index(drop=True)

    # Validate labels
    labels = set(data[LABEL_COL].unique().tolist())
    allowed = set(CLASS_COLUMNS)
    if not labels.issubset(allowed):
        raise ValueError(
            f"Unexpected labels in raw data: {sorted(labels - allowed)}. Allowed = {sorted(allowed)}"
        )

    return data


def _deterministic_stratified_split(df: pd.DataFrame, test_size: float = 0.25):
    """Split ``df`` into stratified train/test frames, reproducibly.

    A single ``StratifiedShuffleSplit`` seeded with ``RANDOM_SEED`` is used, so
    the class mix of ``LABEL_COL`` is preserved in both parts and repeated runs
    on the same input yield the same split.

    Args:
        df: Labeled frame whose ``ID_COL`` values are expected to be unique
            (as produced by ``_read_sources``).
        test_size: Passed straight through as ``StratifiedShuffleSplit``'s
            ``test_size`` (0.25 means a quarter of the rows go to test).

    Returns:
        ``(train_df, test_df)``: row subsets of ``df``, copied, each with a
        fresh RangeIndex.

    Raises:
        ValueError: propagated from scikit-learn when stratification is
            impossible (e.g. a class with too few rows), or raised here when
            Ids are duplicated within or shared across the splits, or when a
            split is missing one of ``CLASS_COLUMNS``.
    """
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=RANDOM_SEED)
    y = df[LABEL_COL].to_numpy()
    idx_train, idx_test = next(splitter.split(df, y))
    train_df = df.iloc[idx_train].copy().reset_index(drop=True)
    test_df = df.iloc[idx_test].copy().reset_index(drop=True)

    # Integrity checks. These are explicit raises rather than ``assert`` so they
    # are not silently skipped under ``python -O``.
    if not (train_df[ID_COL].is_unique and test_df[ID_COL].is_unique):
        raise ValueError(f"Duplicate {ID_COL!r} values within a split")
    if not set(train_df[ID_COL]).isdisjoint(test_df[ID_COL]):
        raise ValueError(f"{ID_COL!r} values overlap between the train and test splits")
    # Every class must be represented on both sides of the split.
    required = set(CLASS_COLUMNS)
    for split_name, part in (("train", train_df), ("test", test_df)):
        absent = required - set(part[LABEL_COL].unique())
        if absent:
            raise ValueError(f"Classes {sorted(absent)} missing from the {split_name} split")

    return train_df, test_df


def _write_sample_submission(test_public: pd.DataFrame, sub_out: Path) -> None:
    """Write a reproducible random-probability submission aligned with ``test_public``.

    The output has one row per row of ``test_public`` (same ``ID_COL`` order)
    and one column per entry of ``CLASS_COLUMNS``. Each row is a
    Dirichlet(1, ..., 1) draw from a generator seeded with ``RANDOM_SEED``, so
    the row's probabilities sum to 1.

    Raises:
        RuntimeError: if any row sum is non-finite or differs from 1 by 1e-8 or more.
    """
    rng = np.random.default_rng(RANDOM_SEED)
    probs = rng.dirichlet(alpha=np.ones(len(CLASS_COLUMNS)), size=len(test_public))
    sample = pd.DataFrame({ID_COL: test_public[ID_COL].tolist()})
    for j, cls in enumerate(CLASS_COLUMNS):
        sample[cls] = probs[:, j]
    row_sums = sample[CLASS_COLUMNS].sum(axis=1).to_numpy()
    if not (np.all(np.isfinite(row_sums)) and np.all(np.abs(row_sums - 1.0) < 1e-8)):
        raise RuntimeError("sample_submission.csv probabilities do not sum to 1 on every row")
    sample.to_csv(sub_out, index=False)


def _validate_outputs(train_out: Path, test_out: Path, ans_out: Path, sub_out: Path) -> None:
    """Re-read the written CSVs and verify schemas, Id alignment and no label leakage.

    Raises:
        RuntimeError: describing the first check that fails.
    """
    train_r = pd.read_csv(train_out)
    test_r = pd.read_csv(test_out)
    ans_r = pd.read_csv(ans_out)
    sub_r = pd.read_csv(sub_out)

    if list(train_r.columns) != [ID_COL, *TEXT_COLUMNS, LABEL_COL]:
        raise RuntimeError(f"Unexpected columns in {train_out}: {list(train_r.columns)}")
    if list(test_r.columns) != [ID_COL, *TEXT_COLUMNS]:
        raise RuntimeError(f"Unexpected columns in {test_out}: {list(test_r.columns)}")
    # Implied by the check above; kept as an explicit label-leakage guard.
    if LABEL_COL in test_r.columns:
        raise RuntimeError(f"Label column {LABEL_COL!r} leaked into {test_out}")
    if list(ans_r.columns) != [ID_COL, LABEL_COL]:
        raise RuntimeError(f"Unexpected columns in {ans_out}: {list(ans_r.columns)}")
    if list(sub_r.columns) != [ID_COL, *CLASS_COLUMNS]:
        raise RuntimeError(f"Unexpected columns in {sub_out}: {list(sub_r.columns)}")
    if len(test_r) != len(ans_r) or set(test_r[ID_COL]) != set(ans_r[ID_COL]):
        raise RuntimeError(f"{test_out} and {ans_out} do not cover the same Ids")
    # The sample submission must list exactly the test Ids, in the same order.
    if sub_r[ID_COL].tolist() != test_r[ID_COL].tolist():
        raise RuntimeError(f"{sub_out} Ids are not aligned with {test_out}")
    if not set(train_r[ID_COL]).isdisjoint(test_r[ID_COL]):
        raise RuntimeError(f"{train_out} and {test_out} share Ids")


def _write_public_private(train_df: pd.DataFrame, test_df: pd.DataFrame, public: Path, private: Path):
    """Write the competition files, then verify them by reading them back.

    Files produced (directories are created with parents if missing; existing
    files are overwritten):

    - ``public/train.csv``: ``train_df`` as given.
    - ``public/test.csv``: ``test_df`` restricted to ``ID_COL`` + ``TEXT_COLUMNS``.
    - ``private/test_answer.csv``: ``test_df[[ID_COL, LABEL_COL]]``.
    - ``public/sample_submission.csv``: seeded Dirichlet probabilities per test row.
    - ``public/description.txt``: copied from this script's directory, only if
      such a file exists there.

    Raises:
        RuntimeError: if the sample probabilities or the read-back outputs fail
            validation.
    """
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Public files
    train_out = public / "train.csv"
    test_out = public / "test.csv"
    sub_out = public / "sample_submission.csv"

    # Private file
    ans_out = private / "test_answer.csv"

    train_df.to_csv(train_out, index=False)

    # Public test is written without labels.
    test_public = test_df[[ID_COL, *TEXT_COLUMNS]].copy()
    test_public.to_csv(test_out, index=False)

    test_df[[ID_COL, LABEL_COL]].to_csv(ans_out, index=False)

    _write_sample_submission(test_public, sub_out)

    # Copy description.txt for participants; silently skipped when absent.
    root_description = Path(__file__).resolve().parent / "description.txt"
    if root_description.is_file():
        shutil.copy(root_description, public / "description.txt")

    _validate_outputs(train_out, test_out, ans_out, sub_out)


def prepare(raw: Path, public: Path, private: Path):
    """Run the full, deterministic dataset preparation pipeline.

    Steps:
      1. Load and clean the labeled CSVs found in ``raw``.
      2. Split them 75% / 25%, stratified on the label, with a fixed seed.
      3. Emit public/train.csv, public/test.csv and public/sample_submission.csv
         (the latter aligned with public/test.csv), copy description.txt into
         public/ when available, and write private/test_answer.csv holding
         [Id, Y].
    """
    labeled = _read_sources(raw)
    train_part, test_part = _deterministic_stratified_split(labeled, test_size=0.25)
    _write_public_private(train_part, test_part, public=public, private=private)
