from pathlib import Path
import csv
import random
import shutil
from collections import defaultdict
from typing import List, Dict, Tuple


def _read_metadata(src_root: Path) -> List[dict]:
    wavs_dir = src_root / "LJSpeech-1.1" / "wavs"
    meta_path = src_root / "LJSpeech-1.1" / "metadata.csv"
    assert wavs_dir.is_dir(), f"Wavs directory not found: {wavs_dir}"
    assert meta_path.exists(), f"Metadata not found: {meta_path}"

    rows = []
    with meta_path.open("r", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split("|")
            if len(parts) < 3:
                continue
            clip = parts[0].strip()
            transcript = parts[2].strip()
            wav_path = wavs_dir / f"{clip}.wav"
            if wav_path.exists():
                rows.append({"orig_id": clip, "transcript": transcript, "wav_path": wav_path})
    return rows


def _split_by_transcript(items: List[dict], test_fraction: float, seed: int = 1337) -> Tuple[List[dict], List[dict]]:
    groups: Dict[str, List[dict]] = defaultdict(list)
    for it in items:
        groups[it["transcript"]].append(it)
    keys = list(groups.keys())
    rng = random.Random(seed)
    rng.shuffle(keys)
    total = len(items)
    target_test = int(round(total * test_fraction))
    test_set: List[dict] = []
    train_set: List[dict] = []
    count_test = 0
    for k in keys:
        group = groups[k]
        if count_test < target_test:
            test_set.extend(group)
            count_test += len(group)
        else:
            train_set.extend(group)
    if len(test_set) == 0:  # ensure non-empty split
        # move one group
        k = keys[0]
        for it in groups[k]:
            if it in train_set:
                train_set.remove(it)
                test_set.append(it)
    return train_set, test_set


def _assign_new_ids(train_items: List[dict], test_items: List[dict]) -> Dict[str, Dict[str, str]]:
    all_items = [(it, "train") for it in train_items] + [(it, "test") for it in test_items]
    all_items.sort(key=lambda x: x[0]["orig_id"])  # stable deterministic sort
    mapping: Dict[str, Dict[str, str]] = {}
    for idx, (it, split) in enumerate(all_items):
        new_id = f"clip_{idx:06d}"
        mapping[it["orig_id"]] = {"new_id": new_id, "split": split}
    return mapping


def _write_outputs(
    train_items: List[dict],
    test_items: List[dict],
    id_map: Dict[str, Dict[str, str]],
    public: Path,
    private: Path,
) -> None:
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    train_wavs_out = public / "train_wavs"
    test_wavs_out = public / "test_wavs"
    train_wavs_out.mkdir(parents=True, exist_ok=True)
    test_wavs_out.mkdir(parents=True, exist_ok=True)

    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    with train_csv.open("w", encoding="utf-8", newline="") as ftr, \
        test_csv.open("w", encoding="utf-8", newline="") as ftd, \
        test_answer_csv.open("w", encoding="utf-8", newline="") as fta:
        tr_writer = csv.DictWriter(ftr, fieldnames=["clip_id", "audio_path", "transcript"])
        td_writer = csv.DictWriter(ftd, fieldnames=["clip_id", "audio_path"])
        ta_writer = csv.DictWriter(fta, fieldnames=["clip_id", "transcript"])
        tr_writer.writeheader()
        td_writer.writeheader()
        ta_writer.writeheader()

        for it in train_items:
            mid = id_map[it["orig_id"]]["new_id"]
            src = it["wav_path"]
            dst = train_wavs_out / f"{mid}.wav"
            shutil.copy2(src, dst)
            tr_writer.writerow({
                "clip_id": mid,
                "audio_path": f"train_wavs/{mid}.wav",
                "transcript": it["transcript"],
            })

        for it in test_items:
            mid = id_map[it["orig_id"]]["new_id"]
            src = it["wav_path"]
            dst = test_wavs_out / f"{mid}.wav"
            shutil.copy2(src, dst)
            td_writer.writerow({
                "clip_id": mid,
                "audio_path": f"test_wavs/{mid}.wav",
            })
            ta_writer.writerow({
                "clip_id": mid,
                "transcript": it["transcript"],
            })

    # sample submission uses test ids with placeholder text
    # Build a simple vocabulary from training transcripts
    vocab: List[str] = []
    with train_csv.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            vocab.extend(str(row["transcript"]).lower().split())
    if not vocab:
        vocab = ["the", "a", "to", "and", "of", "in"]
    rng = random.Random(1337)
    with test_csv.open("r", encoding="utf-8", newline="") as fin, \
        sample_sub_csv.open("w", encoding="utf-8", newline="") as fout:
        reader = csv.DictReader(fin)
        writer = csv.DictWriter(fout, fieldnames=["clip_id", "transcript"])
        writer.writeheader()
        for row in reader:
            n = rng.randint(1, 8)
            toks = [rng.choice(vocab) for _ in range(n)]
            writer.writerow({"clip_id": row["clip_id"], "transcript": " ".join(toks)})

    # also copy description.txt into public
    root_desc = Path(__file__).with_name("description.txt")
    if root_desc.exists():
        shutil.copy2(root_desc, public / "description.txt")


def _run_checks(public: Path, private: Path) -> None:
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    sample_sub = public / "sample_submission.csv"
    test_ans = private / "test_answer.csv"
    assert train_csv.exists(), f"Missing {train_csv}"
    assert test_csv.exists(), f"Missing {test_csv}"
    assert sample_sub.exists(), f"Missing {sample_sub}"
    assert test_ans.exists(), f"Missing {test_ans}"

    # Load ids
    train_ids = set()
    test_ids = set()
    ans_ids = set()
    with train_csv.open("r", encoding="utf-8", newline="") as f:
        r = csv.DictReader(f)
        for row in r:
            assert row["clip_id"].startswith("clip_")
            # Construct full path for existence check
            audio_full_path = public / row["audio_path"]
            assert audio_full_path.exists(), f"Missing audio: {audio_full_path}"
            train_ids.add(row["clip_id"])
    with test_csv.open("r", encoding="utf-8", newline="") as f:
        r = csv.DictReader(f)
        for row in r:
            assert row["clip_id"].startswith("clip_")
            # Construct full path for existence check
            audio_full_path = public / row["audio_path"]
            assert audio_full_path.exists(), f"Missing audio: {audio_full_path}"
            test_ids.add(row["clip_id"])
    with test_ans.open("r", encoding="utf-8", newline="") as f:
        r = csv.DictReader(f)
        for row in r:
            ans_ids.add(row["clip_id"])
    assert train_ids.isdisjoint(test_ids), "Train/Test ID overlap"
    assert test_ids == ans_ids, "Test ids and answer ids must match"


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process.
    - Read raw/LJSpeech-1.1 metadata and wavs
    - Split into train/test
    - Write public/train.csv, public/test.csv, public/train_wavs/*, public/test_wavs/*
    - Write private/test_answer.csv
    - Write public/sample_submission.csv
    - Copy description.txt into public

    The ``public`` and ``private`` directories are deleted and recreated
    first. To avoid destroying the source data, a ValueError is raised
    (before anything is deleted) if ``raw`` is the same as, or inside,
    either output directory.
    """
    assert raw.is_dir(), f"Raw directory not found: {raw}"

    # Guard the rmtree calls below: wiping an output dir that contains the
    # raw dataset would irreversibly delete the input.
    raw_resolved = raw.resolve()
    for out_dir in (public, private):
        out_resolved = out_dir.resolve()
        if out_resolved == raw_resolved or out_resolved in raw_resolved.parents:
            raise ValueError(
                f"Refusing to delete {out_dir}: it contains the raw data directory {raw}"
            )

    # Clean/create output dirs
    if public.exists():
        shutil.rmtree(public)
    if private.exists():
        shutil.rmtree(private)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    items = _read_metadata(raw)
    assert len(items) > 1000, f"Unexpectedly few items: {len(items)}"

    train_items, test_items = _split_by_transcript(items, test_fraction=0.2, seed=1337)
    assert len(train_items) > 0 and len(test_items) > 0, "Empty split"

    id_map = _assign_new_ids(train_items, test_items)

    _write_outputs(train_items, test_items, id_map, public, private)

    _run_checks(public, private)