from pathlib import Path
import shutil
import random
import pandas as pd
from collections import defaultdict

# Seed used for every random choice in this module; prepare() passes it to
# _grouped_stratified_split so the train/test split is reproducible.
SEED = 42
# Seed the global `random` generator at import time. The split itself draws
# from its own random.Random(seed) instance, so this call only affects code
# that uses the module-level `random` functions.
random.seed(SEED)


def _safe_mkdir(dir_path: Path, clear: bool = True) -> None:
    """Make sure ``dir_path`` exists as a directory.

    When ``clear`` is true and the path already exists, the whole tree is
    removed first so the directory starts out empty. Missing parent
    directories are created as needed.
    """
    wipe_first = dir_path.exists() if clear else False
    if wipe_first:
        shutil.rmtree(dir_path)
    dir_path.mkdir(parents=True, exist_ok=True)


def _load_metadata(raw: Path) -> pd.DataFrame:
    """Load the metadata CSV and keep only rows usable for classification.

    Reads ``raw/bird_songs_metadata.csv`` (every column as text), then keeps
    a row only if:

    * it has a non-blank ``name`` (the species common name used as label),
    * its ``filename`` exists inside ``raw/wavfiles/``,
    * its ``sound_type`` contains "song" (case-insensitive) -- this filter is
      applied only when the CSV has a ``sound_type`` column.

    Args:
        raw: Directory holding ``bird_songs_metadata.csv`` and ``wavfiles/``.

    Returns:
        The filtered metadata with three added columns: ``source_filepath``
        (POSIX path of the audio file), ``exists`` (always True after
        filtering) and ``label`` (whitespace-stripped ``name``).

    Raises:
        AssertionError: if the CSV, the ``name`` column or the audio
            directory is missing, or fewer than two distinct labels remain.
    """
    meta_path = raw / "bird_songs_metadata.csv"
    assert meta_path.exists(), f"Metadata not found: {meta_path}"
    df = pd.read_csv(meta_path, dtype=str)

    assert "name" in df.columns, "Expected 'name' column for species common name"
    # Drop unlabeled rows *before* the string cast below: astype(str) turns a
    # missing cell into the literal text "nan", which would otherwise become
    # a bogus class called "nan".
    df = df[df["name"].notna()].copy()

    # Normalize string columns we care about
    for col in ("id", "genus", "species", "name", "filename", "sound_type"):
        if col in df.columns:
            df[col] = df[col].astype(str)

    # Link to source audio path
    audio_dir = raw / "wavfiles"
    assert audio_dir.exists() and audio_dir.is_dir(), f"Missing audio dir: {audio_dir}"
    df["source_filepath"] = df["filename"].apply(lambda x: (audio_dir / x).as_posix())
    # astype(bool) keeps the mask boolean even when no rows are left.
    df["exists"] = df["source_filepath"].apply(lambda p: Path(p).exists()).astype(bool)
    # Drop rows without audio
    df = df[df["exists"]].copy()

    # Keep songs only when the metadata records a sound type
    if "sound_type" in df.columns:
        is_song = df["sound_type"].str.lower().str.contains("song", na=False)
        df = df[is_song.astype(bool)].copy()

    # Define target label; a whitespace-only name is as unusable as a missing one
    df["label"] = df["name"].str.strip()
    df = df[df["label"] != ""].copy()
    assert df["label"].nunique() >= 2, "Need at least two classes"

    return df


def _grouped_stratified_split(
    df: pd.DataFrame, group_col: str, label_col: str, test_frac: float, seed: int
):
    """Split ``df`` into train/test so that no group straddles both parts.

    Groups (distinct ``group_col`` values) are assigned per label: for each
    label its groups are shuffled with a ``random.Random(seed)`` generator
    and roughly ``test_frac`` of them (at least one) go to the test part,
    except that one group per label is always left for training. A label
    with a single group therefore appears in the train part only.

    Returns a ``(df_train, df_test)`` tuple of copies.

    Raises AssertionError if any group carries more than one label.
    """
    # Every group must map to exactly one label
    labels_per_group = df.groupby(group_col)[label_col].nunique()
    assert labels_per_group[labels_per_group != 1].empty, "Found groups with mixed labels; cannot group-split safely"

    shuffler = random.Random(seed)

    # Collect the groups belonging to each label, in groupby order
    groups_by_label = defaultdict(list)
    for group_key, label_values in df.groupby(group_col)[label_col]:
        groups_by_label[label_values.iloc[0]].append(group_key)

    held_out = set()
    for members in groups_by_label.values():
        candidates = members.copy()
        shuffler.shuffle(candidates)
        total = len(candidates)
        # At least one test group, but never all of them
        take = min(max(1, round(total * test_frac)), total - 1)
        held_out.update(candidates[:take])

    in_test = df[group_col].isin(held_out)
    df_train = df[~in_test].copy()
    df_test = df[in_test].copy()

    # Every test label must also be present in train
    train_labels = set(df_train[label_col].unique())
    assert set(df_test[label_col].unique()) <= train_labels

    return df_train, df_test


def _copy_audio(file_list: list[str], dst_dir: Path) -> None:
    """Copy every file in ``file_list`` into a freshly cleared ``dst_dir``.

    Files keep their base name; ``dst_dir`` is wiped and recreated before
    copying. After the copy pass each expected destination is checked, and
    an AssertionError is raised for the first one that is missing.
    """
    _safe_mkdir(dst_dir, clear=True)
    for entry in file_list:
        origin = Path(entry)
        shutil.copy2(origin, dst_dir / origin.name)
    # Confirm every file arrived
    for target in (dst_dir / Path(entry).name for entry in file_list):
        assert target.exists(), f"Missing copied file: {target}"


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete data preparation pipeline

    Inputs:
    - raw: absolute path to raw/ containing bird_songs_metadata.csv and wavfiles/
    - public: absolute path to public/ output directory
    - private: absolute path to private/ output directory

    Both output directories are deleted and recreated, so anything already
    in them is lost.

    Artifacts created:
    - public/train.csv (id,label)
    - public/test.csv (id)
    - public/sample_submission.csv (id,label) -- every row carries the
      alphabetically first train label
    - public/train_audio/ (wav files)
    - public/test_audio/ (wav files)
    - public/description.txt (copied from raw's parent directory, only if
      a description.txt exists there)
    - private/test_answer.csv (id,label)

    The ``id`` values are the audio filenames. The train/test split is
    grouped on the metadata ``id`` column with a 20% test fraction.

    Raises:
    - AssertionError if a path is not absolute or any consistency check on
      the produced files fails.
    """
    # Preconditions
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Use absolute paths"

    # Ensure clean output dirs
    _safe_mkdir(public, clear=True)
    _safe_mkdir(private, clear=True)

    # Load and split
    df = _load_metadata(raw)
    df_train, df_test = _grouped_stratified_split(
        df, group_col="id", label_col="label", test_frac=0.2, seed=SEED
    )

    # Copy audio into public directory (keep original filenames; they do not leak labels)
    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    _copy_audio(df_train["source_filepath"].tolist(), train_audio_dir)
    _copy_audio(df_test["source_filepath"].tolist(), test_audio_dir)

    # Build CSVs using the filenames as ids
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    sample_csv = public / "sample_submission.csv"
    answer_csv = private / "test_answer.csv"

    train_df = pd.DataFrame({"id": df_train["filename"].tolist(), "label": df_train["label"].tolist()})
    test_ids = pd.DataFrame({"id": df_test["filename"].tolist()})
    answer_df = pd.DataFrame({"id": df_test["filename"].tolist(), "label": df_test["label"].tolist()})

    # Deterministic sorting for reproducibility
    train_df = train_df.sort_values("id").reset_index(drop=True)
    test_ids = test_ids.sort_values("id").reset_index(drop=True)
    answer_df = answer_df.sort_values("id").reset_index(drop=True)

    # Sample submission with a deterministic label (alphabetically first train label)
    labels_sorted = sorted(train_df["label"].unique().tolist())
    default_label = labels_sorted[0]
    sample_df = pd.DataFrame({"id": test_ids["id"], "label": [default_label] * len(test_ids)})

    # Write files
    train_df.to_csv(train_csv, index=False)
    test_ids.to_csv(test_csv, index=False)
    sample_df.to_csv(sample_csv, index=False)
    answer_df.to_csv(answer_csv, index=False)

    # Copy description.txt into public/ (optional: skipped when absent)
    repo_desc = raw.parent / "description.txt"
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Checks
    assert train_csv.exists(), "public/train.csv should exist"
    assert test_csv.exists(), "public/test.csv should exist"
    assert sample_csv.exists(), "public/sample_submission.csv should exist"
    assert answer_csv.exists(), "private/test_answer.csv should exist"

    # Id-path hygiene: reject any id containing a forward slash or a backslash.
    # The character class matches a *single* backslash; the earlier pattern
    # required two in a row and so let Windows-style paths through.
    path_sep = r"[/\\]"
    for df_chk, name in [(train_df, "train"), (test_ids, "test"), (answer_df, "answer"), (sample_df, "sample")]:
        assert not df_chk["id"].astype(str).str.contains(path_sep, regex=True).any(), f"{name} ids must not be paths"

    # Alignment
    assert set(test_ids["id"]) == set(answer_df["id"]), "public/test.csv and private/test_answer.csv must share ids"
    assert train_df["id"].is_unique and test_ids["id"].is_unique and answer_df["id"].is_unique

    # Audio consistency: one .wav on disk per CSV row
    assert len(list(train_audio_dir.glob("*.wav"))) == len(train_df)
    assert len(list(test_audio_dir.glob("*.wav"))) == len(test_ids)

    # No overlap
    assert set(train_df["id"]).isdisjoint(set(test_ids["id"]))
