from __future__ import annotations

import csv
import json
import random
import re
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

from PIL import Image


# Region labels accepted from the annotation files; any other label is ignored.
CLASSES = [
    "rach",
    "mop_lom",
    "tray_son",
    "mat_bo_phan",
]
# Set form of CLASSES for constant-time membership checks.
CLASS_SET = {*CLASSES}

# Fixed seed so repeated runs are reproducible. The module-level generator is
# seeded here, i.e. as a side effect of importing this module.
SEED = 1337
random.seed(SEED)


def _polygon_to_bbox(xs: List[float], ys: List[float], w: int, h: int) -> Tuple[int, int, int, int] | None:
    """Return the axis-aligned box enclosing a polygon, clipped to the image.

    Coordinates are truncated to integers with ``int()``. The lower corner is
    clamped to be non-negative and the upper corner to ``(w - 1, h - 1)``.

    Returns:
        ``(x_min, y_min, x_max, y_max)``, or ``None`` when either coordinate
        list is empty or the clipped box has no positive width and height.
    """
    if not (xs and ys):
        return None

    left = max(0, int(min(xs)))
    top = max(0, int(min(ys)))
    # Largest valid pixel index along each axis (never below zero).
    last_col = max(0, w - 1)
    last_row = max(0, h - 1)
    right = min(last_col, int(max(xs)))
    bottom = min(last_row, int(max(ys)))

    if right > left and bottom > top:
        return left, top, right, bottom
    return None


def _load_annotations(json_paths: List[Path]) -> Dict[str, Dict]:
    """Merge the top-level objects of several JSON files into one dict.

    Files are processed in the given order, so on duplicate keys the entry from
    the later file wins. Paths that do not exist are skipped silently.
    """
    merged: Dict[str, Dict] = {}
    for path in json_paths:
        if path.exists():
            with path.open("r", encoding="utf-8") as fh:
                merged.update(json.load(fh))
    return merged


def _discover_image_path(fname: str, src_train_dir: Path, src_val_dir: Path) -> Path | None:
    """Locate ``fname`` in the train directory, falling back to the validation one.

    Returns the first existing candidate path, or ``None`` if the file is in
    neither directory.
    """
    # Order matters: the train directory takes precedence over validation.
    for folder in (src_train_dir, src_val_dir):
        candidate = folder / fname
        if candidate.exists():
            return candidate
    return None


def _new_name(idx: int) -> str:
    """Build the anonymized file name for image number ``idx`` (zero-padded to 6 digits)."""
    padded = format(idx, "06d")
    return "img_" + padded + ".jpg"


def _split_train_test(
    img_classes: List[set],
    target_test: int,
    classes: List[str],
) -> Tuple[List[int], List[int]]:
    """Split image indices into train/test without stripping any class from train.

    Args:
        img_classes: For each image (by index), the set of class labels it contains.
        target_test: Desired number of test images (may be exceeded by seeding
            or not reached if too few images can be moved safely).
        classes: Class labels, in the order used to seed the test split.

    Returns:
        ``(train_indices, test_indices)``, each sorted ascending.
    """
    train_idx = set(range(len(img_classes)))
    test_idx: set = set()

    # Number of *train* images containing each class. Keeping this up to date
    # makes the "would this move empty a class in train?" check O(#classes)
    # instead of rescanning every train image.
    in_train: Dict[str, int] = {}
    for cls_set in img_classes:
        for c in cls_set:
            in_train[c] = in_train.get(c, 0) + 1

    def can_move(i: int) -> bool:
        # Image i is still in train, so it accounts for one of each count;
        # at least one *other* train image must carry every class it has.
        return all(in_train[c] > 1 for c in img_classes[i])

    def move(i: int) -> None:
        train_idx.discard(i)
        test_idx.add(i)
        for c in img_classes[i]:
            in_train[c] -= 1

    # Seed test with one image per class where possible. Unlike a blind
    # "first match" pick, this skips images whose removal would leave train
    # without one of their classes.
    for c in classes:
        for i, cls_set in enumerate(img_classes):
            if c in cls_set and i not in test_idx and can_move(i):
                move(i)
                break

    # Fill up to the target in index order. Images without any labels always
    # pass can_move(), so no separate "unlabelled images" pass is needed.
    for i in range(len(img_classes)):
        if len(test_idx) >= target_test:
            break
        if i not in test_idx and can_move(i):
            move(i)

    return sorted(train_idx), sorted(test_idx)


def _write_csv_rows(path: Path, header: List[str], rows) -> None:
    """Write ``header`` followed by ``rows`` to ``path`` as UTF-8 CSV."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def _random_prediction_string(rng: random.Random, width: int, height: int) -> str:
    """Return 0-2 random ``class score x1 y1 x2 y2`` predictions, space-joined.

    Boxes lie inside a ``width`` x ``height`` image. Images narrower or shorter
    than 2 pixels cannot hold a box with ``x2 > x1`` and ``y2 > y1`` and get an
    empty string instead of raising ``ValueError`` from ``randint``.
    """
    # Draw the count first so the random stream is consumed the same way for
    # every image, including degenerate ones.
    count = rng.choice([0, 1, 2])
    if width < 2 or height < 2:
        return ""
    parts: List[str] = []
    for _ in range(count):
        cls = rng.choice(CLASSES)
        score = rng.random()
        x1 = rng.randint(0, width - 2)
        y1 = rng.randint(0, height - 2)
        x2 = rng.randint(x1 + 1, width - 1)
        y2 = rng.randint(y1 + 1, height - 1)
        parts.extend([cls, f"{score:.4f}", str(x1), str(y1), str(x2), str(y2)])
    return " ".join(parts)


def prepare(raw: Path, public: Path, private: Path):
    """
    Create the competition structure from the raw dataset.

    Inputs:
    - raw/: contains image/image/, validation/validation/, and
      0Train_via_annos.json, 0Val_via_annos.json
    Outputs:
    - public/: train_images/, test_images/, train_annotations.csv, train.csv,
      test.csv, sample_submission.csv, description.txt (only if a
      description.txt sits next to this script)
    - private/: test_answer.csv

    Images that are annotated but missing on disk, or that cannot be opened,
    are skipped. Roughly 20% of the remaining images go to the test split,
    chosen so that no class present in the data disappears from train.

    Raises:
        AssertionError: if no annotations or images are found, a split is
            empty, or one of the final consistency checks fails.
    """

    # Resolve absolute paths
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    # Source paths in raw
    src_train_dir = raw / "image" / "image"
    src_val_dir = raw / "validation" / "validation"
    train_json = raw / "0Train_via_annos.json"
    val_json = raw / "0Val_via_annos.json"

    # Output paths in public/private
    out_train_dir = public / "train_images"
    out_test_dir = public / "test_images"
    train_csv_path = public / "train_annotations.csv"
    test_answer_csv_path = private / "test_answer.csv"
    sample_sub_path = public / "sample_submission.csv"
    pub_train_list_csv = public / "train.csv"
    pub_test_list_csv = public / "test.csv"

    # Create dirs
    out_train_dir.mkdir(parents=True, exist_ok=True)
    out_test_dir.mkdir(parents=True, exist_ok=True)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Load annotations
    annos = _load_annotations([train_json, val_json])
    assert len(annos) > 0, "No annotations loaded."

    # Aggregate image metadata and boxes
    images: List[Dict] = []
    for fname, rec in annos.items():
        img_path = _discover_image_path(fname, src_train_dir, src_val_dir)
        if img_path is None:
            continue
        try:
            with Image.open(img_path) as im:
                width, height = im.size
        except Exception:
            # Deliberate best-effort: unreadable images are left out of the dataset.
            continue
        boxes: List[Tuple[str, int, int, int, int]] = []
        for region in (rec.get("regions") or []):
            cls = region.get("class")
            if cls not in CLASS_SET:
                continue
            bbox = _polygon_to_bbox(
                region.get("all_x") or [],
                region.get("all_y") or [],
                width,
                height,
            )
            if bbox is None:
                continue
            boxes.append((cls, *bbox))
        images.append({
            "src": img_path,
            "name": fname,
            "size": (width, height),
            "boxes": boxes,
        })

    assert images, "No valid images collected."

    # Deterministic shuffle and split
    random.Random(SEED).shuffle(images)
    test_fraction = 0.2
    target_test = int(round(test_fraction * len(images)))
    img_classes = [{b[0] for b in img["boxes"]} for img in images]
    train_indices, test_indices = _split_train_test(img_classes, target_test, CLASSES)

    assert train_indices and test_indices, "Empty split."

    # Copy with anonymized names and collect annotation rows
    label_leak_pattern = re.compile("|".join(re.escape(c) for c in CLASSES), re.IGNORECASE)

    train_rows: List[Tuple[str, str, int, int, int, int]] = []
    test_rows: List[Tuple[str, str, int, int, int, int]] = []
    train_ids: List[str] = []
    test_ids: List[str] = []
    size_of: Dict[str, Tuple[int, int]] = {}

    # NOTE(review): _new_name always yields a ".jpg" suffix while the bytes are
    # copied unchanged, so a non-JPEG source keeps its format under a .jpg name.
    next_id = 1
    for indices, out_dir, rows, ids in (
        (train_indices, out_train_dir, train_rows, train_ids),
        (test_indices, out_test_dir, test_rows, test_ids),
    ):
        for i in indices:
            img = images[i]
            new_fname = _new_name(next_id)
            next_id += 1
            assert not label_leak_pattern.search(new_fname)
            shutil.copy2(img["src"], out_dir / new_fname)
            ids.append(new_fname)
            size_of[new_fname] = img["size"]
            rows.extend((new_fname, *box) for box in img["boxes"])

    # Write annotations
    annotation_header = ["image_id", "class", "x_min", "y_min", "x_max", "y_max"]
    _write_csv_rows(train_csv_path, annotation_header, train_rows)
    _write_csv_rows(test_answer_csv_path, annotation_header, test_rows)

    # train.csv and test.csv (lists of images, ids only). The lists come from
    # the files copied in this run -- including images without boxes -- rather
    # than from a directory listing, so stale files left by an earlier run
    # cannot leak into them.
    train_ids = sorted(train_ids)
    test_ids = sorted(test_ids)
    _write_csv_rows(pub_train_list_csv, ["image_id"], ([iid] for iid in train_ids))
    _write_csv_rows(pub_test_list_csv, ["image_id"], ([iid] for iid in test_ids))

    # Build sample submission: one row per test image, possibly empty predictions.
    # A private generator keeps the file identical on every call, independent
    # of whatever else has consumed the global `random` state.
    rng = random.Random(SEED)
    _write_csv_rows(
        sample_sub_path,
        ["image_id", "PredictionString"],
        ([iid, _random_prediction_string(rng, *size_of[iid])] for iid in test_ids),
    )

    # Copy description.txt into public/
    repo_desc = Path(__file__).parent / "description.txt"
    if repo_desc.exists():
        (public / "description.txt").write_text(repo_desc.read_text(encoding="utf-8"), encoding="utf-8")

    # Checks
    assert out_train_dir.exists() and out_test_dir.exists(), "Image directories missing."
    assert train_csv_path.exists(), "train_annotations.csv missing"
    assert test_answer_csv_path.exists(), "test_answer.csv missing"
    assert sample_sub_path.exists(), "sample_submission.csv missing"
    assert pub_train_list_csv.exists() and pub_test_list_csv.exists(), "train.csv/test.csv missing"

    # Ensure no overlap of filenames (this also trips on stale files from an
    # earlier run that collide with the other split).
    assert {p.name for p in out_train_dir.iterdir()}.isdisjoint(
        {p.name for p in out_test_dir.iterdir()}
    ), "Train and test image filenames overlap"

    # Ensure sample submission covers all test images
    with sample_sub_path.open("r", encoding="utf-8") as f:
        sub_ids = [r["image_id"] for r in csv.DictReader(f)]
    assert set(sub_ids) == set(test_ids), "sample_submission does not cover all test images"
