from pathlib import Path
import csv
import hashlib
import random
import shutil
from collections import defaultdict, Counter


_DATASET_NAMES = ("kinetics400_5per", "kinetics600_5per")  # probed in this order
_TEST_FRACTION = 0.20  # share of each class held out for the test split
_SPLIT_SEED = 2023  # seeds the per-class shuffle; sample labels use seed + 7
_VIDEO_ID_LENGTH = 20  # hex characters kept from the SHA-1 digest
_TEST_INDEX_OFFSET = 1_000_000  # index offset mixed into test-split hash inputs


def _sha1_token(text: str, n: int = _VIDEO_ID_LENGTH) -> str:
    """Return the first *n* hex characters of the SHA-1 digest of *text*."""
    return hashlib.sha1(text.encode("utf-8")).hexdigest()[:n]


def _collect_class_files(src_base: Path) -> dict[str, list[Path]]:
    """Map each class sub-directory name to its sorted ``.mp4`` files.

    Sub-directories that contain no mp4 file are left out. Keys are inserted
    in sorted directory order, which later fixes the shuffle sequence.
    """
    class_to_files = {}
    for cls_dir in sorted(p for p in src_base.iterdir() if p.is_dir()):
        mp4s = sorted(
            p for p in cls_dir.iterdir()
            if p.is_file() and p.suffix.lower() == ".mp4"
        )
        if mp4s:
            class_to_files[cls_dir.name] = mp4s
    return class_to_files


def _stratified_split(
    class_to_files: dict[str, list[Path]], rng: random.Random
) -> tuple[list[tuple[Path, str]], list[tuple[Path, str]]]:
    """Split every class into train/test ``(path, label)`` items.

    Each class contributes ``round(n * _TEST_FRACTION)`` test videos, clamped
    to ``[1, n - 1]`` so both splits always receive at least one video.
    """
    train_items = []
    test_items = []
    for cls, files in class_to_files.items():
        n = len(files)
        assert n >= 2, f"Class '{cls}' must have at least 2 videos to split, found {n}"
        k = min(max(1, int(round(n * _TEST_FRACTION))), n - 1)
        idxs = list(range(n))
        rng.shuffle(idxs)
        test_idx = set(idxs[:k])
        # Walk the files in sorted order so output order is independent of
        # the shuffle; the shuffle only decides membership.
        for i, f in enumerate(files):
            (test_items if i in test_idx else train_items).append((f, cls))
    return train_items, test_items


def _link_or_copy(src: Path, dst: Path) -> None:
    """Create *dst* as a hard link to *src*, copying when linking fails."""
    import os  # local import: the module header does not import os

    dst.parent.mkdir(parents=True, exist_ok=True)
    dst.unlink(missing_ok=True)
    try:
        # Hard links save disk space; os.link(src, dst) creates dst -> src.
        os.link(src, dst)
    except OSError:
        # E.g. cross-device link or a filesystem without hard-link support.
        shutil.copy2(src, dst)


def _unique_video_id(
    src_path: Path, cls: str, idx: int, salt: str, used_ids: set[str]
) -> str:
    """Return an anonymised id for a video and record it in *used_ids*.

    The id hashes the source file name, label, index and salt. On a collision
    the index is bumped by one more each round, so the loop always makes
    progress instead of retrying the same hash input.
    """
    bump = 0
    vid = _sha1_token(f"{src_path.name}|{cls}|{idx}|{salt}")
    while vid in used_ids:
        bump += 1
        vid = _sha1_token(f"{src_path.name}|{cls}|{idx + bump}|{salt}")
    used_ids.add(vid)
    return vid


def _write_csv(path: Path, fieldnames: list[str], rows: list[dict]) -> None:
    """Write *rows* to *path* as UTF-8 CSV with a header line."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _read_header(path: Path) -> list[str]:
    """Return the first CSV record (the header) of *path*."""
    with path.open("r", newline="", encoding="utf-8") as f:
        return next(csv.reader(f))


def _read_ids(path: Path) -> list[str]:
    """Return the ``video_id`` column of the CSV at *path*, in file order."""
    with path.open("r", newline="", encoding="utf-8") as f:
        return [row["video_id"] for row in csv.DictReader(f)]


def prepare(raw: Path, public: Path, private: Path):
    """Prepare the Kinetics subset competition data.

    Inputs
    - raw: absolute path to raw/ where the original data is placed (read-only)
    - public: absolute path to public/ where all participant-visible files will be written
    - private: absolute path to private/ where hidden ground-truth will be written

    The source videos are looked up under
    ``raw/<name>/<name>/train/<class>/*.mp4`` for ``kinetics400_5per`` and then
    ``kinetics600_5per``; the first existing directory wins. Existing
    ``public`` and ``private`` directories are deleted and rebuilt.

    Outputs in public/
    - train_videos/ (mp4 files, hard-linked when possible, otherwise copied)
    - test_videos/ (mp4 files)
    - train.csv (video_id, filepath, label)
    - test.csv (video_id, filepath)
    - sample_submission.csv (video_id, label)
    - description.txt (copied from raw.parent when that file exists)

    Outputs in private/
    - test_answer.csv (video_id, label)

    Returns None. Raises AssertionError when an input or output sanity check
    fails (non-absolute paths, dataset not found, fewer than 2 classes, a
    class with fewer than 2 videos, or any post-write consistency check).
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), (
        "raw, public, private must be absolute paths"
    )

    # Wipe both output trees first, then recreate them, so stale files from a
    # previous run never survive.
    for out_dir in (public, private):
        if out_dir.exists():
            shutil.rmtree(out_dir)
    for out_dir in (public, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    # Detect source dataset root having mp4 files under class subfolders.
    candidates = [raw / name / name / "train" for name in _DATASET_NAMES]
    src_base = next((cand for cand in candidates if cand.is_dir()), None)
    assert src_base is not None, (
        f"Could not find dataset under {candidates[0]} or {candidates[1]}"
    )

    class_to_files = _collect_class_files(src_base)
    classes = sorted(class_to_files)
    total_files = sum(len(v) for v in class_to_files.values())
    assert len(classes) >= 2, f"Expected at least 2 classes, found {len(classes)}"
    assert total_files > 0, "No mp4 files found in source dataset"

    # Stratified split per class (80/20) with a fixed seed for reproducibility.
    train_items, test_items = _stratified_split(
        class_to_files, random.Random(_SPLIT_SEED)
    )

    # Ensure every class is present in both splits.
    tr_cls = {c for _, c in train_items}
    te_cls = {c for _, c in test_items}
    assert set(classes) == tr_cls | te_cls, "Some classes are missing after split"
    assert set(classes) <= tr_cls and set(classes) <= te_cls, (
        "Every class should appear in both train and test"
    )

    train_dir = public / "train_videos"
    test_dir = public / "test_videos"
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    # Publish videos under hashed names so file names do not reveal labels.
    salt = _sha1_token("kinetics-salt-2023", 40)
    used_ids: set[str] = set()

    rows_train = []
    for i, (src, cls) in enumerate(train_items):
        vid = _unique_video_id(src, cls, i, salt, used_ids)
        dst = train_dir / f"{vid}.mp4"
        _link_or_copy(src, dst)
        rows_train.append({
            "video_id": vid,
            "filepath": str(dst.relative_to(public)),  # train_videos/xxx.mp4
            "label": cls,
        })

    rows_test = []
    rows_ans = []
    for i, (src, cls) in enumerate(test_items):
        vid = _unique_video_id(src, cls, i + _TEST_INDEX_OFFSET, salt, used_ids)
        dst = test_dir / f"{vid}.mp4"
        _link_or_copy(src, dst)
        rows_test.append({
            "video_id": vid,
            "filepath": str(dst.relative_to(public)),  # test_videos/xxx.mp4
        })
        rows_ans.append({"video_id": vid, "label": cls})

    # Write public and private CSVs.
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_csv = public / "sample_submission.csv"
    _write_csv(train_csv, ["video_id", "filepath", "label"], rows_train)
    _write_csv(test_csv, ["video_id", "filepath"], rows_test)
    _write_csv(test_answer_csv, ["video_id", "label"], rows_ans)

    # sample_submission: a random (seeded) but valid training label per test id.
    train_labels = sorted({r["label"] for r in rows_train})
    rng2 = random.Random(_SPLIT_SEED + 7)
    sample_rows = [
        {"video_id": r["video_id"], "label": rng2.choice(train_labels)}
        for r in rows_test
    ]
    _write_csv(sample_csv, ["video_id", "label"], sample_rows)

    assert _read_ids(test_csv) == _read_ids(sample_csv), (
        "sample_submission video_id order must match test.csv"
    )

    # Copy description.txt into the public directory (visible to participants).
    repo_desc = raw.parent / "description.txt"
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Sanity checks
    # 1) Files referenced in CSV exist
    for r in rows_train:
        assert (public / r["filepath"]).exists(), f"Missing train file: {r['filepath']}"
    for r in rows_test:
        assert (public / r["filepath"]).exists(), f"Missing test file: {r['filepath']}"

    # 2) Uniqueness and disjointness of IDs
    train_ids = [r["video_id"] for r in rows_train]
    test_ids = [r["video_id"] for r in rows_test]
    assert len(train_ids) == len(set(train_ids)), "Duplicate video_id in train"
    assert len(test_ids) == len(set(test_ids)), "Duplicate video_id in test"
    assert set(train_ids).isdisjoint(test_ids), "Overlap between train and test IDs"

    # 3) No label leakage in filenames. The directories are listed once rather
    #    than once per label.
    # NOTE: published names are hex digests plus ".mp4", so a label made only
    # of such characters could match by coincidence.
    published = [
        (p.name, p.name.lower())
        for out_dir in (train_dir, test_dir)
        for p in out_dir.glob("*.mp4")
    ]
    leak = []
    for lbl in train_labels:
        needle = lbl.lower()
        hit = next((name for name, lowered in published if needle in lowered), None)
        if hit is not None:
            leak.append((lbl, hit))
    assert not leak, f"Potential label leakage via filenames: {leak[:3]}"

    # 4) Column schema checks
    expected_headers = (
        (train_csv, ["video_id", "filepath", "label"], "train.csv header mismatch"),
        (test_csv, ["video_id", "filepath"], "test.csv header mismatch"),
        (sample_csv, ["video_id", "label"], "sample_submission.csv header mismatch"),
        (test_answer_csv, ["video_id", "label"], "test_answer.csv header mismatch"),
    )
    for csv_path, header, message in expected_headers:
        assert _read_header(csv_path) == header, message

    # 5) Minimal per-class presence in train
    min_per_class = min(Counter(r["label"] for r in rows_train).values())
    assert min_per_class >= 1, "Each class must have at least 1 training sample"
