from pathlib import Path
import shutil
import json
import math
from typing import List, Tuple

import numpy as np
import pandas as pd

# Fixed seed shared by every randomized step in this module (the stratified
# split and the sample-submission probabilities) so reruns are reproducible.
RANDOM_STATE = 20_240_914


def _read_source_csv(path: Path) -> pd.DataFrame:
    assert path.exists(), f"Source data not found: {path}"
    # Robust CSV reading for multi-line quoted fields
    try:
        df = pd.read_csv(path, engine="c")
    except Exception:
        df = pd.read_csv(path, engine="python")

    expected_cols = {"subreddit", "body", "controversiality", "score"}
    missing = expected_cols - set(df.columns)
    assert not missing, f"Source CSV missing columns: {missing}"
    return df


def _build_id_series(n: int) -> pd.Series:
    # Deterministic sequential IDs as strings
    return pd.Series(np.arange(n, dtype=np.int64)).astype(str)


def _stratified_split(df: pd.DataFrame, label_col: str, train_frac: float = 0.8) -> Tuple[np.ndarray, np.ndarray]:
    """Split the rows of ``df`` into train and test sets, class by class.

    Classes are visited in sorted order and the index labels of each class
    are shuffled with a generator seeded from ``RANDOM_STATE``, so the result
    is reproducible for a given input.

    Args:
        df: Frame to split. Its index labels are returned as int64 arrays, so
            they must be convertible to int64.
        label_col: Name of the column holding the class label.
        train_frac: Fraction of each class assigned to train; the per-class
            count is rounded down.

    Returns:
        ``(train_labels, test_labels)``: two int64 arrays of index labels
        (as used by ``df.loc``), grouped by class in sorted class order.

    Note:
        A class with more than one row always contributes at least one row
        to each side. A class with a single row cannot be split; with the
        default ``train_frac`` it lands entirely in the test side.
    """
    rng = np.random.RandomState(RANDOM_STATE)
    labels = df[label_col]
    classes = sorted(labels.unique().tolist())
    train_indices: List[int] = []
    test_indices: List[int] = []
    for c in classes:
        # Request an explicit copy: the array is shuffled in place, and
        # without copy=True to_numpy() gives no guarantee that the result is
        # not a view onto pandas-owned data.
        idx = df.index[labels == c].to_numpy(copy=True)
        rng.shuffle(idx)
        n = len(idx)
        n_train = int(math.floor(train_frac * n))
        # Ensure at least 1 sample goes to each of train and test where possible
        if n > 1:
            n_train = min(max(n_train, 1), n - 1)
        train_indices.extend(idx[:n_train].tolist())
        test_indices.extend(idx[n_train:].tolist())
    return np.array(train_indices, dtype=np.int64), np.array(test_indices, dtype=np.int64)


def _write_csv(df: pd.DataFrame, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(path, index=False, encoding="utf-8")


def _ensure(condition, message: str) -> None:
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used in place of bare ``assert`` statements so the dataset checks are not
    stripped when Python runs with ``-O``, while callers still see the same
    exception type.
    """
    if not condition:
        raise AssertionError(message)


def _write_sample_submission(path: Path, ids: List[str], class_names: List[str]) -> None:
    """Write a sample submission with one probability column per class.

    The file has an ``id,<class names...>`` header followed by one row per
    entry of ``ids``. Row ``i`` uses probability pattern ``i % 128``; the 128
    patterns are Dirichlet draws from a generator seeded with
    ``RANDOM_STATE``, clipped to ``[1e-15, 1]`` and renormalised.
    """
    rng = np.random.RandomState(RANDOM_STATE)
    num_patterns = 128
    alpha = np.ones(len(class_names), dtype=np.float64)
    pattern_probs = rng.dirichlet(alpha, size=num_patterns)

    # Rows cycle through only `num_patterns` distinct vectors, so clip,
    # renormalise and format each vector once instead of once per test row.
    pattern_suffixes = []
    for pattern in pattern_probs:
        probs = np.clip(pattern, 1e-15, 1.0)
        probs = probs / probs.sum()
        pattern_suffixes.append("".join(f",{p:.12g}" for p in probs))

    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        f.write(",".join(["id", *class_names]) + "\n")
        f.writelines(
            f"{_id}{pattern_suffixes[i % num_patterns]}\n"
            for i, _id in enumerate(ids)
        )


def prepare(raw: Path, public: Path, private: Path) -> dict:
    """
    Complete dataset preparation process.

    - Reads raw/kaggle_RC_2019-05.csv
    - Creates deterministic stratified train/test split (80/20 per subreddit)
    - Writes public/train.csv, public/test.csv, public/sample_submission.csv
    - Writes private/test_answer.csv
    - Copies description.txt (if present next to this file) into public/
    - Writes public/_prepare_metadata.json

    Args:
        raw: Directory containing the source CSV.
        public: Output directory for files without test labels.
        private: Output directory for the test answers.

    Returns:
        A dict with the row counts (``num_rows``, ``num_train``,
        ``num_test``), ``num_classes`` and the sorted ``classes`` list; the
        same content is written to ``_prepare_metadata.json``.

    Raises:
        AssertionError: If the source data or the generated splits fail any
            integrity check. All checks on the split run before the first
            output file is written.
    """
    raw = Path(raw).resolve()
    public = Path(public).resolve()
    private = Path(private).resolve()

    src_path = raw / "kaggle_RC_2019-05.csv"
    df = _read_source_csv(src_path)

    # Basic integrity checks
    _ensure(df.shape[0] > 100000, "Dataset unexpectedly small; expected ~1e6 rows.")
    _ensure(df["subreddit"].notna().all(), "Null subreddit labels present.")
    _ensure(df["body"].notna().all(), "Null body present.")
    unique_contr = sorted(pd.unique(df["controversiality"]))
    _ensure(
        set(unique_contr).issubset({0, 1}),
        f"Unexpected controversiality values: {unique_contr}",
    )

    # Create stable IDs
    df = df.copy()
    df.insert(0, "id", _build_id_series(len(df)))

    # Ensure label variety
    classes = sorted(df["subreddit"].unique().tolist())
    _ensure(len(classes) >= 30, f"Expected many subreddits, found {len(classes)}")

    # Deterministic stratified split
    train_idx, test_idx = _stratified_split(df, label_col="subreddit", train_frac=0.8)

    # Assemble splits; the label column is kept out of test_df and stored
    # separately in test_answer_df (same rows, same order).
    train_df = df.loc[train_idx, ["id", "body", "controversiality", "score", "subreddit"]].reset_index(drop=True)
    test_df = df.loc[test_idx, ["id", "body", "controversiality", "score"]].reset_index(drop=True)
    test_answer_df = df.loc[test_idx, ["id", "subreddit"]].reset_index(drop=True)

    # Split consistency checks
    _ensure(train_df["id"].is_unique and test_df["id"].is_unique, "IDs are not unique.")
    _ensure(set(train_df["id"]).isdisjoint(set(test_df["id"])), "Train/Test ID overlap detected.")
    _ensure(test_df.shape[0] == test_answer_df.shape[0], "Mismatch in test and answer rows.")
    _ensure(
        (test_df["id"].values == test_answer_df["id"].values).all(),
        "Test and answer IDs are misaligned.",
    )

    # All test labels must appear in training set
    train_labels = set(train_df["subreddit"].unique().tolist())
    test_labels = set(test_answer_df["subreddit"].unique().tolist())
    _ensure(test_labels.issubset(train_labels), "Some test labels do not appear in training set.")

    # The submission columns are the training classes; verify they match the
    # test classes before anything is written so a failure leaves no partial
    # output behind.
    class_names = sorted(train_labels)
    _ensure(class_names == sorted(test_labels), "Class names mismatch between train and test.")
    K = len(class_names)

    # Output paths
    train_path = public / "train.csv"
    test_path = public / "test.csv"
    test_answer_path = private / "test_answer.csv"
    sample_sub_path = public / "sample_submission.csv"

    # Write CSVs
    _write_csv(train_df, train_path)
    _write_csv(test_df, test_path)
    _write_csv(test_answer_df, test_answer_path)
    _write_sample_submission(sample_sub_path, test_df["id"].tolist(), class_names)

    # Copy description.txt into public/ (skipped when the file is absent)
    root_desc = Path(__file__).resolve().parent / "description.txt"
    if root_desc.exists():
        shutil.copy(root_desc, public / "description.txt")

    # Post-write validations (sampled: only the first rows are read back)
    check_train = pd.read_csv(train_path, nrows=10)
    check_test = pd.read_csv(test_path, nrows=10)
    check_sub = pd.read_csv(sample_sub_path, nrows=10)
    _ensure(
        "subreddit" in check_train.columns and "subreddit" not in check_test.columns,
        "Label leakage in test.csv",
    )
    _ensure(check_sub.shape[1] == 1 + K, "Sample submission has incorrect number of columns.")

    # Metadata for tests
    meta = {
        "num_rows": int(df.shape[0]),
        "num_train": int(train_df.shape[0]),
        "num_test": int(test_df.shape[0]),
        "num_classes": int(K),
        "classes": class_names,
    }
    with open(public / "_prepare_metadata.json", "w", encoding="utf-8") as mf:
        json.dump(meta, mf, ensure_ascii=False, indent=2)

    return meta
