from __future__ import annotations

import csv
import random
import re
import shutil
import uuid
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Tuple

# Full data-preparation pipeline, modeled on samples/sample_prepare.py.
# `prepare()` builds speaker-disjoint train/test splits from an extracted
# LibriSpeech tree and places the outputs in the required public/ and
# private/ directories.

# Seed shared by the randomized steps of the preparation pipeline.
RANDOM_SEED = 2025
# Also seed the process-wide `random` module state at import time.
random.seed(a=RANDOM_SEED)

# Matches runs of characters that are NOT lowercase ASCII letters, digits,
# apostrophes or spaces; used by text normalization to blank them out.
_WORD_RE = re.compile(r"[^a-z0-9' ]+")


def _normalize_text(s: str) -> str:
    """Return a canonical, comparison-friendly form of a transcript.

    The text is stripped and lower-cased, hyphens become spaces, every
    character outside ``[a-z0-9' ]`` is replaced by a space (via ``_WORD_RE``),
    and runs of spaces collapse to a single space with no leading or trailing
    space. ``None`` is tolerated and yields ``""``.
    """
    if s is None:
        return ""
    lowered = s.strip().lower().replace("-", " ")
    # After _WORD_RE only ASCII spaces remain as separators, so split/join
    # both collapses runs and trims the ends.
    tokens = _WORD_RE.sub(" ", lowered).split()
    return " ".join(tokens)


def _iter_transcripts(root: Path) -> Iterator[Tuple[str, str, Path, str]]:
    """Yield (utt_id, transcript, flac_path, speaker_id) for all .trans.txt under root.

    Transcript files are visited in sorted path order, so the output order
    does not depend on the filesystem's directory-listing order (downstream
    seeded shuffles rely on a stable input order).

    Each line is expected to look like ``<utt_id> <transcript>``. Blank lines,
    lines without a transcript, and utterances whose ``<utt_id>.flac`` is not
    a file next to the transcript file are skipped. The speaker id is the part
    of ``utt_id`` before the first ``-``. ``flac_path`` is yielded resolved.
    Transcript files that cannot be read (``OSError``) are skipped.
    """
    for trans_path in sorted(root.rglob("*.trans.txt")):
        try:
            # Read the file up front so the ``try`` covers only I/O and never
            # wraps a ``yield`` (which would swallow errors thrown into the
            # generator by its consumer).
            with trans_path.open("r", encoding="utf-8", errors="ignore") as f:
                lines = list(f)
        except OSError:
            continue  # unreadable transcript file: skip it
        for raw_line in lines:
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                continue
            utt_id, transcript = parts
            flac_path = trans_path.parent / f"{utt_id}.flac"
            try:
                has_audio = flac_path.is_file()
            except OSError:
                # A malformed id can make the stat itself fail (e.g. name too long).
                has_audio = False
            if not has_audio:
                continue
            spk_id = utt_id.split("-")[0]
            yield utt_id, transcript.strip(), flac_path.resolve(), spk_id


def _safe_copy(src: Path, dst: Path) -> None:
    """Copy ``src`` to ``dst`` with ``shutil.copy2``, replacing any existing dst.

    Missing parent directories of ``dst`` are created first. If ``dst``
    already exists it is unlinked before copying rather than overwritten in
    place (so, presumably, a hard-linked destination is replaced rather than
    modified through the shared link).
    """
    target_dir = dst.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    already_present = dst.exists()
    if already_present:
        dst.unlink()
    shutil.copy2(src, dst)


def _collect_examples(libri_root: Path) -> List[Tuple[str, str, Path, str]]:
    """Return every usable (utt_id, transcript, flac_path, spk_id) under libri_root.

    The result is sorted by utterance id (then audio path) so that later
    seeded steps see the same order regardless of filesystem listing order.
    """
    return sorted(_iter_transcripts(libri_root), key=lambda ex: (ex[0], str(ex[2])))


def _split_speakers(
    spk_ids: Iterable[str], rng: random.Random, train_frac: float = 0.85
) -> Tuple[set, set]:
    """Partition speaker ids into disjoint (train, test) sets.

    Speakers are de-duplicated and sorted before the seeded shuffle, so the
    split depends only on ``rng``'s state and the set of speakers. At least
    one speaker always goes to train; with a single speaker, test is empty.
    """
    ordered = sorted(set(spk_ids))
    rng.shuffle(ordered)
    split_point = max(1, int(train_frac * len(ordered)))
    return set(ordered[:split_point]), set(ordered[split_point:])


def _write_csv(path: Path, fieldnames: List[str], rows: List[Dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV with a header line."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _check_outputs(
    public: Path,
    private: Path,
    audio_dirs: Iterable[Path],
    train_rows: List[Dict[str, str]],
    test_rows: List[Dict[str, str]],
    ans_rows: List[Dict[str, str]],
    sample_rows: List[Dict[str, str]],
) -> None:
    """Sanity-check the written layout; raise AssertionError on any violation."""
    # - existence
    assert (public / "train.csv").is_file(), "public/train.csv should exist"
    assert (public / "test.csv").is_file(), "public/test.csv should exist"
    assert (private / "test_answer.csv").is_file(), "private/test_answer.csv should exist"
    assert (public / "sample_submission.csv").is_file(), "public/sample_submission.csv should exist"
    # - audio files exist
    for r in (*train_rows, *test_rows):
        ap = public / r["audio_path"]
        assert ap.exists(), f"Missing audio: {ap}"
    # - ids disjoint
    train_ids = {r["id"] for r in train_rows}
    test_ids = {r["id"] for r in test_rows}
    assert train_ids.isdisjoint(test_ids), "Train and test ids should not overlap"
    # - alignment
    test_order = [r["id"] for r in test_rows]
    assert [r["id"] for r in ans_rows] == test_order, "Answer and test ids must align"
    # - sample ids
    assert [r["id"] for r in sample_rows] == test_order, "Sample submission ids must match test"
    # - no leakage: ensure no .txt in audio dirs
    for folder in audio_dirs:
        for p in folder.rglob("*"):
            if p.is_file():
                assert p.suffix.lower() != ".txt", f"Leakage: found text file {p} in audio folder"
    # - transcripts non-empty after normalization
    assert all(_normalize_text(r["transcript"]) for r in train_rows), "Empty transcript in train"
    assert all(_normalize_text(r["transcript"]) for r in ans_rows), "Empty transcript in answers"


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private layout for the LibriSpeech ASR task.

    Args:
        raw: Directory containing an extracted ``LibriSpeech/`` tree, or the
            ``LibriSpeech`` directory itself.
        public: Output directory for participant-visible files: ``train.csv``,
            ``test.csv``, ``sample_submission.csv``, ``audio_train/``,
            ``audio_test/`` and, if found next to this script,
            ``description.txt``.
        private: Output directory for ``test_answer.csv``.

    The split is speaker-disjoint (about 85/15 by speaker). All randomness
    (split, anonymized ids, sample submission) comes from one RNG seeded with
    ``RANDOM_SEED``, so re-running on the same data gives the same outputs.
    Files already present in the audio directories are not removed.

    Raises:
        AssertionError: if the input is missing/unusable (no examples, fewer
            than two speakers) or an output sanity check fails.
    """
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    assert raw.exists(), f"Raw directory does not exist: {raw}"
    # Accept raw/LibriSpeech/... or raw being the LibriSpeech directory itself.
    libri_root = raw / "LibriSpeech"
    if not libri_root.exists():
        libri_root = raw
    assert libri_root.exists(), f"LibriSpeech directory not found under: {raw}"

    audio_train_dir = public / "audio_train"
    audio_test_dir = public / "audio_test"
    for out_dir in (public, private, audio_train_dir, audio_test_dir):
        out_dir.mkdir(parents=True, exist_ok=True)

    examples = _collect_examples(libri_root)
    assert len(examples) > 0, "No examples found in LibriSpeech raw directory."

    # One dedicated seeded RNG drives every random decision below, instead of
    # uuid4 / the module-global `random`, so outputs are reproducible.
    rng = random.Random(RANDOM_SEED)

    # Speaker-disjoint split (85/15 by speaker)
    train_spk, test_spk = _split_speakers((ex[3] for ex in examples), rng)
    assert train_spk.isdisjoint(test_spk)
    assert test_spk, "Need at least two speakers to build a speaker-disjoint test split."

    used_ids: set[str] = set()

    def _new_id() -> str:
        # 16 hex chars (64 random bits); redraw on the unlikely collision.
        nid = f"{rng.getrandbits(64):016x}"
        while nid in used_ids:
            nid = f"{rng.getrandbits(64):016x}"
        used_ids.add(nid)
        return nid

    train_rows: List[Dict[str, str]] = []
    test_rows: List[Dict[str, str]] = []
    ans_rows: List[Dict[str, str]] = []

    # Copy audio under anonymized ids; only train transcripts go to public.
    for _utt_id, transcript, flac_path, spk_id in examples:
        nid = _new_id()
        if spk_id in train_spk:
            rel_audio = Path("audio_train") / f"{nid}.flac"
            train_rows.append({
                "id": nid,
                "audio_path": rel_audio.as_posix(),
                "transcript": transcript,
            })
        else:
            rel_audio = Path("audio_test") / f"{nid}.flac"
            test_rows.append({"id": nid, "audio_path": rel_audio.as_posix()})
            ans_rows.append({"id": nid, "transcript": transcript})
        _safe_copy(flac_path, public / rel_audio)

    # Sort by id; sorting (test, answer) pairs keeps the two files row-aligned.
    train_rows.sort(key=lambda r: r["id"])
    paired = sorted(zip(test_rows, ans_rows), key=lambda p: p[0]["id"])
    test_rows = [t for t, _ in paired]
    ans_rows = [a for _, a in paired]

    _write_csv(public / "train.csv", ["id", "audio_path", "transcript"], train_rows)
    _write_csv(public / "test.csv", ["id", "audio_path"], test_rows)
    _write_csv(private / "test_answer.csv", ["id", "transcript"], ans_rows)

    # Sample submission: a random normalized training transcript per test id.
    tr_texts = [t for t in (_normalize_text(r["transcript"]) for r in train_rows) if t]
    sample_rows = [
        {"id": r["id"], "transcript": rng.choice(tr_texts) if tr_texts else ""}
        for r in test_rows
    ]
    _write_csv(public / "sample_submission.csv", ["id", "transcript"], sample_rows)

    # Copy description.txt from this script's own directory (not the CWD), if present.
    root_desc = Path(__file__).parent / "description.txt"
    if root_desc.is_file():
        shutil.copyfile(root_desc, public / "description.txt")

    _check_outputs(
        public,
        private,
        (audio_train_dir, audio_test_dir),
        train_rows,
        test_rows,
        ans_rows,
        sample_rows,
    )