import copy
import sys
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

from .transforms import *
from data.dogs_vs_cats_data import DogsVsCatsData
from data.muffin_vs_chihuahua import MuffinVsChihuaha
from data.cifar100_coarse_data import CIFAR100Coarse
from data.imagenet30_data import ImageNet30
from data.pneumonia_data import PneumoniaData
from data.melanoma_data import MelanomaDataset

# Registry of supported datasets: maps the config's dataset name
# (``cfg.name``) to the class that builds it. ``get_datasets`` instantiates
# each entry as ``cls(root=..., train=..., transform=...)``.
DATASETS = dict(
    cifar10=torchvision.datasets.CIFAR10,
    cifar100=CIFAR100Coarse,
    imagenet30=ImageNet30,
    dogs_vs_cats=DogsVsCatsData,
    pneumonia=PneumoniaData,
    muffin_vs_chihuahua=MuffinVsChihuaha,
    melanoma=MelanomaDataset,
)


def get_datasets(cfg):
    """Build train / un-augmented train / test splits for one-class anomaly detection.

    Two mutually exclusive set-ups are supported:

    * "One vs Rest" (``cfg.normal_class`` set): classes ``normal_class ..
      normal_class + n_normal_class - 1`` (wrapped modulo
      ``num_dataset_classes``, default 10) are normal. The training set keeps
      only those classes. Test labels become 0 for normal and 1 for anomaly.
    * "Rest vs One" (``cfg.anomaly_class`` set): that single class is the
      anomaly. It is removed from the training set, and test labels become
      1 for it and 0 otherwise.

    Training-set targets keep their original class labels; only the test
    set is binarized.

    Args:
        cfg: Config object with attribute access plus ``.get`` (presumably an
            OmegaConf ``DictConfig`` -- confirm). Fields read: ``name``,
            ``data_path``, ``normal_class``, ``anomaly_class``,
            ``n_normal_class`` (normal mode only) and optionally
            ``num_dataset_classes``.

    Returns:
        ``(trainset, trainset_no_aug, testset)``. ``trainset_no_aug`` is a
        deep copy of the filtered training set that uses the evaluation
        transform instead of the augmenting one.

    Raises:
        ValueError: if both or neither of ``normal_class`` / ``anomaly_class``
            are set.
        SystemExit: with a non-zero status if ``cfg.name`` is not in
            ``DATASETS``.
    """
    if cfg.name not in DATASETS:
        print(f"Unsupported Dataset {cfg.name}")
        # Non-zero status: a bare sys.exit() would report success (0).
        sys.exit(1)

    # Validate the experiment set-up before building (possibly downloading)
    # the datasets, so config mistakes fail fast.
    has_normal = cfg.normal_class is not None
    has_anomaly = cfg.anomaly_class is not None
    if has_normal and has_anomaly:
        raise ValueError("Cannot specify both normal and anomaly class")
    if not has_normal and not has_anomaly:
        raise ValueError("Must specify either normal or anomaly class")

    train_transform = get_transform(cfg)
    test_transform = get_transform(cfg, eval=True)
    print(cfg.name)
    dataset_cls = DATASETS[cfg.name]
    trainset = dataset_cls(root=cfg.data_path, train=True, transform=train_transform)
    testset = dataset_cls(root=cfg.data_path, train=False, transform=test_transform)

    train_targets = np.asarray(trainset.targets)
    if has_normal:
        # One vs Rest: a contiguous (wrapped) run of classes is "normal".
        start_label = cfg.normal_class
        n_normal_classes = cfg.n_normal_class
        num_classes = cfg.get("num_dataset_classes", 10)
        normal_classes = [
            cls % num_classes
            for cls in range(start_label, start_label + n_normal_classes)
        ]
        # Keep only normal samples in the training set.
        keep = np.isin(train_targets, normal_classes)
        # Binarize test labels: 1 = anomaly, 0 = normal.
        testset.targets = [int(t not in normal_classes) for t in testset.targets]
    else:
        # Rest vs One: a single class is the anomaly; drop it from training.
        keep = train_targets != cfg.anomaly_class
        testset.targets = [int(t == cfg.anomaly_class) for t in testset.targets]

    # NOTE(review): boolean-mask indexing assumes ``trainset.data`` is a numpy
    # array (true for torchvision CIFAR10); confirm for the custom datasets.
    trainset.data = trainset.data[keep]
    trainset.targets = train_targets[keep]

    # Copy of the training set without augmentations: use the eval transform.
    # The old code assigned train_transform here, which left the copy
    # identical to the augmented trainset.
    trainset_no_aug = copy.deepcopy(trainset)
    trainset_no_aug.transform = test_transform

    return trainset, trainset_no_aug, testset
