import shutil
from pathlib import Path
from collections import defaultdict, Counter
import random
import csv

# Class names; each is also the name of a label subfolder under
# raw/Train_Images and raw/Test_Images.
LABELS = "CC EC HGSC LGSC MC".split()
# Seed for the deterministic random split and sample-submission labels.
SEED = 20_230_919
# Fraction of each label's images held out for the test split (20%).
TEST_FRACTION = 0.2


def _gather_source_images(raw: Path):
    """Collect ``(label, png_path)`` pairs from the raw split folders.

    ``raw/Train_Images`` is scanned before ``raw/Test_Images``. Within each
    split folder, labels are visited in alphabetical order and each label
    folder's ``*.png`` files are listed in sorted order. Split or label
    folders that do not exist are skipped.

    Args:
        raw: directory expected to contain ``Train_Images/`` and/or
            ``Test_Images/``, each with one subfolder per label.

    Returns:
        List of ``(label, Path)`` tuples.

    Raises:
        RuntimeError: if no PNG image was found in any label folder.
    """
    found = []
    for split_name in ("Train_Images", "Test_Images"):
        split_root = raw / split_name
        if not split_root.exists():
            continue
        for label in sorted(LABELS):
            label_dir = split_root / label
            if label_dir.exists():
                found.extend((label, png) for png in sorted(label_dir.glob("*.png")))
    if len(found) == 0:
        raise RuntimeError("No PNG images found in raw/Train_Images or raw/Test_Images under labeled subfolders.")
    return found


# CSV files that ``prepare`` always writes into the public output directory.
essential_public_files = [f"{stem}.csv" for stem in ("train", "test", "sample_submission")]


def _make_clean_dir(p: Path):
    """Ensure ``p`` exists as a fresh, empty directory.

    An existing directory tree at ``p`` is removed. A regular file at ``p``
    is deleted as well, because ``shutil.rmtree`` alone would raise
    ``NotADirectoryError`` for it. Missing parent directories are created.

    NOTE: a symlink to a directory is still rejected by ``shutil.rmtree``
    (it raises ``OSError``), which keeps the function from deleting data
    outside the intended output path.

    Args:
        p: path of the directory to (re)create.
    """
    if p.is_dir():
        shutil.rmtree(p)
    elif p.exists():
        p.unlink()
    p.mkdir(parents=True, exist_ok=True)


def _safe_write_csv(path: Path, header, rows):
    """Write ``header`` then ``rows`` to ``path`` as a UTF-8 CSV file, safely.

    The data is written to a sibling temporary file (``<name>.tmp``) first and
    then moved over ``path`` with a single rename. A failure while writing
    therefore never leaves a truncated CSV at ``path``, and the temporary file
    is removed on error.

    Args:
        path: destination CSV file; its parent directory must already exist.
        header: sequence of column names written as the first row.
        rows: iterable of row sequences written after the header.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)
        # Same directory -> same filesystem, so this is a plain rename.
        tmp_path.replace(path)
    finally:
        # Only still present if writing or the rename failed.
        if tmp_path.exists():
            tmp_path.unlink()


def _require(condition, message: str) -> None:
    """Raise ``RuntimeError(message)`` unless ``condition`` is truthy.

    Used instead of ``assert`` so the checks still run under ``python -O``.
    """
    if not condition:
        raise RuntimeError(message)


def _split_per_label(by_label, rng: random.Random):
    """Deterministically split each label's images into train and test.

    For each label in ``LABELS`` order, the paths are sorted, shuffled with
    ``rng``, and the first ``round(n * TEST_FRACTION)`` of them go to test.
    That count is clamped to at least 1 and at most ``n - 1``; the rest go to
    train.

    Returns:
        ``(train_items, test_items)``, each a list of ``(label, path)`` tuples.
    """
    train_items = []
    test_items = []
    for label in LABELS:
        items = sorted(by_label[label], key=str)
        rng.shuffle(items)
        n = len(items)
        n_test = min(max(1, int(round(n * TEST_FRACTION))), n - 1)
        test_items.extend((label, p) for p in items[:n_test])
        train_items.extend((label, p) for p in items[n_test:])
    return train_items, test_items


def _copy_with_new_ids(items, target_dir: Path, first_index: int, rng: random.Random):
    """Copy ``items`` into ``target_dir`` under fresh ``image_NNNNNN.png`` names.

    Items are first put in a canonical ``(label, path)`` order, so the result
    does not depend on input order. They are then shuffled with ``rng``
    before numbering. Without that shuffle, ids would be handed out label by
    label, and both the id numbers and the CSV row order would reveal each
    test image's label.

    Args:
        items: ``(label, src_path)`` tuples to copy.
        target_dir: existing destination directory.
        first_index: numeric part of the first id to assign.
        rng: seeded random generator (keeps the numbering reproducible).

    Returns:
        ``[[new_id, label], ...]`` in ascending id order.
    """
    ordered = sorted(items, key=lambda item: (item[0], str(item[1])))
    rng.shuffle(ordered)
    rows = []
    for offset, (label, src) in enumerate(ordered):
        new_id = f"image_{first_index + offset:06d}.png"
        # copyfile rather than copy2: don't carry over source timestamps or
        # other metadata, which could correlate with the original class folder.
        shutil.copyfile(src, target_dir / new_id)
        rows.append([new_id, label])
    return rows


def _verify_outputs(train_dir: Path, test_dir: Path, train_rows, test_ans_rows):
    """Sanity-check the written split; raise ``RuntimeError`` on any violation."""
    train_ids = [rid for rid, _ in train_rows]
    test_ids = [rid for rid, _ in test_ans_rows]
    all_ids = test_ids + train_ids

    _require(len(set(all_ids)) == len(all_ids), "Duplicate ids detected")
    _require(len(list(train_dir.glob("*.png"))) == len(train_rows), "Mismatch train images vs train.csv")
    _require(len(list(test_dir.glob("*.png"))) == len(test_ans_rows), "Mismatch test images vs test.csv")

    # Every class must be represented on both sides of the split.
    train_counts = Counter(lbl for _, lbl in train_rows)
    test_counts = Counter(lbl for _, lbl in test_ans_rows)
    _require(all(train_counts[lbl] > 0 for lbl in LABELS), f"Zero samples in train for some label: {train_counts}")
    _require(all(test_counts[lbl] > 0 for lbl in LABELS), f"Zero samples in test for some label: {test_counts}")

    # Ensure no accidental label tokens in filenames (ids).
    tokens = [t.lower() for t in LABELS]
    for rid in all_ids:
        stem = Path(rid).stem.lower()
        for token in tokens:
            _require(token not in stem, f"Label token '{token}' found in filename '{stem}'")


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the Ovarian Cancer Subtype dataset split with deterministic behavior.

    - raw: absolute path to raw directory containing Train_Images/ and Test_Images/ each with label subfolders.
    - public: absolute path to output public directory; will be created/overwritten.
    - private: absolute path to output private directory; will be created/overwritten.

    The raw input is gathered and validated before either output directory is
    touched. Each label is split roughly ``TEST_FRACTION`` test / rest train.
    Images are renamed to ``image_NNNNNN.png`` ids whose numbering is shuffled
    (seeded with ``SEED``), so ids and row order do not reveal labels. Train
    ids come first; test ids continue the same sequence.

    Outputs in public/:
      - train.csv (id,label)
      - test.csv (id)
      - train_images/ (image files)
      - test_images/ (image files)
      - sample_submission.csv (id,label) with random labels
      - description.txt (copied from the directory containing this script, if it exists)

    Outputs in private/:
      - test_answer.csv (id,label), same id order as test.csv

    Raises:
        RuntimeError: if no images are found, a label has no images or fewer
            than 2 images (it could not appear in both splits), or a
            post-write sanity check fails.
    """
    # Deterministic RNG
    rng = random.Random(SEED)

    # Gather and validate the input *before* wiping the output directories,
    # so a bad ``raw`` path does not destroy a previously prepared dataset.
    sources = _gather_source_images(raw)
    by_label = defaultdict(list)
    for label, path in sources:
        by_label[label].append(path)

    missing = [lbl for lbl in LABELS if not by_label[lbl]]
    _require(not missing, f"Some labels have no images: {missing}")
    # With a single image, n_test is clamped to n - 1 == 0, so that label
    # would have no test sample. Fail now rather than after copying.
    too_few = {lbl: len(by_label[lbl]) for lbl in LABELS if len(by_label[lbl]) < 2}
    _require(not too_few, f"Each label needs at least 2 images for a train/test split: {too_few}")

    # Create/clean target directories
    _make_clean_dir(public)
    _make_clean_dir(private)
    train_dir = public / "train_images"
    test_dir = public / "test_images"
    train_dir.mkdir(parents=True, exist_ok=True)
    test_dir.mkdir(parents=True, exist_ok=True)

    # Stratified deterministic split, then copy under shuffled fresh ids.
    train_items, test_items = _split_per_label(by_label, rng)
    train_rows = _copy_with_new_ids(train_items, train_dir, 1, rng)
    test_ans_rows = _copy_with_new_ids(test_items, test_dir, len(train_rows) + 1, rng)
    # test.csv is derived from the answer rows, so both list ids in the same order.
    test_rows = [[rid] for rid, _ in test_ans_rows]

    _safe_write_csv(public / "train.csv", ["id", "label"], train_rows)
    _safe_write_csv(public / "test.csv", ["id"], test_rows)
    _safe_write_csv(private / "test_answer.csv", ["id", "label"], test_ans_rows)

    # Sample submission (deterministic random labels)
    sample_rows = [[row[0], rng.choice(LABELS)] for row in test_rows]
    _safe_write_csv(public / "sample_submission.csv", ["id", "label"], sample_rows)

    # Copy description.txt from this script's directory, if present.
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        shutil.copy2(desc_src, public / "description.txt")

    _verify_outputs(train_dir, test_dir, train_rows, test_ans_rows)
