from pathlib import Path
import os
import csv
import shutil
import random
from collections import defaultdict
from typing import Dict, List, Tuple, Set

# Seed the module-level (global) `random` generator once, at import time.
# NOTE(review): nothing visible in this file draws from the global generator --
# the train/test split is purely order-based and the sample submission uses its
# own `random.Random(123)` instance -- so confirm whether this seed is still
# needed. Be aware that it also reseeds the global RNG of any importing program.
random.seed(42)


# ------------------------------
# Helpers
# ------------------------------

def _read_synset_mapping(path: Path) -> Dict[str, str]:
    """
    Load a two-column, tab-separated table and return it keyed by the second column.

    Each usable line has the form ``<name>\\t<code>``; the result maps
    ``code -> name``. Lines are stripped of surrounding whitespace first, and
    any line that does not split into exactly two tab-separated fields
    (including blank lines) is ignored. When a code appears more than once the
    last occurrence wins.
    """
    code_to_name: Dict[str, str] = {}
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            # A blank line splits into a single empty field, so this one
            # check rejects both empty and malformed lines.
            if len(fields) == 2:
                label, synset = fields
                code_to_name[synset] = label
    return code_to_name


def _safe_read_lines_count(path: Path) -> int:
    """
    Return the number of lines in the UTF-8 text file at *path*.

    Every line is counted, blank ones included; the file is streamed so it is
    never held in memory as a whole.
    """
    total = 0
    with path.open("r", encoding="utf-8") as handle:
        for _ in handle:
            total += 1
    return total


def _read_seg_labels(path: Path) -> List[int]:
    """
    Read one integer part label per non-blank line from the file at *path*.

    Lines are stripped of surrounding whitespace and blank lines are skipped.
    A value that is not a plain integer literal (e.g. ``"2.0"``) is parsed as a
    float and truncated toward zero by ``int()``.

    Returns:
        The labels in file order.

    Raises:
        ValueError: if a non-blank line is neither an integer nor a float
            literal.
    """
    labels: List[int] = []
    with path.open("r", encoding="utf-8") as f:
        for ln in f:
            s = ln.strip()
            if not s:
                continue
            try:
                v = int(s)
            except ValueError:
                # Only ValueError can come out of int(str); fall back to
                # float-formatted labels such as "2.0" and let a genuinely
                # malformed value raise from the float() call.
                v = int(float(s))
            labels.append(v)
    return labels


def _stratified_split_with_part_coverage(samples: List[Tuple[str, Path, Path, Path | None]], test_ratio: float = 0.2) -> Tuple[List[int], List[int]]:
    """
    Split indices of samples into train and test ensuring that the union of part labels in test
    is a subset of that in train (per category). Operates deterministically on input order.

    samples: list of tuples (base_id, points_path, seg_path, seg_img_path_or_None)
    """
    n = len(samples)
    if n == 0:
        return [], []

    label_sets: List[Set[int]] = []
    lengths_pts: List[int] = []
    lengths_seg: List[int] = []
    for _, pts, seg, _ in samples:
        lengths_pts.append(_safe_read_lines_count(pts))
        seg_labels = _read_seg_labels(seg)
        lengths_seg.append(len(seg_labels))
        label_sets.append(set(seg_labels))

    for i in range(n):
        assert lengths_pts[i] == lengths_seg[i], f"Point/seg length mismatch for {samples[i][1]} vs {samples[i][2]}: {lengths_pts[i]} vs {lengths_seg[i]}"

    # Initial deterministic split: last k to test
    test_count = max(1, int(round(n * test_ratio)))
    test_count = min(test_count, n - 1) if n > 1 else 0
    test_idx = list(range(n - test_count, n)) if test_count > 0 else []
    train_idx = list(range(0, n - test_count)) if test_count > 0 else list(range(n))

    def union_labels(idxs: List[int]) -> Set[int]:
        u: Set[int] = set()
        for i in idxs:
            u |= label_sets[i]
        return u

    max_iters = 5 * n + 10
    it = 0
    while it < max_iters and test_idx:
        it += 1
        train_parts = union_labels(train_idx)
        move_from_test = [j for j in list(test_idx) if not label_sets[j].issubset(train_parts)]
        if not move_from_test:
            break
        # Move violating samples from test to train
        for j in move_from_test:
            if j in test_idx:
                test_idx.remove(j)
            if j not in train_idx:
                train_idx.append(j)
        # Try to move some back to test if needed while keeping coverage
        desired_test = max(1, int(round(n * test_ratio)))
        while len(test_idx) < desired_test and train_idx:
            moved = False
            for cand in reversed(train_idx):
                parts_wo = union_labels([i for i in train_idx if i != cand])
                if label_sets[cand].issubset(parts_wo):
                    train_idx.remove(cand)
                    test_idx.append(cand)
                    moved = True
                    break
            if not moved:
                break

    assert set(train_idx).isdisjoint(set(test_idx))
    assert len(train_idx) + len(test_idx) == n
    assert len(train_idx) >= 1 and len(test_idx) >= 1, "Each category should have at least 1 train and 1 test sample"
    return train_idx, test_idx


def _write_csv(path: Path, header: List[str], rows: List[List[str]]):
    """
    Write *header* followed by *rows* to *path* as a UTF-8 CSV file.

    Missing parent directories are created, and an existing file at *path* is
    overwritten. The file is opened with ``newline=""`` so the csv writer
    controls the line endings itself.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as sink:
        writer = csv.writer(sink)
        writer.writerow(header)
        writer.writerows(rows)


def _generate_random_labels(n: int, allowed_parts: List[int], rng: random.Random) -> List[int]:
    """
    Produce *n* labels, each drawn independently with ``rng.choice`` from
    *allowed_parts*.

    When *allowed_parts* is empty the result is *n* zeros and *rng* is not
    consulted.
    """
    if allowed_parts:
        picks: List[int] = []
        for _ in range(n):
            picks.append(rng.choice(allowed_parts))
        return picks
    return [0] * n


# ------------------------------
# Public API
# ------------------------------

def prepare(raw: Path, public: Path, private: Path):
    """
    Complete preparation process.

    - raw: absolute path to the raw data directory (contains PartAnnotation/)
    - public: absolute path to write participant-visible files
    - private: absolute path to write hidden ground-truth answers

    Will not delete or modify `raw/` to preserve original files.

    For every 8-digit numeric category directory under ``raw/PartAnnotation``,
    ``points/*.pts`` files are paired by stem with
    ``expert_verified/points_label/*.seg`` files. Each category is split into
    train and test with ``_stratified_split_with_part_coverage`` (test ratio
    0.2) and every sample receives a new zero-padded 7-digit id, counting up
    from ``0100000`` (train samples of a category first, then its test samples).

    Files written:
    - ``public/train/points``, ``public/train/seg`` and, when a matching PNG
      exists, ``public/train/seg_img``: copies of the training files.
    - ``public/test/points``: copies of the test point files only.
    - ``public/train.csv`` and ``public/test.csv``: ``id, category, n_points``.
    - ``private/test_answer.csv``: ``id, labels`` with space-separated labels.
    - ``public/sample_submission.csv``: random labels drawn from the parts seen
      in the category's training samples, using ``random.Random(123)``.
    - ``public/description.txt``: copied from next to this module if present.

    Output directories are created if missing; pre-existing files in them are
    not removed.

    Raises:
        AssertionError: if the raw layout is missing, no (points, seg) pairs
            are found, or any of the post-write consistency checks fails.
    """
    raw = Path(raw).absolute()
    public = Path(public).absolute()
    private = Path(private).absolute()

    assert raw.exists(), f"Raw directory does not exist: {raw}"
    assert (raw / "PartAnnotation").exists(), f"Expected PartAnnotation/ inside raw: {raw}"

    # Source directories
    src_root = raw / "PartAnnotation"

    # Destination directories
    train_points = public / "train" / "points"
    train_seg = public / "train" / "seg"
    train_seg_img = public / "train" / "seg_img"
    test_points = public / "test" / "points"

    # CSV paths
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    # Create the output directories. Existing contents are left in place, so
    # stale files from an earlier run are NOT cleaned up here.
    for out_dir in (train_points, train_seg, train_seg_img, test_points, private):
        out_dir.mkdir(parents=True, exist_ok=True)

    # Read category mapping (category code -> category name)
    synset_path = src_root / "synsetoffset2category.txt"
    code_to_name = _read_synset_mapping(synset_path)

    # List category dirs (8-digit numeric), in sorted order
    category_codes = [d.name for d in sorted(src_root.iterdir()) if d.is_dir() and d.name.isdigit() and len(d.name) == 8]
    assert category_codes, "No category directories found in raw/PartAnnotation."

    # Collect samples per category
    def collect_category_samples(cat_code: str) -> List[Tuple[str, Path, Path, Path | None]]:
        # Return (stem, points path, seg path, seg image path or None) for every
        # stem that has both a .pts and a .seg file, sorted by stem. A category
        # lacking either directory contributes nothing.
        cat_dir = src_root / cat_code
        pts_dir = cat_dir / "points"
        seg_dir = cat_dir / "expert_verified" / "points_label"
        seg_img_dir = cat_dir / "expert_verified" / "seg_img"
        if not (pts_dir.is_dir() and seg_dir.is_dir()):
            return []
        pts_bases = {p.stem for p in pts_dir.glob("*.pts")}
        seg_bases = {p.stem for p in seg_dir.glob("*.seg")}
        has_seg_img_dir = seg_img_dir.is_dir()
        out: List[Tuple[str, Path, Path, Path | None]] = []
        for base in sorted(pts_bases & seg_bases):
            candidate_img = seg_img_dir / f"{base}.png"
            seg_img_path = candidate_img if has_seg_img_dir and candidate_img.exists() else None
            out.append((base, pts_dir / f"{base}.pts", seg_dir / f"{base}.seg", seg_img_path))
        return out

    cat_samples: Dict[str, List[Tuple[str, Path, Path, Path | None]]] = {}
    for code in category_codes:
        samps = collect_category_samples(code)
        if samps:
            cat_samples[code] = samps
    total_candidates = sum(len(v) for v in cat_samples.values())
    assert total_candidates > 0, "No matching (points, seg) pairs found in raw data."

    # Split per category with coverage guarantees
    per_cat_splits: Dict[str, Tuple[List[int], List[int]]] = {}
    for code, samples in cat_samples.items():
        per_cat_splits[code] = _stratified_split_with_part_coverage(samples, test_ratio=0.2)

    # Assign anonymized global IDs
    next_id = 100000
    id_width = 7

    train_rows: List[List[str]] = []
    test_rows: List[List[str]] = []
    test_answer_rows: List[List[str]] = []

    seen_ids: Set[str] = set()
    category_label_coverage_train: Dict[str, Set[int]] = defaultdict(set)
    category_label_coverage_test: Dict[str, Set[int]] = defaultdict(set)

    for code in sorted(cat_samples):
        samples = cat_samples[code]
        train_idx, test_idx = per_cat_splits[code]
        # Fall back to the raw code when the mapping has no name for it.
        cat_name = code_to_name.get(code, code)

        # Copy train
        for i in train_idx:
            _, pts_path, seg_path, seg_img_path = samples[i]
            n_pts = _safe_read_lines_count(pts_path)
            seg_labels = _read_seg_labels(seg_path)
            assert n_pts == len(seg_labels)

            anon_id = f"{next_id:0{id_width}d}"
            next_id += 1
            assert anon_id not in seen_ids
            seen_ids.add(anon_id)

            shutil.copyfile(pts_path, train_points / f"{anon_id}.pts")
            shutil.copyfile(seg_path, train_seg / f"{anon_id}.seg")
            if seg_img_path and seg_img_path.exists():
                shutil.copyfile(seg_img_path, train_seg_img / f"{anon_id}.png")

            train_rows.append([anon_id, cat_name, str(n_pts)])
            category_label_coverage_train[cat_name] |= set(seg_labels)

        # Copy test and write answers
        for i in test_idx:
            _, pts_path, seg_path, _ = samples[i]
            n_pts = _safe_read_lines_count(pts_path)
            seg_labels = _read_seg_labels(seg_path)
            assert n_pts == len(seg_labels)

            anon_id = f"{next_id:0{id_width}d}"
            next_id += 1
            assert anon_id not in seen_ids
            seen_ids.add(anon_id)

            shutil.copyfile(pts_path, test_points / f"{anon_id}.pts")
            # Do NOT copy seg/seg_img to public test

            test_rows.append([anon_id, cat_name, str(n_pts)])
            category_label_coverage_test[cat_name] |= set(seg_labels)
            test_answer_rows.append([anon_id, " ".join(str(v) for v in seg_labels)])

    # Write CSVs
    _write_csv(train_csv, ["id", "category", "n_points"], train_rows)
    _write_csv(test_csv, ["id", "category", "n_points"], test_rows)
    _write_csv(test_answer_csv, ["id", "labels"], test_answer_rows)

    # Sample submission: random labels restricted to the parts observed in the
    # category's training samples, from a dedicated, fixed-seed generator.
    rng = random.Random(123)  # deterministic sample submission
    cat_to_allowed_parts = {cat: sorted(parts) for cat, parts in category_label_coverage_train.items()}
    sample_rows: List[List[str]] = []
    for rid, cat, n_pts in test_rows:
        allowed = cat_to_allowed_parts.get(cat, [0])
        labels = _generate_random_labels(int(n_pts), allowed, rng)
        sample_rows.append([rid, " ".join(str(x) for x in labels)])
    _write_csv(sample_sub_csv, ["id", "labels"], sample_rows)

    # Copy description.txt to public (if one sits next to this module)
    repo_desc = Path(__file__).with_name("description.txt")
    if repo_desc.exists():
        # The file is copied verbatim; its content is not inspected here.
        shutil.copyfile(repo_desc, public / "description.txt")

    # Checks
    assert train_points.is_dir() and train_seg.is_dir()
    assert test_points.is_dir()

    # Train rows match files
    for rid, _, n_pts in train_rows:
        assert (train_points / f"{rid}.pts").exists()
        assert (train_seg / f"{rid}.seg").exists()
        pts_len = _safe_read_lines_count(train_points / f"{rid}.pts")
        seg_len = len(_read_seg_labels(train_seg / f"{rid}.seg"))
        assert pts_len == int(n_pts) == seg_len

    # Test rows match files; and no seg in public test
    for rid, _, _ in test_rows:
        assert (test_points / f"{rid}.pts").exists()
        assert not (public / "test" / "seg" / f"{rid}.seg").exists()

    # Coverage: test parts subset of train parts per category
    for cat, test_parts in category_label_coverage_test.items():
        train_parts = category_label_coverage_train.get(cat, set())
        assert test_parts.issubset(train_parts), f"Test parts {test_parts} not subset of train parts {train_parts} for category {cat}"

    # Sample submission validity: one row per test id, one non-negative
    # integer label per point.
    test_map_n = {rid: int(n_pts) for rid, _, n_pts in test_rows}
    sub_map = {rid: (labels_str.strip().split(" ") if labels_str.strip() else []) for rid, labels_str in sample_rows}
    assert set(sub_map.keys()) == set(test_map_n.keys())
    for rid, n in test_map_n.items():
        assert len(sub_map[rid]) == n
        for s in sub_map[rid]:
            v = int(s)
            assert v >= 0

    # Determinism sanity: ids are unique within each CSV
    assert len({rid for rid, _, _ in train_rows}) == len(train_rows)
    assert len({rid for rid, _, _ in test_rows}) == len(test_rows)
