# rohban_datamodule.py
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from sc_perturb.rohbandataset import RohbanDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset


class RohbanDataModule:
    """
    Thin convenience wrapper around RohbanDataset.

    Provides:
        - a train/val split stratified by the ``label`` column (the test
          split currently aliases the validation split)
        - PyTorch DataLoaders for each split
        - helpers for class distributions and per-perturbation subsets

    NOTE(review): several helpers hand ``self.dataset.df`` index labels to
    ``Subset`` as dataset positions, so they assume the DataFrame carries a
    default 0..n-1 index -- confirm against RohbanDataset.
    """

    # The annotation is quoted on purpose: an unquoted ``OmegaConf | dict`` is
    # evaluated when the method is defined and needs Python 3.10+ union
    # support, while the rest of this class only needs 3.9 syntax.
    def __init__(self, cfg: "OmegaConf | dict"):
        """
        Build the dataset from ``cfg`` and compute the split indices.

        ``cfg`` may be a plain dict or an OmegaConf config. ``data_dir`` is
        required; every other key falls back to the default shown below.
        """
        cfg = OmegaConf.create(cfg)  # accept either
        # ------------------------------------------------------------------ #
        # 1.  Core hyper-parameters
        # ------------------------------------------------------------------ #
        self.data_dir: Path = Path(cfg.data_dir)
        self.mode: str = cfg.get("mode", "full")  # full / morphdiff_...
        self.resize: Optional[int] = cfg.get("resize", None)

        self.keep_controls: bool = cfg.get("keep_controls", True)
        self.one_control_class: bool = cfg.get("one_control_class", True)
        self.collapse_variants: bool = cfg.get("collapse_variants", True)

        self.holdout_ratio: float = cfg.get("holdout_ratio", 0.1)
        self.batch_size: int = cfg.get("batch_size", 8)
        self.num_workers: int = cfg.get("num_workers", 4)
        self.seed: int = cfg.get("seed", 42)
        self.shuffle: bool = cfg.get("shuffle", True)

        # ------------------------------------------------------------------ #
        # 2.  Build the underlying dataset
        # ------------------------------------------------------------------ #
        self.dataset = RohbanDataset(
            data_dir=self.data_dir,
            mode=self.mode,
            resize=self.resize,
            keep_controls=self.keep_controls,
            one_control_class=self.one_control_class,
            collapse_variants=self.collapse_variants,
        )
        self.pert2id = self.dataset.pert2id  # expose for convenience

        # ------------------------------------------------------------------ #
        # 3.  Split indices
        # ------------------------------------------------------------------ #
        # Placeholders only; _setup_split() replaces them (with NumPy arrays,
        # except for an empty validation split which stays a plain list).
        self.train_indices = []
        self.val_indices = []
        self.test_indices = []

        self._setup_split()

    # ---------------------------------------------------------------------- #
    # Split helper
    # ---------------------------------------------------------------------- #
    def _setup_split(self):
        """
        Populate ``train_indices``, ``val_indices`` and ``test_indices``.

        With ``holdout_ratio > 0`` the indices are split by scikit-learn's
        ``train_test_split``; the split is stratified by ``label`` whenever
        shuffling is enabled. Otherwise every sample goes to the training
        split and the validation split is empty.
        """
        indices = np.arange(len(self.dataset))
        if self.holdout_ratio > 0:
            # One vectorised label lookup instead of a ``.loc`` call per row.
            labels = self.dataset.df.loc[indices, "label"].to_numpy()
            self.train_indices, self.val_indices = train_test_split(
                indices,
                test_size=self.holdout_ratio,
                random_state=self.seed,
                shuffle=self.shuffle,
                # scikit-learn raises if ``stratify`` is given together with
                # shuffle=False, so only stratify when shuffling.
                stratify=labels if self.shuffle else None,
            )
        else:
            self.train_indices = indices
            self.val_indices = []

        # for now use the same set for testing
        self.test_indices = self.val_indices

    # ---------------------------------------------------------------------- #
    # Dataset shortcuts
    # ---------------------------------------------------------------------- #
    def get_train_dataset(self) -> Subset:
        """Return the training split as a ``Subset`` of the dataset."""
        return Subset(self.dataset, self.train_indices)

    def get_val_dataset(self) -> Optional[Subset]:
        """Return the validation split, or ``None`` when it is empty."""
        # len() rather than truthiness: the indices may be a NumPy array.
        if len(self.val_indices) == 0:
            return None
        return Subset(self.dataset, self.val_indices)

    def get_test_dataset(self) -> Optional[Subset]:
        """Return the test split, or ``None`` when it is empty."""
        if len(self.test_indices) == 0:
            return None
        return Subset(self.dataset, self.test_indices)

    # ---------------------------------------------------------------------- #
    # Dataloaders
    # ---------------------------------------------------------------------- #
    def get_train_loader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training split."""
        return DataLoader(
            self.get_train_dataset(),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def get_val_loader(self) -> Optional[DataLoader]:
        """Return a non-shuffling validation DataLoader, or ``None`` if empty."""
        val_ds = self.get_val_dataset()
        if val_ds is None:
            return None
        return DataLoader(
            val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def get_test_loader(self) -> Optional[DataLoader]:
        """Return a non-shuffling test DataLoader, or ``None`` if empty."""
        test_ds = self.get_test_dataset()
        if test_ds is None:
            return None
        return DataLoader(
            test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    # ---------------------------------------------------------------------- #
    # Diagnostics & utilities
    # ---------------------------------------------------------------------- #
    def filter_samples_per_perturbation(self, perturbation_id: int) -> Subset:
        """
        Return every sample whose label maps to ``perturbation_id``.

        Labels are translated through ``self.pert2id``; rows are addressed by
        their position in ``self.dataset.df``.
        """
        idxs = [
            i
            for i, lbl in enumerate(self.dataset.df["label"])
            if self.pert2id[lbl] == perturbation_id
        ]
        return Subset(self.dataset, idxs)

    # ---------- metadata helpers ------------------------------------------ #
    def _as_metadata_df(self) -> pd.DataFrame:
        """Return a lightweight copy with only the metadata columns we use."""
        # Columns absent from the dataset's DataFrame are skipped silently.
        cols = ["label", "pert_name", "gene_name", "plate", "well"]
        keep = [c for c in cols if c in self.dataset.df.columns]
        return self.dataset.df[keep].copy()

    # ---------------------------------------------------------------- #
    #   1.  get_perturbation_distribution / cell-type / joint
    # ---------------------------------------------------------------- #
    def get_perturbation_distribution(self) -> dict[str, float]:
        """
        Fraction of images per perturbation (``label`` column).

        Computed over the whole dataset, not over a single split.
        """
        counts = self._as_metadata_df()["label"].value_counts()
        total = counts.sum()
        return {k: float(v / total) for k, v in counts.items()}

    def get_cell_type_distribution(self) -> dict[str, float]:
        """
        Rohban SIGMA-2 pilot is entirely U2OS, so return a single entry.
        Keeps the API symmetric with RxRx1.
        """
        return {"U2OS": 1.0}

    def get_perturbation_cell_type_distribution(self) -> dict[Tuple[str, str], float]:
        """Joint distribution keyed by ``(label, "U2OS")``; only one cell type."""
        pert_dist = self.get_perturbation_distribution()
        return {(p, "U2OS"): f for p, f in pert_dist.items()}

    # ---------------------------------------------------------------- #
    # 2.  Metadata filtering helper  (INT IDs: perturbation_id, cell_type_id)
    # ---------------------------------------------------------------- #
    def filter_metadata(
        self,
        perturbation_id: Optional[int] = None,
        cell_type_id: Optional[int] = None,
    ) -> pd.DataFrame:
        """
        Slice ``self.dataset.df`` by integer IDs.

        Args:
            perturbation_id: value to match in the ``perturbation_id`` column.
            cell_type_id: value to match in the ``cell_type_id`` column.

        A filter left as ``None`` is not applied; with both ``None`` the full
        DataFrame is returned. A requested column that does not exist raises
        ``KeyError``.
        """
        md = self.dataset.df
        if perturbation_id is not None:
            md = md[md["perturbation_id"] == perturbation_id]
        if cell_type_id is not None:
            md = md[md["cell_type_id"] == cell_type_id]
        return md

    # ---------------------------------------------------------------- #
    # 3.  filter_samples / filter_samples_by_perturbation
    # ---------------------------------------------------------------- #
    def _subset_from_metadata(self, md: pd.DataFrame) -> Subset:
        """Convert a metadata slice back to a ``Subset`` of the dataset."""
        # Index labels are used as dataset positions here.
        return Subset(self.dataset, md.index.to_numpy())

    def filter_samples(
        self,
        perturbation_id: Optional[int] = None,
        cell_type_id: Optional[int] = None,
        num_samples: Optional[int] = None,
        seed: int = 42,
    ) -> Optional[Subset]:
        """
        Return a ``Subset`` matching the filters, optionally resized to
        exactly ``num_samples`` rows.

        If more rows match than requested, ``num_samples`` are drawn without
        replacement. If fewer match, every matching row is kept and the
        shortfall is drawn with replacement. Returns ``None`` when nothing
        matches the filters.
        """
        md = self.filter_metadata(perturbation_id, cell_type_id)
        if md.empty:
            return None

        if num_samples is not None:
            if len(md) >= num_samples:
                md = md.sample(n=num_samples, random_state=seed)
            else:  # up-sample with replacement
                extra = md.sample(
                    n=num_samples - len(md),
                    replace=True,
                    random_state=seed,
                )
                md = pd.concat([md, extra])

        return self._subset_from_metadata(md)

    def filter_samples_by_perturbation(
        self,
        perturbation_id: int,
        num_samples: int = 100,
        seed: int = 42,
    ) -> Optional[Subset]:
        """Shorthand for the common single-perturbation query."""
        return self.filter_samples(
            perturbation_id=perturbation_id,
            num_samples=num_samples,
            seed=seed,
        )

    # ---------------------------------------------------------------- #
    # 4.  get_filtered_loader  (train / val / test)
    # ---------------------------------------------------------------- #
    def get_filtered_loader(
        self,
        split: str = "train",
        perturbation_id: Optional[int] = None,
        cell_type_id: Optional[int] = None,
        batch_size: Optional[int] = None,
        num_samples: Optional[int] = None,
        seed: int = 42,
    ) -> Optional[DataLoader]:
        """
        Build a DataLoader over one split, restricted by the metadata filters.

        Args:
            split: ``"train"``, ``"val"`` or ``"test"`` (case-insensitive).
            perturbation_id: optional filter, see ``filter_metadata``.
            cell_type_id: optional filter, see ``filter_metadata``.
            batch_size: overrides ``self.batch_size`` when given.
            num_samples: optional cap on the number of rows drawn.
            seed: random state used for the row sampling.

        Returns:
            A DataLoader, or ``None`` when no row of the split matches. The
            train split is shuffled and uses ``drop_last=True``, so a train
            selection smaller than one batch yields no batches.

        Raises:
            ValueError: if ``split`` is not one of the three split names.
        """
        split = split.lower()
        if split not in {"train", "val", "test"}:
            raise ValueError("split must be 'train', 'val', or 'test'")

        if batch_size is None:
            batch_size = self.batch_size

        split_idx = {
            "train": self.train_indices,
            "val": self.val_indices,
            "test": self.test_indices,
        }[split]

        split_md = self.dataset.df.loc[split_idx]
        # Keep the split rows that also survive the filters. A membership
        # mask is used because ``.loc`` with the split's labels on the
        # filtered frame raises KeyError for every label the filter removed.
        matching = self.filter_metadata(perturbation_id, cell_type_id).index
        md = split_md[split_md.index.isin(matching)]
        if md.empty:
            return None

        if num_samples is not None:
            # NOTE(review): when fewer rows match than requested this draws
            # len(md) rows *with* replacement, so it never up-samples and may
            # repeat rows; filter_samples() up-samples instead. Behaviour is
            # kept as-is -- confirm which one is intended.
            md = md.sample(
                n=min(num_samples, len(md)),
                replace=len(md) < num_samples,
                random_state=seed,
            )

        ds = self._subset_from_metadata(md)
        shuffle = split == "train"

        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=shuffle,
        )


# -------------------------------------------------------------------------- #
# Example — quick smoke test
# -------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Smoke test: build the module from a literal config, print the class
    # distribution and pull a single training batch.
    cfg = dict(
        data_dir="/mnt/pvc/AutoSync/data/cpg0017/broad/workspace",
        mode="morphdiff_exp_5",  # "full" is the other documented option
        resize=512,
        keep_controls=True,
        one_control_class=True,
        collapse_variants=True,
        holdout_ratio=0.1,
        batch_size=4,
        num_workers=4,
        seed=42,
        shuffle=True,
    )

    dm = RohbanDataModule(cfg)

    train_loader = dm.get_train_loader()
    val_loader = dm.get_val_loader()  # built only to exercise the code path

    distribution = dm.get_perturbation_distribution()
    print("class distribution:", distribution)

    # Scale each fraction by the training-set size to get an approximate
    # per-class count.
    for pert_id, prob in distribution.items():
        print(f"Perturbation {pert_id}: count = {prob * len(train_loader.dataset):.0f}")

    # Each batch is expected to unpack into three items; the third is unused.
    imgs, labels, _ = next(iter(train_loader))
    print("batch:", imgs.shape, labels.shape)
