from pathlib import Path
import shutil
import random
from collections import defaultdict
import csv
import pandas as pd

# Seed the process-wide `random` module at import time so that any code in this
# process relying on the global generator produces reproducible results.
# NOTE(review): this is an import-time side effect shared with every other user
# of the global generator in the same process.
random.seed(42)


def _read_metadata(csv_path: Path) -> list[dict]:
    """Load the metadata CSV into a list of ``{filename, callsign, sha256_checksum}`` dicts.

    Values are whitespace-stripped. Rows with an empty filename or callsign are
    skipped; an empty checksum is kept as ``""``.

    Args:
        csv_path: absolute path to the metadata CSV.

    Returns:
        The non-empty list of parsed rows, in file order.

    Raises:
        AssertionError: if the path is not absolute or missing, the CSV has no
            header, a required column is absent, or no usable rows were found.
    """
    assert csv_path.is_absolute() and csv_path.exists(), f"Metadata CSV not found: {csv_path}"
    expected_cols = ["filename", "callsign", "sha256_checksum"]
    rows: list[dict] = []
    # "utf-8-sig" drops a leading byte-order mark when present and otherwise decodes
    # exactly like "utf-8". Without it, a BOM would be glued onto the first header
    # name ("\ufefffilename") and the column check below would reject a valid file.
    with csv_path.open("r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        assert reader.fieldnames is not None, "CSV has no header"
        assert all(c in reader.fieldnames for c in expected_cols), (
            f"CSV must contain columns {expected_cols}, found {reader.fieldnames}"
        )
        for r in reader:
            # `or ""` covers short rows, where DictReader fills missing cells with None.
            fn = (r.get("filename") or "").strip()
            cs = (r.get("callsign") or "").strip()
            sha = (r.get("sha256_checksum") or "").strip()
            if not fn or not cs:
                continue
            rows.append({"filename": fn, "callsign": cs, "sha256_checksum": sha})
    assert len(rows) > 0, "No metadata rows loaded"
    return rows


def _gather_available_audio(audio_dir: Path) -> set[str]:
    """Return the names of regular ``.wav`` files (suffix matched case-insensitively) directly in *audio_dir*.

    Raises:
        AssertionError: if *audio_dir* is not an absolute path to an existing
            directory, or if duplicate names were collected.
    """
    assert audio_dir.is_absolute() and audio_dir.is_dir(), f"Audio directory not found: {audio_dir}"
    wav_names: list[str] = []
    for entry in audio_dir.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix.lower() != ".wav":
            continue
        wav_names.append(entry.name)
    unique_names = set(wav_names)
    assert len(wav_names) == len(unique_names), "Duplicate filenames in audio directory"
    return unique_names


def _build_dataset(rows: list[dict], available_files: set[str]) -> list[dict]:
    """Keep only metadata rows whose ``filename`` is among *available_files*, preserving order.

    Raises:
        AssertionError: if no row survives the filter.
    """
    kept = list(filter(lambda row: row["filename"] in available_files, rows))
    assert kept, "No overlapping audio files between CSV and directory"
    return kept


def _make_id_mapping(examples: list[dict]) -> dict[str, str]:
    """Assign anonymised ids ``clip_000001``, ``clip_000002``, ... to each example's filename.

    Ids follow the order of ``(filename, sha256_checksum)``, so the mapping does not
    depend on the input order.

    Raises:
        AssertionError: if two examples share a filename (the mapping would collapse them).
    """
    ordered = sorted(examples, key=lambda ex: (ex["filename"], ex["sha256_checksum"]))
    mapping: dict[str, str] = {
        ex["filename"]: f"clip_{pos:06d}" for pos, ex in enumerate(ordered, start=1)
    }
    assert len(mapping) == len(ordered)
    assert len(set(mapping.values())) == len(mapping)
    return mapping


def _stratified_split(
    examples: list[dict], test_ratio: float, seed: int = 42
) -> tuple[list[dict], list[dict]]:
    """Split *examples* into (train, test), stratified by ``callsign``.

    Within each label, the examples are sorted by ``(filename, sha256_checksum)`` and
    the last ``int(n * test_ratio)`` go to test. That count is clamped to at least 1
    and at most ``n - 1``, so every label with two or more examples contributes to
    both sides. A label with a single example goes to train only.

    Both outputs are then shuffled with a private ``random.Random(seed)``. This makes
    the result depend only on the inputs and *seed*, not on the global ``random``
    state or on how many times the function has already been called.

    Args:
        examples: rows with ``filename``, ``callsign`` and ``sha256_checksum`` keys.
        test_ratio: target fraction of each label assigned to test.
        seed: seed for the shuffle generator (defaults to the module's seed, 42).

    Returns:
        A ``(train, test)`` tuple of lists of the original row dicts.

    Raises:
        AssertionError: if the split overlaps, or a test label is absent from train.
    """
    by_label: dict[str, list[dict]] = defaultdict(list)
    for r in examples:
        by_label[r["callsign"]].append(r)

    train: list[dict] = []
    test: list[dict] = []
    for lst in by_label.values():
        lst_sorted = sorted(lst, key=lambda r: (r["filename"], r["sha256_checksum"]))
        n = len(lst_sorted)
        if n == 1:
            test_n = 0
        else:
            # At least one test example per label, but always keep one for train.
            test_n = min(max(int(n * test_ratio), 1), n - 1)
        train.extend(lst_sorted[: n - test_n])
        test.extend(lst_sorted[n - test_n :])

    # A dedicated generator keeps the split reproducible regardless of global RNG use.
    rng = random.Random(seed)
    rng.shuffle(train)
    rng.shuffle(test)

    train_set = {(r["filename"], r["sha256_checksum"]) for r in train}
    test_set = {(r["filename"], r["sha256_checksum"]) for r in test}
    assert train_set.isdisjoint(test_set), "Train/Test overlap detected"

    train_labels = {r["callsign"] for r in train}
    test_labels = {r["callsign"] for r in test}
    missing = sorted(test_labels - train_labels)
    assert len(missing) == 0, f"Some test labels are not present in train: {missing}"

    return train, test


def _reset_dir(d: Path) -> None:
    """Ensure *d* exists and holds no regular files (subdirectories are left untouched)."""
    if d.exists():
        for p in d.glob("*"):
            if p.is_file():
                p.unlink()
    else:
        d.mkdir(parents=True, exist_ok=True)


def _write_csv(path: Path, header: list[str], rows: list[list[str]]) -> None:
    """Write *header* followed by *rows* to *path* as UTF-8 CSV."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def _read_str_csv(path: Path) -> pd.DataFrame:
    """Read *path* keeping every cell as its literal string.

    ``keep_default_na=False`` stops pandas from turning tokens such as ``"NA"`` or
    ``"NULL"`` into NaN. NaN never compares equal, so it would silently break the
    set-based id/label checks done in ``_verify_outputs``.
    """
    return pd.read_csv(path, dtype=str, keep_default_na=False)


def _verify_outputs(
    public: Path,
    train_csv: Path,
    test_csv: Path,
    test_answer_csv: Path,
    sample_sub_csv: Path,
    train_audio_dir: Path,
    test_audio_dir: Path,
) -> None:
    """Re-read the written artifacts once and assert schema, alignment, disjointness and no leakage."""
    train_df = _read_str_csv(train_csv)
    test_df = _read_str_csv(test_csv)
    ans_df = _read_str_csv(test_answer_csv)
    samp_df = _read_str_csv(sample_sub_csv)

    # Headers
    assert train_df.columns.tolist() == ["id", "callsign"]
    assert test_df.columns.tolist() == ["id"]
    assert ans_df.columns.tolist() == ["id", "callsign"]
    assert samp_df.columns.tolist() == ["id", "callsign"]

    # Shapes and disjointness
    assert len(test_df) == len(ans_df) == len(samp_df), "test, test_answer, and sample_submission must align"
    assert set(test_df["id"]) == set(ans_df["id"]) == set(samp_df["id"])  # exact ids
    assert set(train_df["id"]).isdisjoint(set(test_df["id"]))

    # No leakage: description is only in public, answers only in private
    assert (public / "description.txt").exists(), "description.txt must be present in public/"
    assert not (public / "test_answer.csv").exists(), "test_answer.csv must not be present in public/"

    # Labels: every test label must be seen in train labels
    assert set(ans_df["callsign"]).issubset(set(train_df["callsign"]))

    # Files exist for each id
    for rid in train_df["id"]:
        assert (train_audio_dir / f"{rid}.wav").exists()
    for rid in test_df["id"]:
        assert (test_audio_dir / f"{rid}.wav").exists()


def prepare(raw: Path, public: Path, private: Path):
    """
    Complete deterministic preparation process.

    - Reads raw metadata and audio files
    - Creates deterministic id mapping and stratified split
    - Writes public/train.csv, public/test.csv, public/sample_submission.csv
    - Copies audio into public/train_audio and public/test_audio with renamed ids
    - Writes private/test_answer.csv
    - Copies description.txt (located next to this script) into public/
    - Re-reads the outputs and checks their consistency

    Args:
        raw: absolute path to the raw data root (metadata CSV + audio folder).
        public: absolute output directory for participant-visible files.
        private: absolute output directory for hidden answers.

    Raises:
        AssertionError: if a path is not absolute, a required input is missing,
            or any post-write consistency check fails.
    """

    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Paths must be absolute"

    # Source paths inside raw
    meta_csv = raw / "Amateur_Radio_2-Meter_FM_Simplex_Transmissions-KAGGLE-Sorted-Hashed.csv"
    audio_src = raw / "Amateur_Radio_Transmissions-2-Meter_FM_Simplex" / "dataset_audio_waveforms"

    # description.txt is mandatory in public/. Check its source before writing anything,
    # so a missing file fails fast instead of after all the audio has been copied.
    root_desc = Path(__file__).resolve().parent / "description.txt"
    assert root_desc.exists(), f"description.txt not found next to the prepare script: {root_desc}"

    # Ensure output directories exist and the audio folders are clean
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)
    train_audio_dir = public / "train_audio"
    test_audio_dir = public / "test_audio"
    for d in (train_audio_dir, test_audio_dir):
        _reset_dir(d)

    # Load and filter dataset
    rows = _read_metadata(meta_csv)
    available = _gather_available_audio(audio_src)
    dataset = _build_dataset(rows, available)

    # Mapping and split
    id_map = _make_id_mapping(dataset)
    train, test = _stratified_split(dataset, test_ratio=0.20)

    # Write CSV files
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    _write_csv(train_csv, ["id", "callsign"], [[id_map[r["filename"]], r["callsign"]] for r in train])
    _write_csv(test_csv, ["id"], [[id_map[r["filename"]]] for r in test])
    # Hidden labels live only in private/.
    _write_csv(test_answer_csv, ["id", "callsign"], [[id_map[r["filename"]], r["callsign"]] for r in test])

    # sample_submission draws labels from train so they are valid. A dedicated,
    # fixed-seed generator and test-order draws make the file reproducible.
    train_labels = [r["callsign"] for r in train]
    assert len(train_labels) > 0
    rng = random.Random(0)
    sample_rows = [[id_map[r["filename"]], rng.choice(train_labels)] for r in test]
    _write_csv(sample_sub_csv, ["id", "callsign"], sample_rows)

    # Copy audio into public directories with renamed ids
    for subset, out_dir in ((train, train_audio_dir), (test, test_audio_dir)):
        for r in subset:
            src_path = audio_src / r["filename"]
            dst_path = out_dir / f"{id_map[r['filename']]}.wav"
            shutil.copyfile(src_path, dst_path)
            assert dst_path.exists(), f"Failed to copy to {dst_path}"

    # Copy description.txt into public (existence verified above)
    shutil.copyfile(root_desc, public / "description.txt")

    # Comprehensive checks
    _verify_outputs(
        public,
        train_csv,
        test_csv,
        test_answer_csv,
        sample_sub_csv,
        train_audio_dir,
        test_audio_dir,
    )
