from __future__ import annotations

import csv
import os
import random
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd

# Seed the module-global RNG once at import time so that any code drawing
# from the `random` module functions starts from a fixed state.
random.seed(42)

# Types
# A bounding box as (x1, y1, x2, y2) corner coordinates. The helpers in this
# file produce these normalised to [0, 1] with x1 <= x2 and y1 <= y2.
Box = Tuple[float, float, float, float]


def _yolo_to_xyxy_norm(line: str) -> Box:
    parts = [p for p in line.strip().split() if p != ""]
    if len(parts) < 5:
        raise ValueError(f"Malformed label line: '{line}'")
    # ignore class id (single class)
    _, cx, cy, w, h = parts[:5]
    cx = float(cx)
    cy = float(cy)
    w = float(w)
    h = float(h)
    x1 = max(0.0, min(1.0, cx - w / 2.0))
    y1 = max(0.0, min(1.0, cy - h / 2.0))
    x2 = max(0.0, min(1.0, cx + w / 2.0))
    y2 = max(0.0, min(1.0, cy + h / 2.0))
    if x2 < x1:
        x1, x2 = x2, x1
    if y2 < y1:
        y1, y2 = y2, y1
    return (x1, y1, x2, y2)


def _read_label_file(path: Path) -> List[Box]:
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as f:
        lines = [ln for ln in f.read().strip().splitlines() if ln.strip() != ""]
    boxes: List[Box] = []
    for ln in lines:
        try:
            boxes.append(_yolo_to_xyxy_norm(ln))
        except Exception:
            # Skip malformed lines gracefully
            continue
    return boxes


def _boxes_to_string(boxes: List[Box]) -> str:
    if not boxes:
        return ""
    vals: List[str] = []
    for (x1, y1, x2, y2) in boxes:
        vals.extend([f"{x1:.6f}", f"{y1:.6f}", f"{x2:.6f}", f"{y2:.6f}"])
    return " ".join(vals)


def _write_csv(path: Path, header: List[str], rows: List[List[str]]):
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def _link_or_copy(src: Path, dst: Path):
    if dst.exists():
        return
    dst.parent.mkdir(parents=True, exist_ok=True)
    try:
        os.link(src, dst)
    except Exception:
        shutil.copy2(src, dst)


def prepare(raw: Path, public: Path, private: Path) -> None:
    """
    Complete data preparation.

    Inputs:
    - raw: path to directory that contains SKU110K_fixed/ with images/ and labels/
      (each holding train/, val/ and test/ sub-directories)
    - public: destination directory for all participant-visible files
    - private: destination directory for hidden test answers

    Source train and val images together form the training set; source test
    images form the test set. Images are linked (or copied) under sequential
    anonymized names ``img_NNNNNN.jpg``.

    Artifacts created (under `public`):
    - images/train/*.jpg
    - images/test/*.jpg
    - train.csv (image_id, boxes)
    - test.csv (image_id)
    - sample_submission.csv (image_id, PredictionString), where
      PredictionString is a space-separated run of ``score x1 y1 x2 y2``
      groups generated from a locally seeded RNG
    - description.txt (copied from repository root if available)

    Hidden (under `private`):
    - test_answer.csv (image_id, boxes)

    Raises:
        AssertionError: if a source split directory is missing or one of the
            post-write consistency checks fails.
    """

    raw = Path(raw).resolve()
    public = Path(public).resolve()
    private = Path(private).resolve()

    # Source directories
    splits = ("train", "val", "test")
    src_root = raw / "SKU110K_fixed"
    src_img = {name: src_root / "images" / name for name in splits}
    src_lbl = {name: src_root / "labels" / name for name in splits}

    # Sanity checks
    for name in splits:
        assert src_img[name].exists() and src_lbl[name].exists(), f"Source {name} split missing"

    # Target directories (parents=True also creates `public` itself)
    img_train_dir = public / "images" / "train"
    img_test_dir = public / "images" / "test"
    for directory in (img_train_dir, img_test_dir, private):
        directory.mkdir(parents=True, exist_ok=True)

    # Deterministic mapping: use all train+val as training, test as testing
    train_imgs = sorted(src_img["train"].glob("*.jpg"))
    val_imgs = sorted(src_img["val"].glob("*.jpg"))
    test_imgs = sorted(src_img["test"].glob("*.jpg"))

    train_pool = train_imgs + val_imgs

    # Create filename mapping to anonymized ids; numbering runs through the
    # training pool first and continues into the test images.
    placements = [(p, img_train_dir) for p in train_pool]
    placements += [(p, img_test_dir) for p in test_imgs]
    mapping: Dict[Path, Path] = {
        src: out_dir / f"img_{idx:06d}.jpg"
        for idx, (src, out_dir) in enumerate(placements, start=1)
    }

    # Copy/link images
    for src, dst in mapping.items():
        _link_or_copy(src, dst)

    def lbl_for(img_path: Path, split: str) -> Path:
        return src_lbl[split] / (img_path.stem + ".txt")

    # Training rows
    train_rows: List[List[str]] = []
    for src in train_pool:
        split = "train" if src.parent.name == "train" else "val"
        gt_boxes = _read_label_file(lbl_for(src, split))
        train_rows.append([mapping[src].name, _boxes_to_string(gt_boxes)])

    # Test rows and hidden answers. Box counts are recorded here so the label
    # files do not have to be parsed a second time for the sample submission.
    test_rows: List[List[str]] = []
    test_ans_rows: List[List[str]] = []
    gt_counts: List[int] = []
    for src in test_imgs:
        name = mapping[src].name
        gt_boxes = _read_label_file(lbl_for(src, "test"))
        gt_counts.append(len(gt_boxes))
        test_rows.append([name])
        test_ans_rows.append([name, _boxes_to_string(gt_boxes)])

    # Write public/private CSVs
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_ans_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    _write_csv(train_csv, ["image_id", "boxes"], train_rows)
    _write_csv(test_csv, ["image_id"], test_rows)
    _write_csv(test_ans_csv, ["image_id", "boxes"], test_ans_rows)

    # Build sample submission with random but valid predictions per image
    # number of boxes roughly around mean of GT counts
    mean_count = max(1, int(sum(gt_counts) / max(1, len(gt_counts))))

    # A private, freshly seeded generator makes the sample submission
    # independent of whatever else has consumed the global `random` state.
    rng = random.Random(42)

    sample_rows: List[List[str]] = []
    for (img_id,) in test_rows:
        k = max(0, mean_count + rng.randint(-3, 3))
        vals: List[str] = []
        for _ in range(k):
            score = rng.random() * 0.9 + 0.05  # [0.05, 0.95]
            x1 = rng.random() * 0.9
            y1 = rng.random() * 0.9
            w = rng.random() * 0.1
            h = rng.random() * 0.1
            x2 = min(1.0, x1 + w)
            y2 = min(1.0, y1 + h)
            vals.extend([
                f"{score:.4f}",
                f"{x1:.6f}",
                f"{y1:.6f}",
                f"{x2:.6f}",
                f"{y2:.6f}",
            ])
        sample_rows.append([img_id, " ".join(vals)])
    _write_csv(sample_sub_csv, ["image_id", "PredictionString"], sample_rows)

    # Copy description.txt to public if available in repo
    repo_desc = Path(__file__).parent.resolve() / "description.txt"
    if repo_desc.exists():
        shutil.copy2(repo_desc, public / "description.txt")

    # Post checks
    out_train_imgs = list(img_train_dir.glob("*.jpg"))
    out_test_imgs = list(img_test_dir.glob("*.jpg"))

    # 1) No split leakage in filenames
    for p in out_train_imgs + out_test_imgs:
        name = p.name.lower()
        assert all(s not in name for s in ["train_", "val_", "test_"]), f"Filename leaks split info: {p.name}"

    # 2) CSV and image counts align
    assert len(out_train_imgs) == len(train_rows), "Mismatch train images vs CSV rows"
    assert len(out_test_imgs) == len(test_rows), "Mismatch test images vs CSV rows"

    # 3) Verify CSV schemas
    def _parse_boxes_str(s: str) -> List[Box]:
        if s.strip() == "":
            return []
        nums = [float(x) for x in s.strip().split()]
        assert len(nums) % 4 == 0, "Boxes string should be multiples of 4 values"
        out: List[Box] = []
        for i in range(0, len(nums), 4):
            x1, y1, x2, y2 = nums[i : i + 4]
            assert 0.0 <= x1 <= 1.0 and 0.0 <= y1 <= 1.0 and 0.0 <= x2 <= 1.0 and 0.0 <= y2 <= 1.0, "Coords must be in [0,1]"
            assert x2 >= x1 and y2 >= y1, "Invalid box with negative size"
            out.append((x1, y1, x2, y2))
        return out

    # Read everything back as plain strings. With pandas' default NA handling
    # an empty `boxes` field (image without labels) is loaded as NaN, which
    # stringifies to "nan" and would wrongly fail the multiple-of-4 check.
    train_df = pd.read_csv(train_csv, dtype=str, keep_default_na=False)
    test_df = pd.read_csv(test_csv, dtype=str, keep_default_na=False)
    test_ans_df = pd.read_csv(test_ans_csv, dtype=str, keep_default_na=False)

    total_train_boxes = 0
    for boxes_str in train_df["boxes"].tolist():
        total_train_boxes += len(_parse_boxes_str(boxes_str))
    for boxes_str in test_ans_df["boxes"].tolist():
        _parse_boxes_str(boxes_str)

    # 4) Ensure training set has at least one labeled box
    assert total_train_boxes > 0, "Training set must contain at least one labeled box"

    # 5) Ensure public test.csv contains only image_id
    with open(test_csv, "r", encoding="utf-8") as f:
        header = next(csv.reader(f))
        assert header == ["image_id"], "test.csv should only contain image_id"

    # 6) Private answers and public ids alignment
    pub_ids = set(test_df["image_id"].tolist())
    prv_ids = set(test_ans_df["image_id"].tolist())
    assert pub_ids == prv_ids, "Mismatch between public test.csv and private test_answer.csv"
