import csv
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Tuple


def _link_or_copy(src: Path, dst: Path) -> None:
    """Materialize *src* at *dst* as a hardlink, falling back to a full copy.

    Parent directories of *dst* are created as needed. A pre-existing file or
    symlink at *dst* is replaced rather than overwritten in place, so a stale
    hardlink at *dst* can never be used to write through into the file it
    shares data with. If *dst* already refers to the same file as *src*, the
    call is a no-op.

    Args:
        src: Existing file to expose at *dst*.
        dst: Destination path.

    Raises:
        OSError: If *src* is missing or the fallback copy fails.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)

    if dst.is_symlink() or dst.is_file():
        if dst.exists() and os.path.samefile(src, dst):
            # dst is already a link to src's data; nothing left to do.
            return
        # Drop the old directory entry instead of letting copy2 truncate it:
        # it may be a hardlink whose inode is shared with another source file.
        dst.unlink()

    try:
        # Try hardlink for speed and space efficiency
        os.link(src, dst)
    except (AttributeError, NotImplementedError, OSError):
        # Linking is unavailable or failed (e.g., cross-filesystem): copy instead
        shutil.copy2(src, dst)


def prepare(raw: Path, public: Path, private: Path) -> None:
    """
    Complete deterministic preparation process for the Bundesliga video classification task.

    Each class directory is split into train/test (roughly 80/20, at least one
    video on each side) using a fixed seed. Both splits are then shuffled
    before anonymized ids are assigned, so that neither the CSV row order nor
    the sequential id numbers reveal a video's class.

    Class labels are the upper-cased first alphabetic character of the class
    directory name. If that character is already used by an earlier directory
    (or the name has no alphabetic character), the label is "C<index>" instead.

    Args (must be absolute Paths):
    - raw: absolute path to raw/ directory containing:
        raw/DFL Bundesliga Data Shootout/train/<class_dir>/*.mp4
        (raw/test may exist but is not used)
    - public: absolute path where public artifacts will be written
    - private: absolute path where private artifacts will be written

    This function will create the following under public/ and private/:
    - public/
        - train.csv (columns: id,label)
        - test.csv (columns: id)
        - sample_submission.csv (columns: id,label)
        - train_videos/ (anonymized video files)
        - test_videos/ (anonymized video files)
        - description.txt (copied from the directory of this module, if present)
    - private/
        - test_answer.csv (columns: id,label)

    Raises:
    - AssertionError: if a precondition on the inputs or a sanity check on the
      generated outputs fails.
    """

    # Preconditions
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Paths must be absolute"

    # Ensure base dirs exist
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Source dataset root (original files)
    src_root = raw / "DFL Bundesliga Data Shootout" / "train"
    assert src_root.is_dir(), f"Source root not found: {src_root}"

    # Output directories
    train_vid_dir = public / "train_videos"
    test_vid_dir = public / "test_videos"

    # Clean old outputs so a re-run never mixes in artifacts of a previous run
    for vid_dir in (train_vid_dir, test_vid_dir):
        if vid_dir.exists():
            shutil.rmtree(vid_dir)
        vid_dir.mkdir(parents=True, exist_ok=True)

    for f in [public / "train.csv", public / "test.csv", public / "sample_submission.csv", private / "test_answer.csv"]:
        if f.exists():
            f.unlink()

    # Deterministic seed
    SEED = 1337
    rng = random.Random(SEED)

    # Discover class directories (sorted: iterdir() order is arbitrary)
    class_dirs = sorted(p for p in src_root.iterdir() if p.is_dir())
    assert len(class_dirs) >= 2, "Expected at least two class folders under train/"

    # Derive anonymized class labels as first letter uppercase, fallback to enumerated
    def class_to_label(name: str, idx: int, taken: Dict[str, List[Path]]) -> str:
        for c in name:
            if c.isalpha():
                letter = c.upper()
                if letter not in taken:
                    return letter
                # Letter already belongs to another class: do not merge them.
                break
        return f"C{idx}"

    # Gather videos per anonymized label (labels are unique per class directory)
    by_label: Dict[str, List[Path]] = {}
    for idx, cdir in enumerate(class_dirs):
        videos = sorted(p for p in cdir.iterdir() if p.is_file() and p.suffix.lower() == ".mp4")
        assert len(videos) >= 2, f"Class {cdir.name} must contain at least two videos"
        by_label[class_to_label(cdir.name, idx, by_label)] = videos

    # Stratified split by label
    train_pairs: List[Tuple[Path, str]] = []
    test_pairs: List[Tuple[Path, str]] = []

    for lbl, paths in by_label.items():
        shuffled = list(paths)
        rng.shuffle(shuffled)
        n = len(shuffled)
        # ~20% to test, but always keep at least one video on each side
        k = min(max(1, int(round(0.2 * n))), n - 1)
        test_pairs.extend((p, lbl) for p in shuffled[:k])
        train_pairs.extend((p, lbl) for p in shuffled[k:])

    # The lists above are grouped by class. Shuffle them before handing out
    # sequential ids, otherwise id order / row order would leak the labels.
    rng.shuffle(train_pairs)
    rng.shuffle(test_pairs)

    # Assign anonymized ids and link/copy
    train_rows: List[Tuple[str, str]] = []
    test_rows: List[Tuple[str, str]] = []

    next_id = 0
    for pairs, out_dir, rows in (
        (train_pairs, train_vid_dir, train_rows),
        (test_pairs, test_vid_dir, test_rows),
    ):
        for src, lbl in pairs:
            next_id += 1
            nid = f"vid_{next_id:06d}.mp4"
            _link_or_copy(src, out_dir / nid)
            rows.append((nid, lbl))

    # Write CSVs
    _write_csv(public / "train.csv", ["id", "label"], train_rows)
    _write_csv(public / "test.csv", ["id"], [(nid,) for nid, _ in test_rows])
    _write_csv(private / "test_answer.csv", ["id", "label"], test_rows)

    # Sample submission: choose random valid labels from observed set
    labels = sorted({lbl for _, lbl in test_rows})
    _write_csv(
        public / "sample_submission.csv",
        ["id", "label"],
        [(nid, rng.choice(labels)) for nid, _ in test_rows],
    )

    # Copy description.txt (expected next to this module) into public/
    repo_desc = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "description.txt")))
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Checks
    assert (public / "train_videos").exists(), "public/train_videos should exist"
    assert (public / "test_videos").exists(), "public/test_videos should exist"

    # Check id uniqueness and alignment
    tr_ids = [nid for nid, _ in train_rows]
    te_ids = [nid for nid, _ in test_rows]
    assert len(tr_ids) == len(set(tr_ids)), "Duplicate ids in train.csv"
    assert len(te_ids) == len(set(te_ids)), "Duplicate ids in test_answer.csv"

    # Read test.csv back from disk to check alignment with the answers
    with (public / "test.csv").open("r", encoding="utf-8") as f:
        test_csv_ids = [row["id"] for row in csv.DictReader(f)]

    assert set(te_ids) == set(test_csv_ids), "Mismatch between test.csv and test_answer.csv ids"

    # Ensure anonymized id shape
    for nid in tr_ids + te_ids:
        assert nid.lower().startswith("vid_") and nid.lower().endswith(".mp4"), "IDs must be anonymized vid_*.mp4"

    # Basic sanity checks
    assert len(tr_ids) > 0 and len(te_ids) > 0, "Non-empty splits required"


def _write_csv(path: Path, header: List[str], rows: List[Tuple[str, ...]]) -> None:
    """Write *header* followed by *rows* to *path* as UTF-8 CSV, replacing any existing file."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)
