from __future__ import annotations

import hashlib
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Tuple
import csv
import json

# Deterministic behavior: SEED is the single seed for all randomness here.
# The module-level `random` generator is seeded at import time; prepare()
# itself draws from its own random.Random(SEED) instance instead.
SEED = 42
random.seed(SEED)

# Speaker IDs are extracted without regexes, following the dataset's naming
# scheme: most filenames look like '00f0204f_nohash_0.wav', where the speaker
# ID is the prefix before the first underscore (see _get_speaker_id).

def _is_wav(path: Path) -> bool:
    """Return True when ``path`` is an existing regular file ending in ``.wav``.

    The suffix comparison is case-insensitive (``.WAV`` also matches).
    """
    if path.suffix.lower() != ".wav":
        return False
    return path.is_file()


def _list_label_dirs(raw: Path) -> List[Path]:
    """Return the label sub-directories directly under ``raw``, sorted by name.

    A sub-directory qualifies when its name starts with neither "." (hidden /
    system folders) nor "_" (e.g. the background-noise folder), and it
    directly contains at least one ``.wav`` file.
    """

    def _qualifies(candidate: Path) -> bool:
        if not candidate.is_dir() or candidate.name.startswith((".", "_")):
            return False
        # _is_wav already rejects non-files, so no separate is_file() filter.
        return any(_is_wav(entry) for entry in candidate.iterdir())

    return sorted(
        (entry for entry in raw.iterdir() if _qualifies(entry)),
        key=lambda entry: entry.name,
    )


def _get_speaker_id(file_path: Path) -> str:
    """Return the speaker ID encoded in ``file_path``'s file name.

    The ID is the part of the stem before the first underscore; a stem with
    no underscore is returned unchanged.
    """
    speaker, _sep, _rest = file_path.stem.partition("_")
    return speaker


def _hash_frac(s: str) -> float:
    """Map ``s`` deterministically onto the unit interval via its MD5 digest.

    The hex digest is read as an integer and divided by ``16 ** len(digest)``.
    Unlike the built-in ``hash``, the result does not vary between
    interpreter runs, so it can drive reproducible bucketing.
    """
    hex_digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    scale = float(16 ** len(hex_digest))
    return int(hex_digest, 16) / scale


def _clean_dir(p: Path):
    """Ensure ``p`` exists as an empty directory, removing what was there.

    - An existing directory is deleted recursively.
    - A regular file, or a symlink (dangling or not), at ``p`` is unlinked.
      The link itself is removed; its target is never followed or deleted.

    Parent directories are created as needed.
    """
    # shutil.rmtree refuses symlinks and non-directories, and a dangling
    # symlink makes exists() False while still blocking mkdir, so remove
    # those entries with unlink() instead.
    if p.is_symlink() or (p.exists() and not p.is_dir()):
        p.unlink()
    elif p.exists():
        shutil.rmtree(p)
    p.mkdir(parents=True, exist_ok=True)


def _safe_write_csv(path: Path, header: List[str], rows: List[List[str]]):
    """Write ``header`` then every row of ``rows`` to ``path`` as UTF-8 CSV.

    An existing file at ``path`` is overwritten. ``newline=""`` is passed as
    the csv module requires, so the writer controls line endings itself.
    """
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)


def _export_split(
    recs: List[Tuple[Path, str, str]], dest_dir: Path, id_prefix: str
) -> List[Tuple[str, str]]:
    """Copy each record's audio file into ``dest_dir`` under an anonymized id.

    Ids are ``f"{id_prefix}{n:07d}.wav"``, with ``n`` counting from 1 in the
    order of ``recs``.

    Returns:
        ``(id, label)`` pairs, in the same order as ``recs``.
    """
    pairs: List[Tuple[str, str]] = []
    for idx, (src, lbl, _spk) in enumerate(recs, start=1):
        new_id = f"{id_prefix}{idx:07d}.wav"
        # copyfile copies bytes only. copy2 would also carry the source file's
        # timestamps into the public tree, and those could correlate with the
        # original label folder.
        shutil.copyfile(src, dest_dir / new_id)
        pairs.append((new_id, lbl))
    return pairs


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process for the Speech Commands classification task.

    Input directories (absolute Path objects):
      - raw: contains original dataset folders with WAV files
      - public: target directory for public artifacts (train/test csv, data, sample_submission, description)
      - private: target directory for private artifacts (test_answer.csv)

    Split:
      Speaker-disjoint. A speaker (filename prefix before the first "_") goes
      to test when _hash_frac(speaker) < 0.20. If a label then has no test
      clips, the lexicographically first train speaker with that label is
      moved to test.

    Ids:
      Clips are renamed to clip_tr_NNNNNNN.wav / clip_te_NNNNNNN.wav. Numbers
      are assigned after a seeded shuffle, so an id's position carries no
      label information.

    Outputs (created deterministically):
      Public directory:
        - train.csv  (columns: id,label)
        - test.csv   (columns: id)
        - sample_submission.csv (columns: id,label)
        - train_audio/ (copied wav files referenced by train.csv)
        - test_audio/  (copied wav files referenced by test.csv)
        - labels.json  ({"labels": [...]} in sorted label order)
        - description.txt (copied from raw.parent if present there)
      Private directory:
        - test_answer.csv (columns: id,label)

    Raises:
        AssertionError: if fewer than two label folders are found, a label is
            missing from either split, or any output integrity check fails.
    """
    # Absolute paths required
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Clean or create target data directories and CSVs
    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    _clean_dir(train_audio_dir)
    _clean_dir(test_audio_dir)

    # Remove possible stale csvs
    for fp in [public / "train.csv", public / "test.csv", public / "sample_submission.csv", private / "test_answer.csv"]:
        if fp.exists():
            fp.unlink()

    # Discover labels and files
    label_dirs = _list_label_dirs(raw)
    labels = [p.name for p in label_dirs]
    assert len(labels) > 1, "Expected multiple label folders under raw/."

    # Collect files per label (sorted by name for determinism)
    label_to_files: Dict[str, List[Path]] = {}
    for d in label_dirs:
        files = sorted((p for p in d.iterdir() if _is_wav(p)), key=lambda x: x.name)
        if files:
            label_to_files[d.name] = files

    # Build global file records (path, label, speaker), grouped by label
    TEST_SPEAKER_RATIO = 0.20
    file_recs: List[Tuple[Path, str, str]] = [
        (fp, label, _get_speaker_id(fp))
        for label in labels
        for fp in label_to_files.get(label, [])
    ]
    speakers = {spk for _, _, spk in file_recs}

    # Deterministic speaker split (hash-based, independent of file order)
    test_speakers = {spk for spk in speakers if _hash_frac(spk) < TEST_SPEAKER_RATIO}

    def _assign_split() -> Tuple[List[Tuple[Path, str, str]], List[Tuple[Path, str, str]]]:
        # Reads test_speakers at call time, so it reflects later additions.
        train, test = [], []
        for pth, lbl, spk in file_recs:
            (test if spk in test_speakers else train).append((pth, lbl, spk))
        return train, test

    train_recs, test_recs = _assign_split()

    def _labels_in(recs):
        return {lbl for _, lbl, _ in recs}

    # Ensure every label appears in test: if missing, move one speaker to test
    te_labels = _labels_in(test_recs)
    for lbl in labels:
        if lbl not in te_labels:
            # Choose a deterministic speaker that has this label in train
            candidates = sorted({spk for _, rec_lbl, spk in train_recs if rec_lbl == lbl})
            if candidates:
                test_speakers.add(candidates[0])
    train_recs, test_recs = _assign_split()

    # Final coverage checks (moving a speaker could, in principle, empty a
    # label in train, so both sides are re-checked)
    assert set(labels).issubset(_labels_in(train_recs)), "All labels must appear in train set."
    assert set(labels).issubset(_labels_in(test_recs)), "All labels must appear in test set."

    # file_recs is grouped by label, so numbering records in that order would
    # let anyone read a test clip's label off its id range. Shuffle each split
    # with a dedicated seeded RNG before assigning ids.
    order_rng = random.Random(SEED)
    order_rng.shuffle(train_recs)
    order_rng.shuffle(test_recs)

    # Assign anonymized IDs and copy files into public/train_audio and public/test_audio
    train_pairs = _export_split(train_recs, train_audio_dir, "clip_tr_")
    test_pairs = _export_split(test_recs, test_audio_dir, "clip_te_")

    # Write CSVs
    _safe_write_csv(public / "train.csv", ["id", "label"], [[i, lbl] for i, lbl in train_pairs])
    _safe_write_csv(public / "test.csv", ["id"], [[i] for i, _ in test_pairs])
    _safe_write_csv(private / "test_answer.csv", ["id", "label"], [[i, lbl] for i, lbl in test_pairs])

    # Sample submission (choose a valid placeholder label deterministically)
    rng = random.Random(SEED)
    _safe_write_csv(public / "sample_submission.csv", ["id", "label"], [[i, rng.choice(labels)] for i, _ in test_pairs])

    # Write labels.json to public for convenience (optional but helpful)
    with (public / "labels.json").open("w", encoding="utf-8") as f:
        json.dump({"labels": labels}, f, ensure_ascii=False, indent=2)

    # Copy description.txt into public if it exists next to raw/
    repo_description = (raw.parent / "description.txt").resolve()
    if repo_description.exists():
        shutil.copy2(str(repo_description), str(public / "description.txt"))

    # Assertions to ensure integrity
    # IDs
    train_ids = [i for i, _ in train_pairs]
    test_ids = [i for i, _ in test_pairs]
    assert len(set(train_ids)) == len(train_ids), "Duplicate ids in train.csv"
    assert len(set(test_ids)) == len(test_ids), "Duplicate ids in test/test_answer"
    assert set(train_ids).isdisjoint(set(test_ids)), "Train and test ids should not overlap"

    # Files exist
    for i in train_ids:
        assert (train_audio_dir / i).is_file(), f"Missing train audio {i}"
    for i in test_ids:
        assert (test_audio_dir / i).is_file(), f"Missing test audio {i}"

    # Public CSV sanity: reload and check columns. Read everything as plain
    # strings so label names such as "None"/"nan" are not turned into NaN.
    import pandas as pd

    read_opts = {"dtype": str, "keep_default_na": False}
    df_tr = pd.read_csv(public / "train.csv", **read_opts)
    df_te = pd.read_csv(public / "test.csv", **read_opts)
    df_ans = pd.read_csv(private / "test_answer.csv", **read_opts)
    assert df_tr.columns.tolist() == ["id", "label"]
    assert df_te.columns.tolist() == ["id"]
    assert df_ans.columns.tolist() == ["id", "label"]
    assert set(df_tr["label"]) <= set(labels)
    assert set(df_ans["label"]) <= set(labels)
    assert set(df_te["id"]) == set(df_ans["id"]) == set(test_ids)

    # Ensure sample submission matches test ids and valid labels
    df_sample = pd.read_csv(public / "sample_submission.csv", **read_opts)
    assert df_sample.columns.tolist() == ["id", "label"]
    assert set(df_sample["id"]) == set(test_ids)
    assert set(df_sample["label"]) <= set(labels)
