import os
import csv
import math
import shutil
from pathlib import Path
from typing import List, Dict, Tuple

import pandas as pd
from PIL import Image

# NOTE: Do not include a main() function. The single entrypoint must be:
#   def prepare(raw: Path, public: Path, private: Path)


def _ensure_clean_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)
    # clean contents deterministically
    for child in list(p.iterdir()):
        if child.is_file() or child.is_symlink():
            child.unlink()
        elif child.is_dir():
            shutil.rmtree(child)


def _collect_images(dir_paths: List[Path]) -> List[Path]:
    out: List[Path] = []
    for d in dir_paths:
        if not d.exists():
            continue
        for ext in ("*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG"):
            out.extend(sorted(d.glob(ext)))
    return out


def _yolo_to_xyxy(cx: float, cy: float, w: float, h: float, img_w: int, img_h: int) -> Tuple[int, int, int, int]:
    # YOLO normalized center x/y, width/height -> xyxy integer pixel box (inclusive)
    x_min = (cx - w / 2.0) * img_w
    y_min = (cy - h / 2.0) * img_h
    x_max = (cx + w / 2.0) * img_w
    y_max = (cy + h / 2.0) * img_h
    x_min = int(max(0, math.floor(x_min)))
    y_min = int(max(0, math.floor(y_min)))
    x_max = int(min(img_w - 1, math.ceil(x_max)))
    y_max = int(min(img_h - 1, math.ceil(y_max)))
    x_max = max(x_max, x_min)
    y_max = max(y_max, y_min)
    return x_min, y_min, x_max, y_max


# Fallback class names used when data.yaml is missing or yields no names.
# List position is the YOLO class id (labels are looked up as classes[cls_id]).
essential_default_classes = [
    'HMV',
    'LMV',
    'Pedestrian',
    'RoadDamages',
    'SpeedBump',
    'UnsurfacedRoad',
]


def _parse_yaml_classes(yaml_path: Path) -> List[str]:
    if not yaml_path.exists():
        return list(essential_default_classes)
    content = yaml_path.read_text(encoding="utf-8")
    lines = [ln.strip() for ln in content.splitlines() if ln.strip()]
    for i, ln in enumerate(lines):
        if ln.startswith("names:"):
            rhs = ln.split(":", 1)[1].strip()
            if rhs.startswith("[") and rhs.endswith("]"):
                items = rhs[1:-1]
                names = [x.strip().strip("'\"") for x in items.split(",") if x.strip()]
                return names if names else list(essential_default_classes)
            # multi-line list
            names: List[str] = []
            j = i + 1
            while j < len(lines) and lines[j].startswith("-"):
                names.append(lines[j].lstrip("-").strip().strip("'\""))
                j += 1
            return names if names else list(essential_default_classes)
    return list(essential_default_classes)


def _read_yolo_labels(lbl_path: Path) -> List[Tuple[int, float, float, float, float]]:
    # return list of (cls_id, cx, cy, w, h) in normalized coords
    if not lbl_path.exists():
        return []
    items: List[Tuple[int, float, float, float, float]] = []
    with lbl_path.open("r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            parts = ln.split()
            if len(parts) != 5:
                # skip malformed rows rather than failing hard
                continue
            try:
                cls_id = int(parts[0])
                cx, cy, w, h = map(float, parts[1:])
            except Exception:
                continue
            if not (0 <= cx <= 1 and 0 <= cy <= 1 and 0 <= w <= 1 and 0 <= h <= 1):
                continue
            items.append((cls_id, cx, cy, w, h))
    return items


def _boxes_for_image(
    img_path: Path, lbl_path: Path, classes: List[str]
) -> Tuple[int, int, List[Tuple[str, int, int, int, int]]]:
    """Return ``(width, height, boxes)`` for one image.

    ``boxes`` holds ``(class_name, x_min, y_min, x_max, y_max)`` tuples converted
    from the YOLO label file ``lbl_path``; a missing label file yields no boxes.
    Labels whose class id has no entry in ``classes`` are skipped.
    """
    with Image.open(img_path) as im:
        width, height = im.size
    boxes: List[Tuple[str, int, int, int, int]] = []
    for cid, cx, cy, bw, bh in _read_yolo_labels(lbl_path):
        if not 0 <= cid < len(classes):
            continue  # unknown class id -> skip
        x1, y1, x2, y2 = _yolo_to_xyxy(cx, cy, bw, bh, width, height)
        boxes.append((classes[cid], x1, y1, x2, y2))
    return width, height, boxes


def _require_unique_names(images: List[Path], split: str) -> None:
    """Raise ValueError if two different source files in ``images`` share a file name.

    Images are copied into public/ under their bare file name and CSV rows are
    keyed by that name. A collision would silently overwrite one image and
    attribute both images' boxes to the survivor.
    """
    seen: Dict[str, Path] = {}
    for p in images:
        first = seen.setdefault(p.name, p)
        if first != p:
            raise ValueError(f"Duplicate {split} image file name {p.name!r}: {first} and {p}")


def prepare(raw: Path, public: Path, private: Path):
    """
    Full preparation pipeline.

    Inputs:
      - raw: absolute path to the raw/ directory (contains images/ with train/valid/test)
      - public: absolute path where public artifacts will be written
      - private: absolute path where private artifacts will be written

    Both output directories are emptied before anything is written. Source images
    are collected and checked first, so a raw/ with no images (or with colliding
    file names) leaves the previous outputs untouched.

    Outputs in public/:
      - train.csv: columns [image_id, width, height, class, x_min, y_min, x_max, y_max]
        (one row per labelled box; images without valid boxes have no rows)
      - test.csv:  column [image_id]
      - sample_submission.csv: columns [image_id, PredictionString] (empty predictions)
      - train/images/* and test/images/*: image files, copied under their original names
      - description.txt: a copy of the task description (if present next to this file)

    Outputs in private/:
      - test_answer.csv: columns [image_id, class, x_min, y_min, x_max, y_max]

    Raises:
      AssertionError: a path is not absolute, no train/test images were found,
        or a post-write integrity check failed.
      ValueError: an output directory is raw/ or one of its ancestors, public and
        private are the same directory, or two source images of one split share
        a file name.
    """
    assert raw.is_absolute() and public.is_absolute() and private.is_absolute(), "Use absolute paths"

    # The output dirs are wiped below: refuse layouts where that would delete the
    # raw inputs, or where private answers would be written into public/.
    raw_real = raw.resolve()
    public_real, private_real = public.resolve(), private.resolve()
    for out_real in (public_real, private_real):
        if out_real == raw_real or out_real in raw_real.parents:
            raise ValueError(f"Refusing to wipe {out_real}: it is raw/ or contains it")
    if public_real == private_real:
        raise ValueError("public and private must be different directories")

    # Source layout under raw/
    src_root = raw / "images"
    train_img_dir = src_root / "train" / "images"
    train_lbl_dir = src_root / "train" / "labels"
    valid_img_dir = src_root / "valid" / "images"
    valid_lbl_dir = src_root / "valid" / "labels"
    test_img_dir = src_root / "test" / "images"
    test_lbl_dir = src_root / "test" / "labels"
    data_yaml = src_root / "data.yaml"

    classes = _parse_yaml_classes(data_yaml)

    # Collect inputs before touching outputs. dict.fromkeys drops exact duplicate
    # paths (one file matched by several extension patterns) while keeping order.
    train_images = list(dict.fromkeys(_collect_images([train_img_dir, valid_img_dir])))
    test_images = list(dict.fromkeys(_collect_images([test_img_dir])))

    assert len(test_images) > 0, "No test images found under raw/images/test/images"
    assert len(train_images) > 0, "No training images found under raw/images/train|valid/images"
    _require_unique_names(train_images, "train/valid")
    _require_unique_names(test_images, "test")

    # Prepare output folders
    _ensure_clean_dir(public)
    _ensure_clean_dir(private)
    public_train_dir = public / "train" / "images"
    public_test_dir = public / "test" / "images"
    public_train_dir.mkdir(parents=True, exist_ok=True)
    public_test_dir.mkdir(parents=True, exist_ok=True)

    # copy description.txt (kept next to this script) into public
    desc_src = Path(__file__).resolve().parent / "description.txt"
    if desc_src.exists():
        shutil.copy2(desc_src, public / "description.txt")

    # copy images into public with original names for clarity
    for p in train_images:
        shutil.copy2(p, public_train_dir / p.name)
    for p in test_images:
        shutil.copy2(p, public_test_dir / p.name)

    # build train.csv from train/ and valid/
    train_csv_path = public / "train.csv"
    with train_csv_path.open("w", newline="", encoding="utf-8") as f:
        wcsv = csv.writer(f)
        wcsv.writerow(["image_id", "width", "height", "class", "x_min", "y_min", "x_max", "y_max"])
        for img_path in train_images:
            # each image's labels live in the labels/ dir of the split it came from
            lbl_dir = train_lbl_dir if img_path.parent == train_img_dir else valid_lbl_dir
            w, h, boxes = _boxes_for_image(img_path, lbl_dir / (img_path.stem + ".txt"), classes)
            for cname, x1, y1, x2, y2 in boxes:
                wcsv.writerow([img_path.name, w, h, cname, x1, y1, x2, y2])

    # build private/test_answer.csv and public/test.csv and public/sample_submission.csv
    test_answer_path = private / "test_answer.csv"
    test_list_path = public / "test.csv"
    sample_sub_path = public / "sample_submission.csv"

    with test_answer_path.open("w", newline="", encoding="utf-8") as f_ans, \
         test_list_path.open("w", newline="", encoding="utf-8") as f_list, \
         sample_sub_path.open("w", newline="", encoding="utf-8") as f_sub:
        w_ans = csv.writer(f_ans)
        w_list = csv.writer(f_list)
        w_sub = csv.writer(f_sub)
        w_ans.writerow(["image_id", "class", "x_min", "y_min", "x_max", "y_max"])  # GT boxes only
        w_list.writerow(["image_id"])  # list of test ids
        w_sub.writerow(["image_id", "PredictionString"])  # empty strings allowed

        for img_path in sorted(test_images, key=lambda p: p.name):
            w_list.writerow([img_path.name])
            w_sub.writerow([img_path.name, ""])  # deterministic empty predictions as baseline
            _, _, boxes = _boxes_for_image(img_path, test_lbl_dir / (img_path.stem + ".txt"), classes)
            for cname, x1, y1, x2, y2 in boxes:
                w_ans.writerow([img_path.name, cname, x1, y1, x2, y2])

    # integrity checks
    # 1. public directories exist
    assert public_train_dir.exists()
    assert public_test_dir.exists()

    # 2. CSVs exist
    assert train_csv_path.exists()
    assert test_list_path.exists()
    assert sample_sub_path.exists()
    assert test_answer_path.exists()

    # 3. test.csv ids equal to images in public/test/images
    #    (read ids as plain strings so no file name is coerced to a number or NaN)
    test_files = sorted(p.name for p in public_test_dir.glob("*") if p.is_file())
    df_test = pd.read_csv(test_list_path, dtype=str, keep_default_na=False)
    assert sorted(df_test["image_id"].tolist()) == test_files, "Mismatch between test.csv and public/test/images"

    # 4. sample_submission has identical set/order of ids as test.csv
    df_sub = pd.read_csv(sample_sub_path, dtype=str, keep_default_na=False)
    assert df_sub.columns.tolist() == ["image_id", "PredictionString"], "sample_submission must have two columns"
    assert df_sub["image_id"].tolist() == df_test["image_id"].tolist(), "sample_submission ids must match test.csv"

    # 5. train.csv boxes lie inside their images (vectorized; reports the first bad row)
    df_train = pd.read_csv(train_csv_path)
    x_ok = (df_train["x_min"] >= 0) & (df_train["x_min"] <= df_train["x_max"]) & (df_train["x_max"] < df_train["width"])
    y_ok = (df_train["y_min"] >= 0) & (df_train["y_min"] <= df_train["y_max"]) & (df_train["y_max"] < df_train["height"])
    bad_x = df_train[~x_ok]
    assert bad_x.empty, f"x bounds invalid in train.csv row: {bad_x.iloc[0].to_dict()}"
    bad_y = df_train[~y_ok]
    assert bad_y.empty, f"y bounds invalid in train.csv row: {bad_y.iloc[0].to_dict()}"
