import os
from collections.abc import Callable

import torch

from .base import AnomalyDataset


# The 15 MVTec AD category names accepted by ``MVTecAD``; each one is also the
# name of a sub-directory under ``<root>/mvtec_anomaly_detection/``.
CATEGORIES = (
    "bottle", "cable", "capsule", "carpet", "grid",
    "hazelnut", "leather", "metal_nut", "pill", "screw",
    "tile", "toothbrush", "transistor", "wood", "zipper",
)


class MVTecAD(AnomalyDataset):
    """MVTec Anomaly Detection dataset restricted to a single category.

    Images are discovered under
    ``<root>/mvtec_anomaly_detection/<category>/<split>/<sub_dir>/*.png``.
    Images in the ``good`` sub-directory are labelled normal (0); images in any
    other sub-directory are labelled anomalous (1). Non-directory entries
    directly under ``<split>`` are ignored.

    Attributes:
        category: The category this dataset was built for.
        attr_names: Names of the columns of ``attr`` (only ``"anomaly"``).
        filename: Image paths relative to the category directory, e.g.
            ``"test/<sub_dir>/<name>.png"``, in sorted (deterministic) order.
            NOTE(review): presumably resolved against ``root`` by
            ``AnomalyDataset`` — confirm in ``.base``.
        attr: ``int64`` tensor of shape ``(len(filename), 1)`` holding the
            0/1 anomaly label of each image.
    """

    def __init__(self, category: str, root: str, split: str, transform: Callable | None = None):
        """Scan the category/split directory and build filenames and labels.

        Args:
            category: One of ``CATEGORIES``.
            root: Directory containing the ``mvtec_anomaly_detection`` folder.
            split: ``"train"`` or ``"test"``.
            transform: Optional callable forwarded to ``AnomalyDataset``.

        Raises:
            ValueError: If ``category`` or ``split`` is not recognised.
        """
        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        if category not in CATEGORIES:
            raise ValueError(f"unknown MVTec AD category {category!r}; expected one of {CATEGORIES}")
        if split not in ("train", "test"):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")

        data_root = os.path.join(root, "mvtec_anomaly_detection", category)
        super().__init__(root=data_root, split=split, transform=transform)

        self.category = category
        self.attr_names: list[str] = ["anomaly"]
        self.filename: list[str] = []

        labels: list[list[int]] = []

        # os.scandir yields entries in arbitrary, filesystem-dependent order; sort
        # so sample indices are reproducible across machines and runs. Using the
        # iterators as context managers closes the underlying directory handles.
        with os.scandir(os.path.join(data_root, split)) as entries:
            sub_dirs = sorted((e for e in entries if e.is_dir()), key=lambda e: e.name)

        for sub_dir in sub_dirs:
            with os.scandir(sub_dir.path) as entries:
                names = sorted(e.name for e in entries if e.name.endswith(".png"))
            self.filename.extend(os.path.join(split, sub_dir.name, name) for name in names)
            label = 0 if sub_dir.name == "good" else 1
            labels.extend([label] for _ in names)

        # reshape keeps the (N, 1) shape even when no images were found
        # (torch.tensor([]) alone would be 1-D with shape (0,)).
        self.attr = torch.tensor(labels, dtype=torch.int64).reshape(-1, len(self.attr_names))


def mvtec(
    category: str,
    root: str = "./data",
    split: str = "train",
    transform: Callable | None = None,
):
    """Convenience factory that builds an ``MVTecAD`` dataset.

    Args:
        category: MVTec AD category name, validated by ``MVTecAD``.
        root: Directory containing the ``mvtec_anomaly_detection`` folder.
        split: ``"train"`` or ``"test"``.
        transform: Optional callable, forwarded unchanged to ``MVTecAD``.

    Returns:
        The constructed ``MVTecAD`` instance.
    """
    options = {"root": root, "split": split, "transform": transform}
    return MVTecAD(category=category, **options)
