from pathlib import Path
import csv
import shutil
import random
from typing import List, Tuple

import numpy as np
import pandas as pd
from PIL import Image

# Single seed shared by every source of randomness in this module.
SEED = 20240913
# Seed the global stdlib and NumPy generators at import time so that any code
# relying on the module-level `random` / `np.random` state is reproducible.
random.seed(SEED)
np.random.seed(SEED)


def _read_df(raw: Path) -> list[tuple[Path, Path, Path]]:
    df_path = raw / "df.csv"
    assert df_path.is_file(), f"Missing df.csv at {df_path}"

    rows: list[tuple[Path, Path, Path]] = []
    with df_path.open("r", newline="") as f:
        reader = csv.DictReader(f)
        # Accept optional unnamed index column created by pandas
        expected_cols = {"images", "masks", "collages"}
        assert expected_cols.issubset(set(reader.fieldnames or [])), (
            f"Unexpected df.csv headers: {reader.fieldnames}"
        )
        # Source root where the relative paths in df.csv start from
        src_root = raw / "supervisely_person_clean_2667_img" / "supervisely_person_clean_2667_img"
        for r in reader:
            img_p = src_root / r["images"]
            msk_p = src_root / r["masks"]
            col_p = src_root / r["collages"]
            assert img_p.is_file(), f"Image not found: {img_p}"
            assert msk_p.is_file(), f"Mask not found: {msk_p}"
            assert col_p.is_file(), f"Collage not found: {col_p}"
            rows.append((img_p, msk_p, col_p))
    assert len(rows) >= 100, f"Too few rows in df.csv: {len(rows)}"
    return rows


def _split_indices(n: int, test_ratio: float = 0.2) -> tuple[list[int], list[int]]:
    """Partition ``range(n)`` into sorted train and test index lists.

    The indices are shuffled with a private ``random.Random(SEED)`` instance,
    so the split does not depend on the global ``random`` state. The first
    ``round(n * test_ratio)`` shuffled indices become the test set.

    Args:
        n: Number of samples to split.
        test_ratio: Fraction of samples assigned to the test set.

    Returns:
        ``(train_idx, test_idx)``, each sorted ascending.

    Raises:
        AssertionError: if either side of the split ends up empty.
    """
    order = list(range(n))
    random.Random(SEED).shuffle(order)

    n_test = int(round(n * test_ratio))
    test_idx = sorted(order[:n_test])
    train_idx = sorted(order[n_test:])

    # Sanity checks: a true partition with both sides populated.
    assert set(train_idx).isdisjoint(test_idx)
    assert len(train_idx) + len(test_idx) == n
    assert train_idx and test_idx
    return train_idx, test_idx


def _safe_copy(src: Path, dst: Path) -> None:
    assert src.is_file(), f"Source not a file: {src}"
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    assert dst.is_file(), f"Failed to copy to {dst}"


def _rle_encode(mask: np.ndarray) -> str:
    # mask: HxW binary, encode in column-major (Fortran) order with 1-indexed starts
    assert mask.ndim == 2
    m = (mask.astype(np.uint8).flatten(order="F") > 0).astype(np.uint8)
    if m.sum() == 0:
        return ""
    m = np.concatenate([[0], m, [0]])
    diff = np.diff(m)
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]
    lengths = ends - starts
    pairs: list[str] = []
    for s, l in zip(starts, lengths):
        pairs.append(str(s))  # already 1-indexed because of the leading zero
        pairs.append(str(int(l)))
    return " ".join(pairs)


def _make_ids(indices: List[int]) -> list[str]:
    ids = [f"id_{i:05d}" for i in indices]
    assert len(ids) == len(set(ids))
    return ids


def prepare(raw: Path, public: Path, private: Path):
    """Build the public/private competition layout from the raw dataset.

    WARNING: ``public`` and ``private`` are deleted and recreated on each call.

    The raw index is loaded with ``_read_df`` and split into train and test
    with ``_split_indices``. Files are copied under ``public`` using the
    identifiers produced by ``_make_ids``. The following files are written:

    * ``public/train.csv`` -- ``id, image_filename, mask_filename``
    * ``public/test.csv`` -- ``id, image_filename``
    * ``public/sample_submission.csv`` -- ``id, rle`` with random sparse masks
    * ``private/test_answer.csv`` -- ``id, rle`` with the ground-truth masks
    * ``public/description.txt`` -- copied from the directory of this file

    Args:
        raw: Directory containing ``df.csv`` and the source images.
        public: Output directory for participant-visible data.
        private: Output directory for the held-out answers.

    Raises:
        AssertionError: if any integrity check on inputs or outputs fails.
    """
    # Normalize to absolute paths
    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    # Clean and recreate public/private to ensure deterministic reruns
    if public.exists():
        shutil.rmtree(public)
    if private.exists():
        shutil.rmtree(private)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Directory layout under public/
    train_images_dir = public / "train" / "images"
    train_masks_dir = public / "train" / "masks"
    train_collage_dir = public / "train" / "collage"
    test_images_dir = public / "test" / "images"
    test_collage_dir = public / "test" / "collage"

    for d in [train_images_dir, train_masks_dir, train_collage_dir, test_images_dir, test_collage_dir]:
        d.mkdir(parents=True, exist_ok=True)

    # Load raw index and spot-check integrity
    rows = _read_df(raw)
    for i, (img_p, msk_p, col_p) in enumerate(rows[:50]):  # spot-check first 50
        with Image.open(img_p) as im:
            w_i, h_i = im.size
        with Image.open(msk_p) as mm:
            w_m, h_m = mm.size
            mm_arr = np.array(mm)
        assert (w_i, h_i) == (w_m, h_m), f"Image/Mask size mismatch at row {i}: {img_p} vs {msk_p}"
        if mm_arr.ndim == 3:
            # Multi-channel mask: only the first channel is inspected.
            mm_arr = mm_arr[..., 0]
        assert np.isfinite(mm_arr).all(), f"Non-finite in mask at {msk_p}"

    n = len(rows)
    train_idx, test_idx = _split_indices(n, test_ratio=0.2)
    train_ids = _make_ids(train_idx)
    test_ids = _make_ids(test_idx)

    # Write CSVs
    train_csv = public / "train.csv"
    test_csv = public / "test.csv"
    test_answer_csv = private / "test_answer.csv"
    sample_submission_csv = public / "sample_submission.csv"

    # Train split: copy image, mask and collage under the new id-based names.
    with train_csv.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "image_filename", "mask_filename"])
        for new_id, idx in zip(train_ids, train_idx):
            img_p, msk_p, col_p = rows[idx]
            img_ext = img_p.suffix.lower()
            msk_ext = msk_p.suffix.lower()
            col_ext = col_p.suffix.lower()
            assert img_ext in [".png", ".jpg", ".jpeg"]
            assert msk_ext in [".png", ".jpg", ".jpeg"]
            img_name = f"{new_id}{img_ext}"
            msk_name = f"{new_id}{msk_ext}"
            col_name = f"{new_id}{col_ext}"
            _safe_copy(img_p, train_images_dir / img_name)
            _safe_copy(msk_p, train_masks_dir / msk_name)
            _safe_copy(col_p, train_collage_dir / col_name)
            writer.writerow([new_id, img_name, msk_name])

    # Test split: masks are NOT copied; they are RLE-encoded into the private
    # answer file instead.
    with test_csv.open("w", newline="") as f_test, test_answer_csv.open("w", newline="") as f_ans:
        test_writer = csv.writer(f_test)
        ans_writer = csv.writer(f_ans)
        test_writer.writerow(["id", "image_filename"])
        ans_writer.writerow(["id", "rle"])
        for new_id, idx in zip(test_ids, test_idx):
            img_p, msk_p, col_p = rows[idx]
            img_ext = img_p.suffix.lower()
            col_ext = col_p.suffix.lower()
            img_name = f"{new_id}{img_ext}"
            col_name = f"{new_id}{col_ext}"
            _safe_copy(img_p, test_images_dir / img_name)
            _safe_copy(col_p, test_collage_dir / col_name)
            with Image.open(msk_p) as mm:
                mask_arr = np.array(mm)
            if mask_arr.ndim == 3:
                mask_arr = mask_arr[..., 0]
            gt_bin = (mask_arr > 0).astype(np.uint8)
            gt_rle = _rle_encode(gt_bin)
            test_writer.writerow([new_id, img_name])
            ans_writer.writerow([new_id, gt_rle])

    # Create sample submission with random but valid RLEs matching image sizes.
    # A private generator is used so the output does not depend on whatever
    # else has consumed the global NumPy random state (e.g. a previous call).
    rng = np.random.RandomState(SEED)
    p = 0.01  # sparse example
    with sample_submission_csv.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "rle"])
        test_df = pd.read_csv(test_csv)
        for _, r in test_df.iterrows():
            img_path = test_images_dir / r["image_filename"]
            with Image.open(img_path) as im:
                w, h = im.size
            rand_mask = (rng.rand(h, w) < p).astype(np.uint8)
            writer.writerow([r["id"], _rle_encode(rand_mask)])

    # Copy description.txt into public/
    root_desc = Path(__file__).resolve().parent / "description.txt"
    assert root_desc.is_file(), f"Missing description.txt at {root_desc}"
    shutil.copy2(root_desc, public / "description.txt")

    # Checks
    assert train_csv.is_file(), "public/train.csv should exist"
    assert test_csv.is_file(), "public/test.csv should exist"
    assert sample_submission_csv.is_file(), "public/sample_submission.csv should exist"
    assert test_answer_csv.is_file(), "private/test_answer.csv should exist"

    def _load_csv(p: Path):
        """Read a CSV file into a list of per-row dicts."""
        with p.open("r", newline="") as f:
            return list(csv.DictReader(f))

    train_rows = _load_csv(train_csv)
    test_rows = _load_csv(test_csv)
    ans_rows = _load_csv(test_answer_csv)

    assert len(train_rows) + len(test_rows) == len(rows)
    assert [r["id"] for r in test_rows] == [r["id"] for r in ans_rows]

    # Ensure CSV filenames don't contain any paths
    for r in train_rows:
        assert "/" not in r["image_filename"] and "/" not in r["mask_filename"]
    for r in test_rows:
        assert "/" not in r["image_filename"]

    # Ensure referenced files exist
    for r in train_rows:
        ip = train_images_dir / r["image_filename"]
        mp = train_masks_dir / r["mask_filename"]
        assert ip.is_file() and mp.is_file()
        with Image.open(ip) as im:
            w_i, h_i = im.size
        with Image.open(mp) as mm:
            w_m, h_m = mm.size
        assert (w_i, h_i) == (w_m, h_m)
    for r in test_rows:
        ip = test_images_dir / r["image_filename"]
        assert ip.is_file()

    # Ensure IDs are unique and disjoint
    train_ids_set = {r["id"] for r in train_rows}
    test_ids_set = {r["id"] for r in test_rows}
    assert len(train_ids_set) == len(train_rows)
    assert len(test_ids_set) == len(test_rows)
    assert train_ids_set.isdisjoint(test_ids_set)

    # Ensure RLEs in answers decode to the right shapes
    def _rle_decode(rle: str, shape: Tuple[int, int]) -> np.ndarray:
        """Decode a column-major, 1-indexed RLE string into an ``(h, w)`` mask.

        Runs reaching outside the image are clamped to its bounds.
        """
        h, w = shape
        if rle is None or rle.strip() == "":
            return np.zeros((h, w), dtype=np.uint8)
        s = rle.strip().split()
        assert len(s) % 2 == 0
        # Convert starts to 0-indexed FIRST, then derive the exclusive ends;
        # computing ends from the 1-indexed starts makes each run one pixel
        # too long.
        starts = np.asarray(s[0::2], dtype=np.int64) - 1
        lengths = np.asarray(s[1::2], dtype=np.int64)
        ends = starts + lengths
        flat = np.zeros(h * w, dtype=np.uint8)
        for st, en in zip(starts, ends):
            if st < 0:
                st = 0
            if en > h * w:
                en = h * w
            flat[st:en] = 1
        return flat.reshape((h, w), order="F")

    for r, a in zip(test_rows, ans_rows):
        ip = test_images_dir / r["image_filename"]
        with Image.open(ip) as im:
            w, h = im.size
        decoded = _rle_decode(a["rle"], (h, w))
        assert decoded.shape == (h, w)

