from collections.abc import Callable

import numpy as np

import torch
from torch.utils.data import Subset
from torchvision.datasets import MNIST

from PIL import Image

from .base import AnomalyDataset


def _train_valid_split(dataset, num_train, num_valid, seed=42):
    rng = torch.Generator().manual_seed(seed)
    train_idxs, valid_idxs = [], []

    for i in range(10):
        class_idxs = (dataset.targets == i).nonzero(as_tuple=False).flatten().tolist()
        perm_class_idxs = torch.randperm(len(class_idxs), generator=rng).tolist()

        train_idxs.extend([class_idxs[idx] for idx in perm_class_idxs[:num_train]])
        valid_idxs.extend([class_idxs[idx] for idx in perm_class_idxs[num_train:num_train+num_valid]])

    train_dataset = Subset(dataset, sorted(train_idxs))
    valid_dataset = Subset(dataset, sorted(valid_idxs))
    return train_dataset, valid_dataset


def coloring_inplace(img: np.ndarray, color: str) -> np.ndarray:
    """Tint a channels-last RGB array by zeroing every channel except ``color``.

    The array is modified in place and also returned for convenience.

    Args:
        img: Array indexed as ``img[:, :, channel]`` with channels 0/1/2 = R/G/B.
        color: ``"red"``, ``"green"`` or ``"blue"``.

    Returns:
        The same ``img`` object, after modification.

    Raises:
        ValueError: If ``color`` is not one of the supported names.
    """
    # Channels to clear so that only the named channel keeps its values.
    channels_to_clear = {"red": [1, 2], "green": [0, 2], "blue": [0, 1]}
    for name, cleared in channels_to_clear.items():
        if color == name:
            img[:, :, cleared] = 0
            return img
    raise ValueError(f"Color {color} is not supported.")


class ColorMNIST(AnomalyDataset):
    """Colored MNIST with two binary attributes: ``label`` and ``color``.

    Each digit is zero-padded by 28 pixels on every side, replicated to three
    channels, and tinted red, green or blue by zeroing the other two channels.
    Within each digit class, colors are assigned round-robin in the order of
    ``config["color"]``, so per-class color counts differ by at most one.

    ``config`` maps each attribute value to the value stored in ``self.attr``.
    A digit that is missing from ``config["label"]`` or mapped to ``None`` is
    excluded from the dataset entirely.
    """

    config = {
        "label": {
            0: True,
            1: True,
            2: True,
            3: True,
            4: True,
            5: False,
            6: False,
            7: False,
            8: False,
            9: False,
        },
        "color": {
            "red": True,
            "green": False,
            "blue": False,
        },
    }

    def __init__(self, root: str, split: str, transform: Callable | None = None, seed: int = 42):
        """Load MNIST and build the colored images and attribute table for ``split``.

        Args:
            root: Directory passed to torchvision ``MNIST`` (downloaded if missing).
            split: ``"train"`` or ``"valid"`` (disjoint per-class subsets of the MNIST
                training set) or ``"test"`` (the MNIST test set).
            transform: Forwarded to ``AnomalyDataset.__init__``.
            seed: Seed for the train/valid split; ignored for ``"test"``.

        Raises:
            AssertionError: If ``split`` is not one of the three supported names.
        """
        assert split in ["train", "valid", "test"]
        super().__init__(root=root, split=split, transform=transform)

        self.attr_names: list[str] = list(self.config.keys())
        self.dataset = []
        attr = []

        # Per-class sample budget for each split. The train/valid values also size
        # the split below; keeping them in one place stops the two from drifting apart.
        per_class = {"train": 4500, "valid": 900, "test": 870}
        label_cfg = self.config["label"]
        # Digits absent from config["label"] (or mapped to None) get a budget of 0.
        num_of_data = [0 if label_cfg.get(i) is None else per_class[split] for i in range(10)]

        dataset = MNIST(root=root, train=split != "test", download=True)

        if split in ["train", "valid"]:
            train_dataset, valid_dataset = _train_valid_split(
                dataset, per_class["train"], per_class["valid"], seed=seed
            )
            dataset = train_dataset if split == "train" else valid_dataset

        # Loop-invariant: the round-robin color order.
        colors = list(self.config["color"].keys())
        cnt = [0] * 10
        for img, label_digit in dataset:
            if cnt[label_digit] >= num_of_data[label_digit]:
                continue

            _attr = [0] * len(self.attr_names)
            # With the default config: True for digits 0-4, False for 5-9.
            _attr[0] = label_cfg[label_digit]

            # Pad 28 px on each side, replicate to 3 channels, then tint in place.
            img = np.pad(np.array(img), ((28, 28), (28, 28)), "constant", constant_values=0)
            img = np.stack([img] * 3, axis=-1)
            color = colors[cnt[label_digit] % len(colors)]
            coloring_inplace(img, color)
            _attr[1] = self.config["color"][color]

            self.dataset.append(Image.fromarray(img))
            attr.append(_attr)
            cnt[label_digit] += 1

        # One row per kept image, columns ordered like attr_names; booleans become 1/0.
        self.attr = torch.tensor(attr, dtype=torch.int64)

    def load_image(self, index: int) -> Image.Image:
        """Return the stored (untransformed) PIL image at ``index``."""
        image = self.dataset[index]
        return image

    def __len__(self) -> int:
        """Return the number of images kept for this split."""
        return len(self.dataset)


def color_mnist(
    root: str = "./data",
    split: str = "train",
    transform: Callable | None = None,
    seed: int = 42,
) -> "ColorMNIST":
    """Construct a ``ColorMNIST`` dataset.

    Args:
        root: Directory where MNIST is stored or downloaded.
        split: ``"train"``, ``"valid"`` or ``"test"``.
        transform: Forwarded to ``ColorMNIST``.
        seed: Seed for the train/valid split; ignored for ``"test"``. The default
            matches ``ColorMNIST``'s own default, so existing calls are unchanged.

    Returns:
        The constructed ``ColorMNIST`` instance.
    """
    return ColorMNIST(root=root, split=split, transform=transform, seed=seed)
