from typing import List, Tuple, Dict

from avalanche.benchmarks.datasets.imagenet_data import (
    IMAGENET_TORCHVISION_WNID_TO_IDX,
    IMAGENET_TORCHVISION_CLASSES,
)

# WordNet IDs (WNIDs) of the 100 classes making up this mini-ImageNet
# definition.
#
# List order is significant: MINI_IMAGENET_WNID_TO_IDX assigns each WNID its
# position in this list as its class index, and MINI_IMAGENET_CLASSES is built
# by walking this list in order.
#
# Every WNID here must also be a key of IMAGENET_TORCHVISION_WNID_TO_IDX:
# the class names are looked up through that mapping at import time, so an
# unknown WNID makes importing this module fail.
#
# NOTE(review): presumably the standard mini-ImageNet class set; confirm the
# source of this particular ordering before relying on specific indices.
MINI_IMAGENET_WNIDS: List[str] = [
    "n02110341",
    "n01930112",
    "n04509417",
    "n04067472",
    "n04515003",
    "n02120079",
    "n03924679",
    "n02687172",
    "n03075370",
    "n07747607",
    "n09246464",
    "n02457408",
    "n04418357",
    "n03535780",
    "n04435653",
    "n03207743",
    "n04251144",
    "n03062245",
    "n02174001",
    "n07613480",
    "n03998194",
    "n02074367",
    "n04146614",
    "n04243546",
    "n03854065",
    "n03838899",
    "n02871525",
    "n03544143",
    "n02108089",
    "n13133613",
    "n03676483",
    "n03337140",
    "n03272010",
    "n01770081",
    "n09256479",
    "n02091244",
    "n02116738",
    "n04275548",
    "n03773504",
    "n02606052",
    "n03146219",
    "n04149813",
    "n07697537",
    "n02823428",
    "n02089867",
    "n03017168",
    "n01704323",
    "n01532829",
    "n03047690",
    "n03775546",
    "n01843383",
    "n02971356",
    "n13054560",
    "n02108551",
    "n02101006",
    "n03417042",
    "n04612504",
    "n01558993",
    "n04522168",
    "n02795169",
    "n06794110",
    "n01855672",
    "n04258138",
    "n02110063",
    "n07584110",
    "n02091831",
    "n03584254",
    "n03888605",
    "n02113712",
    "n03980874",
    "n02219486",
    "n02138441",
    "n02165456",
    "n02108915",
    "n03770439",
    "n01981276",
    "n03220513",
    "n02099601",
    "n02747177",
    "n01749939",
    "n03476684",
    "n02105505",
    "n02950826",
    "n04389033",
    "n03347037",
    "n02966193",
    "n03127925",
    "n03400231",
    "n04296562",
    "n03527444",
    "n04443257",
    "n02443484",
    "n02114548",
    "n04604644",
    "n01910747",
    "n04596742",
    "n02111277",
    "n03908618",
    "n02129165",
    "n02981792",
]

# Inverse lookup for MINI_IMAGENET_WNIDS: each WNID maps to its position in
# that list, which serves as its mini-ImageNet class index. Should a WNID ever
# occur more than once, the later position is the one kept.
MINI_IMAGENET_WNID_TO_IDX: Dict[str, int] = dict(
    zip(MINI_IMAGENET_WNIDS, range(len(MINI_IMAGENET_WNIDS)))
)

# Class-name entries for each mini-ImageNet class, in the same order as
# MINI_IMAGENET_WNIDS (so entry i belongs to class index i). Each entry is the
# tuple of name strings that IMAGENET_TORCHVISION_CLASSES stores for the
# WNID's full-ImageNet (torchvision) index.
#
# Built with a comprehension rather than a module-level for-loop so the
# iteration variables do not leak into this module's namespace.
# A WNID missing from IMAGENET_TORCHVISION_WNID_TO_IDX makes the lookup fail
# and aborts the import of this module.
MINI_IMAGENET_CLASSES: List[Tuple[str, ...]] = [
    IMAGENET_TORCHVISION_CLASSES[IMAGENET_TORCHVISION_WNID_TO_IDX[wnid]]
    for wnid in MINI_IMAGENET_WNIDS
]

# Flattened name -> class-index lookup: every name string in every entry of
# MINI_IMAGENET_CLASSES maps to that entry's position (its mini-ImageNet class
# index). If the same name string appears under more than one class, the
# later class's index silently overwrites the earlier one.
MINI_IMAGENET_CLASS_TO_IDX: Dict[str, int] = {
    name: class_idx
    for class_idx, names in enumerate(MINI_IMAGENET_CLASSES)
    for name in names
}

# Public names exported by `from ... import *`: the WNID list and the
# structures derived from it.
__all__ = [
    "MINI_IMAGENET_WNIDS", "MINI_IMAGENET_WNID_TO_IDX",
    "MINI_IMAGENET_CLASSES", "MINI_IMAGENET_CLASS_TO_IDX",
]
