from pathlib import Path
from robustbench.data import load_cifar10c, load_cifar10, load_imagenet
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, ConcatDataset
import torchvision
import numpy as np
from robustness.tools.breeds_helpers import print_dataset_info
from robustness.tools.breeds_helpers import make_living17
from robustness.tools.breeds_helpers import ClassHierarchy
from robustness import datasets



def _group_indices_by_class(targets):
    """Map each class id in ``0..max(targets)`` to the ascending list of positions holding it.

    Single pass over ``targets`` (a list, NumPy array or CPU tensor of integer
    labels). Classes in that range with no samples map to an empty list; labels
    outside it (e.g. negative ones) are ignored.

    Raises:
        ValueError: if ``targets`` is empty.
    """
    labels = np.asarray(targets).tolist()
    if not labels:
        raise ValueError("cannot split an empty dataset: no targets found")
    groups = {cls: [] for cls in range(int(max(labels)) + 1)}
    for idx, label in enumerate(labels):
        bucket = groups.get(label)
        if bucket is not None:
            bucket.append(idx)
    return groups


def _split_per_class(groups, bounds):
    """Shuffle each class's index list in place and cut it at ``bounds``.

    ``bounds`` is an increasing sequence ``(b0, b1, ..., bk)``. The j-th returned
    array concatenates ``indices[b_j:b_{j+1}]`` over all classes, in class order.
    Makes one ``np.random.shuffle`` call per class (global NumPy RNG).
    """
    parts = [[] for _ in range(len(bounds) - 1)]
    for indices in groups.values():
        np.random.shuffle(indices)
        for part, lo, hi in zip(parts, bounds, bounds[1:]):
            part.append(indices[lo:hi])
    # np.concatenate promotes all-empty slices to float64, which cannot index
    # tensors or Subsets, so the integer dtype is forced explicitly.
    return [np.concatenate(part).astype(np.int64, copy=False) for part in parts]


def _check_train_n(cfg, limit):
    """Raise ValueError when ``cfg.args.train_n`` exceeds ``limit`` for the chosen dataset."""
    if cfg.args.train_n > limit:
        raise ValueError(
            f"cfg.args.train_n={cfg.args.train_n} exceeds the limit of {limit} "
            f"for dataset {cfg.data.dataset_name!r}"
        )


def get_loaders(cfg, corruption_type, severity):
    """Build train/val/test DataLoaders from per-class random splits of one dataset.

    Args:
        cfg: Config object read via attribute access: ``cfg.data.dataset_name``
            (``"cifar10"``, ``"imagenet-c"`` or ``"living17"``), ``cfg.data.batch_size``,
            ``cfg.args.train_n``, plus ``cfg.data.flip`` (cifar10) and
            ``cfg.data.root_dir`` (imagenet-c).
        corruption_type: Corruption name; used by cifar10 (when not flipped) and imagenet-c.
        severity: Corruption severity; used by cifar10 (when not flipped) and imagenet-c.

    Returns:
        dict with ``'train'`` (shuffled), ``'test'`` and ``'val'`` DataLoaders, all
        built with ``cfg.data.batch_size``.

    Raises:
        ValueError: if ``cfg.args.train_n`` exceeds the dataset's cap (9000 for
            cifar10, 20000 for imagenet-c) or ``dataset_name`` is not recognized.

    The per-class splits come from ``np.random.shuffle``. Seed NumPy's global RNG
    to get reproducible splits.
    """
    name = cfg.data.dataset_name

    if name == "cifar10":
        _check_train_n(cfg, 9000)
        # NOTE: the CIFAR data root is hard-coded to './'.
        if cfg.data.flip:
            # Clean CIFAR-10 with labels remapped y -> 9 - y; corruption args are unused.
            x_corr, y_corr = load_cifar10(10000, './')
            y_corr = 9 - y_corr
        else:
            x_corr, y_corr = load_cifar10c(10000, severity, './', False, [corruption_type])
        groups = _group_indices_by_class(y_corr)
        num_ex = cfg.args.train_n // len(groups)
        # Per class: num_ex train, next 10 val, next 90 test.
        tr_idxs, val_idxs, test_idxs = _split_per_class(
            groups, (0, num_ex, num_ex + 10, num_ex + 100)
        )
        tr_dataset = TensorDataset(x_corr[tr_idxs], y_corr[tr_idxs])
        val_dataset = TensorDataset(x_corr[val_idxs], y_corr[val_idxs])
        te_dataset = TensorDataset(x_corr[test_idxs], y_corr[test_idxs])

    elif name == "imagenet-c":
        _check_train_n(cfg, 20000)
        image_dir = Path(cfg.data.root_dir) / "ImageNet-C" / corruption_type / str(severity)
        dataset = ImageFolder(image_dir, transform=transforms.ToTensor())
        groups = _group_indices_by_class(dataset.targets)
        num_ex = cfg.args.train_n // len(groups)
        # Per class: num_ex train, next 10 val, next 10 test.
        tr_idxs, val_idxs, test_idxs = _split_per_class(
            groups, (0, num_ex, num_ex + 10, num_ex + 20)
        )
        tr_dataset = Subset(dataset, tr_idxs)
        val_dataset = Subset(dataset, val_idxs)
        te_dataset = Subset(dataset, test_idxs)

    elif name == "living17":
        # cfg.args.train_n, corruption_type and severity are not used here; split
        # sizes are fixed below.
        data_dir = './imagenet/'
        info_dir = './imagenet_class_hierarchy/modified'
        num_workers = 0
        _superclasses, subclass_split, _label_map = make_living17(info_dir, split="rand")
        # Only the second subclass group of the split is loaded; every sample
        # below comes from it.
        _train_subclasses, test_subclasses = subclass_split

        dataset_target = datasets.CustomImageNet(data_dir, test_subclasses)
        train_loader_target, test_loader_target = dataset_target.make_loaders(
            num_workers, cfg.data.batch_size
        )
        train_dataset = train_loader_target.dataset
        test_dataset = test_loader_target.dataset

        # Per class of the loader's train split: 50 train, next 20 val.
        tr_idxs, val_idxs = _split_per_class(
            _group_indices_by_class(train_dataset.targets), (0, 50, 70)
        )
        # Per class of the loader's test split: 20 test.
        (test_idxs,) = _split_per_class(
            _group_indices_by_class(test_dataset.targets), (0, 20)
        )
        tr_dataset = Subset(train_dataset, tr_idxs)
        val_dataset = Subset(train_dataset, val_idxs)
        te_dataset = Subset(test_dataset, test_idxs)

    else:
        raise ValueError(
            f"Unknown dataset_name {name!r}; expected 'cifar10', 'imagenet-c' or 'living17'"
        )

    batch_size = cfg.data.batch_size
    return {
        'train': DataLoader(tr_dataset, batch_size=batch_size, shuffle=True),
        'test': DataLoader(te_dataset, batch_size=batch_size),
        'val': DataLoader(val_dataset, batch_size=batch_size),
    }
    

    