import os
import shutil
from pathlib import Path
import random
from typing import Tuple

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit

# Seed shared by every randomized step in this module (ID shuffling, the
# stratified split and the sample-submission generator) so that repeated runs
# produce identical outputs.
RNG_SEED = 42
# Fraction of the images that is held out as the test split.
TEST_FRACTION = 0.2

# Class labels the training metadata is allowed to contain; any other value
# makes metadata loading fail.
VALID_CLASSES = {
    "Lung Opacity",
    "Normal",
    "No Lung Opacity / Not Normal",
}


def _link_or_copy(src: Path, dst: Path) -> None:
    """Make ``src`` available at ``dst``, hard-linking when possible.

    The parent directory of ``dst`` is created if needed. A hard link is
    attempted first because it is instantaneous and uses no extra disk space;
    if the operating system refuses, the file is copied with its metadata
    instead.

    Args:
        src: Existing file to expose.
        dst: Destination path.

    Raises:
        OSError: If the fallback copy fails as well (for example when ``src``
            does not exist).
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    try:
        os.link(src, dst)
    except OSError:
        # Typical causes: src and dst live on different filesystems, the
        # filesystem does not support hard links, or dst already exists.
        # Only OS-level failures trigger the fallback; anything else is a
        # programming error and is allowed to propagate.
        shutil.copy2(src, dst)


def _deterministic_ids(patient_ids) -> dict:
    """Map every original id to an anonymized ``img_NNNNNN`` identifier.

    The ids are shuffled with a private generator seeded by ``RNG_SEED``, so
    the same input sequence always yields the same mapping. The new id encodes
    the position of the original id in the shuffled order, zero-padded to six
    digits.
    """
    shuffled = list(patient_ids)
    random.Random(RNG_SEED).shuffle(shuffled)

    mapping = {}
    for position, original in enumerate(shuffled):
        mapping[original] = f"img_{position:06d}"
    return mapping


def _rle_encode(mask: np.ndarray) -> str:
    """Run-length encode a 2D mask in column-major order.

    Every non-zero pixel counts as foreground. The result is a
    space-separated sequence of ``start length`` pairs, where ``start`` is the
    1-based position of the run in the column-major flattened mask.

    Args:
        mask: 2D array; any numeric or boolean dtype.

    Returns:
        The encoded runs, or an empty string when the mask has no foreground.

    Raises:
        ValueError: If ``mask`` is not two-dimensional.
    """
    if mask.ndim != 2:
        raise ValueError("Mask must be 2D")

    # Binarize before looking for transitions: comparing raw values would
    # treat a change between two different non-zero values as a run boundary,
    # and casting floats straight to uint8 would truncate values below 1 to
    # background.
    foreground = (mask.T.flatten() != 0).astype(np.uint8)  # column-major

    # Zero sentinels on both ends guarantee that every run has both a rising
    # and a falling edge, so edges always come in (start, end) pairs.
    padded = np.concatenate([[0], foreground, [0]])
    edges = np.where(padded[1:] != padded[:-1])[0] + 1  # +1 -> 1-based
    starts = edges[0::2]
    lengths = edges[1::2] - starts

    if len(starts) == 0:
        return ""
    return " ".join(str(x) for pair in zip(starts, lengths) for x in pair)


def _read_mask(path: Path) -> np.ndarray:
    """Load the mask image at ``path`` as a ``uint8`` array of 0s and 1s.

    Every pixel value greater than zero becomes 1; everything else becomes 0.
    """
    with Image.open(path) as handle:
        pixels = np.array(handle)
    is_foreground = pixels > 0
    return is_foreground.astype(np.uint8)


def _read_train_metadata(raw: Path) -> Tuple[pd.DataFrame, Path, Path]:
    """Load, validate and aggregate the training metadata.

    Reads ``stage2_train_metadata.csv`` from ``raw``, collapses it to one row
    per ``patientId``, keeps only the ids that have a PNG in
    ``Training/Images`` and appends each image's ``height`` and ``width``.

    Args:
        raw: Root directory of the raw dataset.

    Returns:
        A tuple ``(metadata, images_dir, masks_dir)`` where ``metadata`` has
        one row per image with the columns ``patientId``, ``Target``,
        ``class``, ``age``, ``sex``, ``modality``, ``position``, ``height``
        and ``width``.

    Raises:
        RuntimeError: If a required column is missing, a class outside
            ``VALID_CLASSES`` is present, or an image has no matching mask.
    """
    train_meta_path = raw / "stage2_train_metadata.csv"
    images_dir = raw / "Training" / "Images"
    masks_dir = raw / "Training" / "Masks"

    df = pd.read_csv(train_meta_path)

    req_cols = ["patientId", "Target", "class", "age", "sex", "modality", "position"]
    missing = [c for c in req_cols if c not in df.columns]
    if missing:
        raise RuntimeError(f"Missing required columns in train metadata: {missing}")

    # Aggregate by image id (patientId). The sort must be stable: "first"
    # picks the first row of each group, and an unstable sort could reorder
    # rows that share a patientId, making the chosen row non-reproducible.
    agg = (
        df.sort_values(["patientId"], kind="mergesort")
        .groupby("patientId", as_index=False)
        .agg(
            {
                "Target": "max",
                "class": "first",
                "age": "first",
                "sex": "first",
                "modality": "first",
                "position": "first",
            }
        )
    )

    extra = set(agg["class"]) - VALID_CLASSES
    if extra:
        raise RuntimeError(f"Unexpected classes encountered: {extra}")

    img_files = {p.stem for p in images_dir.glob("*.png")}
    msk_files = {p.stem for p in masks_dir.glob("*.png")}

    # Metadata rows without an image on disk are dropped without an error.
    agg = agg[agg["patientId"].isin(img_files)].reset_index(drop=True)

    # Every remaining image, however, must come with a mask.
    missing_masks = [pid for pid in agg["patientId"] if pid not in msk_files]
    if missing_masks:
        raise RuntimeError(
            f"Missing masks for {len(missing_masks)} images, examples: {missing_masks[:5]}"
        )

    # Add image size. Only the header is needed, so the file is opened just
    # long enough to read ``size``, which PIL reports as (width, height).
    sizes = []
    for pid in agg["patientId"]:
        with Image.open(images_dir / f"{pid}.png") as img:
            sizes.append(img.size)
    agg["height"] = [h for _, h in sizes]
    agg["width"] = [w for w, _ in sizes]

    return agg, images_dir, masks_dir


def _metadata_fields_with_defaults(row) -> dict:
    """Return the age/sex/modality/position fields of ``row`` with NaN replaced.

    Missing ``age`` becomes ``-1``; missing text fields become ``"Unknown"``.
    """
    return {
        "age": int(row["age"]) if pd.notna(row["age"]) else -1,
        "sex": str(row["sex"]) if pd.notna(row["sex"]) else "Unknown",
        "modality": str(row["modality"]) if pd.notna(row["modality"]) else "Unknown",
        "position": str(row["position"]) if pd.notna(row["position"]) else "Unknown",
    }


def prepare(raw: Path, public: Path, private: Path):
    """Build the public and private competition files from the raw dataset.

    Steps, in order:

    1. Delete and recreate ``public`` and ``private``.
    2. Load the training metadata and split it into train/test, stratified by
       class, using ``TEST_FRACTION`` and ``RNG_SEED``.
    3. Link or copy images (and, for the train split, masks) into ``public``
       under anonymized ids.
    4. Write ``train.csv``, ``test.csv`` and ``sample_submission.csv`` to
       ``public`` and ``test_answer.csv`` (with run-length encoded masks) to
       ``private``.
    5. Copy ``description.txt`` from the parent of ``public`` if it exists.
    6. Run consistency assertions on the generated outputs.

    Args:
        raw: Directory holding the raw dataset.
        public: Output directory for participant-visible files. Any existing
            content is removed.
        private: Output directory for the hidden answers. Any existing
            content is removed.

    Raises:
        RuntimeError: Propagated from metadata loading when the raw data is
            invalid.
        AssertionError: If one of the sanity checks fails. These checks are
            plain ``assert`` statements and are skipped under ``python -O``.
    """
    # Deterministic behavior
    random.seed(RNG_SEED)
    np.random.seed(RNG_SEED)

    raw = raw.resolve()
    public = public.resolve()
    private = private.resolve()

    # Clean output dirs. Both removals happen before either directory is
    # recreated so that nested output directories are handled correctly.
    if public.exists():
        shutil.rmtree(public)
    if private.exists():
        shutil.rmtree(private)
    public.mkdir(parents=True, exist_ok=True)
    private.mkdir(parents=True, exist_ok=True)

    # Read and validate metadata
    meta, raw_img_dir, raw_msk_dir = _read_train_metadata(raw)

    # Stratified deterministic split
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=TEST_FRACTION, random_state=RNG_SEED)
    y = meta["class"].values
    idx_tr, idx_te = next(splitter.split(meta, y))
    meta_train = meta.iloc[idx_tr].reset_index(drop=True)
    meta_test = meta.iloc[idx_te].reset_index(drop=True)

    # Basic sanity: each class present in both splits
    classes_all = set(meta["class"].unique())
    assert set(meta_train["class"]) == classes_all
    assert set(meta_test["class"]) == classes_all

    # Map to anonymized ids. The mapping is built over train and test together
    # so the new id does not reveal which split an image belongs to.
    id_map = _deterministic_ids(meta["patientId"].tolist())

    # Output dirs
    train_images_out = public / "train_images"
    train_masks_out = public / "train_masks"
    test_images_out = public / "test_images"
    train_images_out.mkdir(parents=True, exist_ok=True)
    train_masks_out.mkdir(parents=True, exist_ok=True)
    test_images_out.mkdir(parents=True, exist_ok=True)

    # Copy/link files according to split
    for _, row in meta_train.iterrows():
        pid = row["patientId"]
        new_id = id_map[pid]
        _link_or_copy(raw_img_dir / f"{pid}.png", train_images_out / f"{new_id}.png")
        _link_or_copy(raw_msk_dir / f"{pid}.png", train_masks_out / f"{new_id}.png")

    # Test masks are deliberately not published; they only reach the private
    # answer file below, in run-length encoded form.
    for _, row in meta_test.iterrows():
        pid = row["patientId"]
        new_id = id_map[pid]
        _link_or_copy(raw_img_dir / f"{pid}.png", test_images_out / f"{new_id}.png")

    # Build train.csv (public)
    train_rows = []
    for _, r in meta_train.iterrows():
        new_id = id_map[r["patientId"]]
        train_rows.append(
            {
                "id": new_id,
                "class": r["class"],
                **_metadata_fields_with_defaults(r),
                "height": int(r["height"]),
                "width": int(r["width"]),
            }
        )
    train_df_out = pd.DataFrame(train_rows)
    # Let pandas open the file itself: routing the CSV text through
    # Path.write_text applies text-mode newline translation (doubling the
    # carriage return on Windows) and the locale's default encoding.
    train_df_out.to_csv(public / "train.csv", index=False)

    # Build test.csv (public) and test_answer.csv (private)
    test_rows = []
    ans_rows = []
    for _, r in meta_test.iterrows():
        pid = r["patientId"]
        new_id = id_map[pid]
        h, w = int(r["height"]), int(r["width"])
        test_rows.append(
            {
                "id": new_id,
                **_metadata_fields_with_defaults(r),
                "height": h,
                "width": w,
            }
        )
        mask = _read_mask(raw_msk_dir / f"{pid}.png")
        assert mask.shape == (h, w), f"Mask shape mismatch for {pid}: {mask.shape} vs ({h},{w})"
        rle = _rle_encode(mask)
        ans_rows.append(
            {
                "id": new_id,
                "class": r["class"],
                "height": h,
                "width": w,
                "mask_rle": rle,
            }
        )

    test_df_out = pd.DataFrame(test_rows)
    test_df_out.to_csv(public / "test.csv", index=False)

    ans_df_out = pd.DataFrame(ans_rows)
    ans_df_out.to_csv(private / "test_answer.csv", index=False)

    # Sample submission (public): a random class per test image and, for a
    # minority of "Lung Opacity" predictions, one small random rectangle as
    # the mask. The order of the rng calls determines the output, so it must
    # not be changed.
    rng = np.random.default_rng(RNG_SEED)
    classes_sorted = sorted(classes_all)
    ss_rows = []
    for _, r in test_df_out.iterrows():
        pred_cls = rng.choice(classes_sorted)
        h, w = int(r["height"]), int(r["width"])
        # rng.random() is only drawn when the predicted class is
        # "Lung Opacity" (short-circuit evaluation).
        if pred_cls != "Lung Opacity" or rng.random() < 0.7:
            rle = ""
        else:
            hh = max(1, int(h * rng.uniform(0.03, 0.08)))
            ww = max(1, int(w * rng.uniform(0.03, 0.08)))
            y0 = int(rng.integers(0, max(1, h - hh)))
            x0 = int(rng.integers(0, max(1, w - ww)))
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y0 : y0 + hh, x0 : x0 + ww] = 1
            rle = _rle_encode(mask)
        ss_rows.append({"id": r["id"], "class": pred_cls, "mask_rle": rle})
    ss_df = pd.DataFrame(ss_rows)
    ss_df.to_csv(public / "sample_submission.csv", index=False)

    # Copy description.txt to public/
    root_desc = public.parent / "description.txt"
    if root_desc.exists():
        shutil.copy2(root_desc, public / "description.txt")

    # Assertions
    train_ids = set(train_df_out["id"]) if not train_df_out.empty else set()
    test_ids = set(test_df_out["id"]) if not test_df_out.empty else set()
    assert train_ids.isdisjoint(test_ids), "Train/Test overlap in IDs"

    # File counts
    assert len(list(train_images_out.glob("*.png"))) == len(train_ids), "Mismatch train images count"
    assert len(list(train_masks_out.glob("*.png"))) == len(train_ids), "Mismatch train masks count"
    assert len(list(test_images_out.glob("*.png"))) == len(test_ids), "Mismatch test images count"

    # Classes coverage
    assert set(train_df_out["class"]) == classes_all, "Training set must cover all classes"

    # Positive masks sanity (sample subset)
    sample_pos = train_df_out[train_df_out["class"] == "Lung Opacity"].head(50)
    for _, r in sample_pos.iterrows():
        with Image.open(train_masks_out / f"{r['id']}.png") as m:
            arr = np.array(m)
        assert (arr > 0).any(), f"Positive class with empty mask: {r['id']}"

    # Sample submission IDs must match test IDs
    assert set(ss_df["id"]) == test_ids, "Sample submission IDs must match test IDs"

    # Hidden answers alignment
    ans_sorted = ans_df_out.sort_values("id").reset_index(drop=True)
    assert ans_sorted["id"].tolist() == sorted(test_ids), "Answer CSV misaligned with test IDs"
