import glob
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from omegaconf import DictConfig
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as TF

try:
    import pyspng
except ImportError:
    pyspng = None

# Channel indices 0..5.
DEFAULT_CHANNELS = tuple(range(6))

# One (rgb triple, upper bound of "range") entry per channel index, in order.
_CHANNEL_SPECS = (
    ((19, 0, 249), 51),
    ((42, 255, 31), 107),
    ((255, 0, 25), 64),
    ((45, 255, 252), 191),
    ((250, 0, 253), 89),
    ((254, 255, 40), 191),
)

# Lookup table keyed by channel index; every "range" starts at 0.
RGB_MAP = {
    channel: {"rgb": np.array(list(rgb)), "range": [0, upper]}
    for channel, (rgb, upper) in enumerate(_CHANNEL_SPECS)
}


def rescale_intensity(
    arr: torch.Tensor, bounds=(0.5, 99.5), out_range=(0.0, 1.0)
) -> torch.Tensor:
    """Percentile-clip a tensor and stretch the clipped window onto ``out_range``.

    The input is cast to float and divided by 255. The two percentiles in
    ``bounds`` are estimated from every 100th element of the flattened tensor,
    values are clamped to that window, and the window is mapped linearly onto
    ``out_range``.

    Args:
        arr: Input tensor of any shape (must be non-empty).
        bounds: ``(low, high)`` percentiles, expressed on a 0-100 scale.
        out_range: ``(min, max)`` of the returned values.

    Returns:
        Float tensor with the same shape as ``arr``. When both percentiles
        coincide (e.g. a constant image) every element equals ``out_range[0]``
        instead of NaN.
    """
    arr = arr.float() / 255
    # Sub-sample to keep the quantile computation cheap on large images.
    sample = arr.flatten()[::100]
    # torch.quantile needs q on the same device and with the same dtype as its input.
    quantile_levels = torch.tensor(
        [bounds[0] / 100.0, bounds[1] / 100.0],
        dtype=sample.dtype,
        device=sample.device,
    )
    low, high = torch.quantile(sample, quantile_levels)
    arr = torch.clamp(arr, low, high)
    # Guard against a zero-width window, which would otherwise produce 0/0 = NaN.
    span = (high - low).clamp_min(1e-12)
    arr = (arr - low) / span
    return arr * (out_range[1] - out_range[0]) + out_range[0]


def to_rgb(img: torch.Tensor, dtype=torch.float32) -> torch.Tensor:  # type: ignore[no-untyped-def]
    """Convert a batch of 6-channel cell-painting images to RGB.

    Reference: https://github.com/recursionpharma/rxrx1-utils/blob/d34b2b0db0af1cb4fe357573bb8de76bd042b34f/rxrx/io.py#L61

    Channels 0..5 are coloured blue, green, red, cyan, magenta and yellow
    respectively, summed into three output channels, divided by 3 and then
    passed through ``rescale_intensity`` with bounds ``(0.1, 99.9)``.

    Args:
        img: Tensor shaped ``(batch, channels, length, width)``. Inputs with
            fewer than six channels are zero-padded; extra channels beyond the
            sixth are ignored.
        dtype: Floating dtype used for the colour mixing. Defaults to
            ``torch.float32``.

    Returns:
        Tensor shaped ``(batch, 3, length, width)`` as returned by
        ``rescale_intensity``.
    """
    num_channels_required = 6
    b, num_channels, length, width = img.shape  # b x c x l x w
    prepped_img = torch.zeros(
        b, num_channels_required, length, width, dtype=img.dtype, device=img.device
    )
    # Copy over as many channels as are available, up to the required six.
    usable = min(num_channels, num_channels_required)
    prepped_img[:, :usable, :, :] += img[:, :usable, :, :]

    # Colour assigned to each of the six channels, in channel order.
    red = [1, 0, 0]
    green = [0, 1, 0]
    blue = [0, 0, 1]
    yellow = [1, 1, 0]
    magenta = [1, 0, 1]
    cyan = [0, 1, 1]
    rgb_map = torch.tensor(
        [blue, green, red, cyan, magenta, yellow],
        dtype=dtype,
        device=prepped_img.device,
    )
    rgb_img: torch.FloatTensor = (
        torch.einsum(  # type: ignore[assignment]
            "nchw,ct->nthw",
            prepped_img.to(dtype=dtype),
            rgb_map,
        )
        / 3.0
    )
    return rescale_intensity(rgb_img, bounds=(0.1, 99.9))


class RandomExactRotation:
    """Rotate an image by one angle drawn from a fixed list, with probability ``p``."""

    def __init__(self, angles, p=0.5):
        # Candidate rotation angles (e.g. [90, 180, 270]).
        self.angles = angles
        # Chance that a call rotates the image at all.
        self.p = p

    def __call__(self, img):
        """Return ``img`` rotated by a randomly chosen angle, or unchanged."""
        apply_rotation = np.random.rand() < self.p
        if not apply_rotation:
            return img
        chosen = np.random.choice(self.angles)
        return TF.rotate(img, int(chosen))


class CustomTransform:
    """Class for scaling, resizing, augmenting, and normalizing input images."""

    def __init__(
        self,
        augment: Optional[bool] = False,
        normalize: Optional[bool] = False,
        img_size: Optional[Tuple[int, int]] = (512, 512),
        reduce_channels: Optional[bool] = False,
    ):
        """Initialize the CustomTransform instance.

        Args:
            augment (bool, optional): Whether to apply augmentation (random flips, rotations). Defaults to False.
            normalize (bool, optional): Whether to apply per-channel self-standardization. Defaults to False.
            img_size (Tuple[int, int], optional): Target image size. Defaults to (512, 512).
            reduce_channels (bool, optional): When True the division by 255 is skipped in ``__call__``.
                Defaults to False.
        """
        self.augment = augment
        self.normalize = normalize
        self.resize_shape = img_size
        self.reduce_channels = reduce_channels

    def self_standardize(self, X: torch.Tensor) -> torch.Tensor:
        """Apply channel-wise self-standardization to the input image.

        Args:
            X (torch.Tensor): Input image tensor (C, H, W).

        Returns:
            torch.Tensor: Tensor where each channel has had its own mean
            subtracted and been divided by its own standard deviation.
        """
        mean = X.mean(dim=(1, 2), keepdim=True)
        # Epsilon keeps the division finite for constant channels.
        std = X.std(dim=(1, 2), keepdim=True) + 1e-6
        return (X - mean) / std

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the transformations to the input image.

        Args:
            X (torch.Tensor): Input image tensor (C, H, W). The tensor is not
                modified in place.

        Returns:
            torch.Tensor: Transformed image tensor.
        """
        # Resize the input image to the target size
        t = [T.Resize(self.resize_shape, interpolation=Image.BICUBIC)]

        # Each augmentation fires independently with p=0.2, so roughly half of
        # the samples ((1 - 0.2)^3 ~= 0.51) pass through unaugmented.
        if self.augment:
            t.append(T.RandomHorizontalFlip(p=0.2))  # Horizontal flip
            t.append(T.RandomVerticalFlip(p=0.2))  # Vertical flip
            t.append(
                RandomExactRotation(angles=[90, 180, 270], p=0.2)
            )  # 90 degree rotations

        transform = T.Compose(t)

        if not self.reduce_channels:
            # Out-of-place division: leaves the caller's tensor untouched and
            # also works for integer inputs, unlike ``X /= 255.0``.
            X = X / 255.0
        X = transform(X)

        if self.normalize:
            X = self.self_standardize(X)

        return X


class CellDataset(Dataset):
    """Dataset yielding ``(image, sirna_id, cell_type)`` tensors for each metadata row.

    Images are read from Parquet files, ``.npy`` files or stacked PNG files,
    depending on ``use_parquet`` / ``use_numpy`` (checked in that order).
    """

    def __init__(
        self,
        metadata: pd.DataFrame,
        img_size: Optional[Tuple[int, int]] = (224, 224),
        resize: Optional[Tuple[int, int]] = (224, 224),
        reduce_channels: Optional[bool] = False,
        augment: Optional[bool] = False,
        normalize: Optional[bool] = False,
        use_numpy: Optional[bool] = True,
        use_parquet: Optional[bool] = True,
        split: Optional[str] = "train",
    ):
        """Store the metadata columns and build the image transform.

        Args:
            metadata: Frame with ``image_paths``, ``sirna_id`` and ``cell_type``
                columns, plus ``numpy_path`` when ``use_numpy`` is set and
                ``parquet_path`` when ``use_parquet`` is set.
            img_size: Stored image size, used to reshape flat Parquet data.
            resize: Target size handed to ``CustomTransform``.
            reduce_channels: Convert samples with ``to_rgb`` before transforming.
            augment: Enable augmentation; only honoured when ``split == "train"``.
            normalize: Enable per-channel self-standardization.
            use_numpy: Read samples from ``.npy`` files (if Parquet is not used).
            use_parquet: Read samples from Parquet files; takes precedence.
            split: Name of the split; any value other than ``"train"`` disables
                augmentation.
        """
        self.metadata = metadata
        self.image_paths = metadata["image_paths"].values
        self.sirna_ids = metadata["sirna_id"].values
        # Cell type names mapped to integer labels.
        self.cell_type_to_label = {
            "HEPG2": 0,
            "HUVEC": 1,
            "RPE": 2,
            "U2OS": 3,
        }
        self.cell_types = [
            self.cell_type_to_label[cell_type]
            for cell_type in metadata["cell_type"].values
        ]
        self.use_numpy = use_numpy
        self.use_parquet = use_parquet
        self.img_size = img_size
        self.augment = augment
        self.normalize = normalize
        self.resize = resize
        self.reduce_channels = reduce_channels

        if use_numpy:
            self.numpy_paths = metadata["numpy_path"].values
        if use_parquet:
            self.parquet_paths = metadata["parquet_path"].values

        # Augmentation is only ever applied to the training split.
        self.transform = CustomTransform(
            augment=augment if split == "train" else False,
            normalize=normalize,
            img_size=self.resize,
            reduce_channels=reduce_channels,
        )

    def __len__(self) -> int:
        """Return the number of metadata rows."""
        return len(self.metadata)

    @staticmethod
    def _load_png_stack(image_paths) -> np.ndarray:
        """Stack the per-channel PNG files of one sample along axis 0.

        ``image_paths`` may be a sequence of paths or the string form of such a
        list (which is what a list column becomes after a CSV round-trip).
        """
        if isinstance(image_paths, str):
            # literal_eval only parses Python literals, unlike eval(), so CSV
            # content can never execute code.
            import ast

            image_paths = ast.literal_eval(image_paths)
        channels = []
        for image_path in image_paths:
            with Image.open(image_path) as image:
                channels.append(np.array(image))
        return np.stack(channels, axis=0)

    def __getitem__(
        self, idx: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load, optionally RGB-reduce, and transform the sample at ``idx``.

        Returns:
            ``(image, sirna_id, cell_type)`` where the last two are scalar
            ``torch.long`` tensors.
        """
        if self.use_parquet:
            # The flat "data" column is reshaped to (6, *img_size).
            with pa.memory_map(self.parquet_paths[idx], "r") as source:
                table = pq.read_table(source)
                sample = np.array(table.column("data")).reshape(6, *self.img_size)
        elif self.use_numpy:
            # Copy out of the read-only memory map so the array is writable.
            sample = np.copy(np.load(self.numpy_paths[idx], mmap_mode="r"))
        else:
            sample = self._load_png_stack(self.image_paths[idx])

        sample = torch.from_numpy(sample).float()
        if self.reduce_channels:
            sample = to_rgb(sample.unsqueeze(0))[0]
        sample = self.transform(sample)

        y = torch.tensor(self.sirna_ids[idx], dtype=torch.long)
        cell_type = torch.tensor(self.cell_types[idx], dtype=torch.long)
        return sample, y, cell_type


def get_loader(
    metadata: pd.DataFrame,
    use_numpy: bool,
    img_size: Tuple[int, int],
    augment: bool,
    normalize: bool,
    resize: Tuple[int, int],
    reduce_channels: bool,
    split: str,
    batch_size: int,
    train_contrastive: bool,
    moco: bool,
    num_workers: int,
    shuffle: bool,
) -> DataLoader:
    """Wrap ``metadata`` in a ``CellDataset`` and return a DataLoader over it.

    When ``train_contrastive`` is set and ``moco`` is not, batches come from a
    ``ClassBalancedBatchSampler`` with 4 samples per class and
    ``batch_size // 4`` classes per batch; otherwise ``batch_size`` and
    ``shuffle`` are passed straight to the DataLoader.
    """
    dataset = CellDataset(
        metadata,
        use_numpy=use_numpy,
        img_size=img_size,
        augment=augment,
        normalize=normalize,
        resize=resize,
        reduce_channels=reduce_channels,
        split=split,
    )

    loader_options = {
        "num_workers": num_workers,
        "pin_memory": True,
        # Persistent workers are only valid when worker processes exist.
        "persistent_workers": bool(num_workers > 0),
    }

    class_balanced = train_contrastive and not moco
    if not class_balanced:
        loader_options["batch_size"] = batch_size
        loader_options["shuffle"] = shuffle
        return DataLoader(dataset, **loader_options)

    num_positives = 4
    loader_options["batch_sampler"] = ClassBalancedBatchSampler(
        metadata["sirna_id"],
        num_classes=batch_size // num_positives,
        num_samples=num_positives,
        drop_last=True,
    )
    return DataLoader(dataset, **loader_options)


class CellDataModule:
    def __init__(self, args: DictConfig):
        """Copy settings from ``args``, load the metadata CSV and run ``setup``.

        ``metadata.csv`` is always read; if ``metadata_extended.csv`` exists in
        the data directory it replaces the loaded frame, otherwise the numpy
        and parquet output directories are created.
        """
        root = args.data_dir
        self.data_dir = root
        self.img_path = os.path.join(root, "images")
        self.numpy_img_path = os.path.join(root, "numpy_images")
        self.parquet_path = os.path.join(root, "parquet_data")
        self.metadata_file = os.path.join(root, "metadata.csv")
        self.image_embed_file = os.path.join(root, "embeddings.csv")

        # Flags controlling which on-disk formats are written / read.
        self.save_numpy = args.save_numpy
        self.save_parquet = args.save_parquet

        # Loader and transform settings forwarded to the datasets.
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.img_size = args.img_size
        self.augment = args.augment
        self.normalize = args.normalize
        self.resize = args.resize
        self.train_contrastive = args.train_contrastive
        self.moco = args.moco
        self.reduce_channels = args.reduce_channels

        self.metadata = pd.read_csv(self.metadata_file)
        self.cell_type_to_label = dict(HEPG2=0, HUVEC=1, RPE=2, U2OS=3)

        self.extended_metadata_file = os.path.join(root, "metadata_extended.csv")
        if not os.path.exists(self.extended_metadata_file):
            for output_dir in (self.numpy_img_path, self.parquet_path):
                os.makedirs(output_dir, exist_ok=True)
        else:
            self.metadata = pd.read_csv(self.extended_metadata_file)

        self.setup(args.holdout_ratio, args.seed)

    def setup(
        self, val_size: Optional[float] = 0.2, random_state: Optional[int] = 1337
    ):
        """Prepares the dataset by organizing metadata, mapping image file paths, and reserving train/val indexing."""
        # Reserve train/validation split
        train_metadata = self.metadata[self.metadata["dataset"] == "train"]
        if val_size == 0.0:
            self.metadata["train_index"] = self.metadata["dataset"] == "train"
            self.metadata["val_index"] = False
        else:
            train_index, val_index = train_test_split(
                train_metadata.index,
                test_size=val_size,
                random_state=random_state,
                # stratify=train_metadata["sirna_id"],
            )
            self.metadata["train_index"] = self.metadata.index.isin(train_index)
            self.metadata["val_index"] = self.metadata.index.isin(val_index)

        if os.path.exists(self.extended_metadata_file):
            print("Extended metadata already exists. Skipping setup.")
            return

        # Add column for full image paths in metadata
        def construct_image_paths(row: pd.Series) -> list[str]:
            cell_type_batch = row[
                "experiment"
            ]  # Assuming "experiment" includes cell type and batch info
            plate = f"Plate{row['plate']}"
            file_paths = [
                os.path.join(
                    self.img_path,
                    cell_type_batch,
                    plate,
                    f"{row['well']}_s{row['site']}_w{well_num}.png",
                )
                for well_num in range(1, 7)
            ]
            return file_paths

        self.metadata["image_paths"] = self.metadata.apply(
            construct_image_paths, axis=1
        )

        # Ensure all image paths exist
        all_paths = [path for paths in self.metadata["image_paths"] for path in paths]
        missing_paths = [path for path in all_paths if not os.path.exists(path)]
        if missing_paths:
            raise FileNotFoundError(f"Missing image files: {missing_paths}")

        # Save numpy files if enabled
        if self.save_numpy:
            self.save_all_samples_to_npy()
        if self.save_parquet:
            self.save_all_samples_to_parquet()

    def load_sample(self, index: int) -> Tuple[torch.Tensor, int]:
        row = self.metadata.iloc[index]

        if "numpy_path" in row and os.path.exists(row["numpy_path"]):
            # Load from .npy file
            return torch.from_numpy(np.load(row["numpy_path"])).float()
        else:
            # Load and stack from PNG files
            images = []
            for image_path in row["image_paths"]:
                image = Image.open(image_path)
                images.append(np.array(image))

            return torch.from_numpy(np.stack(images, axis=0)).float()  # Shape (6, H, W)

    def site_id_to_label(self, site_id: str) -> str:
        # find the site_id in metadata and return its sirna_id
        return self.metadata[self.metadata["site_id"] == site_id]["sirna_id"].values[0]

    def save_all_samples_to_npy(self):
        """Loads all samples, saves them as .npy files, and updates metadata with numpy file paths."""
        print("Saving samples as numpy files...")
        numpy_paths = []

        for index in range(len(self.metadata)):
            # Check if numpy file already exists to avoid redundant operations
            row = self.metadata.iloc[index]
            cell_type_batch = row["experiment"]
            plate = f"Plate{row['plate']}"
            sample_dir = os.path.join(self.numpy_img_path, cell_type_batch, plate)
            os.makedirs(sample_dir, exist_ok=True)

            numpy_file_path = os.path.join(
                sample_dir, f"{row['well']}_s{row['site']}.npy"
            )
            numpy_paths.append(numpy_file_path)

            if os.path.exists(numpy_file_path):
                continue

            # Load the sample
            sample = self.load_sample(index)

            # Save the sample as a numpy file
            np.save(numpy_file_path, sample.numpy())

        # Add numpy file paths to metadata
        self.metadata["numpy_path"] = numpy_paths

        # Save the extended metadata to a new CSV
        self.metadata.to_csv(self.extended_metadata_file, index=False)

    def save_all_samples_to_parquet(self):
        """Converts images to Parquet format and updates metadata."""
        print("Converting images to Parquet format...")
        parquet_paths = []

        for index, row in self.metadata.iterrows():
            cell_type_batch = row["experiment"]
            plate = f"Plate{row['plate']}"
            filename = f"{cell_type_batch}_{plate}_{row['well']}_s{row['site']}.parquet"
            parquet_file_path = os.path.join(self.parquet_path, filename)

            parquet_paths.append(parquet_file_path)

            if os.path.exists(parquet_file_path):
                continue

            # Load and preprocess images
            images = []
            for image_path in row["image_paths"]:
                image = Image.open(image_path)
                images.append(np.array(image, dtype=np.uint8))

            stacked_images = np.stack(images, axis=0)  # Shape (6, H, W)
            table = pa.Table.from_arrays(
                [pa.array(stacked_images.reshape(-1))], names=["data"]
            )
            pq.write_table(table, parquet_file_path)

        self.metadata["parquet_path"] = parquet_paths
        self.metadata.to_csv(self.extended_metadata_file, index=False)

    def load_sample_from_parquet(self, index: int) -> torch.Tensor:
        """Load the sample at metadata position ``index`` from its Parquet file.

        The flat ``data`` column is reshaped to ``(6, H, W)``. ``self.img_size``
        may be either a single int (square image) or an ``(H, W)`` pair.

        Returns:
            Float tensor shaped ``(6, H, W)``.
        """
        row = self.metadata.iloc[index]
        parquet_file_path = row["parquet_path"]

        # Passing a pair directly as a reshape dimension raises, so unpack it.
        if isinstance(self.img_size, int):
            height = width = self.img_size
        else:
            height, width = self.img_size

        table = pq.read_table(parquet_file_path)
        stacked_images = np.array(table.column("data")).reshape(6, height, width)
        return torch.from_numpy(stacked_images).float()

    def get_train_loader(self) -> DataLoader:
        """Build the DataLoader over rows flagged in the ``train_index`` column.

        Returns:
            DataLoader: Loader created by ``get_loader`` with ``split="train"``
            and shuffling enabled.
        """
        train_rows = self.metadata[self.metadata["train_index"]]
        loader_config = dict(
            metadata=train_rows,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=self.augment,
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="train",
            batch_size=self.batch_size,
            train_contrastive=self.train_contrastive,
            moco=self.moco,
            num_workers=self.num_workers,
            shuffle=True,
        )
        return get_loader(**loader_config)

    def get_val_loader(self) -> DataLoader:
        """Build the DataLoader over rows flagged in the ``val_index`` column.

        Returns:
            DataLoader: Loader created by ``get_loader`` with ``split="val"``
            and shuffling disabled.
        """
        val_rows = self.metadata[self.metadata["val_index"]]
        loader_config = dict(
            metadata=val_rows,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=self.augment,
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="val",
            batch_size=self.batch_size,
            train_contrastive=self.train_contrastive,
            moco=self.moco,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return get_loader(**loader_config)

    def get_test_loader(self) -> DataLoader:
        """Build the DataLoader over rows whose ``dataset`` column equals ``"test"``.

        Returns:
            DataLoader: Loader created by ``get_loader`` with ``split="test"``
            and shuffling disabled.
        """
        test_rows = self.metadata[self.metadata["dataset"] == "test"]
        loader_config = dict(
            metadata=test_rows,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=self.augment,
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="test",
            batch_size=self.batch_size,
            train_contrastive=self.train_contrastive,
            moco=self.moco,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return get_loader(**loader_config)

    def get_train_dataset(self, perturbation_id=None, cell_type_id=None):
        """Get the training dataset, optionally filtered by perturbation ID and/or cell type.

        Args:
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for

        Returns:
            Dataset filtered according to the specified criteria
        """
        train_metadata = self.metadata[self.metadata["train_index"]]

        # Apply filters if specified
        if perturbation_id is not None:
            train_metadata = train_metadata[
                train_metadata["sirna_id"] == perturbation_id
            ]

        if cell_type_id is not None:
            train_metadata = train_metadata[
                train_metadata["cell_type_id"] == cell_type_id
            ]

        # Return empty dataset if no data found with the filters
        if len(train_metadata) == 0:
            print(
                f"No training samples found with perturbation_id={perturbation_id} and cell_type_id={cell_type_id}"
            )
            return None

        return CellDataset(
            metadata=train_metadata,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=self.augment,
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="train",
        )

    def get_val_dataset(self, perturbation_id=None, cell_type_id=None):
        """Get the validation dataset, optionally filtered by perturbation ID and/or cell type.

        Args:
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for

        Returns:
            Dataset filtered according to the specified criteria
        """
        val_metadata = self.metadata[self.metadata["val_index"]]

        # Apply filters if specified
        if perturbation_id is not None:
            val_metadata = val_metadata[val_metadata["sirna_id"] == perturbation_id]

        if cell_type_id is not None:
            val_metadata = val_metadata[val_metadata["cell_type_id"] == cell_type_id]

        # Return empty dataset if no data found with the filters
        if len(val_metadata) == 0:
            print(
                f"No validation samples found with perturbation_id={perturbation_id} and cell_type_id={cell_type_id}"
            )
            return None

        return CellDataset(
            metadata=val_metadata,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=False,  # No augmentation for validation
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="val",
        )

    def get_test_dataset(self, perturbation_id=None, cell_type_id=None):
        """Get the test dataset, optionally filtered by perturbation ID and/or cell type.

        Args:
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for

        Returns:
            Dataset filtered according to the specified criteria
        """
        test_metadata = self.metadata[self.metadata["dataset"] == "test"]

        # Apply filters if specified
        if perturbation_id is not None:
            test_metadata = test_metadata[test_metadata["sirna_id"] == perturbation_id]

        if cell_type_id is not None:
            test_metadata = test_metadata[test_metadata["cell_type_id"] == cell_type_id]

        # Return empty dataset if no data found with the filters
        if len(test_metadata) == 0:
            print(
                f"No test samples found with perturbation_id={perturbation_id} and cell_type_id={cell_type_id}"
            )
            return None

        return CellDataset(
            metadata=test_metadata,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=False,  # No augmentation for test
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="test",
        )

    def get_filtered_loader(
        self, split="train", perturbation_id=None, cell_type_id=None, batch_size=None
    ):
        """Get a DataLoader for the specified split, filtered by perturbation ID and/or cell type.

        Args:
            split: Which dataset split to use ('train', 'val', or 'test')
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for
            batch_size: Optional batch size override

        Returns:
            DataLoader with the filtered dataset
        """
        if batch_size is None:
            batch_size = self.batch_size

        if split == "train":
            dataset = self.get_train_dataset(perturbation_id, cell_type_id)
        elif split == "val":
            dataset = self.get_val_dataset(perturbation_id, cell_type_id)
        elif split == "test":
            dataset = self.get_test_dataset(perturbation_id, cell_type_id)
        else:
            raise ValueError(f"Invalid split: {split}")

        if dataset is None or len(dataset) == 0:
            return None

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),  # Only shuffle training data
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=(split == "train"),  # Only drop last batch for training
        )

    def get_perturbation_distribution(self):
        """Get the distribution of perturbation IDs in the training set.

        Returns:
            Dictionary mapping perturbation_id to frequency
        """
        train_metadata = self.metadata[self.metadata["train_index"]]
        counts = train_metadata["sirna_id"].value_counts()
        total = counts.sum()
        return {int(k): v / total for k, v in counts.items()}

    def get_cell_type_distribution(self):
        """Get the distribution of cell types in the training set.

        Returns:
            Dictionary mapping cell_type_id to frequency
        """
        train_metadata = self.metadata[self.metadata["train_index"]]
        counts = train_metadata["cell_type_id"].value_counts()
        total = counts.sum()
        return {int(k): v / total for k, v in counts.items()}

    def get_perturbation_cell_type_distribution(self):
        """Get the joint distribution of perturbation IDs and cell types in the training set.

        Returns:
            Dictionary mapping (perturbation_id, cell_type_id) to frequency
        """
        train_metadata = self.metadata[self.metadata["train_index"]]
        grouped = train_metadata.groupby(["sirna_id", "cell_type_id"]).size()
        total = len(train_metadata)
        return {
            (int(idx[0]), int(idx[1])): count / total for idx, count in grouped.items()
        }

    def filter_samples_by_perturbation(
        self, perturbation_id, num_samples=100, device=None, cell_type=None
    ):
        """
        Filter the dataset for a specific perturbation and ensure exactly num_samples are returned.
        Similar to process_perturbation_samples in train.py

        Args:
            perturbation_id: The perturbation ID to filter for
            num_samples: Number of samples to return
            device: Device to place tensors on
            cell_type: Optional cell type ID to further filter

        Returns:
            Tuple of (samples_tensor, metadata)
        """
        dataset = self.get_train_dataset()

        # Filter samples with the specified perturbation ID
        perturbation_samples = []
        metadata = []

        for i in range(len(dataset)):
            sample, meta = dataset[i]
            if meta["sirna_id"] == perturbation_id:
                if cell_type is None or meta["cell_type_id"] == cell_type:
                    perturbation_samples.append(sample)
                    metadata.append(meta)

        if len(perturbation_samples) == 0:
            return None, None

        # Handle case with fewer samples than requested
        if len(perturbation_samples) < num_samples:
            # Randomly repeat samples to reach num_samples
            indices = np.random.choice(
                len(perturbation_samples),
                num_samples - len(perturbation_samples),
                replace=True,
            )
            additional_samples = [perturbation_samples[i] for i in indices]
            additional_metadata = [metadata[i] for i in indices]
            perturbation_samples.extend(additional_samples)
            metadata.extend(additional_metadata)

        # Handle case with more samples than requested
        elif len(perturbation_samples) > num_samples:
            # Randomly select num_samples
            indices = np.random.choice(
                len(perturbation_samples), num_samples, replace=False
            )
            perturbation_samples = [perturbation_samples[i] for i in indices]
            metadata = [metadata[i] for i in indices]

        # Convert to tensor
        samples_tensor = torch.stack(perturbation_samples)
        if device is not None:
            samples_tensor = samples_tensor.to(device)

        return samples_tensor, metadata

    def filter_samples(
        self,
        perturbation_id=None,
        cell_type_id=None,
        num_samples=None,
        seed=42,
    ):
        """
        Filter samples by perturbation ID and/or cell type, with option to limit sample count.

        Args:
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for
            num_samples: Optional maximum number of samples to return
            device: Optional device to place tensors on
            seed: Random seed for reproducibility when selecting subset of samples

        Returns:
            Tuple of (samples_tensor, metadata) with filtered samples
        """
        # Set random seed for reproducibility
        np.random.seed(seed)

        # Get training metadata
        train_metadata = self.metadata[self.metadata["train_index"]]

        # Apply filters directly on metadata DataFrame
        filtered_metadata = train_metadata.copy()

        if perturbation_id is not None:
            filtered_metadata = filtered_metadata[
                filtered_metadata["sirna_id"] == perturbation_id
            ]

        if cell_type_id is not None:
            filtered_metadata = filtered_metadata[
                (
                    filtered_metadata["cell_type"]
                    == list(self.cell_type_to_label.keys())[cell_type_id]
                    if isinstance(cell_type_id, int)
                    else cell_type_id
                )
            ]

        if len(filtered_metadata) == 0:
            return None, None

        # Limit sample count if needed
        if num_samples is not None:
            if len(filtered_metadata) > num_samples:
                # Randomly select num_samples
                filtered_metadata = filtered_metadata.sample(
                    n=num_samples, random_state=seed
                )
            elif len(filtered_metadata) < num_samples:
                # Upsample to reach num_samples by repeating
                additional_indices = np.random.choice(
                    filtered_metadata.index,
                    num_samples - len(filtered_metadata),
                    replace=True,
                )
                additional_metadata = self.metadata.loc[additional_indices]
                filtered_metadata = pd.concat([filtered_metadata, additional_metadata])

        # Create dataset with the filtered metadata
        filtered_dataset = CellDataset(
            metadata=filtered_metadata,
            use_numpy=self.save_numpy,
            img_size=self.img_size,
            augment=False,  # No augmentation for filtered samples
            normalize=self.normalize,
            resize=self.resize,
            reduce_channels=self.reduce_channels,
            split="train",
        )
        return filtered_dataset
    
    def filter_metadata(
        self,
        perturbation_id=None,
        cell_type_id=None,
    ):
        """
        Filter metadata by perturbation ID and/or cell type

        Args:
            perturbation_id: Optional specific perturbation ID to filter for
            cell_type_id: Optional specific cell type ID to filter for

        Returns:
            Metadata
        """
        # Get training metadata
        train_metadata = self.metadata[self.metadata["train_index"]]

        # Apply filters directly on metadata DataFrame
        filtered_metadata = train_metadata.copy()

        if perturbation_id is not None:
            filtered_metadata = filtered_metadata[
                filtered_metadata["sirna_id"] == perturbation_id
            ]

        if cell_type_id is not None:
            filtered_metadata = filtered_metadata[
                (
                    filtered_metadata["cell_type"]
                    == self.cell_type_to_label.keys()[cell_type_id]
                    if isinstance(cell_type_id, int)
                    else cell_type_id
                )
            ]
        return filtered_metadata


class CustomDataset(Dataset):
    """Dataset pairing image files with pre-computed feature files.

    Images are read from ``<data_dir>/images`` and features from
    ``<data_dir>/vae-sd``. Labels are read from the ``"labels"`` entry of
    ``<data_dir>/vae-sd/dataset.json`` and looked up by feature file name.
    Image and feature files are paired purely by their position in the two
    sorted file lists, so both directories must contain the same number of
    supported files.
    """

    def __init__(self, data_dir):
        """Index image/feature files under ``data_dir`` and load the labels.

        Args:
            data_dir: Directory containing the ``images`` and ``vae-sd``
                sub-directories.
        """
        PIL.Image.init()
        # Every extension PIL has registered, plus raw NumPy arrays.
        supported_ext = PIL.Image.EXTENSION.keys() | {".npy"}

        self.images_dir = os.path.join(data_dir, "images")
        self.features_dir = os.path.join(data_dir, "vae-sd")

        # images
        self._image_fnames = self._walk_relative(self.images_dir)
        self.image_fnames = sorted(
            name
            for name in self._image_fnames
            if self._file_ext(name) in supported_ext
        )
        # features
        self._feature_fnames = self._walk_relative(self.features_dir)
        self.feature_fnames = sorted(
            name
            for name in self._feature_fnames
            if self._file_ext(name) in supported_ext
        )
        # labels
        with open(os.path.join(self.features_dir, "dataset.json"), "rb") as f:
            labels = json.load(f)["labels"]
        labels = dict(labels)
        # Keys in the label table use forward slashes regardless of platform.
        labels = np.array(
            [labels[name.replace("\\", "/")] for name in self.feature_fnames]
        )
        # 1-D label arrays are stored as int64, 2-D label arrays as float32.
        self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    @staticmethod
    def _walk_relative(top):
        """Return the set of all file paths under ``top``, relative to ``top``."""
        return {
            os.path.relpath(os.path.join(root, fname), start=top)
            for root, _dirs, files in os.walk(top)
            for fname in files
        }

    def _file_ext(self, fname):
        """Return the lower-cased extension of ``fname``, including the dot."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        """Return the number of samples; asserts images and features line up."""
        assert len(self.image_fnames) == len(
            self.feature_fnames
        ), "Number of image files and feature files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, features, label)`` tensors.

        ``.npy`` images are reshaped to ``(-1, H, W)`` using their last two
        axes; every other image format is decoded to height x width
        (x channels) and moved to channels-first layout.
        """
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), "rb") as f:
            if image_ext == ".npy":
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            else:
                # Use pyspng for PNG files when it is installed, PIL otherwise.
                if image_ext == ".png" and pyspng is not None:
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
                # HW or HWC -> CHW
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return (
            torch.from_numpy(image),
            torch.from_numpy(features),
            torch.tensor(self.labels[idx]),
        )


def get_feature_dir_info(root):
    """Summarise the ``.npy`` files found directly inside ``root``.

    Files matching ``*_*.npy`` are counted as caption files; the remaining
    ``.npy`` files are counted as data items. Each caption file name must be
    of the form ``<int>_<something>.npy`` and is tallied under its integer
    prefix.

    Returns:
        ``(num_data, n_captions)`` where ``n_captions`` maps every index in
        ``range(num_data)`` to the number of caption files with that prefix.
    """
    every_npy = glob.glob(os.path.join(root, "*.npy"))
    caption_paths = glob.glob(os.path.join(root, "*_*.npy"))
    num_data = len(every_npy) - len(caption_paths)

    n_captions = dict.fromkeys(range(num_data), 0)
    for caption_path in caption_paths:
        stem, _ext = os.path.splitext(os.path.basename(caption_path))
        # Exactly one underscore is expected: "<data index>_<caption index>".
        data_idx, _caption_idx = stem.split("_")
        n_captions[int(data_idx)] += 1
    return num_data, n_captions


class DatasetFactory:
    """Base holder for a ``train`` and a ``test`` dataset.

    Subclasses are expected to populate ``train`` / ``test`` and to override
    the hooks that raise ``NotImplementedError`` here.
    """

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset stored for ``split`` (``"train"`` or ``"test"``).

        Raises ``ValueError`` for any other split name. When ``has_label`` is
        false, ``labeled`` must be false as well (checked with an assert).
        """
        if split not in ("train", "test"):
            raise ValueError
        chosen = self.train if split == "train" else self.test

        # Labeled factories hand the dataset back as-is; unlabeled ones
        # additionally refuse a request for labels.
        if not self.has_label:
            assert not labeled
        return chosen

    def unpreprocess(self, v):  # to B C H W and [0, 1]
        """Map ``v`` through ``0.5 * (v + 1)`` and clamp the result to [0, 1]."""
        rescaled = 0.5 * (v + 1.0)
        return rescaled.clamp_(0.0, 1.0)

    @property
    def has_label(self):
        """Whether the datasets carry labels; always ``True`` in the base class."""
        return True

    @property
    def data_shape(self):
        """Shape of one data item; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def data_dim(self):
        """Total number of elements in ``data_shape``."""
        shape = self.data_shape
        return int(np.prod(shape))

    @property
    def fid_stat(self):
        """FID statistics reference; ``None`` in the base class."""
        return None

    def sample_label(self, n_samples, device):
        """Not implemented in the base class."""
        raise NotImplementedError

    def label_prob(self, k):
        """Not implemented in the base class."""
        raise NotImplementedError


class MSCOCOFeatureDataset(Dataset):
    """Dataset over a flat directory of ``<i>.png``, ``<i>.npy`` and ``<i>_<k>.npy`` files.

    Original note: the image features are got through sample.
    """

    def __init__(self, root):
        self.root = root
        info = get_feature_dir_info(root)
        self.num_data, self.n_captions = info

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        """Return ``(x, z, c)`` for item ``index``.

        ``x`` is the PNG decoded to a channels-first array, ``z`` is the
        content of ``<index>.npy`` and ``c`` is one of the item's
        ``<index>_<k>.npy`` files, with ``k`` drawn uniformly at random.
        """
        png_path = os.path.join(self.root, f"{index}.png")
        with open(png_path, "rb") as fh:
            pixels = np.array(PIL.Image.open(fh))
        # HW or HWC -> CHW
        x = pixels.reshape(*pixels.shape[:2], -1).transpose(2, 0, 1)

        z = np.load(os.path.join(self.root, f"{index}.npy"))

        # Pick one of the available captions for this item.
        last_caption = self.n_captions[index] - 1
        k = random.randint(0, last_caption)
        c = np.load(os.path.join(self.root, f"{index}_{k}.npy"))
        return x, z, c


class CFGDataset(Dataset):  # for classifier free guidance
    """Wrap a dataset of ``(x, z, y)`` triples and randomly blank out ``y``.

    With probability ``p_uncond`` the third element of a sample is replaced
    by ``empty_token``; the first two elements are passed through untouched.
    """

    def __init__(self, dataset, p_uncond, empty_token):
        self.dataset = dataset
        self.p_uncond = p_uncond
        self.empty_token = empty_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        x, z, y = self.dataset[item]
        # One random draw per sample decides whether the condition is dropped.
        drop_condition = random.random() < self.p_uncond
        return x, z, (self.empty_token if drop_condition else y)


class MSCOCO256Features(DatasetFactory):
    """MS-COCO feature datasets loaded from ``path``.

    Original note: the moments calculated by Stable Diffusion image encoder &
    the contexts calculated by clip.

    ``mode="val"`` populates ``self.test`` from ``<path>/val``; any other
    mode populates ``self.train`` from ``<path>/train``, optionally wrapped
    in ``CFGDataset`` for classifier free guidance. ``empty_context.npy`` is
    loaded from ``path`` in both cases.
    """

    def __init__(self, path, cfg=True, p_uncond=0.1, mode="train"):
        super().__init__()
        print("Prepare dataset...")

        is_val = mode == "val"
        # Sub-directory to read and the number of items it must contain.
        subdir, expected_len = ("val", 40504) if is_val else ("train", 82783)

        features = MSCOCOFeatureDataset(os.path.join(path, subdir))
        assert len(features) == expected_len
        self.empty_context = np.load(os.path.join(path, "empty_context.npy"))

        if is_val:
            self.test = features
            return

        if cfg:  # classifier free guidance
            assert p_uncond is not None
            print(
                f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}"
            )
            features = CFGDataset(features, p_uncond, self.empty_context)
        self.train = features

    @property
    def data_shape(self):
        """Fixed per-item shape ``(4, 32, 32)``."""
        return (4, 32, 32)

    @property
    def fid_stat(self):
        """Path of the FID statistics file for the validation set."""
        return "assets/fid_stats/fid_stats_mscoco256_val.npz"
