from pathlib import Path
import os
import re
import csv
import random
import shutil
from collections import defaultdict
from typing import List, Tuple, Dict

# Fixed seed shared by the train/test split and by the label sampling used to
# fill the sample submission.
GLOBAL_SEED = 2025

# Regex that extracts the lexical token sitting between the block and mic markers.
# Filenames follow: <SPK>_B<block>_<TOKEN>_M<mic>.wav
# TOKEN may be C10, CW37, D0, LA, etc. Group 1 captures the underscore-free text
# between "_B<digits>_" and "_M"; matching is case-insensitive.
LABEL_RE = re.compile(r"_B\d+_([^_]+)_M", re.IGNORECASE)


def _discover_wavs(raw: Path) -> List[Path]:
    """Collect every WAV file below the known corpus subfolders of `raw`.

    Only `noisereduced-uaspeech` and `noisereduced-uaspeech-control` are walked;
    a subfolder that does not exist is skipped. The extension check is
    case-insensitive.

    Returns:
        The unique matching paths in sorted order.

    Raises:
        FileNotFoundError: if no WAV file was found at all.
    """
    found = set()
    for subfolder in ("noisereduced-uaspeech", "noisereduced-uaspeech-control"):
        base = raw / subfolder
        if base.exists():
            for dirpath, _dirnames, filenames in os.walk(base):
                found.update(
                    Path(dirpath) / name
                    for name in filenames
                    if name.lower().endswith(".wav")
                )
    if not found:
        raise FileNotFoundError(f"No WAV files found under {raw}")
    return sorted(found)


def _extract_label(filename: str) -> str:
    """Return the upper-cased token that LABEL_RE captures from `filename`.

    Raises:
        ValueError: if the filename does not match LABEL_RE.
    """
    match = LABEL_RE.search(filename)
    if match is None:
        raise ValueError(f"Could not parse label from filename: {filename}")
    token = match.group(1)
    return token.upper()


def _build_index(wavs: List[Path]) -> List[Tuple[Path, str]]:
    """Pair each WAV path with the label parsed from its file name.

    The result is ordered by the string form of the path, which fixes the
    position (and therefore the published id) of every file.
    """
    labelled = [(wav, _extract_label(wav.name)) for wav in wavs]
    return sorted(labelled, key=lambda pair: str(pair[0]))


def _stratified_split(items: List[Tuple[Path, str]], test_ratio: float = 0.2) -> Tuple[List[int], List[int]]:
    """Split item indices into train and test lists, stratified by label.

    Each label's indices are shuffled with their own RNG and a share of them is
    assigned to test:
      - 1 sample       -> nothing goes to test
      - 2 to 4 samples -> exactly one goes to test
      - 5 or more      -> round(n * test_ratio), but at least one
    At least one sample of every label always stays in train.

    The per-label RNG is seeded with a string built from GLOBAL_SEED and the
    label. hash(label) must not be used for this: str hashes are salted per
    interpreter process (PYTHONHASHSEED), which would make the split differ
    from one run to the next.

    Args:
        items: (path, label) pairs; positions in this list are the indices returned.
        test_ratio: target fraction of each label's samples to place in test.

    Returns:
        (train_idx, test_idx): disjoint index lists that together cover all items.

    Raises:
        AssertionError: if a split invariant is violated (overlapping or missing
            indices, an empty side, or a test label that is absent from train).
    """
    by_label: Dict[str, List[int]] = defaultdict(list)
    for idx, (_, label) in enumerate(items):
        by_label[label].append(idx)

    train_idx: List[int] = []
    test_idx: List[int] = []

    for label, idxs in by_label.items():
        # A str seed is expanded deterministically by random.seed (version 2),
        # so the shuffle is reproducible across processes and machines.
        rnd = random.Random(f"{GLOBAL_SEED}:{label}")
        idxs = idxs.copy()
        rnd.shuffle(idxs)
        n = len(idxs)
        if n <= 1:
            n_test = 0
        elif n <= 4:
            n_test = 1
        else:
            n_test = max(1, int(round(n * test_ratio)))
        # Never let a label disappear from train, whatever test_ratio is.
        if n - n_test < 1:
            n_test = max(0, n - 1)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])

    assert len(set(train_idx).intersection(test_idx)) == 0, "Overlap between train and test indices"
    assert len(train_idx) + len(test_idx) == len(items), "Split mismatch"
    assert len(train_idx) > 0 and len(test_idx) > 0, "Empty split encountered"

    # Safety net: ensure every test label exists in train by moving one sample
    # of any uncovered label back from test to train.
    train_labels = {items[i][1] for i in train_idx}
    test_labels = {items[i][1] for i in test_idx}
    missing = test_labels - train_labels
    if missing:
        for lbl in sorted(missing):
            for i in list(test_idx):
                if items[i][1] == lbl:
                    test_idx.remove(i)
                    train_idx.append(i)
                    break
        train_labels = {items[i][1] for i in train_idx}
        test_labels = {items[i][1] for i in test_idx}
        assert test_labels.issubset(train_labels), "Could not ensure test labels subset of train labels"
        assert len(test_idx) > 0, "Fixing label coverage moved all to train"

    return train_idx, test_idx


def _safe_link(src: Path, dst: Path):
    """Place `src` at `dst` as a hard link, copying the file when that fails.

    The parent directory of `dst` is created if needed and an existing `dst`
    is removed before linking. Any exception raised while removing or linking
    triggers the fallback `shutil.copy2(src, dst)`.
    """
    parent_dir = dst.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    try:
        already_there = dst.exists()
        if already_there:
            dst.unlink()
        os.link(src, dst)
    except Exception:
        # Best-effort fallback: a plain copy stands in for the hard link.
        shutil.copy2(src, dst)


def _write_csv(path: Path, rows: List[List[str]], header: List[str]):
    """Write `header` followed by `rows` to `path` as a UTF-8 CSV file.

    An existing file at `path` is overwritten.
    """
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)


def prepare(raw: Path, public: Path, private: Path):
    """
    Build the public/private dataset layout from the raw audio tree.

    Steps, in order:
    - Discover WAV files under `raw` and pair each one with its parsed label
    - Assign ids of the form audio_NNNNNN.wav following the sorted path order
    - Split the indices into train/test, stratified by label
    - Link (or copy) the audio into `public`/train_audio and `public`/test_audio
    - Write train.csv, test.csv and sample_submission.csv into `public`
    - Write test_answer.csv into `private`
    - Copy description.txt from this module's directory into `public`, if present
    - Run consistency assertions over the produced files and tables
    """
    random.seed(GLOBAL_SEED)

    # Index the corpus; an item's position in `items` determines its id.
    # NOTE(review): ids follow the sorted source-path order, so neighbouring ids
    # come from neighbouring source files — confirm this ordering does not
    # reveal label information to participants.
    items = _build_index(_discover_wavs(raw))
    ids = [f"audio_{position:06d}.wav" for position, _ in enumerate(items)]

    train_idx, test_idx = _stratified_split(items, test_ratio=0.2)

    # Make sure every output directory exists
    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    for folder in (train_audio_dir, test_audio_dir, public, private):
        folder.mkdir(parents=True, exist_ok=True)

    # Drop audio published by an earlier run; failures here are ignored
    for folder in (train_audio_dir, test_audio_dir):
        for stale in folder.glob("audio_*.wav"):
            try:
                stale.unlink()
            except Exception:
                pass

    # Publish the training audio together with its labels
    train_rows: List[List[str]] = []
    for pos in train_idx:
        wav_path, wav_label = items[pos]
        published_name = ids[pos]
        _safe_link(wav_path, train_audio_dir / published_name)
        train_rows.append([published_name, wav_label])

    # Publish the test audio; its labels only go into the private answer table
    test_rows: List[List[str]] = []
    test_ans_rows: List[List[str]] = []
    for pos in test_idx:
        wav_path, wav_label = items[pos]
        published_name = ids[pos]
        _safe_link(wav_path, test_audio_dir / published_name)
        test_rows.append([published_name])
        test_ans_rows.append([published_name, wav_label])

    # Order every table by id so the written files are stable
    train_rows = sorted(train_rows, key=lambda row: row[0])
    test_rows = sorted(test_rows, key=lambda row: row[0])
    test_ans_rows = sorted(test_ans_rows, key=lambda row: row[0])

    _write_csv(public / "train.csv", train_rows, header=["id", "label"])
    _write_csv(public / "test.csv", test_rows, header=["id"])
    _write_csv(private / "test_answer.csv", test_ans_rows, header=["id", "label"])

    # Sample submission: one label per test id, drawn from the training labels
    # with a dedicated seeded RNG
    unique_train_labels = sorted(set(row[1] for row in train_rows))
    assert unique_train_labels, "No labels in training set"
    sampler = random.Random(GLOBAL_SEED)
    sample_rows: List[List[str]] = []
    for row in test_rows:
        (rid,) = row
        sample_rows.append([rid, sampler.choice(unique_train_labels)])
    sample_rows = sorted(sample_rows, key=lambda row: row[0])
    _write_csv(public / "sample_submission.csv", sample_rows, header=["id", "label"])

    # Ship the description file that sits next to this module, if there is one
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        shutil.copy2(desc_src, public / "description.txt")

    # Consistency checks over what was just written
    train_files = sorted(p.name for p in train_audio_dir.glob("*.wav"))
    test_files = sorted(p.name for p in test_audio_dir.glob("*.wav"))
    train_ids = [row[0] for row in train_rows]
    test_ids = [row[0] for row in test_rows]
    assert train_files == sorted(train_ids), "Mismatch train_audio vs train.csv"
    assert test_files == sorted(test_ids), "Mismatch test_audio vs test.csv"

    assert not any(os.sep in rid for rid in train_ids), "Paths found in train.csv ids"
    assert not any(os.sep in rid for rid in test_ids), "Paths found in test.csv ids"

    all_ids = train_ids + test_ids
    assert len(set(all_ids)) == len(all_ids), "Duplicate ids across splits"

    train_label_set = {row[1] for row in train_rows}
    test_label_set = {row[1] for row in test_ans_rows}
    assert test_label_set <= train_label_set, "Unseen test labels present"

    assert sorted(row[0] for row in test_ans_rows) == sorted(test_ids), "test_answer alignment"

    # Published file names must not contain any of these marker substrings
    leak_markers = ("_C", "_D", "_L")
    assert not any(marker in name.upper() for name in train_files for marker in leak_markers)
    assert not any(marker in name.upper() for name in test_files for marker in leak_markers)

    assert len(train_rows) > len(test_rows), "Train should be larger than test"
