import sys
import csv
import random
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
from PIL import Image

# The RLE "encoding" column can hold very long strings, so raise the csv
# module's per-field size cap as far as the platform accepts.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    # sys.maxsize was rejected as too large; use the signed 32-bit maximum.
    csv.field_size_limit(2**31 - 1)

# Seed the global RNGs once at import time so repeated runs are reproducible.
RANDOM_SEED = 1337
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Human-readable class names; the list index is the class id.
CLASS_NAMES = [
    "impervious",
    "building",
    "low_vegetation",
    "tree",
    "car",
    "clutter",
]

# Label colour -> class id. Keys are (R, G, B) triples as they come out of a
# PIL "RGB" image; the trailing remark on each entry gives the same colour
# with the channel order reversed.
# NOTE(review): the ISPRS benchmark pages appear to list building as
# (0, 0, 255) and clutter as (255, 0, 0) in RGB order - confirm which channel
# order the raw label files really use before relying on CLASS_NAMES.
RGB_TO_CLASS: Dict[Tuple[int, int, int], int] = {
    (255, 255, 255): 0,  # impervious surfaces
    (255, 0, 0): 1,      # building (reversed: 0,0,255)
    (255, 255, 0): 2,    # low vegetation (reversed: 0,255,255)
    (0, 255, 0): 3,      # tree (reversed: 0,255,0)
    (0, 255, 255): 4,    # car (reversed: 255,255,0)
    (0, 0, 255): 5,      # clutter/background (reversed: 255,0,0)
}

# Array views of the table above, row-aligned with each other:
# KNOWN_COLORS[i] is the colour whose class id is KNOWN_CLASS_IDS[i].
KNOWN_COLORS = np.array(list(RGB_TO_CLASS), dtype=np.int16)  # shape (6, 3)
KNOWN_CLASS_IDS = np.array(list(RGB_TO_CLASS.values()), dtype=np.int16)

# Single-channel label files may store class ids directly, either 0..5 or 1..6.
VALID_CLASS_VALUES = set(range(len(CLASS_NAMES)))
VALID_CLASS_VALUES_1_BASED = {cid + 1 for cid in VALID_CLASS_VALUES}


def _ensure_dirs(public: Path):
    """Create the output sub-directories under `public`.

    Missing parents are created as well, and directories that already exist
    are left untouched, so the call is safe to repeat.
    """
    subdirs = ("train/images", "train/masks", "test/images", "extra/images")
    for rel in subdirs:
        (public / rel).mkdir(parents=True, exist_ok=True)


def _list_source_pairs(raw: Path) -> List[Tuple[Path, Path]]:
    """Collect (image, label) path pairs for the labeled datasets under `raw`.

    Potsdam images are matched to `<name>_label.tif`, where a trailing
    `_RGB.tif` on the image name is replaced by `_label.tif`. Vaihingen images
    are matched to a label file with the identical file name. Images without
    an existing label file are skipped, and a dataset whose Images or Labels
    directory is missing contributes nothing.

    Returns:
        Pairs in sorted image order, Potsdam first and Vaihingen second.
    """
    found: List[Tuple[Path, Path]] = []

    potsdam_images = raw / "Potsdam/Images"
    potsdam_labels = raw / "Potsdam/Labels"
    if potsdam_images.exists() and potsdam_labels.exists():
        for image in sorted(potsdam_images.glob("*.tif")):
            filename = image.name
            label_name = (
                filename.replace("_RGB.tif", "_label.tif")
                if filename.endswith("_RGB.tif")
                else f"{image.stem}_label.tif"
            )
            label = potsdam_labels / label_name
            if label.exists():
                found.append((image, label))

    vaihingen_images = raw / "Vaihingen/Images"
    vaihingen_labels = raw / "Vaihingen/Labels"
    if vaihingen_images.exists() and vaihingen_labels.exists():
        for image in sorted(vaihingen_images.glob("*.tif")):
            label = vaihingen_labels / image.name
            if label.exists():
                found.append((image, label))

    return found


def _list_unlabeled_images(raw: Path) -> List[Path]:
    """Return the sorted `*.tif` files in `raw/Toronto/Images`.

    An empty list is returned when that directory does not exist.
    """
    toronto_images = raw / "Toronto/Images"
    if toronto_images.exists():
        return sorted(toronto_images.glob("*.tif"))
    return []


def _palette_index_to_mask(img: Image.Image) -> np.ndarray:
    """Convert a palette-mode ("P") label image into an HxW uint8 class-id mask.

    When the image carries a palette, every palette index that occurs in the
    image is translated through its palette colour: an exact hit in
    RGB_TO_CLASS gives the class id directly, any other colour is assigned the
    class of the nearest known colour (squared Euclidean distance in RGB).

    When the image has no palette, the raw indices are interpreted as class
    ids, either 0-based (returned unchanged) or 1-based (shifted down by one).

    Raises:
        AssertionError: if `img` is not in mode "P", or if it has no palette
            and its indices are neither a subset of 0..5 nor of 1..6.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if img.mode != "P":
        raise AssertionError(f"Expected a paletted ('P') image, got mode {img.mode!r}")

    arr = np.array(img)  # palette indices, HxW
    uniq_idx = np.unique(arr)
    pal = img.getpalette()  # flat [r, g, b, r, g, b, ...] list, or None
    if pal is None:
        uniq = set(uniq_idx.tolist())
        if uniq.issubset(VALID_CLASS_VALUES):
            return arr.astype(np.uint8)
        elif uniq.issubset(VALID_CLASS_VALUES_1_BASED):
            return (arr.astype(np.int32) - 1).astype(np.uint8)
        else:
            raise AssertionError(
                f"Paletted label without palette and unexpected indices: {sorted(list(uniq))[:10]}"
            )

    pal = np.array(pal, dtype=np.uint8).reshape(-1, 3)
    # int32 on purpose: a squared channel difference can reach 255**2 = 65025,
    # which does not fit in int16 and would wrap to a negative "distance".
    known = KNOWN_COLORS.astype(np.int32)

    # Look-up table: palette index -> class id. Only indices that actually
    # occur in the image are filled in; they are the only ones ever looked up.
    lut = np.zeros(256, dtype=np.uint8)
    for idx in uniq_idx.tolist():
        entry = pal[idx]
        cid = RGB_TO_CLASS.get(tuple(int(c) for c in entry))
        if cid is None:
            # Off-palette colour: take the class of the nearest known colour.
            diffs = known - entry.astype(np.int32)
            d2 = np.sum(diffs * diffs, axis=1)
            cid = int(KNOWN_CLASS_IDS[int(np.argmin(d2))])
        lut[idx] = cid

    return lut[arr]


def _rgb_label_to_mask(arr: np.ndarray) -> np.ndarray:
    """Convert an HxWx3 colour-coded label array into an HxW uint8 class-id mask.

    Pixels whose colour is a key of RGB_TO_CLASS get that class id. Any other
    pixel is assigned the class of the nearest known colour (squared Euclidean
    distance in RGB; ties go to the colour listed first in KNOWN_COLORS).

    Args:
        arr: label image as an array of shape (H, W, C); channels 0..2 are read
            as R, G, B. Callers in this file pass 8-bit data.

    Returns:
        Array of shape (H, W), dtype uint8, holding class ids.
    """
    h, w, _ = arr.shape
    mask = np.zeros((h, w), dtype=np.uint8)
    matched = np.zeros((h, w), dtype=bool)
    for rgb, cid in RGB_TO_CLASS.items():
        hit = (arr[:, :, 0] == rgb[0]) & (arr[:, :, 1] == rgb[1]) & (arr[:, :, 2] == rgb[2])
        mask[hit] = np.uint8(cid)
        matched |= hit
    if matched.all():
        return mask

    # Nearest-colour fallback, applied to the off-palette pixels only so that
    # exact matches keep their class and the temporaries stay proportional to
    # the number of unmatched pixels.
    unmatched = ~matched
    # int32 on purpose: a squared channel difference can reach 255**2 = 65025,
    # which overflows int16 and would turn large distances negative.
    pixels = arr[unmatched][:, :3].astype(np.int32)  # shape (K, 3)
    best_d2 = None
    best_cid = np.zeros(pixels.shape[0], dtype=np.uint8)
    for color, cid in zip(KNOWN_COLORS.astype(np.int32), KNOWN_CLASS_IDS):
        delta = pixels - color
        d2 = np.sum(delta * delta, axis=1)
        if best_d2 is None:
            best_d2 = d2
            best_cid[:] = int(cid)
            continue
        # Strict "<" keeps the earlier colour on ties, like argmin would.
        closer = d2 < best_d2
        best_d2[closer] = d2[closer]
        best_cid[closer] = int(cid)
    mask[unmatched] = best_cid
    return mask


def _load_label_to_ids(label_path: Path) -> np.ndarray:
    """Read a label file and return its HxW uint8 class-id mask.

    The decoding depends on the image mode:

    - "P": handled by `_palette_index_to_mask`.
    - "RGB" / "RGBA": converted to RGB and handled by `_rgb_label_to_mask`.
    - anything else: must decode to a 2-D array. Values within 0..5 are taken
      as class ids, values within 1..6 are shifted down by one, and any other
      content is replicated into three channels and colour-matched.

    Raises:
        AssertionError: if a non-palette, non-RGB(A) image does not decode to
            a 2-D array.
    """
    # The context manager closes the file deterministically instead of
    # leaving that to garbage collection.
    with Image.open(label_path) as img:
        mode = img.mode
        if mode == "P":
            return _palette_index_to_mask(img)
        if mode in ("RGB", "RGBA"):
            return _rgb_label_to_mask(np.array(img.convert("RGB")))
        arr = np.array(img)

    # Single channel
    if arr.ndim != 2:
        raise AssertionError(
            f"Unexpected label dims for {label_path}: shape={arr.shape}, mode={mode}"
        )
    uniq = set(np.unique(arr).tolist())
    # NOTE: the 0-based test runs first, so an image whose values all lie in
    # 1..5 is treated as 0-based; the 1-based branch needs a 6 to be present.
    if uniq.issubset(VALID_CLASS_VALUES):
        return arr.astype(np.uint8)
    if uniq.issubset(VALID_CLASS_VALUES_1_BASED):
        return (arr.astype(np.int32) - 1).astype(np.uint8)
    arr3 = np.stack([arr, arr, arr], axis=-1).astype(np.uint8)
    return _rgb_label_to_mask(arr3)


def _rle_encode(mask: np.ndarray) -> str:
    """Run-length encode the foreground of `mask` as "start length ..." pairs.

    The mask is cast to uint8 and flattened in row-major order; every value
    greater than zero counts as foreground. Starts are 1-indexed positions in
    the flattened array. An empty mask or one without foreground yields "".
    """
    flat = mask.astype(np.uint8).ravel(order="C")
    if flat.size == 0:
        return ""

    foreground = (flat > 0).astype(np.uint8)
    # Pad with a zero on both sides so runs touching either end still produce
    # a rising and a falling edge.
    padded = np.concatenate(([0], foreground, [0])).astype(np.uint8)
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    if edges.size == 0:
        return ""

    # Edges alternate: run start (1-indexed), then one past the run's end.
    starts = edges[0::2]
    stops = edges[1::2]
    tokens = []
    for start, length in zip(starts.tolist(), (stops - starts).tolist()):
        if length > 0:
            tokens += [str(int(start)), str(int(length))]
    return " ".join(tokens)


def _write_csv(path: Path, rows: List[List]):
    """Write `rows` to `path` as CSV, creating parent directories as needed.

    An existing file at `path` is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline='' leaves line endings to the csv writer.
    with path.open('w', newline='') as handle:
        csv.writer(handle).writerows(rows)


def _link_or_copy(src: Path, dst: Path):
    """Place `src` at `dst`, as a hard link when possible and a copy otherwise.

    Parent directories of `dst` are created. Whatever currently sits at `dst`
    is removed first, including a dangling symlink, so a stale entry is
    replaced rather than written through.

    Raises:
        OSError: if neither the hard link nor the fallback copy succeeds
            (for example when `src` does not exist).
    """
    import os  # kept function-local: the module does not import os at the top

    dst.parent.mkdir(parents=True, exist_ok=True)

    try:
        dst.unlink()
    except OSError:
        # Nothing to remove, or removal is not permitted. Best effort: carry
        # on, the copy fallback may still be able to overwrite in place.
        pass

    try:
        os.link(str(src), str(dst))
    except (OSError, NotImplementedError, AttributeError):
        # Hard links unavailable here (other filesystem, unsupported platform,
        # destination still present): fall back to a regular copy.
        shutil.copy2(str(src), str(dst))


def prepare(raw: Path, public: Path, private: Path):
    """
    Build the public/private competition files from the raw imagery.

    Steps, in order:

    - Pair every labeled image under `raw` with its label file and convert the
      label into a class-id mask.
    - Split the pairs into train/test with a fixed seed (80/20, at least one
      train image), then move images so that every class seen anywhere also
      occurs in train.
    - Write train images + PNG masks and test images into `public/` under
      `tile_NNNNNN` names, and the unlabeled images into `public/extra/images`.
    - Write public/train.csv, public/test.csv, public/sample_submission.csv
      and private/test_answer.csv (not visible to participants).
    - Copy description.txt, looked up next to this module, into `public/` when
      it exists.
    - Re-read the written files and assert that they are consistent.

    Args:
        raw: directory holding the source datasets.
        public: output directory for participant-visible files.
        private: output directory for the ground-truth answers.

    Raises:
        AssertionError: if no labeled pairs are found, an image and its mask
            differ in size, class coverage cannot be established, or any of
            the final integrity checks fails.
    """

    # Ensure directories
    _ensure_dirs(public)

    # Enumerate labeled image/label pairs
    pairs = _list_source_pairs(raw)
    assert len(pairs) > 0, "No labeled image/label pairs found in raw/."

    # Build per-image metadata. Every decoded mask is kept in memory until the
    # output files have been written further down.
    items = []
    for img_path, lbl_path in pairs:
        with Image.open(img_path) as img:
            w, h = img.size
        mask = _load_label_to_ids(lbl_path)
        assert mask.shape == (h, w), f"Shape mismatch for {img_path}: image {(h,w)} vs mask {mask.shape}"
        present = set(np.unique(mask).tolist()) & VALID_CLASS_VALUES
        items.append({
            'img_path': img_path,
            'lbl_path': lbl_path,
            'h': h,
            'w': w,
            'classes': present,
            'mask': mask,
        })

    # Deterministic shuffle and split
    idxs = list(range(len(items)))
    random.Random(RANDOM_SEED).shuffle(idxs)
    split_ratio = 0.8
    n_train = int(round(split_ratio * len(items)))
    n_train = max(1, min(len(items)-1, n_train))
    train_ids = set(idxs[:n_train])
    test_ids = set(idxs[n_train:])
    assert len(train_ids) + len(test_ids) == len(items)

    # Ensure each class appears in train if present overall
    all_classes = set().union(*[it['classes'] for it in items])
    train_classes = set().union(*[items[i]['classes'] for i in train_ids]) if train_ids else set()
    missing = list(all_classes - train_classes)
    if missing:
        # NOTE(review): both lists are snapshots taken before any swap, so with
        # several missing classes a later iteration can pick an index that has
        # already changed sides, and a swap can move out the only train image
        # of another class. The assertion below reports incomplete coverage.
        # The selection is left as is so existing splits stay reproducible.
        train_ids_list = list(train_ids)
        test_ids_list = list(test_ids)
        for cls in missing:
            candidate_test = next((i for i in test_ids_list if cls in items[i]['classes']), None)
            if candidate_test is None:
                continue
            # Bring a test image holding the class into train and, when some
            # train image lacks the class, send that one to test in exchange.
            swap_train = next((i for i in train_ids_list if cls not in items[i]['classes']), None)
            train_ids.add(candidate_test)
            test_ids.discard(candidate_test)
            if swap_train is not None:
                test_ids.add(swap_train)
                train_ids.discard(swap_train)
        train_classes = set().union(*[items[i]['classes'] for i in train_ids])
        assert all(c in train_classes for c in all_classes), (
            f"Not all classes present in train after fix: missing {sorted(list(all_classes - train_classes))}"
        )

    # Create anonymized names.
    # NOTE(review): numbers are handed out in ascending source index, i.e. in
    # the order returned by _list_source_pairs, not in shuffled order.
    id_to_name: Dict[int, str] = {}
    for k, i in enumerate(sorted(train_ids.union(test_ids))):
        id_to_name[i] = f"tile_{k+1:06d}"

    # CSV rows
    train_csv_rows = [["image_id", "height", "width"]]
    test_csv_rows = [["image_id", "height", "width"]]

    # answer and sample submission rows
    test_answer_rows = [["image_id", "class_id", "encoding"]]
    sample_sub_rows = [["image_id", "class_id", "encoding"]]

    # Copy unlabeled extras
    for j, upath in enumerate(_list_unlabeled_images(raw)):
        new_id = f"extra_{j+1:05d}"
        dst = public / "extra/images" / f"{new_id}.tif"
        _link_or_copy(upath, dst)

    # Copy train/test data and generate RLEs
    for i, it in enumerate(items):
        new_name = id_to_name[i]
        h, w = it['h'], it['w']
        if i in train_ids:
            dst_img = public / "train/images" / f"{new_name}.tif"
            _link_or_copy(it['img_path'], dst_img)
            dst_msk = public / "train/masks" / f"{new_name}.png"
            Image.fromarray(it['mask']).save(dst_msk)
            train_csv_rows.append([new_name, h, w])
        else:
            dst_img = public / "test/images" / f"{new_name}.tif"
            _link_or_copy(it['img_path'], dst_img)
            test_csv_rows.append([new_name, h, w])
            # ground truth encodings for private: one row per class
            for cid in range(6):
                bin_mask = (it['mask'] == cid).astype(np.uint8)
                enc = _rle_encode(bin_mask)
                test_answer_rows.append([new_name, cid, enc])
            # sample submission: a few deterministic seed pixels per class,
            # drawn from an RNG seeded by class id and image index
            num_pixels = h * w
            k = min(100, max(1, num_pixels // 200000))
            for cid in range(6):
                rng = np.random.default_rng(RANDOM_SEED + cid + i)
                seed_px = rng.choice(num_pixels, size=k, replace=False)
                m = np.zeros(num_pixels, dtype=np.uint8)
                m[seed_px] = 1
                enc = _rle_encode(m.reshape((h, w)))
                sample_sub_rows.append([new_name, cid, enc])

    # Write CSVs
    _write_csv(public / "train.csv", train_csv_rows)
    _write_csv(public / "test.csv", test_csv_rows)
    _write_csv(private / "test_answer.csv", test_answer_rows)
    _write_csv(public / "sample_submission.csv", sample_sub_rows)

    # Copy description.txt to public when it exists in the directory that
    # contains this module (not the caller's working directory).
    repo_desc = Path(__file__).parent / "description.txt"
    if repo_desc.exists():
        shutil.copy2(str(repo_desc), str(public / "description.txt"))

    # Assertions and integrity checks
    _verify_prepared_outputs(public, private)


def _verify_prepared_outputs(public: Path, private: Path) -> None:
    """Re-read the files written by `prepare` and assert they are consistent.

    Checks directory and file existence, CSV headers, image/mask sizes against
    the CSV, mask values, six answer and sample rows per test image, the
    `tile_` naming of ids, and that every class with a non-empty encoding in
    the test answers also occurs in some train mask.

    Raises:
        AssertionError: on the first failed check.
    """
    from collections import Counter

    assert (public / "train/images").exists()
    assert (public / "train/masks").exists()
    assert (public / "test/images").exists()

    def _load_csv(path: Path):
        # newline='' is what the csv module asks for when reading files.
        with open(path, 'r', newline='') as f:
            r = list(csv.reader(f))
        header, rows = r[0], r[1:]
        return header, rows

    tr_h, tr_rows = _load_csv(public / "train.csv")
    te_h, te_rows = _load_csv(public / "test.csv")
    assert tr_h == ["image_id", "height", "width"], "train.csv header must be image_id,height,width"
    assert te_h == ["image_id", "height", "width"], "test.csv header must be image_id,height,width"

    # Files exist and sizes match. Each mask is decoded once here; its class
    # values also feed the train/test coverage check at the end.
    train_class_union = set()
    for image_id, h, w in tr_rows:
        pimg = public / "train/images" / f"{image_id}.tif"
        pmsk = public / "train/masks" / f"{image_id}.png"
        assert pimg.exists(), f"Missing train image {pimg}"
        assert pmsk.exists(), f"Missing train mask {pmsk}"
        with Image.open(pimg) as im:
            iw, ih = im.size
        with Image.open(pmsk) as mm:
            m = np.array(mm)
        mh, mw = m.shape
        assert iw == int(w) and ih == int(h), f"Train image size mismatch for {image_id}"
        assert mw == int(w) and mh == int(h), f"Train mask size mismatch for {image_id}"
        u = set(np.unique(m).tolist())
        assert u.issubset(VALID_CLASS_VALUES), (
            f"Train mask {image_id} has invalid classes {sorted(list(u))}"
        )
        train_class_union |= u

    # Test files and answer alignment
    ta_h, ta_rows = _load_csv(private / "test_answer.csv")
    ss_h, ss_rows = _load_csv(public / "sample_submission.csv")
    assert ta_h == ["image_id", "class_id", "encoding"], "test_answer.csv header mismatch"
    assert ss_h == ["image_id", "class_id", "encoding"], "sample_submission.csv header mismatch"

    test_image_ids = [row[0] for row in te_rows]
    ans_ids = sorted(set(row[0] for row in ta_rows))
    assert sorted(test_image_ids) == ans_ids, "test.csv and test_answer.csv image ids mismatch"

    # Ensure six rows per image in answer and sample
    c_ans = Counter(row[0] for row in ta_rows)
    c_ss = Counter(row[0] for row in ss_rows)
    for tid in test_image_ids:
        assert c_ans[tid] == 6, f"test_answer should have 6 rows per image for {tid}"
        assert c_ss[tid] == 6, f"sample_submission should have 6 rows per image for {tid}"

    # Filenames are anonymized (no city names or label hints)
    def _no_leak(name: str) -> bool:
        bad = ["potsdam", "vaihingen", "label", "rgb", "images", "labels"]
        low = name.lower()
        return not any(b in low for b in bad)

    for image_id, _, _ in tr_rows + te_rows:
        assert image_id.startswith("tile_") and _no_leak(image_id), (
            f"Image id not anonymized: {image_id}"
        )

    # Class coverage: every class in test must appear in train at least once.
    # A class counts as present in test when its encoding is non-empty, so no
    # decoding is needed.
    test_class_union = set()
    for _, cid, enc in ta_rows:
        class_id = int(cid)
        if enc.strip():
            test_class_union.add(class_id)
    assert test_class_union.issubset(train_class_union), (
        f"All test classes must appear in train. Missing: {sorted(list(test_class_union - train_class_union))}"
    )
