from pathlib import Path
import os
import csv
import re
import shutil
import random
from collections import defaultdict
from typing import List, Tuple


def _natural_key(s: str):
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]


def _list_speaker_dirs(src_dir: Path) -> List[Path]:
    if not src_dir.is_dir():
        raise FileNotFoundError(f"Source directory not found: {src_dir}")
    speakers: List[Path] = []
    for p in sorted(src_dir.iterdir()):
        if p.is_dir() and any(f.suffix.lower() == ".wav" for f in p.iterdir()):
            speakers.append(p)
    if not speakers:
        raise RuntimeError("No speaker folders with WAV files found under source directory")
    return speakers


def _collect_wavs(speaker_dir: Path) -> List[Path]:
    """Return the ``.wav`` files directly inside *speaker_dir*, in natural filename order.

    The extension match is case-insensitive. Only regular files (or symlinks to
    them) are returned. A subdirectory that merely has a ``.wav`` suffix would
    otherwise be handed to the link/copy step and fail there.
    """
    wavs = [p for p in speaker_dir.iterdir() if p.suffix.lower() == ".wav" and p.is_file()]
    return sorted(wavs, key=lambda p: _natural_key(p.name))


def _compute_split(n: int, test_ratio: float = 0.20) -> Tuple[int, int]:
    if n <= 0:
        return 0, 0
    n_test = max(1, int(round(n * test_ratio)))
    if n > 1 and n_test >= n:
        n_test = n - 1
    n_train = n - n_test
    return n_train, n_test


def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)


def _link_or_copy(src: Path, dst: Path):
    if dst.exists():
        return
    try:
        dst.parent.mkdir(parents=True, exist_ok=True)
        os.link(str(src), str(dst))
    except Exception:
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(str(src), str(dst))


def _check_output_dirs(raw: Path, source_dir: Path, public: Path, private: Path) -> None:
    """Reject output locations whose cleanup or contents would damage the dataset.

    ``prepare`` deletes *public* and *private* with ``shutil.rmtree`` before
    writing. Therefore neither may be *raw* itself, an ancestor of *raw*, or lie
    inside *source_dir*. *private* also must not be (or be inside) *public*;
    otherwise the hidden answers would be published.

    Raises:
        ValueError: if any of the above holds.
    """
    raw_r = raw.resolve()
    source_r = source_dir.resolve()
    for out in (public, private):
        out_r = out.resolve()
        if out_r == raw_r or out_r in raw_r.parents:
            raise ValueError(f"Output directory {out} contains raw data {raw}; refusing to delete it")
        if out_r == source_r or source_r in out_r.parents:
            raise ValueError(f"Output directory {out} lies inside source data {source_dir}; refusing to delete it")
    public_r, private_r = public.resolve(), private.resolve()
    if private_r == public_r or public_r in private_r.parents:
        raise ValueError(f"Private directory {private} must not be inside public directory {public}")


def _plan_clips(speakers: List[Path], label_map: dict) -> List[Tuple[str, Path, str]]:
    """Split every speaker's clips into train/test and return ``(split, src, label)`` tuples.

    Within a speaker, clips are taken in natural filename order. ``_compute_split``
    decides how many leading clips go to train; the rest go to test.

    A speaker with exactly one clip contributes it to train. Placing it in test
    would create a test label absent from training, which the unseen-label check
    in ``prepare`` rejects.
    """
    clips: List[Tuple[str, Path, str]] = []
    for sp_dir in speakers:
        label = label_map[sp_dir]
        wavs = _collect_wavs(sp_dir)
        n = len(wavs)
        if n == 0:
            continue
        if n == 1:
            train_files, test_files = wavs, []
        else:
            n_train, _ = _compute_split(n)
            train_files, test_files = wavs[:n_train], wavs[n_train:]
            assert len(train_files) >= 1 and len(test_files) >= 1, f"Degenerate split for speaker {sp_dir.name}"
        assert len(train_files) + len(test_files) == n, "Split mismatch"
        clips.extend(("train", src, label) for src in train_files)
        clips.extend(("test", src, label) for src in test_files)
    return clips


def _write_csv(path: Path, header: List[str], rows) -> None:
    """Write *header* followed by *rows* to *path* as UTF-8 CSV."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def prepare(raw: Path, public: Path, private: Path, *, seed: int = 0):
    """
    Complete preparation process.

    Inputs:
    - raw: absolute path to the directory containing the original data. Expects
      raw/50_speakers_audio_data/<speaker folder>/*.wav
    - public: absolute path to the directory where public files are written.
      It is deleted and recreated.
    - private: absolute path to the directory where hidden answers are written.
      It is deleted and recreated.
    - seed: seed for the local RNG that shuffles clip IDs and fills the sample
      submission. The same seed and the same input give the same output.

    Outputs written:
    - public/train_audio/ and public/test_audio/ with anonymized clip names
    - public/train_metadata.csv (file_id, filepath, label)
    - public/test_data.csv (file_id, filepath)
    - public/sample_submission.csv (file_id, label)
    - public/description.txt (copied from the parent directory of `public` if
      present; presumably the repository root — verify against the layout)
    - private/test_answer.csv (file_id, label)

    Split policy:
    - Per speaker, clips are natural-sorted. The leading ~80% go to train and
      the rest go to test.
    - A single-clip speaker goes to train only.
    - File IDs are assigned after a seeded shuffle of all clips, so an ID does
      not reveal its speaker or its neighbours' labels.

    Raises:
    - ValueError if public/private overlap the raw data, or private is inside public.
    - FileNotFoundError / RuntimeError (from _list_speaker_dirs) if no usable source data exists.
    - AssertionError if any consistency check fails.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Paths must be absolute"

    source_dir = raw / "50_speakers_audio_data"
    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    train_meta_csv = public / "train_metadata.csv"
    test_data_csv = public / "test_data.csv"
    sample_sub_csv = public / "sample_submission.csv"
    test_answer_csv = private / "test_answer.csv"

    # The output dirs are rmtree'd below. Make sure that cannot destroy raw
    # data or place the hidden answers inside the public folder.
    _check_output_dirs(raw, source_dir, public, private)

    # Clean previous outputs if any
    for out_dir in (public, private):
        if out_dir.exists():
            shutil.rmtree(out_dir)
    for d in (public, private, train_audio_dir, test_audio_dir):
        _ensure_dir(d)

    speakers = _list_speaker_dirs(source_dir)

    # Deterministic label map by sorted folder name
    label_map = {sp: f"S{idx:03d}" for idx, sp in enumerate(sorted(speakers, key=lambda p: p.name))}

    # Local RNG: reproducible for a given seed without reseeding the caller's
    # global `random` module state.
    rng = random.Random(seed)

    clips = _plan_clips(speakers, label_map)
    # Shuffle before numbering. Numbering in speaker order would make IDs
    # contiguous per speaker, so a test clip's label could be read off the
    # neighbouring train IDs.
    rng.shuffle(clips)

    train_rows: List[Tuple[str, str, str]] = []  # (file_id, relpath, label)
    test_rows: List[Tuple[str, str]] = []        # (file_id, relpath)
    answer_rows: List[Tuple[str, str]] = []      # (file_id, label)

    for idx, (split, src, label) in enumerate(clips, start=1):
        file_id = f"clip_{idx:06d}"
        rel = Path(f"{split}_audio") / f"{file_id}.wav"
        _link_or_copy(src, public / rel)
        if split == "train":
            train_rows.append((file_id, rel.as_posix(), label))
        else:
            test_rows.append((file_id, rel.as_posix()))
            answer_rows.append((file_id, label))

    # Global checks
    assert train_rows and test_rows, "Empty train or test set"
    train_ids = {fid for fid, _, _ in train_rows}
    test_ids = {fid for fid, _ in test_rows}
    assert train_ids.isdisjoint(test_ids), "Train and test IDs must be disjoint"

    # Ensure referenced files exist
    for _, rel_path, _ in train_rows:
        assert (public / rel_path).is_file(), f"Missing train audio file: {rel_path}"
    for _, rel_path in test_rows:
        assert (public / rel_path).is_file(), f"Missing test audio file: {rel_path}"

    # Write CSVs
    _write_csv(train_meta_csv, ["file_id", "filepath", "label"], train_rows)
    _write_csv(test_data_csv, ["file_id", "filepath"], test_rows)
    _write_csv(test_answer_csv, ["file_id", "label"], answer_rows)

    # Sample submission: random but valid labels drawn from the train label set
    labels_list = sorted({lbl for _, _, lbl in train_rows})
    _write_csv(
        sample_sub_csv,
        ["file_id", "label"],
        [(fid, rng.choice(labels_list)) for fid, _ in test_rows],
    )

    # Copy description.txt into public if present next to the public directory
    root_description = public.parent / "description.txt"
    if root_description.is_file():
        shutil.copy2(root_description, public / "description.txt")

    # Final checks about schema consistency
    # Ensure every label in private answers is present in training labels
    train_labels = {lbl for _, _, lbl in train_rows}
    answer_labels = {lbl for _, lbl in answer_rows}
    assert answer_labels.issubset(train_labels), "Unseen labels detected in test answers"

    # Ensure sample submission ids match test ids, with no duplicates
    with sample_sub_csv.open("r", newline="", encoding="utf-8") as f:
        sub_ids = [r["file_id"] for r in csv.DictReader(f)]
    assert len(sub_ids) == len(test_ids) and set(sub_ids) == test_ids, (
        "sample_submission.csv must include all test file_ids exactly once"
    )

    # Ensure public/private dirs exist
    assert public.exists() and private.exists()
