from pathlib import Path
import os
import shutil
from typing import List, Tuple

import numpy as np
import pandas as pd

# Seed shared by every random step in this file (the train/test permutation
# and the sample-submission values) so repeated runs give identical output.
SEED = 42

# Label columns that every label CSV must contain and that are exported as
# prediction targets. Raw values are converted to soft labels in [0, 1] by
# dividing by 3 and clipping (see _build_targets).
# NOTE(review): the divisor presumably reflects three annotators per clip --
# confirm against the dataset documentation.
TARGET_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
    "NoStutteredWords",
    "PoorAudioQuality",
    "DifficultToUnderstand",
    "NaturalPause",
    "Music",
    "NoSpeech",
]


def _read_labels_csv(path: Path) -> pd.DataFrame:
    """Load one label CSV and reduce it to the id columns plus TARGET_LABELS.

    Header names are whitespace-stripped before validation. "Show" and "EpId"
    are returned as stripped strings and "ClipId" as integers.

    Raises:
        RuntimeError: if an id column or any TARGET_LABELS column is missing.
    """
    id_columns = ["Show", "EpId", "ClipId"]

    table = pd.read_csv(path, skipinitialspace=True)
    table.columns = [name.strip() for name in table.columns]
    present = set(table.columns)

    # Report the first missing id column, then the first missing label column
    missing_id = next((name for name in id_columns if name not in present), None)
    if missing_id is not None:
        raise RuntimeError(f"Unexpected label file format at {path}; missing column '{missing_id}'")
    missing_label = next((name for name in TARGET_LABELS if name not in present), None)
    if missing_label is not None:
        raise RuntimeError(f"Column '{missing_label}' not found in {path}")

    table = table[id_columns + TARGET_LABELS].copy()
    for text_column in ("Show", "EpId"):
        table[text_column] = table[text_column].astype(str).str.strip()

    # Direct integer cast first; retry through stripped strings if that fails
    try:
        clip_ids = table["ClipId"].astype(int)
    except Exception:
        clip_ids = table["ClipId"].astype(str).str.strip().astype(int)
    table["ClipId"] = clip_ids
    return table


def _index_audio_files(audio_dir: Path) -> set:
    """Return the set of ``*.wav`` file names found directly inside audio_dir.

    Raises:
        FileNotFoundError: if audio_dir is not an existing directory.
    """
    if audio_dir.is_dir():
        return {wav_path.name for wav_path in audio_dir.glob("*.wav")}
    raise FileNotFoundError(f"Audio directory not found: {audio_dir}")


def _possible_filenames(show: str, ep: str, clipid: int) -> List[str]:
    """List candidate ``{show}_{episode}_{clipid}.wav`` names for one label row.

    The episode id is tried as given (whitespace-stripped), zero-padded to
    widths 2, 3 and 4, and, when it parses as an integer, in its canonical
    integer form (which drops leading zeros).

    Args:
        show: Show name used as the file-name prefix.
        ep: Episode id as a string; surrounding whitespace is ignored.
        clipid: Clip number used as the file-name suffix.

    Returns:
        Candidate file names in preference order, without duplicates.
    """
    ep_clean = ep.strip()

    candidates = [f"{show}_{ep_clean}_{clipid}.wav"]
    candidates.extend(
        f"{show}_{ep_clean.zfill(width)}_{clipid}.wav" for width in (2, 3, 4)
    )

    # Only a failed integer parse is expected here; anything else should surface.
    try:
        ep_int = int(ep_clean)
    except ValueError:
        ep_int = None
    if ep_int is not None:
        candidates.append(f"{show}_{ep_int}_{clipid}.wav")

    # dict preserves insertion order, so this dedupes while keeping priority
    return list(dict.fromkeys(candidates))


def _match_rows_to_audio(df: pd.DataFrame, available_files: set) -> pd.DataFrame:
    """Attach the matching audio file name to each row and drop unmatched rows.

    For every row the candidates from _possible_filenames are tried in order
    and the first one present in available_files is stored in a new
    "orig_filename" column. The input frame is left untouched; the result has
    a fresh 0..n-1 index.
    """

    def first_available(show, ep, clipid):
        # First candidate name that actually exists, or None when none does
        hits = (
            name
            for name in _possible_filenames(show, ep, clipid)
            if name in available_files
        )
        return next(hits, None)

    resolved = [
        first_available(record["Show"], record["EpId"], int(record["ClipId"]))
        for _, record in df.iterrows()
    ]

    matched = df.copy()
    matched["orig_filename"] = resolved
    has_audio = matched["orig_filename"].notna()
    return matched[has_audio].reset_index(drop=True)


def _build_targets(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df whose TARGET_LABELS columns are soft labels in [0, 1].

    Each label column is coerced to numeric (unparseable or missing values
    become 0), divided by 3 and clipped to the closed interval [0, 1].

    The input frame is not modified: the conversion runs on a copy, matching
    _match_rows_to_audio, so a caller that keeps a reference to the raw
    values does not see them silently rescaled.
    """
    out = df.copy()
    for col in TARGET_LABELS:
        numeric = pd.to_numeric(out[col], errors="coerce").fillna(0.0)
        out[col] = (numeric / 3.0).clip(0.0, 1.0)
    return out


def _deterministic_split(n: int, test_ratio: float = 0.2, seed: int = SEED) -> Tuple[np.ndarray, np.ndarray]:
    """Split the indices 0..n-1 into train and test parts reproducibly.

    Args:
        n: Number of items to split.
        test_ratio: Fraction of items assigned to the test part.
        seed: Seed for the permutation; the same seed yields the same split.

    Returns:
        ``(train_idx, test_idx)``: the first ``round(n * test_ratio)`` entries
        of the seeded permutation form the test part and the remainder the
        train part.
    """
    shuffled = np.random.RandomState(seed).permutation(n)
    test_size = int(round(n * test_ratio))
    return shuffled[test_size:], shuffled[:test_size]


def _ensure_label_coverage(Y: np.ndarray, train_idx: np.ndarray, test_idx: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Rebalance a split so labels with positives appear in both parts when possible.

    A label counts as present in a split when at least one row of that split
    has a value > 0 in the label's column of Y. For every label missing from
    one split, one positive row is moved over from the other split.

    A row is only moved when
      * the receiving split still lacks the label at move time (an earlier
        move of a multi-label row may already have fixed it), and
      * the donating split keeps at least one other positive row for it.
    The second rule means a label with a single positive row is left where it
    is instead of being bounced between the splits and dragging other labels
    along with it.

    Args:
        Y: 2-D array of label values, one row per sample, one column per label.
        train_idx: Row indices of Y currently in the train split.
        test_idx: Row indices of Y currently in the test split.

    Returns:
        ``(train_idx, test_idx)`` after rebalancing; moved rows are appended
        to the end of the receiving array. The input arrays are not modified.
    """
    positive = Y > 0
    coverable = positive.any(axis=0)

    def has_pos(idxs: np.ndarray) -> np.ndarray:
        # Per-label flag: does any row of this split carry the label?
        if idxs.size == 0:
            return np.zeros(positive.shape[1], dtype=bool)
        return positive[idxs].any(axis=0)

    def shift_one(receiver: np.ndarray, donor: np.ndarray, lbl: int):
        # Move the donor's first positive row for lbl, if the move is useful
        # and the donor stays covered. Returns (receiver, donor, moved).
        if receiver.size and positive[receiver, lbl].any():
            return receiver, donor, False
        if donor.size == 0:
            return receiver, donor, False
        cand = np.flatnonzero(positive[donor, lbl])
        if cand.size < 2:
            return receiver, donor, False
        j = cand[0]
        row = donor[j]
        return np.append(receiver, row), np.delete(donor, j), True

    # A move may uncover a different label in the donor split, so repeat a
    # bounded number of rounds until a round makes no move.
    for _ in range(10):
        moved = False
        needs_train = np.flatnonzero(coverable & ~has_pos(train_idx))
        needs_test = np.flatnonzero(coverable & ~has_pos(test_idx))
        for lbl in needs_train:
            train_idx, test_idx, did_move = shift_one(train_idx, test_idx, lbl)
            moved = moved or did_move
        for lbl in needs_test:
            test_idx, train_idx, did_move = shift_one(test_idx, train_idx, lbl)
            moved = moved or did_move
        if not moved:
            break
    return train_idx, test_idx


def _safe_mkdir(path: Path):
    """Make path an empty directory, deleting whatever tree was there before.

    Missing parent directories are created as needed.
    """
    already_there = path.exists()
    if already_there:
        # Start from a clean slate: remove the old tree and its contents
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


def prepare(raw: Path, public: Path, private: Path):
    """Build the public and private dataset folders from the raw download.

    Steps, in order:
      1. Check that both label CSVs and the audio directory exist under raw.
      2. Recreate public, private and public/audio/{train,test} from scratch
         (any existing content of public and private is deleted).
      3. Load and concatenate the label files, keep rows with a matching
         audio file, convert labels to soft targets and assign ids
         ``clip_000001``, ``clip_000002``, ...
      4. Split 80/20 with the fixed seed and rebalance for label coverage.
      5. Copy each clip to ``public/audio/<split>/<id>.wav``.
      6. Write public/train.csv, public/test.csv, private/test_answer.csv
         and public/sample_submission.csv (seeded random values).
      7. Copy ``description.txt`` from raw's parent directory when present.
      8. Assert basic consistency of the outputs.

    Raises:
        FileNotFoundError: if a label file or the audio directory is missing.
        RuntimeError: if no label row matches any audio file.
        AssertionError: if one of the final consistency checks fails.
    """
    # Input paths
    sep_labels_path = raw / "SEP-28k_labels.csv"
    flu_labels_path = raw / "fluencybank_labels.csv"
    audio_dir = raw / "clips" / "stuttering-clips" / "clips"

    for labels_path in (sep_labels_path, flu_labels_path):
        if not labels_path.is_file():
            raise FileNotFoundError(f"Missing labels file: {labels_path}")
    if not audio_dir.is_dir():
        raise FileNotFoundError(f"Missing audio directory: {audio_dir}")

    # Create output structure (parents first, each one wiped and recreated)
    train_audio = public / "audio" / "train"
    test_audio = public / "audio" / "test"
    for folder in (public, private, public / "audio", train_audio, test_audio):
        _safe_mkdir(folder)

    # Load labels and match to audio
    labelled = pd.concat(
        [_read_labels_csv(sep_labels_path), _read_labels_csv(flu_labels_path)],
        ignore_index=True,
    )
    labelled = _match_rows_to_audio(labelled, _index_audio_files(audio_dir))
    if len(labelled) == 0:
        raise RuntimeError("No labeled audio files matched the raw audio directory.")

    # Build targets and anonymized ids
    labelled = _build_targets(labelled).reset_index(drop=True)
    labelled["id"] = [f"clip_{number:06d}" for number in range(1, len(labelled) + 1)]

    # Deterministic split with coverage
    targets = labelled[TARGET_LABELS].to_numpy()
    train_rows, test_rows = _deterministic_split(len(labelled), test_ratio=0.2, seed=SEED)
    train_rows, test_rows = _ensure_label_coverage(targets, train_rows, test_rows)

    train_df = labelled.iloc[train_rows].reset_index(drop=True)
    test_df = labelled.iloc[test_rows].reset_index(drop=True)

    # Copy audio to public/audio/{train,test}/{id}.wav
    for split_df, dest_dir in ((train_df, train_audio), (test_df, test_audio)):
        for clip_id, wav_name in zip(split_df["id"], split_df["orig_filename"]):
            shutil.copy2(audio_dir / wav_name, dest_dir / f"{clip_id}.wav")

    # Prepare and write CSVs
    id_and_labels = ["id"] + TARGET_LABELS
    train_out = train_df[id_and_labels].copy()
    test_out = test_df[["id"]].copy()
    answer_out = test_df[id_and_labels].copy()

    train_out.to_csv(public / "train.csv", index=False)
    test_out.to_csv(public / "test.csv", index=False)
    answer_out.to_csv(private / "test_answer.csv", index=False)

    # Sample submission: random probabilities in (0,1), drawn column by column
    rng = np.random.RandomState(SEED)
    sample = test_df[["id"]].copy()
    for col in TARGET_LABELS:
        sample[col] = rng.uniform(0.01, 0.99, size=len(sample)).round(6)
    sample.to_csv(public / "sample_submission.csv", index=False)

    # Copy description to public
    description_src = raw.parent / "description.txt"
    if description_src.is_file():
        shutil.copy2(description_src, public / "description.txt")

    # Checks
    assert train_audio.exists()
    assert test_audio.exists()
    assert train_out["id"].is_unique and test_out["id"].is_unique
    assert set(train_out["id"]).isdisjoint(set(test_out["id"]))
    assert set(test_out["id"]) == set(answer_out["id"]) == set(sample["id"])  # ids match

    # Ensure every label that exists overall has positives in both splits
    for col in TARGET_LABELS:
        if (labelled[col] > 0).any():
            assert (train_out[col] > 0).any(), f"No positive examples in train for label {col}"
            assert (answer_out[col] > 0).any(), f"No positive examples in test for label {col}"
