# rohban_dataset.py

import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Resize


def rohban_to_6ch(batch_5ch: torch.Tensor) -> torch.Tensor:
    """
    Re-arrange a Rohban 5-channel batch into an RxRx1-style 6-channel batch.

    Input channel order (dim 1):  DNA, Mito, AGP, ER, RNA
    Output channel order (dim 1):

        w1  nuclei / DNA     ← DNA
        w2  ER               ← ER
        w3  actin            ← AGP
        w4  nucleoli / RNA   ← RNA
        w5  mitochondria     ← Mito
        w6  Golgi            ← AGP  (same plane as w3, duplicated)

    batch_5ch: (B,5,H,W) tensor; dim 1 must hold exactly five channels,
               otherwise the unpacking below raises ValueError
    returns   : (B,6,H,W) tensor of the same dtype
    """
    # One (B,1,H,W) slice per source channel, in source order.
    planes = torch.split(batch_5ch, 1, dim=1)
    dna, mito, agp, er, rna = planes
    rxrx1_order = (dna, er, agp, rna, mito, agp)
    return torch.cat(rxrx1_order, dim=1)


##############################################################################
# 1.  CONFIG –– adjust only these if your paths differ
##############################################################################
# Root of the metadata tree; merged_sigma2_pilot() reads load_data_csv/ and
# metadata/platemaps/ beneath it, and it is the default data_dir of RohbanDataset.
WORKSPACE = Path("/mnt/pvc/AutoSync/data/cpg0017/broad/workspace")
# Local directory substituted for the
# "s3://cellpainting-gallery/cpg0017-rohban-pathways" prefix of every image URL
# (see RohbanDataset._s3_to_local).
LOCAL_PREFIX = "/mnt/pvc/AutoSync/data/cpg0017"
# Plate directory names under load_data_csv/2013_10_11_SIGMA2_Pilot; each is also
# converted to int to look up its platemap by Assay_Plate_Barcode.
# Order only affects the row order of the merged table.
SIGMA_PLATE_DIRS = ["41744", "41754", "41755", "41757", "41756"]
# Channel suffixes of the "URL_Orig<channel>" columns. The order here is the
# channel order of the loaded 5-channel tensor and must stay
# (DNA, Mito, AGP, ER, RNA) because rohban_to_6ch() unpacks it positionally.
CHANNELS = ["DNA", "Mito", "AGP", "ER", "RNA"]


##############################################################################
# 2.  MERGE helper –– joins each plate's load_data.csv with its platemap
##############################################################################
def merged_sigma2_pilot(WORKSPACE) -> pd.DataFrame:
    """
    Join every SIGMA2-pilot plate's image list with its platemap metadata.

    For each plate in ``SIGMA_PLATE_DIRS`` this reads
    ``load_data_csv/2013_10_11_SIGMA2_Pilot/<plate>/load_data.csv``, looks up
    the plate's ``Plate_Map_Name`` in ``barcode_platemap.csv`` (matching
    ``Assay_Plate_Barcode`` against the plate number as an int), reads the
    tab-separated platemap ``<Plate_Map_Name>_2.txt`` and merges the two on
    ``Metadata_Well`` == ``well_position``.

    WORKSPACE: workspace root, as a ``Path`` or a plain string
    returns   : one DataFrame with the per-plate merges concatenated in
                ``SIGMA_PLATE_DIRS`` order and a fresh 0..N-1 index
    raises    : IndexError if a plate has no row in ``barcode_platemap.csv``;
                whatever ``pd.read_csv`` raises for a missing/unreadable file
    """
    # Coerce so that a str path works with the "/" operator below.
    root = Path(WORKSPACE)
    batch = "2013_10_11_SIGMA2_Pilot"
    load_data_dir = root / "load_data_csv" / batch
    platemap_dir = root / "metadata/platemaps" / batch
    barcode_csv = platemap_dir / "barcode_platemap.csv"
    barcode = pd.read_csv(barcode_csv)

    frames = []
    for plate in SIGMA_PLATE_DIRS:
        images = pd.read_csv(load_data_dir / plate / "load_data.csv")

        names = barcode.loc[
            barcode.Assay_Plate_Barcode == int(plate), "Plate_Map_Name"
        ]
        if names.empty:
            # Same exception type the bare .iat[0] would raise, but say which
            # plate and which file are involved.
            raise IndexError(
                f"plate {plate!r} has no Assay_Plate_Barcode entry in {barcode_csv}"
            )
        pm_name = names.iat[0]

        platemap = pd.read_csv(
            platemap_dir / "platemap" / f"{pm_name}_2.txt",
            sep="\t",
        )
        # Default (inner) merge: image rows whose well is absent from the
        # platemap are dropped.
        frames.append(
            images.merge(platemap, left_on="Metadata_Well", right_on="well_position")
        )
    return pd.concat(frames, ignore_index=True)


##############################################################################
# 3.  DATASET
##############################################################################
class RohbanDataset(Dataset):
    """
    Image dataset over the plates merged by ``merged_sigma2_pilot``.

    Each item is a 3-tuple ``(img_tensor, label_id, cell_type_id)``:

      • img_tensor   : float32 tensor, shape (6,H,W) – the five source channels
                       (DNA, Mito, AGP, ER, RNA) re-ordered by ``rohban_to_6ch``.
                       Pixels are divided by 65535, so values lie in [0,1]
                       assuming 16-bit source images.
      • label_id     : int, index into ``id2pert``
      • cell_type_id : always 1

    Label ids are assigned by sorting the unique label strings, so the control
    class is NOT guaranteed to be id 0 – look it up through ``pert2id``
    ("control" when collapse_variants=True, "CONTROL" otherwise, provided
    keep_controls=True and one_control_class=True).

    Modes
    -----
    - "full"                : every perturbation (the original note cites
                              327 labels incl. controls)
    - "morphdiff_exp_5"     : only the 5-gene set used in MorphoDiff paper
    - "morphdiff_exp_12"    : the 12-gene cluster set

    Set collapse_variants=True to collapse alleles / WT replicates to a
    single *gene* label (used by MorphoDiff when they say “5 genes”, “12 genes”).

    Attributes
    ----------
    df       : merged metadata table, one row per image set, with the extra
               columns ``label``, ``perturbation_id`` and ``cell_type_id``
    pert2id  : label string → int id
    id2pert  : int id → label string
    resize   : torchvision ``Resize`` transform, or None
    """

    MD5_GENES = ["rac1", "kras", "cdc42", "rhoa", "pak1"]

    MD12_GENES = [
        "xbp1",
        "mapk14",
        "rac1",
        "akt1",
        "akt3",
        "rhoa",
        "prkaca",
        "smad4",
        "rps6kb1",
        "kras",
        "braf",
        "raf1",
    ]

    CONTROL_NAMES = {
        "EMPTY_",
        "EMPTY",
        "GFP",
        "BFP",
        "CHERRY",
        "GFP_WT",
        "BFP_WT",
        "CHERRY_WT",
    }

    _MODES = ("full", "morphdiff_exp_5", "morphdiff_exp_12")

    # Prefix of the image URLs stored in load_data.csv that gets swapped for a
    # local directory.
    _S3_PREFIX = "s3://cellpainting-gallery/cpg0017-rohban-pathways"

    def __init__(
        self,
        data_dir: Path = WORKSPACE,
        mode: str = "full",
        resize: int | None = None,
        keep_controls: bool = True,
        one_control_class: bool = True,
        collapse_variants: bool = True,
        local_prefix: str | None = None,
    ):
        """
        data_dir          : workspace root handed to ``merged_sigma2_pilot``
        mode              : "full", "morphdiff_exp_5" or "morphdiff_exp_12"
        resize            : if set, images are resized to (resize, resize)
        keep_controls     : keep rows whose pert_name is in CONTROL_NAMES
        one_control_class : merge all kept controls into one "CONTROL" label
        collapse_variants : label by gene (text before the first "_" or ".",
                            lowercased) instead of by full pert_name
        local_prefix      : local directory replacing the S3 URL prefix;
                            defaults to the module-level LOCAL_PREFIX

        Raises ValueError for an unknown mode.
        """
        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

        self.local_prefix = (
            LOCAL_PREFIX if local_prefix is None else str(local_prefix)
        )

        df = merged_sigma2_pilot(data_dir)
        # Make sure the perturbation columns are strings (missing values
        # become the string "nan").
        df["pert_name"] = df["pert_name"].astype(str)
        df["gene_name"] = df["gene_name"].astype(str)

        # -- 1. restrict to the requested gene set -------------------------- #
        if mode.startswith("morphdiff"):
            wanted = self.MD5_GENES if mode == "morphdiff_exp_5" else self.MD12_GENES
            mask = df["gene_name"].str.lower().isin(wanted)
            if keep_controls:
                mask |= df["pert_name"].isin(self.CONTROL_NAMES)
            df = df[mask].copy()

        # -- 2. handle controls (must happen before labels are built) ------- #
        if not keep_controls:
            df = df[~df["pert_name"].isin(self.CONTROL_NAMES)].copy()
        elif one_control_class:
            ctrl_mask = df["pert_name"].isin(self.CONTROL_NAMES)
            df.loc[ctrl_mask, "pert_name"] = "CONTROL"
            df.loc[ctrl_mask, "gene_name"] = "CONTROL"

        # -- 3. canonical label column ------------------------------------- #
        if collapse_variants:
            df["label"] = df["pert_name"].apply(self._canon_gene)
        else:
            df["label"] = df["pert_name"]

        # -- 4. label <-> id maps, built from the label column -------------- #
        uniq = sorted(df["label"].unique())
        self.pert2id = {p: i for i, p in enumerate(uniq)}
        self.id2pert = {i: p for p, i in self.pert2id.items()}

        # -- 5. store + bookkeeping columns --------------------------------- #
        self.df = df.reset_index(drop=True)
        self.resize = Resize((resize, resize)) if resize else None
        self.df["perturbation_id"] = self.df["label"].map(self.pert2id).astype(int)
        self.df["cell_type_id"] = 1

    # ------------ helpers -------------------------------------------------
    @staticmethod
    def _canon_gene(p: str) -> str:
        """Return the text before the first "_" or "." in *p*, lowercased."""
        # maxsplit as a keyword: passing it positionally is deprecated in 3.13.
        return re.split(r"[_.]", p, maxsplit=1)[0].lower()

    def _s3_to_local(self, url: str) -> str:
        """Swap the S3 prefix of *url* (first occurrence only) for the local one."""
        return url.replace(self._S3_PREFIX, self.local_prefix, 1)

    def _load_5ch(self, row: pd.Series) -> torch.Tensor:
        """
        Load the five channel images referenced by *row* into one tensor.

        Reads the ``URL_Orig<channel>`` column for every entry of CHANNELS,
        in that order.  Returns a float32 tensor of shape (5,H,W), or
        (5,resize,resize) when a resize was configured.
        """
        chans = []
        for ch in CHANNELS:
            path = self._s3_to_local(row[f"URL_Orig{ch}"])
            # Context manager closes the file handle as soon as the pixels
            # have been copied out.
            with Image.open(path) as im:
                # Scale by the 16-bit maximum; assumes 16-bit source images.
                arr = np.asarray(im, np.float32) / 65535.0
            chans.append(arr)
        img = np.stack(chans, 0)  # (C,H,W)
        t = torch.from_numpy(img)
        if self.resize is not None:
            # torchvision Resize expects (C,H,W) float tensor
            t = self.resize(t)
        return t

    # ------------ Dataset interface --------------------------------------
    def __len__(self):
        """Number of rows (image sets) in the filtered metadata table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(img_6ch, label_id, 1)`` for row *idx* of ``self.df``."""
        row = self.df.iloc[idx]
        x5 = self._load_5ch(row)
        # rohban_to_6ch works on batches, so add and remove a batch dim.
        x6 = rohban_to_6ch(x5.unsqueeze(0)).squeeze(0)  # (6,H,W)
        y = self.pert2id[row["label"]]  # canonical label → int id
        return x6, y, 1


##############################################################################
# 4.  QUICK TEST / USAGE
##############################################################################
if __name__ == "__main__":
    # Smoke test: build the dataset in every mode, report its size and the
    # number of classes, then load the first sample and show its basic stats.
    for run_mode in ("full", "morphdiff_exp_5", "morphdiff_exp_12"):
        dataset = RohbanDataset(
            data_dir=WORKSPACE,
            mode=run_mode,
            resize=512,
            collapse_variants=True,
        )
        n_imgs = len(dataset)
        n_classes = len(dataset.pert2id)
        print(f"{run_mode:<16}  imgs={n_imgs:6}  classes={n_classes:3}")

        image, label, _ = dataset[0]
        lo = image.min().item()
        hi = image.max().item()
        print(" sample:", image.shape, image.dtype, lo, hi, label)
