from pathlib import Path
import csv
import os
import random
import shutil
import tarfile
from typing import Tuple, Dict

# Seed the process-wide `random` generator once, at import time.
# NOTE: this is a module-level side effect -- importing this file resets the
# global RNG state for every other user of `random` in the same process.
random.seed(0)


def _extract_tgz(archive: Path, dst: Path) -> None:
    dst.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(path=dst)


def _read_indexed_file(path: Path, expect_cols: int | None = None) -> list[list[str]]:
    rows: list[list[str]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if expect_cols is not None and len(parts) < expect_cols:
                raise ValueError(f"Unexpected format in {path}: {line}")
            rows.append(parts)
    return rows


def _load_raw_metadata(raw_root: Path) -> Tuple[Dict[int, str], Dict[int, int], Dict[int, str], Dict[int, bool]]:
    """Load and cross-check the index files under ``raw_root / "CUB_200_2011"``.

    Reads ``images.txt``, ``image_class_labels.txt``, ``classes.txt`` and
    ``train_test_split.txt`` and verifies that they describe the same set of
    images before anything downstream relies on them.

    Args:
        raw_root: Directory that contains the extracted ``CUB_200_2011`` folder.

    Returns:
        A 4-tuple ``(idx_to_relpath, idx_to_classid, classid_to_name,
        idx_to_is_train)``:

        - image index -> image path as listed in ``images.txt``
        - image index -> class id
        - class id -> class name (extra whitespace-separated tokens joined
          with ``_``)
        - image index -> ``True`` when the split flag equals 1

    Raises:
        ValueError: If a file is malformed, no images are listed, the three
            per-image files do not cover the same indices, the indices are
            not exactly ``1..n``, or a label refers to an unknown class id.
    """
    cub_dir = raw_root / "CUB_200_2011"
    images_rows = _read_indexed_file(cub_dir / "images.txt", expect_cols=2)
    idx_to_relpath: Dict[int, str] = {int(i): r for i, r in images_rows}

    label_rows = _read_indexed_file(cub_dir / "image_class_labels.txt", expect_cols=2)
    idx_to_classid: Dict[int, int] = {int(i): int(c) for i, c in label_rows}

    class_rows = _read_indexed_file(cub_dir / "classes.txt", expect_cols=2)
    classid_to_name: Dict[int, str] = {int(cid): "_".join(name_parts) for cid, *name_parts in class_rows}

    split_rows = _read_indexed_file(cub_dir / "train_test_split.txt", expect_cols=2)
    idx_to_is_train: Dict[int, bool] = {int(i): (int(t) == 1) for i, t in split_rows}

    # Explicit raises (not `assert`) so the checks still run under `python -O`.
    if not idx_to_relpath:
        raise ValueError(f"No images listed in {cub_dir / 'images.txt'}")

    # Compare the key sets, not just the sizes: equal lengths with different
    # indices would otherwise surface later as an opaque KeyError.
    if not (idx_to_relpath.keys() == idx_to_classid.keys() == idx_to_is_train.keys()):
        raise ValueError(
            "Metadata mismatch: images.txt, image_class_labels.txt and "
            "train_test_split.txt must list the same image indices"
        )

    # n distinct integers with min 1 and max n are necessarily 1..n.
    n = len(idx_to_relpath)
    if min(idx_to_relpath) != 1 or max(idx_to_relpath) != n:
        raise ValueError("Indices must be contiguous starting at 1")

    unknown_classes = set(idx_to_classid.values()) - classid_to_name.keys()
    if unknown_classes:
        raise ValueError(
            f"image_class_labels.txt references class ids missing from classes.txt: "
            f"{sorted(unknown_classes)[:5]}"
        )

    return idx_to_relpath, idx_to_classid, classid_to_name, idx_to_is_train


def _ensure_dirs(public: Path) -> None:
    for d in [
        public / "train_images",
        public / "test_images",
        public / "train_segmentations",
        public / "test_segmentations",
    ]:
        d.mkdir(parents=True, exist_ok=True)


def prepare(raw: Path, public: Path, private: Path):
    """
    Prepare the dataset for the competition.

    - Extract raw archives in raw/ (skipped for archives already extracted)
    - Copy images and segmentation masks into
      public/{train,test}_images and public/{train,test}_segmentations,
      renamed to img_<index>.jpg / img_<index>.png
    - Write train.csv, test.csv, classes.csv and sample_submission.csv in public/
    - Write test_answer.csv in private/
    - Copy description.txt into public/ when it exists next to this file
    - Sanity-check the written outputs

    Args:
        raw: Directory holding CUB_200_2011.tgz / segmentations.tgz (or their
            already-extracted folders).
        public: Output directory for everything participants may see.
        private: Output directory for the held-out test labels.

    Raises:
        FileNotFoundError: If a required archive, image or mask is missing.
        AssertionError: If one of the post-write sanity checks fails.
    """

    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # 1) Extract archives if needed
    for archive_name, extracted_name in (
        ("CUB_200_2011.tgz", "CUB_200_2011"),
        ("segmentations.tgz", "segmentations"),
    ):
        if (raw / extracted_name).exists():
            continue
        archive = raw / archive_name
        # Raise (rather than assert) so the check survives `python -O` and
        # matches the FileNotFoundError used for missing images/masks below.
        if not archive.exists():
            raise FileNotFoundError(f"Missing archive: {archive}")
        _extract_tgz(archive, raw)

    # 2) Load metadata from raw
    idx_to_relpath, idx_to_classid, classid_to_name, idx_to_is_train = _load_raw_metadata(raw)

    # Build normalized class labels (ensure no spaces)
    classid_to_label: Dict[int, str] = {
        cid: cname.strip().replace(" ", "_")
        for cid, cname in sorted(classid_to_name.items())
    }

    # 3) Prepare output directories
    _ensure_dirs(public)

    train_rows: list[tuple[str, str]] = []
    test_answer_rows: list[tuple[str, str]] = []

    # Copy files; collect every missing source so the error reports a count
    # instead of stopping at the first one.
    missing_images: list[Path] = []
    missing_masks: list[Path] = []

    for idx in sorted(idx_to_relpath):
        relpath = Path(idx_to_relpath[idx])
        label = classid_to_label[idx_to_classid[idx]]
        is_train = idx_to_is_train[idx]
        split = "train" if is_train else "test"

        src_img = raw / "CUB_200_2011" / "images" / relpath
        # Masks mirror the image layout: <first path component>/<stem>.png
        src_mask = raw / "segmentations" / relpath.parts[0] / f"{relpath.stem}.png"

        new_id = f"img_{idx:06d}"
        dst_img = public / f"{split}_images" / f"{new_id}.jpg"
        dst_mask = public / f"{split}_segmentations" / f"{new_id}.png"

        if src_img.exists():
            shutil.copyfile(src_img, dst_img)
        else:
            missing_images.append(src_img)
        if src_mask.exists():
            shutil.copyfile(src_mask, dst_mask)
        else:
            missing_masks.append(src_mask)

        if is_train:
            train_rows.append((new_id, label))
        else:
            test_answer_rows.append((new_id, label))

    if missing_images:
        raise FileNotFoundError(f"Missing {len(missing_images)} images, e.g. {missing_images[:3]}")
    if missing_masks:
        raise FileNotFoundError(f"Missing {len(missing_masks)} segmentation masks, e.g. {missing_masks[:3]}")

    test_ids = [image_id for image_id, _ in test_answer_rows]

    # 4) Write CSVs
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    classes_csv = public / "classes.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_sub_csv = public / "sample_submission.csv"

    def _write_csv(path: Path, header: list[str], rows) -> None:
        # One place for the open()/writer settings shared by every output CSV.
        with path.open("w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(header)
            w.writerows(rows)

    _write_csv(train_csv, ["id", "label"], train_rows)
    _write_csv(test_csv, ["id"], ([image_id] for image_id in test_ids))
    _write_csv(test_answer_csv, ["id", "label"], test_answer_rows)
    _write_csv(classes_csv, ["label"], ([classid_to_label[cid]] for cid in sorted(classid_to_label)))

    # Create sample submission using only labels that appear in the test set (to satisfy validation).
    # A dedicated generator seeded with 0 yields the same labels as a freshly
    # seeded global RNG, but stays reproducible when prepare() is called more
    # than once or when other code has consumed the global `random` state.
    allowed_labels = sorted({label for _, label in test_answer_rows})
    rng = random.Random(0)
    _write_csv(
        sample_sub_csv,
        ["id", "label"],
        ([image_id, rng.choice(allowed_labels)] for image_id in test_ids),
    )

    # 5) Copy description.txt into public/ (optional: skipped when absent)
    root_desc = Path(__file__).parent / "description.txt"
    if root_desc.exists():
        shutil.copyfile(root_desc, public / "description.txt")

    # 6) Validations -- sanity checks on the files written above
    def _read_rows(p: Path):
        with p.open("r", encoding="utf-8") as f:
            return list(csv.DictReader(f))

    train_rows_df = _read_rows(train_csv)
    test_rows_df = _read_rows(test_csv)
    test_ans_df = _read_rows(test_answer_csv)

    assert len(train_rows_df) > 0 and len(test_rows_df) > 0, "Empty train or test"
    assert {r["id"] for r in test_rows_df} == {r["id"] for r in test_ans_df}, "test vs answer ids mismatch"

    # Files existence
    for r in train_rows_df:
        iid = r["id"]
        assert (public / "train_images" / f"{iid}.jpg").exists(), f"Missing train image {iid}"
        assert (public / "train_segmentations" / f"{iid}.png").exists(), f"Missing train mask {iid}"
    for r in test_rows_df:
        iid = r["id"]
        assert (public / "test_images" / f"{iid}.jpg").exists(), f"Missing test image {iid}"
        assert (public / "test_segmentations" / f"{iid}.png").exists(), f"Missing test mask {iid}"

    # Ensure no leakage in public/
    assert not (public / "test_answer.csv").exists(), "test_answer.csv must NOT be in public/"
