from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Single seed reused for the stratified split, the post-split row shuffles and
# the random sample submission, so outputs are reproducible for the same input.
RANDOM_SEED = 20250914


def _read_source_df(raw: Path) -> pd.DataFrame:
    """Load ``DisneylandReviews.csv`` from *raw* and return a normalized frame.

    The returned frame has exactly the columns ``Review_ID``, ``Rating``,
    ``Year_Month``, ``Reviewer_Location``, ``Review_Text`` and ``Branch``, in
    that order. ``Rating`` is cast to integer, present values of the four text
    columns are cast to ``str``, and ``Review_ID`` is regenerated as
    ``10_000_000 + row position`` so it is unique even if the raw IDs are not.

    Missing values in the text columns stay missing instead of becoming the
    literal string ``"nan"``, so callers can still detect them with
    ``isnull()``.

    Raises:
        AssertionError: if the CSV file is absent or lacks an expected column.
        ValueError: if ``Rating`` holds non-numeric, non-finite or
            non-integral values.
    """
    src_path = raw / "DisneylandReviews.csv"
    assert src_path.exists(), f"Source data not found at: {src_path}"

    # The file is read as latin-1, which accepts every byte value.
    df = pd.read_csv(src_path, encoding="latin-1")

    # Expected columns from the raw file
    expected_cols = [
        "Review_ID",
        "Rating",
        "Year_Month",
        "Reviewer_Location",
        "Review_Text",
        "Branch",
    ]
    missing = [c for c in expected_cols if c not in df.columns]
    assert not missing, (
        f"Missing expected columns in source CSV: {missing}. Found: {list(df.columns)}"
    )

    # Keep only the expected columns, in the expected order.
    df = df[expected_cols].copy()

    # Refuse fractional ratings explicitly: a bare astype(int) would silently
    # truncate e.g. 4.5 to 4 and corrupt the label.
    ratings = pd.to_numeric(df["Rating"], errors="raise")
    fractional = ratings.notna() & (ratings % 1 != 0)
    if fractional.any():
        raise ValueError(
            "Rating must hold whole numbers; offending values include: "
            f"{ratings[fractional].unique()[:5].tolist()}"
        )
    df["Rating"] = ratings.astype(int)

    # Cast text columns to str, then put the missing entries back so they are
    # not disguised as the string "nan".
    for col in ("Year_Month", "Reviewer_Location", "Review_Text", "Branch"):
        present = df[col].notna()
        df[col] = df[col].astype(str).where(present)

    # Replace the raw IDs with deterministic, unique ones based on row position.
    df["Review_ID"] = np.arange(len(df), dtype=np.int64) + 10_000_000
    assert df["Review_ID"].is_unique, "Constructed Review_ID must be unique"

    return df[expected_cols]


def _stratified_split(df: pd.DataFrame, test_size: float, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into train and test frames, stratified on ``Rating``.

    Row positions are partitioned with ``train_test_split`` using *seed* as the
    random state; each partition is then re-shuffled with the same *seed* and
    given a fresh 0-based index.

    Returns:
        ``(train_df, test_df)``.
    """
    positions = np.arange(len(df))
    train_pos, test_pos = train_test_split(
        positions,
        test_size=test_size,
        random_state=seed,
        stratify=df["Rating"],
        shuffle=True,
    )

    def _reshuffled(rows) -> pd.DataFrame:
        # Select the rows by position, shuffle them, and discard the old index.
        picked = df.iloc[rows]
        return picked.sample(frac=1.0, random_state=seed).reset_index(drop=True)

    return _reshuffled(train_pos), _reshuffled(test_pos)


def prepare(raw: Path, public: Path, private: Path):
    """
    Build the public and private competition files from the raw reviews CSV.

    Arguments (each is resolved to an absolute path before use):
    - raw: folder holding DisneylandReviews.csv
    - public: folder receiving the participant-visible files
    - private: folder receiving the hidden answers

    Files written:
    - public/train.csv: training rows including Rating
    - public/test.csv: test rows with Rating removed
    - public/sample_submission.csv: test Review_IDs paired with random ratings
    - public/description.txt: copy of description.txt from the parent folder
      of public, only when that file exists
    - private/test_answer.csv: Review_ID and Rating of the test rows

    Assertions guard the source data, the split and the written files; any
    violation raises AssertionError.
    """
    raw, public, private = (Path(p).resolve() for p in (raw, public, private))

    for out_dir in (public, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    source = _read_source_df(raw)

    # Basic validity of the loaded data
    assert source["Rating"].between(1, 5).all(), "All ratings must be within [1, 5]"
    assert not source.isnull().any().any(), "Source dataset contains nulls in essential fields"

    # 80/20 stratified split driven by the module-level seed
    train_part, test_part = _stratified_split(source, test_size=0.20, seed=RANDOM_SEED)

    # Split invariants: same schema, disjoint IDs, every class seen in training
    source_cols = set(source.columns)
    assert set(train_part.columns) == source_cols
    assert set(test_part.columns) == source_cols
    assert set(train_part["Review_ID"]).isdisjoint(set(test_part["Review_ID"]))
    train_classes = set(train_part["Rating"].unique())
    assert train_classes.issuperset({1, 2, 3, 4, 5}), "Training must have all classes"
    assert set(test_part["Rating"].unique()).issubset(train_classes)

    train_path = public / "train.csv"
    test_path = public / "test.csv"
    answers_path = private / "test_answer.csv"
    sample_path = public / "sample_submission.csv"

    # Labeled training data
    train_part.to_csv(train_path, index=False)

    # Test features only; the label column is withheld
    unlabeled_test = test_part.drop(columns=["Rating"])
    unlabeled_test.to_csv(test_path, index=False)

    # Hidden ground truth
    answers = test_part[["Review_ID", "Rating"]]
    answers.to_csv(answers_path, index=False)

    # Random baseline: each prediction is drawn from the ratings that occur in
    # the test answers, using a generator seeded with RANDOM_SEED.
    rng = np.random.RandomState(RANDOM_SEED)
    label_pool = np.sort(test_part["Rating"].unique())
    picks = rng.randint(0, len(label_pool), size=len(test_part))
    baseline = pd.DataFrame({
        "Review_ID": test_part["Review_ID"].values,
        "Rating": label_pool[picks].astype(int),
    })
    baseline.to_csv(sample_path, index=False)

    # Optional description file, taken from the folder that contains `public`
    desc_src = (public.parent / "description.txt").resolve()
    if desc_src.exists():
        desc_text = desc_src.read_text(encoding="utf-8")
        (public / "description.txt").write_text(desc_text, encoding="utf-8")

    # ---- Verify what was written to disk ----
    def header_of(csv_path):
        # Read only the header row of a CSV and return its column names.
        return pd.read_csv(csv_path, nrows=0).columns.tolist()

    assert header_of(train_path) == [
        "Review_ID", "Rating", "Year_Month", "Reviewer_Location", "Review_Text", "Branch",
    ]
    assert header_of(test_path) == [
        "Review_ID", "Year_Month", "Reviewer_Location", "Review_Text", "Branch",
    ]
    assert header_of(answers_path) == ["Review_ID", "Rating"]
    assert header_of(sample_path) == ["Review_ID", "Rating"]

    # The three test-side files must list the same IDs in the same order
    id_lists = [
        pd.read_csv(csv_path)["Review_ID"].tolist()
        for csv_path in (test_path, answers_path, sample_path)
    ]
    assert id_lists[0] == id_lists[1] == id_lists[2], "Mismatch of Review_ID order among outputs"

    # The public test file must not expose the label
    assert "Rating" not in header_of(test_path)

    # Every sample-submission rating must be a valid class
    written_ratings = pd.read_csv(sample_path)["Rating"]
    assert written_ratings.astype(int).between(1, 5).all(), "Sample submission ratings must be in [1, 5]"
