from pathlib import Path
from typing import List
import shutil
import numpy as np
import pandas as pd

# Split configuration. Both values are fixed so repeated runs over the same
# raw CSV produce the same outputs.
# RANDOM_SEED seeds the NumPy RandomState that shuffles row indices for the
# train/test split; RANDOM_SEED + 123 seeds the sample-submission values.
RANDOM_SEED = 42
# Fraction of rows assigned to the test split; the row count is floored.
TEST_FRACTION = 0.15


def _find_target_columns(columns: List[str]) -> List[str]:
    """Return the names in *columns* that begin with "Good for ", sorted ascending.

    The input list is left untouched; a new list is returned.
    """
    return sorted(name for name in columns if name.startswith("Good for "))


def _ensure_dirs(public: Path, private: Path):
    """Create the public and private output directories, including parents.

    Directories that already exist are left as they are.
    """
    for directory in (public, private):
        directory.mkdir(parents=True, exist_ok=True)


def prepare(raw: Path, public: Path, private: Path):
    """
    Build the public/private competition files from the raw CSV.

    Steps:
    - Read ``raw/spotify_dataset.csv`` (decoded as utf-8-sig so a BOM is tolerated).
    - Prepend an integer ``id`` column numbered 0..N-1 in file order.
    - Validate that every target column (name starting with "Good for ") holds
      only 0/1 and cast it to int. Missing or unparseable entries count as 0.
    - Shuffle row indices with ``RANDOM_SEED`` and take the first
      ``floor(TEST_FRACTION * N)`` as the test split. If a label has positives
      in test but none in train, one positive test row is moved to train.
    - Write:
        public/train.csv, public/test.csv, public/sample_submission.csv
        private/test_answer.csv
    - Copy ``description.txt`` from the parent directory of ``raw`` into
      ``public`` when that file exists; otherwise it is skipped.
    - public/test.csv never contains the target columns.

    Args:
        raw: Directory containing ``spotify_dataset.csv``.
        public: Output directory for participant-visible files.
        private: Output directory for the held-out answers.

    Raises:
        FileNotFoundError: The source CSV is missing.
        RuntimeError: No column starts with "Good for ".
        ValueError: A target column contains a value other than 0/1.
        AssertionError: The ``text`` column is missing, or a post-write
            integrity check fails.
    """
    raw = Path(raw).resolve()
    public = Path(public).resolve()
    private = Path(private).resolve()

    _ensure_dirs(public, private)

    # Load source CSV (utf-8-sig to handle potential BOM)
    source_csv = raw / "spotify_dataset.csv"
    if not source_csv.exists():
        raise FileNotFoundError(f"Missing source CSV: {source_csv}")
    df = pd.read_csv(source_csv, encoding="utf-8-sig")

    # Identify target columns
    target_cols = _find_target_columns(df.columns.tolist())
    if not target_cols:
        raise RuntimeError("No target columns found (expected columns starting with 'Good for ').")

    # Create id column deterministically (0..N-1)
    n = len(df)
    df.insert(0, "id", np.arange(n, dtype=np.int64))

    # Coerce target columns to binary integers 0/1.
    # Validate BEFORE the integer cast: casting first would truncate values
    # such as 0.5 or 1.7 to 0/1 and let non-binary data through unnoticed.
    for c in target_cols:
        numeric = pd.to_numeric(df[c], errors="coerce").fillna(0)
        if not ((numeric == 0) | (numeric == 1)).all():
            raise ValueError(f"Target column {c} must be binary 0/1.")
        df[c] = numeric.astype(int)

    # Fail fast on a missing feature column, before any output is written.
    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    if "text" not in df.columns:
        raise AssertionError("Expected 'text' column in train.csv")

    # Deterministic split indices. The frame keeps read_csv's default
    # 0..N-1 index, so these values are both row positions and row labels.
    rng = np.random.RandomState(RANDOM_SEED)
    perm = rng.permutation(np.arange(n))
    test_size = int(np.floor(TEST_FRACTION * n))
    test_idx = perm[:test_size]
    train_idx = perm[test_size:]

    # Ensure: for any label that has positives in test, train has at least one positive
    for c in target_cols:
        is_positive = df[c].to_numpy() == 1
        test_positive = is_positive[test_idx]
        if test_positive.any() and not is_positive[train_idx].any():
            # Move the first positive example (in test order) from test to train.
            move_i = test_idx[test_positive][0]
            train_idx = np.append(train_idx, move_i)
            # Boolean mask keeps the integer dtype even if test becomes empty.
            test_idx = test_idx[test_idx != move_i]

    # Materialize splits
    train_df = df.iloc[train_idx].copy()
    test_full_df = df.iloc[test_idx].copy()

    # Public test: drop target columns (no leakage)
    test_df = test_full_df.drop(columns=target_cols)

    # Private answers: id + targets
    test_answer_df = test_full_df[["id"] + target_cols].copy()

    # Sample submission: id + targets with deterministic random probabilities in [0,1]
    rng2 = np.random.RandomState(RANDOM_SEED + 123)
    sample_submission = pd.DataFrame({"id": test_df["id"].values})
    for c in target_cols:
        sample_submission[c] = rng2.rand(len(sample_submission)).astype(np.float32)

    # Save CSVs. "id" was inserted at position 0 and every frame above is
    # built id-first, so no column reordering is needed before writing.
    train_path = public / "train.csv"
    test_path = public / "test.csv"
    answer_path = private / "test_answer.csv"
    sample_path = public / "sample_submission.csv"

    train_df.to_csv(train_path, index=False)
    test_df.to_csv(test_path, index=False)
    test_answer_df.to_csv(answer_path, index=False)
    sample_submission.to_csv(sample_path, index=False)

    # Copy description.txt into public (only if it exists next to the raw dir)
    desc_src = raw.parent / "description.txt"
    if desc_src.exists():
        shutil.copy(desc_src, public / "description.txt")

    # Integrity checks on what was just produced
    assert train_path.exists(), "public/train.csv missing"
    assert test_path.exists(), "public/test.csv missing"
    assert sample_path.exists(), "public/sample_submission.csv missing"
    assert answer_path.exists(), "private/test_answer.csv missing"

    # No target leakage in public/test.csv
    for c in target_cols:
        assert c not in test_df.columns, f"Target leakage detected in public/test.csv: {c}"

    # Matching ids between sample submission and test answers
    assert set(sample_submission["id"]) == set(test_answer_df["id"]), "Sample submission ids must match test ids"

    # Column schema alignment for sample submission and test answers
    assert list(sample_submission.columns) == list(test_answer_df.columns), "Sample submission columns mismatch with test answers"

    # Ensure uniqueness and disjointness of ids
    assert train_df["id"].nunique() == len(train_df)
    assert test_df["id"].nunique() == len(test_df)
    assert set(train_df["id"]).isdisjoint(set(test_df["id"]))
