import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import Dataset, default_collate

import smlm


class ImagesAndActivationsDataset(Dataset):
    """
    Iterates over sliding windows of a 3D image stack together with the
    activations of each window's center frame.

    Item ``i`` holds frames ``i .. i + window - 1`` of ``y`` and the rows of
    ``x`` whose ``frame`` equals the center frame ``i + window // 2``.
    """

    # Columns read from ``x``; additional columns are allowed and ignored.
    _REQUIRED_COLUMNS = ("frame", "x", "y", "z", "photons", "significant")

    def __init__(self, y: Tensor, window: int, x: pd.DataFrame) -> None:
        """
        y: 3D Tensor [N, W, H] of image frames
        window: number of consecutive frames per item; must be odd and >= 1
        x: DataFrame of activations with (at least) the columns
           [frame, x, y, z, photons, significant]; ``frame`` is matched
           against positions along the first axis of ``y``

        Raises ValueError if any of the constraints above is violated.
        """
        if y.ndim != 3:
            raise ValueError("y must be 3D tensor [N, W, H]")
        if window < 1:
            raise ValueError("window size must be >= 1")
        if window % 2 == 0:
            raise ValueError("window size must be odd")
        if x.ndim != 2:
            raise ValueError(
                "x must be a 2D DataFrame with columns "
                "[frame, x, y, z, photons, significant]"
            )
        missing = [c for c in self._REQUIRED_COLUMNS if c not in x.columns]
        if missing:
            raise ValueError(f"x is missing required columns: {missing}")
        self.y = y
        self.window = window
        # Offset of the center frame inside a window.
        self.pad = window // 2
        self.x = x
        # Largest number of activations in any single frame; collate_fn pads
        # every sample to this length. ``max()`` of an empty Series is NaN, so
        # fall back to 0 when ``x`` has no rows.
        counts = x.groupby("frame").size()
        self.max_n_acts = int(counts.max()) if len(counts) else 0

    def collate_fn(self, batch):
        """
        Collate a list of samples produced by ``__getitem__`` into a batch.

        ``y`` windows are combined with ``default_collate``. The
        variable-length ``x`` and ``s`` entries are padded to
        ``self.max_n_acts`` via ``smlm.utils.nested.pad_sequence`` with
        ``returns_lengths=True`` (presumably the result then also carries
        per-sample lengths; confirm against smlm).
        """
        # Transpose list-of-dicts into dict-of-lists keyed like the samples.
        batch_dict = {key: [d[key] for d in batch] for key in batch[0]}
        y, x, s = batch_dict["y"], batch_dict["x"], batch_dict["s"]
        y = default_collate(y)
        x = smlm.utils.nested.pad_sequence(
            x,
            target_len=self.max_n_acts,
            returns_lengths=True,
        )
        s = smlm.utils.nested.pad_sequence(
            s,
            target_len=self.max_n_acts,
            returns_lengths=True,
        )
        return {"y": y, "x": x, "s": s}

    def __len__(self) -> int:
        # Number of full windows. Clamp at 0: a negative value would make the
        # builtin ``len()`` raise when ``y`` is shorter than ``window``.
        return max(0, self.y.size(0) - self.window + 1)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """
        Return window ``index`` as a dict with keys:
          y: frames ``index .. index + window - 1`` of ``y``, [window, W, H]
          x: float32 Tensor [K, 4] of (x, y, z, photons) for the K activations
             in the center frame ``index + pad``
          s: bool Tensor [K] with the ``significant`` flag of those rows

        Raises StopIteration for an out-of-range (including negative) index.
        """
        # NOTE(review): IndexError would be the conventional exception here;
        # StopIteration is kept so existing callers are not broken.
        if index < 0 or index >= len(self):
            raise StopIteration()

        y = self.y[index : index + self.window]

        # Filter the activation table once and reuse it for both x and s.
        rows = self.x[self.x["frame"] == index + self.pad]

        x = torch.from_numpy(
            rows[["x", "y", "z", "photons"]].to_numpy().astype(np.float32)
        )
        # ``np.bool`` does not exist in NumPy 1.24-1.26; ``np.bool_`` works on
        # every version. Selecting the single column as a Series yields a 1D
        # array directly, so no squeeze is needed.
        s = torch.from_numpy(rows["significant"].to_numpy().astype(np.bool_))

        return {"y": y, "x": x, "s": s}
