"""
Code taken from https://github.com/ryanchankh/cifar100coarse
"""

import numpy as np
from torchvision.datasets import CIFAR100


class CIFAR100Coarse(CIFAR100):
    """CIFAR-100 whose targets are the 20 coarse superclasses, not the 100 fine classes.

    Constructor arguments are forwarded unchanged to ``CIFAR100``. After
    construction:

    * ``self.targets`` is a NumPy integer array in which every fine label
      ``f`` has been replaced by ``coarse_labels[f]`` (a value in ``0..19``).
    * ``self.classes`` is a list of 20 groups; group ``k`` lists the fine
      class names intended to belong to coarse label ``k``.

    NOTE(review): any other label metadata the parent constructor sets
    (e.g. ``class_to_idx``) is left as-is and still describes fine labels —
    confirm whether callers need it remapped.
    """

    # Fine class names grouped by superclass; entry k describes coarse label k
    # (the values produced by ``coarse_labels`` below).
    classes = [
        ["beaver", "dolphin", "otter", "seal", "whale"],
        ["aquarium_fish", "flatfish", "ray", "shark", "trout"],
        ["orchid", "poppy", "rose", "sunflower", "tulip"],
        ["bottle", "bowl", "can", "cup", "plate"],
        ["apple", "mushroom", "orange", "pear", "sweet_pepper"],
        ["clock", "keyboard", "lamp", "telephone", "television"],
        ["bed", "chair", "couch", "table", "wardrobe"],
        ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
        ["bear", "leopard", "lion", "tiger", "wolf"],
        ["bridge", "castle", "house", "road", "skyscraper"],
        ["cloud", "forest", "mountain", "plain", "sea"],
        ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"],
        ["fox", "porcupine", "possum", "raccoon", "skunk"],
        ["crab", "lobster", "snail", "spider", "worm"],
        ["baby", "boy", "girl", "man", "woman"],
        ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
        ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
        ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"],
        ["bicycle", "bus", "motorcycle", "pickup_truck", "train"],
        ["lawn_mower", "rocket", "streetcar", "tank", "tractor"],
    ]

    # Lookup table: coarse_labels[fine_label] -> coarse_label.
    # Ten fine labels per row: row i covers fine labels 10*i .. 10*i + 9.
    # Built once at class definition instead of on every instantiation.
    # fmt: off
    coarse_labels = np.array([
         4,  1, 14,  8,  0,  6,  7,  7, 18,  3,
         3, 14,  9, 18,  7, 11,  3,  9,  7, 11,
         6, 11,  5, 10,  7,  6, 13, 15,  3, 15,
         0, 11,  1, 10, 12, 14, 16,  9, 11,  5,
         5, 19,  8,  8, 15, 13, 14, 17, 18, 10,
        16,  4, 17,  4,  2,  0, 17,  4, 18, 17,
        10,  3,  2, 12, 12, 16, 12,  1,  9, 19,
         2, 10,  0,  1, 16, 12,  9, 13, 15, 13,
        16, 19,  2,  4,  6, 19,  5,  5,  8, 19,
        18,  1,  2, 15,  6,  0, 17,  8, 14, 13,
    ])
    # fmt: on

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

        # Remap every fine label to its coarse superclass via NumPy fancy
        # indexing; the result is an integer ndarray aligned with the data.
        self.targets = self.coarse_labels[self.targets]

        # torchvision's CIFAR100 constructor stores the *fine* label names in an
        # instance attribute ``classes``, which shadows the class attribute
        # above. Overwrite it so ``self.classes`` matches the coarse targets.
        # A per-instance copy keeps callers from mutating the shared table.
        self.classes = [list(group) for group in type(self).classes]
