import os
from itertools import product
from collections.abc import Callable
from typing import Any

import torch
from torch.utils.data import Subset
from torchvision.datasets import CelebA
from torchvision.datasets.vision import VisionDataset

from PIL import Image


# Names exported by a star-import of this module.
__all__ = [
    "AnomalyDataset",
    "make_condition",
    "make_subset",
    "sample_index",
]


class AnomalyDataset(VisionDataset):
    """Vision dataset that pairs image files under ``root`` with attribute rows.

    The constructor only forwards ``root``/``transform`` to the base class and
    records ``split``. The ``attr``, ``attr_names`` and ``filename`` attributes
    are declared but never assigned here, so they must be populated before the
    dataset is indexed (presumably by a subclass -- confirm against callers).
    """

    def __init__(self, root: str, split: str, transform: Callable | None = None) -> None:
        """Store ``split`` and hand ``root`` and ``transform`` to the base class."""
        super().__init__(root=root, transform=transform)
        self.split = split

        # Declarations only: no values are bound to these names here.
        self.filename: list[str]
        self.attr_names: list[str]
        self.attr: torch.Tensor

    def load_image(self, index: int) -> Image.Image:
        """Open ``filename[index]`` relative to ``root`` and return the loaded image."""
        path = os.path.join(self.root, self.filename[index])
        with open(path, "rb") as stream:
            picture = Image.open(stream)
            # Read the pixel data now, while the file handle is still open.
            picture.load()
        return picture

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return ``(image, target)`` where ``target`` is ``attr[index, :]``.

        The image is passed through ``transform`` when one is set.
        """
        image = self.load_image(index)
        target = self.attr[index, :]

        if self.transform is None:
            return image, target
        return self.transform(image), target

    def __len__(self) -> int:
        """Return the number of entries in ``filename``."""
        n_files = len(self.filename)
        return n_files


def make_condition(attr_names: list[str], attribute_config: dict[int | str, bool]) -> Callable:
    attr_names_lower = [name.lower() for name in attr_names]
    new_attribute_config = {
        k if isinstance(k, int) else attr_names_lower.index(k.lower()): v
        for k, v in attribute_config.items()
    }
    assert len(attribute_config.keys()) == len(new_attribute_config.keys()), "Duplicate keys"

    return lambda labels: all(labels[i] if v else not labels[i] for i, v in new_attribute_config.items())


def make_subset(dataset: AnomalyDataset | CelebA, cond: Callable):
    """Return a ``Subset`` of ``dataset`` restricted to rows accepted by ``cond``.

    ``cond`` is called once per entry of ``dataset.attr``, in order; the
    positions for which it returns a truthy value become the subset indices.
    """
    selected = []
    for position, labels in enumerate(dataset.attr):
        if cond(labels):
            selected.append(position)
    return Subset(dataset, selected)


def sample_index(
    attrs: torch.Tensor,  # [N] or [N, num_attrs]
    n_samples: int,
    per_attr: bool = True,  # if False, draw n_samples from all rows at once
    seed: int = 0,
    sort: bool = True,
) -> torch.Tensor:
    """Draw a seeded random sample of row indices from ``attrs``.

    Args:
        attrs: Attribute values, one row per item. A 1-D tensor is treated as
            a single attribute column.
        n_samples: With ``per_attr=True``, the maximum number of indices taken
            from each group of rows sharing the same attribute combination.
            With ``per_attr=False``, the maximum number of indices taken from
            all rows together.
        per_attr: Whether to sample separately within each distinct attribute
            combination.
        seed: Seed for the generator that drives every permutation.
        sort: Whether to return the indices in ascending order. When ``False``
            the indices are left in the order they were drawn (group by group
            when ``per_attr`` is set).

    Returns:
        A 1-D tensor of row indices on ``attrs``' device. A group with fewer
        than ``n_samples`` rows contributes all of its rows. An input with no
        rows yields an empty tensor.
    """
    rng = torch.Generator().manual_seed(seed)

    if per_attr:
        if attrs.ndim == 1:
            attrs = attrs.unsqueeze(dim=1)

        if attrs.size(0) == 0:
            # No rows means no groups; torch.cat rejects an empty list.
            indices = torch.empty(0, dtype=torch.long)
        else:
            # Group rows by their complete attribute combination. Only
            # combinations that actually occur are visited, so the work grows
            # with the data instead of with the product of the per-attribute
            # value counts (which is exponential in the number of attributes).
            combos, group_ids = torch.unique(attrs, dim=0, return_inverse=True)

            indices_list = []
            for group in range(combos.size(0)):
                candidates = (group_ids == group).nonzero(as_tuple=False).flatten()
                chosen = torch.randperm(candidates.size(0), generator=rng)[:n_samples]
                indices_list.append(candidates[chosen])

            indices = torch.cat(indices_list)

    else:
        indices = torch.randperm(attrs.size(0), generator=rng)[:n_samples]

    indices = indices.to(attrs.device)
    if sort:
        indices = indices.sort().values

    return indices
