import copy
import sys
import warnings

import numpy as np
import torchvision

from data.br35h import BR35HDataset
from data.cifar100_coarse_data import CIFAR100Coarse
from data.dogs_vs_cats_data import DogsVsCatsData
from data.imagenet30_data import ImageNet30
from data.kvasir import KvasirDataset
from data.medmnist_data import BreastMNISTDataset, OctMNISTDataset
from data.melanoma_data import MelanomaDataset
from data.muffin_vs_chihuahua import MuffinVsChihuaha
from data.pneumonia_data import PneumoniaData

from .transforms import get_transform

# Registry of supported datasets, keyed by ``cfg.name``.
# ``get_datasets`` instantiates the selected class with the keyword arguments
# ``root``, ``train`` and ``transform``. It then reads and rewrites the
# instance's ``data`` and ``targets`` attributes, so every registered class
# must expose those.
DATASETS = dict(
    cifar10=torchvision.datasets.CIFAR10,
    cifar100=CIFAR100Coarse,
    imagenet30=ImageNet30,
    dogs_vs_cats=DogsVsCatsData,
    pneumonia=PneumoniaData,
    muffin_vs_chihuahua=MuffinVsChihuaha,
    melanoma=MelanomaDataset,
    br35h=BR35HDataset,
    breast_mnist=BreastMNISTDataset,
    octmnist=OctMNISTDataset,
    kvasir=KvasirDataset,
)


def get_datasets(cfg):
    """Build train/test datasets for a one-class anomaly-detection experiment.

    The dataset class is looked up in ``DATASETS`` by ``cfg.name``. It is
    instantiated twice under ``cfg.data_path``: once with ``train=True`` and
    the training transform, once with ``train=False`` and the evaluation
    transform.

    Exactly one of ``cfg.normal_class`` / ``cfg.anomaly_class`` must be set:

    * ``normal_class`` ("one vs rest"): the classes ``normal_class`` ..
      ``normal_class + n_normal_class - 1`` are normal. They are taken modulo
      ``cfg.get("num_dataset_classes", 10)``. Every other class is anomalous.
    * ``anomaly_class`` ("rest vs one"): that single class is anomalous and
      every other class is normal.

    The training set keeps only normal samples, labelled ``0``. If
    ``cfg.get("contamination_ratio", 0)`` is non-zero, randomly chosen
    anomalous training samples are appended, labelled ``1``. Enough are added
    that they form that fraction of the resulting training set, or as many as
    exist if there are too few. Test-set labels are binarized to ``0`` for
    normal and ``1`` for anomalous.

    Args:
        cfg: Dataset config supporting both attribute access and ``.get``
            (presumably an OmegaConf-style config; confirm against callers).

    Returns:
        ``(trainset, trainset_no_aug, testset)``. ``trainset_no_aug`` is an
        independent deep copy of ``trainset`` that uses the evaluation
        transform, i.e. no training augmentations.

    Raises:
        ValueError: if both or neither of ``normal_class`` / ``anomaly_class``
            are set, or if ``contamination_ratio`` is outside ``[0, 1)``.
        SystemExit: with a non-zero status if ``cfg.name`` is not in
            ``DATASETS``.
    """
    if cfg.name not in DATASETS:
        # Passing the message makes the interpreter print it to stderr and
        # exit with status 1. A bare ``sys.exit()`` reported success.
        sys.exit(f"Unsupported Dataset {cfg.name}")

    train_transform = get_transform(cfg)
    test_transform = get_transform(cfg, eval=True)
    print(cfg.name)
    dataset_cls = DATASETS[cfg.name]
    trainset = dataset_cls(
        root=cfg.data_path,
        train=True,
        transform=train_transform,
    )
    testset = dataset_cls(
        root=cfg.data_path,
        train=False,
        transform=test_transform,
    )

    normal_mask = _normal_train_mask(cfg, trainset, testset)

    # Boolean indexing returns a new array and never mutates its source, so a
    # plain reference to the unfiltered data suffices (no deep copy needed).
    full_train_data = trainset.data
    trainset.data = full_train_data[normal_mask]
    trainset.targets = np.zeros_like(np.array(trainset.targets)[normal_mask])

    contamination_ratio = cfg.get("contamination_ratio", 0)
    if contamination_ratio != 0:
        _add_contamination(
            trainset, full_train_data, normal_mask, contamination_ratio
        )

    # Independent copy of the training set that applies the evaluation
    # transform, i.e. without training-time augmentations.
    trainset_no_aug = copy.deepcopy(trainset)
    trainset_no_aug.transform = test_transform

    return trainset, trainset_no_aug, testset


def _normal_train_mask(cfg, trainset, testset):
    """Return a boolean mask of normal training samples.

    Also rewrites ``testset.targets`` as a list of ``0`` (normal) /
    ``1`` (anomalous) labels.
    """
    if cfg.normal_class is not None and cfg.anomaly_class is not None:
        raise ValueError("Cannot specify both normal and anomaly class")
    train_targets = np.array(trainset.targets)
    if cfg.normal_class is not None:
        # One vs rest: a run of consecutive classes, wrapping around the
        # number of dataset classes, is treated as normal.
        num_classes = cfg.get("num_dataset_classes", 10)
        start = cfg.normal_class
        normal_classes = [
            cls % num_classes for cls in range(start, start + cfg.n_normal_class)
        ]
        # A list (not a set) keeps ``in`` based on ``==``, which also works
        # for non-hashable-by-value label types.
        testset.targets = [int(t not in normal_classes) for t in testset.targets]
        return np.isin(train_targets, normal_classes)
    if cfg.anomaly_class is not None:
        # Rest vs one: a single class is anomalous.
        testset.targets = [int(t == cfg.anomaly_class) for t in testset.targets]
        return train_targets != cfg.anomaly_class
    raise ValueError("Must specify either normal or anomaly class")


def _add_contamination(trainset, full_train_data, normal_mask, contamination_ratio):
    """Append randomly chosen anomalous samples (label 1) to ``trainset``.

    Enough are added to make up about ``contamination_ratio`` of the result.
    The count is capped at the number of anomalies actually available.
    """
    if not 0 <= contamination_ratio < 1:
        raise ValueError(
            f"contamination_ratio must be in [0, 1), got {contamination_ratio}"
        )
    anomaly_idx = np.flatnonzero(~normal_mask)
    np.random.shuffle(anomaly_idx)
    # Solve n / (n_normal + n) = ratio for n.
    n_normal = len(trainset.targets)
    n_anomalies = int(contamination_ratio * n_normal / (1 - contamination_ratio))
    if n_anomalies > len(anomaly_idx):
        # Without this cap, data and targets would get different lengths.
        warnings.warn(
            f"Requested {n_anomalies} contaminating anomalies but only "
            f"{len(anomaly_idx)} are available; using all of them."
        )
        n_anomalies = len(anomaly_idx)
    trainset.data = np.concatenate(
        [trainset.data, full_train_data[anomaly_idx[:n_anomalies]]], axis=0
    )
    # Keep the integer label dtype of the normal-only targets instead of
    # silently upcasting to float64.
    trainset.targets = np.concatenate(
        [trainset.targets, np.ones(n_anomalies, dtype=trainset.targets.dtype)],
        axis=0,
    )
