import glob
import os
import pdb

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler

from wilds import get_dataset
# get_train_loader is deliberately not imported from WILDS: this module defines its own
# copy (with a `shuffle` argument) further below.
#from wilds.common.data_loaders import get_train_loader
from wilds.common.data_loaders import get_eval_loader
from wilds.common.data_loaders import GroupSampler
from wilds.common.utils import get_counts, split_into_groups
from wilds.datasets.wilds_dataset import WILDSSubset

# Target dimension (stored as `target_dim` by the dataset wrappers) for each dataset.
# NOTE(review): an older note read "Fmow: 200" while fmow is configured as 62 here — confirm.
num_classes_dict = {
    'iwildcam': 182,
    'camelyon17': 2,
    'rxrx1': 1139,
    'fmow': 62,
    'ogb-molpcba': 128,
    'waterbirds': 2,
}
# Size passed to `transforms.Resize` for each image dataset (a (height, width) pair).
resize_dict = {
    'iwildcam': (448, 448),
    'camelyon17': (96, 96),
    'rxrx1': (256, 256),
    'fmow': (224, 224),
    'waterbirds': (224, 224),
}


class WildsDataset():
	def __init__(self, data_name, args, test_type='test'):
		"""
		Dataset class of the WILDS benchmark datasets (image datasets).

		Args:
			data_name ([string]): name of the dataset; must be a key of both
				`resize_dict` and `num_classes_dict` (e.g. "iwildcam",
				"camelyon17", "rxrx1", "fmow").
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser.
				Not read by the constructor; kept for interface compatibility.
			test_type (str, optional): test data type. it can be In-distribution test('idtest') or
				Out-Of-Distribution test('test'). Defaults to 'test'.
		"""
		self.data_name = data_name
		self.dataset = get_dataset(dataset=data_name, download=True)

		# Both splits get the same resize + tensor conversion pipeline.
		self.train_data = self.dataset.get_subset("train", transform=self._make_transform())
		self.test_data = self.dataset.get_subset(test_type, transform=self._make_transform())
		self.N_training = len(self.train_data)
		self.N_test = len(self.test_data)
		self.target_dim = num_classes_dict[data_name]

	def _make_transform(self):
		"""Build the preprocessing pipeline: resize to this dataset's size, then convert to a tensor."""
		return transforms.Compose(
			[transforms.Resize(resize_dict[self.data_name]), transforms.ToTensor()]
		)

	def get_loader(self, args, shuffle_train=True):
		"""
		Return the standard train and test dataloaders.

		Args:
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser;
				`args.batch_size` is used for both loaders.
			shuffle_train (bool, optional): shuffle the training loader. Defaults to True.

		Returns:
			([iter],[iter]): (training set dataloader, test set dataloader)
		"""
		train_loader = get_train_loader("standard", self.train_data, batch_size=args.batch_size, shuffle=shuffle_train)
		test_loader = get_eval_loader("standard", self.test_data, batch_size=args.batch_size)
		return train_loader, test_loader

	def get_hard_loader(self, args, idx, shuffle=True):
		"""
		Return a dataloader that only contains the given (hard) examples.

		The subset is also stored on `self.train_data_hard`.

		Args:
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser;
				`args.hard_batch_size` is used as the batch size.
			idx ([np.array[Int]]): indices of the hard samples in the full dataset.
			shuffle (bool, optional): shuffle the loader. Defaults to True.

		Returns:
			([iter]): hard training set dataloader
		"""
		self.train_data_hard = WILDSSubset(self.dataset, idx, transform=self._make_transform())
		train_loader_hard = get_train_loader("standard", self.train_data_hard, batch_size=args.hard_batch_size, shuffle=shuffle)
		return train_loader_hard

	def add_training_data(self, idx, times=1):
		"""
		Append `idx` to the training subset's indices `times` times (oversampling).

		Args:
			idx (sequence[int] or np.array[Int]): indices into the full dataset to append.
			times (int, optional): number of copies of `idx` to append; values <= 0
				append nothing. Defaults to 1.
		"""
		current = np.asarray(self.train_data.indices)
		if not np.issubdtype(current.dtype, np.integer):
			current = current.astype(np.int64)
		# Cast the new indices to the existing integer dtype. Concatenating an int array
		# with an empty list (idx == [] or times == 0) would otherwise promote the whole
		# index array to float, which can no longer be used for indexing.
		extra = np.tile(np.asarray(idx, dtype=current.dtype), max(times, 0))
		self.train_data.indices = np.concatenate((current, extra))
        

class WildsDatasetMolPCBA():
	def __init__(self, args, test_type='test'):
		"""
		Dataset class of the WILDS 'ogb-molpcba' dataset (no image transform is applied).

		Args:
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser.
				Not read by the constructor; kept for interface compatibility.
			test_type (str, optional): test data type. it can be In-distribution test('idtest') or
				Out-Of-Distribution test('test'). Defaults to 'test'.
		"""
		# Stored so methods can refer to the dataset by name, as WildsDataset does.
		self.data_name = 'ogb-molpcba'
		self.dataset = get_dataset(dataset=self.data_name, download=True)
		self.train_data = self.dataset.get_subset("train")
		self.test_data = self.dataset.get_subset(test_type)
		self.N_training = len(self.train_data)
		self.N_test = len(self.test_data)
		self.target_dim = num_classes_dict[self.data_name]

	def get_loader(self, args, shuffle_train=True):
		"""
		Return the standard train and test dataloaders.

		Args:
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser;
				`args.batch_size` is used for both loaders.
			shuffle_train (bool, optional): shuffle the training loader. Defaults to True.

		Returns:
			([iter],[iter]): (training set dataloader, test set dataloader)
		"""
		train_loader = get_train_loader("standard", self.train_data, batch_size=args.batch_size, shuffle=shuffle_train)
		test_loader = get_eval_loader("standard", self.test_data, batch_size=args.batch_size)
		return train_loader, test_loader

	def get_hard_loader(self, args, idx, shuffle=True):
		"""
		Return a dataloader that only contains the given (hard) examples.

		The subset is also stored on `self.train_data_hard`.

		Args:
			args ([ArgumentParser.parse_argument()]): parsed arguments of argparser;
				`args.hard_batch_size` is used as the batch size.
			idx ([np.array[Int]]): indices of the hard samples in the full dataset.
			shuffle (bool, optional): shuffle the loader. Defaults to True.

		Returns:
			([iter]): hard training set dataloader
		"""
		# Molecular graphs need no resizing, so there is no `resize_dict` lookup
		# ('ogb-molpcba' has no entry there). `transform` is passed explicitly because
		# WILDSSubset's `transform` parameter may have no default.
		self.train_data_hard = WILDSSubset(self.dataset, idx, transform=None)
		train_loader_hard = get_train_loader("standard", self.train_data_hard, batch_size=args.hard_batch_size, shuffle=shuffle)
		return train_loader_hard


def get_cifar10_dataloader(args, test_batch_size=100):
	"""
	Build the CIFAR-10 train and test dataloaders (downloading the data if needed).

	The training split is augmented with a 32px random crop (4px padding) and a random
	horizontal flip. Both splits are converted to tensors and normalized with the same
	per-channel mean/std.

	Args:
		args ([ArgumentParser.parse_argument()]): parsed arguments; reads `args.data_path`
			(dataset root directory) and `args.batch_size` (training batch size).
		test_batch_size (int, optional): batch size of the test loader. Defaults to 100.

	Returns:
		(DataLoader, DataLoader): (shuffled training loader, unshuffled test loader)
	"""
	normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
	transform_train = transforms.Compose([
		transforms.RandomCrop(32, padding=4),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		normalize,
	])
	transform_test = transforms.Compose([
		transforms.ToTensor(),
		normalize,
	])

	trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
	trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)

	testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_test)
	testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=0)
	return trainloader, testloader


class CIFARC_Dataset(torch.utils.data.Dataset):
    """
    In-memory dataset of (image, label) pairs used for CIFAR-C evaluation.

    If `transform` is given, it is applied eagerly to every image in the constructor and
    the results are stacked into a single tensor, so `__getitem__` does no extra work.
    """

    def __init__(self, images, labels, transform=None):
        """
        Args:
            images: indexable collection of images (e.g. a numpy array or tensor).
            labels: labels aligned with `images`; must support integer and slice indexing.
            transform (callable, optional): per-image transform; its outputs must be
                tensors of equal shape so they can be stacked.
        """
        if transform is not None:
            self.images = torch.stack([transform(img) for img in images])
        else:
            self.images = images
        self.labels = labels
        # Kept for reference only: the transform has already been applied above.
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, key):
        # A slice yields a new dataset over the already-transformed images; no transform
        # is passed so it is not applied a second time.
        if isinstance(key, slice):
            return CIFARC_Dataset(self.images[key], self.labels[key])
        return self.images[key], self.labels[key]


def get_cifarC_dataloader(directory, data_name='cifar10', batch_size=512, samples_per_severity=10000):
	"""
	Build evaluation loaders for every corruption file in a CIFAR-C directory.

	Every ``*.npy`` file in `directory` except ``labels.npy`` is loaded as one set of
	corrupted images. Its rows are split into 5 consecutive blocks of
	`samples_per_severity` images, one block per shift intensity 1..5. The labels in
	``labels.npy`` are sliced the same way.

	Args:
		directory (str): folder containing the ``.npy`` files; a trailing path
			separator is optional.
		data_name (str, optional): 'cifar10' or 'cifar100'. It does not affect loading:
			both use the same normalization below. Defaults to 'cifar10'.
		batch_size (int, optional): batch size of every loader. Defaults to 512.
		samples_per_severity (int, optional): number of images per intensity block.
			Defaults to 10000.

	Returns:
		dict[str, list[DataLoader]]: maps intensity '1'..'5' to one unshuffled loader per
			corruption file, ordered by sorted file path.
	"""
	transform_test = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
	])
	# os.path.join makes `directory` work with or without a trailing separator.
	labels_path = os.path.join(directory, 'labels.npy')
	# Sort so the loader lists have a reproducible order across runs and filesystems.
	shift_files = sorted(
		path for path in glob.glob(os.path.join(directory, '*.npy'))
		if os.path.normpath(path) != os.path.normpath(labels_path)
	)
	label = torch.from_numpy(np.load(labels_path))
	loaders = {str(level): [] for level in range(1, 6)}

	for file in shift_files:
		print(f"Make loaders for {file}")
		image = np.load(file)
		testset = CIFARC_Dataset(image, label, transform=transform_test)
		for level in range(1, 6):
			print(f"Shift intensity : [{level}]")
			start = (level - 1) * samples_per_severity
			subset = testset[start:start + samples_per_severity]
			loaders[str(level)].append(
				torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=False)
			)
	return loaders


##* Adapted from the WILDS package: https://github.com/p-lambda/wilds
def get_train_loader(loader, dataset, batch_size,
        uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, shuffle=True, **loader_kwargs):
    """
    Constructs and returns the data loader for training.
    Args:
        - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders,
                        which first samples groups and then samples a fixed number of examples belonging
                        to each group.
        - dataset (WILDSDataset or WILDSSubset): Data; must expose `collate` (and
                                                 `metadata_array` when groups are used).
        - batch_size (int): Batch size
        - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according
                                              to the natural data distribution.
                                              Setting to None applies the defaults for each type of loaders.
                                              For standard loaders, the default is False. For group loaders,
                                              the default is True.
        - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True
        - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders.
        - n_groups_per_batch (int): Number of groups to sample in each minibatch for group loaders.
        - shuffle (bool): Whether to shuffle the data. Only used by standard loaders without
                          group-uniform sampling; the weighted sampler and the group sampler
                          decide the order themselves.
        - loader_kwargs: kwargs passed into torch DataLoader initialization.
    Output:
        - data loader (DataLoader): Data loader.
    Raises:
        - ValueError: if `loader` is neither 'standard' nor 'group', or if
                      n_groups_per_batch exceeds the number of groups.
    """
    # Fail fast on an unknown loader type instead of silently returning None.
    if loader not in ('standard', 'group'):
        raise ValueError(f"Unknown loader type {loader!r}; expected 'standard' or 'group'.")

    if loader == 'standard':
        if uniform_over_groups is None or not uniform_over_groups:
            return DataLoader(
                dataset,
                shuffle=shuffle, # Shuffle training dataset
                sampler=None,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)
        else:
            assert grouper is not None
            groups, group_counts = grouper.metadata_to_group(
                dataset.metadata_array,
                return_counts=True)
            # Inverse-frequency weights so every group is drawn equally often on average.
            group_weights = 1 / group_counts
            weights = group_weights[groups]

            # Replacement needs to be set to True, otherwise we'll run out of minority samples
            sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
            return DataLoader(
                dataset,
                shuffle=False, # The WeightedRandomSampler already shuffles
                sampler=sampler,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)

    elif loader == 'group':
        if uniform_over_groups is None:
            uniform_over_groups = True
        assert grouper is not None
        assert n_groups_per_batch is not None
        if n_groups_per_batch > grouper.n_groups:
            raise ValueError(f'n_groups_per_batch was set to {n_groups_per_batch} but there are only {grouper.n_groups} groups specified.')

        group_ids = grouper.metadata_to_group(dataset.metadata_array)
        batch_sampler = GroupSampler(
            group_ids=group_ids,
            batch_size=batch_size,
            n_groups_per_batch=n_groups_per_batch,
            uniform_over_groups=uniform_over_groups,
            distinct_groups=distinct_groups)

        return DataLoader(dataset,
              shuffle=None,
              sampler=None,
              collate_fn=dataset.collate,
              batch_sampler=batch_sampler,
              drop_last=False,
              **loader_kwargs)