import os
import csv
import random
import shutil
from pathlib import Path
from typing import List, Tuple

# This module exposes exactly one function: prepare(raw: Path, public: Path, private: Path)
# It prepares the competition files into the provided public/ and private/ folders.

# Fixed seed so the randomly drawn labels in sample_submission.csv are reproducible.
RANDOM_SEED = 42
# Import-time side effect: seeds the process-wide `random` generator once, when
# this module is first imported.
random.seed(RANDOM_SEED)


def _safe_link(src: Path, dst: Path):
    """Create a hardlink if possible, else copy. Ensure dst parent exists.
    Overwrite if an old file exists (idempotent).
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    try:
        if dst.exists() or dst.is_symlink():
            dst.unlink()
        os.link(src, dst)
    except OSError:
        shutil.copy2(src, dst)


def _collect_videos(root_dir: Path) -> List[Tuple[Path, str]]:
    """Collect all .avi videos under class subfolders.

    Returns a list of tuples (abs_path, class_label), where class_label is the
    immediate subfolder name under root_dir.
    """
    videos: List[Tuple[Path, str]] = []
    if not root_dir.is_dir():
        return videos
    for cls in sorted([p for p in root_dir.iterdir() if p.is_dir()]):
        for f in sorted(cls.iterdir()):
            if f.is_file() and f.suffix.lower() == ".avi":
                videos.append((f.resolve(), cls.name))
    return videos


def _write_csv(path: Path, header: List[str], rows) -> None:
    """Write *header* followed by *rows* to *path* as a UTF-8 CSV file."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def _count_csv_records(path: Path) -> int:
    """Return the number of data records in *path*, not counting the header."""
    # Parse with csv.reader (and newline="") so a quoted field containing a
    # line break still counts as a single record; the handle is closed on exit.
    with path.open("r", newline="", encoding="utf-8") as f:
        return sum(1 for _ in csv.reader(f)) - 1


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/ and private/ competition folders from raw/.

    Expects ``raw/train``, ``raw/val`` and ``raw/test``, each holding one
    subfolder per class with ``.avi`` files inside. Train and val are merged
    into one training pool. Videos are hardlinked (or copied) under anonymized
    names and the following files are written:

    * ``public/train_videos/tr_NNNNNN.avi`` and ``public/train.csv``
      (``id,file_name,label``)
    * ``public/test_videos/te_NNNNNN.avi`` and ``public/test.csv``
      (``id,file_name``)
    * ``public/sample_submission.csv`` (``id,label`` with random labels)
    * ``private/test_answer.csv`` (``id,label``)
    * ``public/description.txt``, copied from next to this module if present

    Args:
        raw: Directory containing the ``train``/``val``/``test`` splits.
        public: Output directory for participant-visible files.
        private: Output directory for the answer key.

    Raises:
        AssertionError: If a split has no videos, fewer than two classes are
            found, test contains a label absent from training, or any of the
            post-write consistency checks fails.
    """
    # 1) Discover dataset in raw/
    src_train = raw / "train"
    src_val = raw / "val"
    src_test = raw / "test"

    train_videos = _collect_videos(src_train)
    val_videos = _collect_videos(src_val)
    test_videos = _collect_videos(src_test)

    assert len(train_videos) > 0, f"No videos found in {src_train}"
    assert len(val_videos) > 0, f"No videos found in {src_val}"
    assert len(test_videos) > 0, f"No videos found in {src_test}"

    # Merge train+val into training pool
    merged_train = train_videos + val_videos

    # Derive label set from merged train (sorted list: order matters for the
    # reproducible random draws below)
    label_set = sorted({lbl for _, lbl in merged_train})
    assert len(label_set) >= 2, "Expected multiple classes in training set."

    # Ensure test labels are subset of train labels
    missing_in_train = {lbl for _, lbl in test_videos} - set(label_set)
    assert not missing_in_train, f"Test contains unseen labels: {missing_in_train}"

    # 2) Prepare output directories (under public/ and private/)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    out_train_dir = public / "train_videos"
    out_test_dir = public / "test_videos"
    out_train_dir.mkdir(parents=True, exist_ok=True)
    out_test_dir.mkdir(parents=True, exist_ok=True)

    # 3) Create anonymized filenames and write CSVs
    # public/train.csv: id,file_name,label
    # public/test.csv: id,file_name
    # private/test_answer.csv: id,label
    #
    # NOTE(review): ids follow the class-sorted discovery order, so in
    # test.csv neighbouring ids share a label. If id order must not hint at
    # the label, a seeded shuffle of test_videos is needed here. That would
    # change the generated dataset, so it is left for a deliberate decision.

    train_rows = []
    for i, (src_path, lbl) in enumerate(merged_train, start=1):
        new_name = f"tr_{i:06d}.avi"
        _safe_link(src_path, out_train_dir / new_name)
        train_rows.append({"id": i, "file_name": new_name, "label": lbl})

    test_rows = []
    for i, (src_path, lbl) in enumerate(test_videos, start=1):
        new_name = f"te_{i:06d}.avi"
        _safe_link(src_path, out_test_dir / new_name)
        test_rows.append({"id": i, "file_name": new_name, "label": lbl})  # label only for private answer

    # Write CSVs
    public_train_csv = public / "train.csv"
    public_test_csv = public / "test.csv"
    private_answer_csv = private / "test_answer.csv"
    public_sample_sub_csv = public / "sample_submission.csv"

    _write_csv(
        public_train_csv,
        ["id", "file_name", "label"],
        ([r["id"], r["file_name"], r["label"]] for r in train_rows),
    )
    _write_csv(
        public_test_csv,
        ["id", "file_name"],
        ([r["id"], r["file_name"]] for r in test_rows),
    )
    # Internal answer key
    _write_csv(
        private_answer_csv,
        ["id", "label"],
        ([r["id"], r["label"]] for r in test_rows),
    )

    # Sample submission with random labels drawn from training label set.
    # A dedicated generator seeded with RANDOM_SEED keeps the file identical
    # across repeated prepare() calls and independent of whatever else has
    # consumed the process-wide `random` state.
    rng = random.Random(RANDOM_SEED)
    _write_csv(
        public_sample_sub_csv,
        ["id", "label"],  # submission schema
        ([r["id"], rng.choice(label_set)] for r in test_rows),
    )

    # Also place a copy of description.txt into public/
    repo_root = Path(__file__).resolve().parent
    desc_src = repo_root / "description.txt"
    if desc_src.exists():
        shutil.copy2(desc_src, public / "description.txt")

    # 4) Assertions and checks
    # Basic counts
    assert _count_csv_records(public_train_csv) == len(train_rows), "Mismatch in public/train.csv rows"
    assert _count_csv_records(public_test_csv) == len(test_rows), "Mismatch in public/test.csv rows"

    # Uniqueness checks
    train_ids = [r["id"] for r in train_rows]
    test_ids = [r["id"] for r in test_rows]
    assert len(train_ids) == len(set(train_ids)), "Duplicate ids in train"
    assert len(test_ids) == len(set(test_ids)), "Duplicate ids in test"

    train_files = [r["file_name"] for r in train_rows]
    test_files = [r["file_name"] for r in test_rows]
    assert len(train_files) == len(set(train_files)), "Duplicate file names in train"
    assert len(test_files) == len(set(test_files)), "Duplicate file names in test"

    # Ensure outputs exist and point to files
    for fname in train_files:
        p = out_train_dir / fname
        assert p.is_file(), f"Missing train video file: {p}"
    for fname in test_files:
        p = out_test_dir / fname
        assert p.is_file(), f"Missing test video file: {p}"

    # Consistency between public/test.csv and private/test_answer.csv
    with public_test_csv.open("r", newline="", encoding="utf-8") as f_t, \
            private_answer_csv.open("r", newline="", encoding="utf-8") as f_a:
        rt = list(csv.DictReader(f_t))
        ra = list(csv.DictReader(f_a))
    assert len(rt) == len(ra), "test.csv and test_answer.csv length mismatch"
    ids_t = [int(row["id"]) for row in rt]
    ids_a = [int(row["id"]) for row in ra]
    assert ids_t == ids_a, "Order or ids mismatch between test.csv and test_answer.csv"

    # Non-leakage: anonymized names do not contain label text
    lowered_labels = [lbl.lower() for lbl in label_set]  # hoisted out of the per-file loop

    def _non_leaking(name: str) -> bool:
        low = name.lower()
        return all(lbl not in low for lbl in lowered_labels)

    for fname in train_files + test_files:
        assert _non_leaking(fname), f"Filename leaks label: {fname}"

    # Label coverage check
    labels_in_train = {r["label"] for r in train_rows}
    assert labels_in_train == set(label_set), "Label set mismatch in train"
