from pathlib import Path
import shutil
import hashlib
import random
import csv
from collections import defaultdict
from typing import Dict, List, Tuple


def _sha1_id(text: str, length: int = 12) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()[:length]


def _list_audio_files(audio_root: Path) -> List[Path]:
    # Collect mp3 files up to two levels deep: <root>/*/*.mp3 and <root>/*/*/*.mp3
    files: List[Path] = []
    files.extend(audio_root.glob("*/*.mp3"))
    files.extend(audio_root.glob("*/*/*.mp3"))
    # Filter files only
    files = [f for f in files if f.is_file() and f.suffix.lower() == ".mp3"]
    # Sort deterministically
    files.sort()
    return files


def _derive_label_from_path(path: Path) -> str:
    # Expect .../<Label>_sound/<file>.mp3 but be robust
    parent = path.parent.name
    label = parent[:-6] if parent.endswith("_sound") else parent
    return label.strip()


def _stratified_split(items_by_label: Dict[str, List[Path]], rng: random.Random, target_test_frac: float = 0.2) -> Tuple[List[Tuple[str, Path]], List[Tuple[str, Path]]]:
    train_items: List[Tuple[str, Path]] = []
    test_items: List[Tuple[str, Path]] = []
    for label, items in sorted(items_by_label.items()):
        items_sorted = sorted(items)
        rng.shuffle(items_sorted)
        n = len(items_sorted)
        if n <= 1:
            n_test = 0
        elif n <= 4:
            n_test = 1
        else:
            n_test = max(1, int(round(target_test_frac * n)))
            if n - n_test < 1:
                n_test = n - 1
        test_items.extend([(label, f) for f in items_sorted[:n_test]])
        train_items.extend([(label, f) for f in items_sorted[n_test:]])
    return train_items, test_items


def _resolve_audio_root(raw: Path) -> Path:
    """Return the directory whose sub-folders hold the per-label audio files.

    Prefers ``raw/Voice of Birds/Voice of Birds``. Falls back to
    ``raw/Voice of Birds`` when that is a directory containing at least one
    sub-directory.

    Raises:
        FileNotFoundError: if neither layout is present.
    """
    audio_root = raw / "Voice of Birds" / "Voice of Birds"
    if audio_root.exists():
        return audio_root
    candidate = raw / "Voice of Birds"
    # iterdir() already yields full paths, so test them directly. The is_dir()
    # guard avoids NotADirectoryError if "Voice of Birds" is a plain file.
    if candidate.is_dir() and any(d.is_dir() for d in candidate.iterdir()):
        return candidate
    raise FileNotFoundError(f"Audio root not found under {raw}. Expected {audio_root}")


def _write_csv(path: Path, header: List[str], rows) -> None:
    """Write *header*, then every row from the iterable *rows*, to *path* as UTF-8 CSV."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def _read_csv_column(path: Path, column: str) -> List[str]:
    """Return the values of *column* from every data row of the UTF-8 CSV at *path*."""
    with path.open("r", newline="", encoding="utf-8") as f:
        return [row[column] for row in csv.DictReader(f)]


def _require(condition: bool, message: str) -> None:
    """Raise ``RuntimeError(message)`` when *condition* is false.

    Used instead of ``assert`` so the output checks still run under ``python -O``.
    """
    if not condition:
        raise RuntimeError(message)


def prepare(raw: Path, public: Path, private: Path):
    """
    Full preparation pipeline.

    Inputs:
    - raw: path to the raw data directory (contains the original dataset)
    - public: path to the public output directory
    - private: path to the private output directory
    All three are resolved to absolute paths before use.

    Outputs in public/:
    - train_audio/  .mp3 files named as <id>.mp3
    - test_audio/   .mp3 files named as <id>.mp3
    - train.csv     columns: id,label
    - test.csv      columns: id
    - labels.csv    column: label
    - sample_submission.csv columns: id,label (random valid labels)
    - description.txt copied into public/ (only if present next to this script)

    Outputs in private/:
    - test_answer.csv columns: id,label

    Raises:
    - FileNotFoundError: if the audio root cannot be located under raw.
    - RuntimeError: if no audio is found, fewer than two labels exist, a split
      comes out empty, or any post-write consistency check fails.
    """
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    audio_root = _resolve_audio_root(raw)

    # Discover mp3 files and group them by label (derived from the parent folder).
    all_audio_files = _list_audio_files(audio_root)
    if not all_audio_files:
        raise RuntimeError("No audio files found to prepare.")

    items_by_label: Dict[str, List[Path]] = defaultdict(list)
    for f in all_audio_files:
        items_by_label[_derive_label_from_path(f)].append(f)

    labels = sorted(items_by_label.keys())
    if len(labels) < 2:
        raise RuntimeError("Need at least two classes to form a classification competition.")

    # Fixed seed so the split is reproducible across runs.
    rng = random.Random(42)
    train_pairs, test_pairs = _stratified_split(items_by_label, rng, target_test_frac=0.2)
    if not test_pairs or not train_pairs:
        raise RuntimeError("Empty split encountered; please check the raw data.")

    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    train_audio_dir.mkdir(parents=True, exist_ok=True)
    test_audio_dir.mkdir(parents=True, exist_ok=True)

    used_ids = set()

    def new_id_for(path: Path) -> str:
        # Id = SHA-1 prefix of the resolved absolute source path. It is stable
        # across reruns from the same location, but changes if raw/ is moved.
        # On a prefix collision a numeric suffix is appended.
        base_id = _sha1_id(str(path.resolve()), length=12)
        candidate = base_id
        i = 1
        while candidate in used_ids:
            candidate = f"{base_id}_{i}"
            i += 1
        used_ids.add(candidate)
        return candidate

    train_rows: List[Tuple[str, str]] = []
    test_rows: List[Tuple[str]] = []
    answer_rows: List[Tuple[str, str]] = []

    # Existing destination files are kept as-is (incremental reruns). A stale
    # or partial file from an earlier run is therefore not replaced.
    for label, src in train_pairs:
        _id = new_id_for(src)
        dst = train_audio_dir / f"{_id}.mp3"
        if not dst.exists():
            shutil.copy2(src, dst)
        train_rows.append((_id, label))

    for label, src in test_pairs:
        _id = new_id_for(src)
        dst = test_audio_dir / f"{_id}.mp3"
        if not dst.exists():
            shutil.copy2(src, dst)
        test_rows.append((_id,))
        answer_rows.append((_id, label))

    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    labels_csv = public / "labels.csv"
    sample_sub_csv = public / "sample_submission.csv"
    test_answer_csv = private / "test_answer.csv"

    _write_csv(train_csv, ["id", "label"], sorted(train_rows))
    _write_csv(test_csv, ["id"], sorted(test_rows))
    _write_csv(labels_csv, ["label"], ([lbl] for lbl in labels))
    _write_csv(test_answer_csv, ["id", "label"], sorted(answer_rows))

    # Sample submission: one random valid label per test id, from its own seeded RNG.
    rng2 = random.Random(2025)
    test_ids_sorted = [r[0] for r in sorted(test_rows)]
    _write_csv(
        sample_sub_csv,
        ["id", "label"],
        ([tid, rng2.choice(labels)] for tid in test_ids_sorted),
    )

    root_desc = Path(__file__).resolve().parent / "description.txt"
    if root_desc.exists():
        shutil.copy2(root_desc, public / "description.txt")

    # --- Post-write consistency checks (explicit raises survive python -O) ---
    _require(len(used_ids) == len(train_rows) + len(test_rows), "ID collision detected")

    missing = [f"{_id}.mp3" for _id, _ in train_rows if not (train_audio_dir / f"{_id}.mp3").is_file()]
    missing += [f"{_id}.mp3" for (_id,) in test_rows if not (test_audio_dir / f"{_id}.mp3").is_file()]
    _require(not missing, f"Missing copied audio files: {missing[:10]}")

    t_ids = _read_csv_column(test_csv, "id")
    a_ids = _read_csv_column(test_answer_csv, "id")
    _require(set(t_ids) == set(a_ids), "Mismatch between public/test.csv and private/test_answer.csv IDs")

    # Leakage check: no (case/space-normalized) label may appear inside a file stem.
    # NOTE(review): ids are hex digests, so a label made only of hex characters
    # (e.g. "bee") could trigger a false positive.
    normalized_labels = [lbl.lower().replace(" ", "") for lbl in labels]

    def _contains_label_name(fn: str) -> bool:
        s = fn.lower().replace(" ", "")
        return any(lbl in s for lbl in normalized_labels)

    leak_train = [p.name for p in train_audio_dir.glob("*.mp3") if _contains_label_name(p.stem)]
    leak_test = [p.name for p in test_audio_dir.glob("*.mp3") if _contains_label_name(p.stem)]
    _require(not leak_train and not leak_test, f"Potential label leakage in filenames: {leak_train}, {leak_test}")

    with sample_sub_csv.open("r", newline="", encoding="utf-8") as f:
        subs = list(csv.DictReader(f))
    _require(len(subs) == len(t_ids), "sample_submission.csv row count does not match test.csv")
    # Build lookup sets once instead of once per row.
    t_id_set = set(t_ids)
    label_set = set(labels)
    bad_subs = [r["id"] for r in subs if r["id"] not in t_id_set or r["label"] not in label_set]
    _require(not bad_subs, f"sample_submission.csv has unknown ids or invalid labels: {bad_subs[:10]}")

    # Only train.csv carries real labels; sample_submission labels are random.
    _require(
        all(p.is_file() for p in (train_csv, test_csv, labels_csv, sample_sub_csv)),
        "Missing public CSV output",
    )