import csv
import hashlib
import math
import random
import shutil
from pathlib import Path
from typing import List, Dict, Tuple

# Seed the global `random` module RNG once, at import time, so that any use of
# the module-level random functions is reproducible from run to run.
random.seed(42)


def _read_classes(path: Path) -> List[str]:
    """
    Load the class names stored in *path*, one name per line.

    Surrounding whitespace is stripped from every line and lines that are
    blank after stripping are dropped. The order of the remaining names is
    the order in which they appear in the file.

    Raises AssertionError if no class name is left.
    """
    names: List[str] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            name = raw_line.strip()
            if name:
                names.append(name)
    assert len(names) > 0, "classes.txt is empty"
    return names


def _parse_yolo_label_file(lbl_path: Path) -> List[Tuple[int, float, float, float, float]]:
    """
    Parse a YOLO-style label file into (class_id, cx, cy, w, h) tuples.

    Each useful line holds exactly five whitespace-separated tokens:
    ``class_id cx cy w h``. The class id is parsed through ``float`` first so
    that values such as ``"1.0"`` are accepted. The four box values are
    clipped to [0, 1].

    Skipped silently:
    - blank lines and lines that do not have exactly five tokens,
    - lines whose tokens cannot be parsed as numbers,
    - lines with a non-finite (nan / inf) box value; clipping would otherwise
      turn ``nan`` into 0.0 and ``inf`` into 1.0 and fabricate a box,
    - degenerate boxes whose clipped width or height is not positive.

    Returns an empty list when *lbl_path* does not exist. The class id is not
    range-checked here.
    """
    items: List[Tuple[int, float, float, float, float]] = []
    if not lbl_path.exists():
        return items
    with lbl_path.open("r", encoding="utf-8") as f:
        for line in f:
            # split() with no argument also discards blank / whitespace-only lines
            parts = line.split()
            if len(parts) != 5:
                continue
            try:
                cls = int(float(parts[0]))
                coords = [float(tok) for tok in parts[1:]]
            except (ValueError, OverflowError):
                # ValueError: not a number (or a nan class id);
                # OverflowError: infinite class id passed to int()
                continue
            if not all(math.isfinite(v) for v in coords):
                continue
            # clip to [0,1]
            cx, cy, w, h = (min(1.0, max(0.0, v)) for v in coords)
            if w <= 0.0 or h <= 0.0:
                continue
            items.append((cls, cx, cy, w, h))
    return items


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the Hardhat + Vest detection dataset for competition use.

    Inputs
    - raw: directory containing the original dataset with structure:
        raw/
          images/{train,val,test}/<image>.jpg
          labels/{train,val,test}/<image>.txt
          labels/classes.txt
    Outputs
    - public/:
        images/train/  (train + val images)
        images/test/
        train.csv        (rows: image_id,class_id,cx,cy,w,h for training set)
        test.csv         (rows: image_id for test set)
        sample_submission.csv
        description.txt  (copied from the directory of this module, if present)
    - private/:
        test_answer.csv  (rows: image_id,class_id,cx,cy,w,h for test set)

    Notes
    - Every image is paired with the label file of the split it was copied
      from. If the same file name exists in both train and val, the val image
      overwrites the train one in public/images/train and the val label is
      used, so image and boxes always belong together.
    - The image lists written to the CSV files come from the files copied in
      this run, not from a listing of the output directories, so leftovers
      from an earlier run cannot end up in test.csv.

    Raises
    - AssertionError when the raw layout is incomplete, a label uses a class
      id outside classes.txt, train.csv would be empty, there are no test
      images, a class occurs in test but not in train, or an image_id
      contains a path separator.
    """

    # Ensure dirs
    assert raw.exists() and raw.is_dir(), f"Raw directory not found: {raw}"
    public_train_dir = public / "images" / "train"
    public_test_dir = public / "images" / "test"
    public_train_dir.mkdir(parents=True, exist_ok=True)
    public_test_dir.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Paths in raw
    raw_images = raw / "images"
    raw_labels = raw / "labels"
    for kind, base in (("images", raw_images), ("labels", raw_labels)):
        for split in ("train", "val", "test"):
            assert (base / split).exists(), f"raw/{kind}/{split} missing"

    classes = _read_classes(raw_labels / "classes.txt")
    num_classes = len(classes)

    def _copy_split(split: str, dst: Path) -> Dict[str, Path]:
        """Copy raw/images/<split>/*.jpg into dst; map file name -> label path of that split."""
        label_paths: Dict[str, Path] = {}
        for img in sorted((raw_images / split).glob("*.jpg")):
            shutil.copy2(img, dst / img.name)
            label_paths[img.name] = raw_labels / split / (img.stem + ".txt")
        return label_paths

    def _write_annotations(csv_path: Path, label_paths: Dict[str, Path]) -> Dict[int, int]:
        """Write one row per box to csv_path (images in name order); return per-class box counts."""
        counts: Dict[int, int] = {i: 0 for i in range(num_classes)}
        with csv_path.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["image_id", "class_id", "cx", "cy", "w", "h"])
            for img_name in sorted(label_paths):
                for cls, cx, cy, w, h in _parse_yolo_label_file(label_paths[img_name]):
                    assert 0 <= cls < num_classes, f"Invalid class id {cls} in labels for {img_name}"
                    writer.writerow([img_name, cls, f"{cx:.6f}", f"{cy:.6f}", f"{w:.6f}", f"{h:.6f}"])
                    counts[cls] += 1
        return counts

    # Copy images: merge train + val into public/images/train.
    # val is copied second, so on a name clash its image AND its label win.
    train_labels = _copy_split("train", public_train_dir)
    train_labels.update(_copy_split("val", public_train_dir))
    test_labels = _copy_split("test", public_test_dir)

    # Build train.csv from labels/train + labels/val
    per_class_train_counts = _write_annotations(public / "train.csv", train_labels)
    assert sum(per_class_train_counts.values()) > 0, "No rows written to train.csv"

    # Build test.csv and private/test_answer.csv from labels/test
    test_imgs = sorted(test_labels)
    assert len(test_imgs) > 0, "No test images found"

    with (public / "test.csv").open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id"])
        for img in test_imgs:
            writer.writerow([img])

    per_class_test_counts = _write_annotations(private / "test_answer.csv", test_labels)

    # Ensure classes present in test also appear in train
    for cid, cnt in per_class_test_counts.items():
        if cnt > 0:
            assert (
                per_class_train_counts.get(cid, 0) > 0
            ), f"Class id {cid} appears in test but not in train"

    # Build sample_submission.csv with deterministic pseudo predictions
    # Read test ids back from public/test.csv to guarantee ordering
    test_ids: List[str] = []
    with (public / "test.csv").open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            test_ids.append(row["image_id"])

    # available classes (at least 1)
    # NOTE(review): this set is taken from the private test labels, so the
    # public sample submission reveals which classes occur in the test set.
    available_classes = [cid for cid, c in per_class_test_counts.items() if c > 0]
    if not available_classes:
        available_classes = list(range(num_classes))

    with (public / "sample_submission.csv").open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "PredictionString"])
        for img_id in test_ids:
            # deterministic number of predictions: 1-3 based on stable md5 hash;
            # a per-image RNG keeps each row independent of the global random state
            md5 = hashlib.md5(img_id.encode("utf-8")).hexdigest()
            seed_int = int(md5[:8], 16)
            rng = random.Random(seed_int)
            k = 1 + (seed_int % 3)
            parts: List[str] = []
            for _ in range(k):
                # the order of the rng calls below fixes the generated values
                cls = rng.choice(available_classes)
                conf = max(1e-6, min(0.999999, rng.random()))
                cx = max(0.0, min(1.0, rng.random()))
                cy = max(0.0, min(1.0, rng.random()))
                w = max(1e-6, min(1.0, rng.random() * 0.6))
                h = max(1e-6, min(1.0, rng.random() * 0.6))
                parts += [
                    str(int(cls)),
                    f"{conf:.6f}",
                    f"{cx:.6f}",
                    f"{cy:.6f}",
                    f"{w:.6f}",
                    f"{h:.6f}",
                ]
            writer.writerow([img_id, " ".join(parts)])

    # Copy description.txt into public/ (best effort: skipped when the file is absent)
    repo_desc = Path(__file__).parent / "description.txt"
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Basic checks and invariants
    assert public_train_dir.exists(), "public/images/train should exist"
    assert public_test_dir.exists(), "public/images/test should exist"
    assert (public / "train.csv").exists(), "public/train.csv should exist"
    assert (public / "test.csv").exists(), "public/test.csv should exist"
    assert (public / "sample_submission.csv").exists(), "public/sample_submission.csv should exist"
    assert (private / "test_answer.csv").exists(), "private/test_answer.csv should exist"

    # Ensure there are no path separators in image_id columns
    for csv_path in [public / "train.csv", public / "test.csv", private / "test_answer.csv"]:
        with csv_path.open("r", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for row in reader:
                image_id = row[0]
                assert "/" not in image_id and "\\" not in image_id, (
                    f"Found path separator in image_id in {csv_path}: {image_id}"
                )
