#from wilds.common.data_loaders import get_train_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
import torchvision.transforms as transforms
import torch
import glob
import pdb
import copy
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
from wilds.common.utils import get_counts, split_into_groups
from wilds.datasets.wilds_dataset import WILDSSubset, WILDSDataset
from transformers import BertTokenizerFast
from tqdm import tqdm

# Output dimension for each supported dataset; the dataset wrappers below
# store the looked-up value as ``target_dim``.
num_classes_dict = {'iwildcam': 182, 'camelyon17': 2, 'rxrx1': 1139, 'fmow': 62, 'ogb-molpcba': 128, 'waterbirds': 2, 'povertymap': 1, 'celebA': 2,
                    'multinli': 2, 'civilcomments': 2}  # * Fmow: 200
# Target size handed to ``transforms.Resize`` for each image dataset; datasets
# missing here fall back to (64, 64) in ``WildsDataset.__init__``.
resize_dict = {'iwildcam': (448, 448), 'camelyon17': (96, 96), 'rxrx1': (
    256, 256), 'fmow': (224, 224), 'waterbirds': (224, 224), 'celebA': (178, 178)}

# Per-channel mean / std passed to ``transforms.Normalize`` in ``WildsDataset``.
# NOTE(review): these look like the usual ImageNet statistics -- confirm.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
# Extra keyword arguments that ``WildsDataset.get_loader`` forwards to its loaders.
loader_kwargs = {'num_workers': 8, 'pin_memory': True}


def get_subset_with_idx(self, split, frac=1.0, transform=None):
    """
    Build a WILDSSubsetWithIdx for one split of the dataset.

    Args:
        - split (str): Name of the split ('train', 'val', 'test', ...); it has
                        to be a key of self.split_dict.
        - frac (float): Fraction of the split to keep. Values below 1.0 draw a
                        random sub-sample (handy for quick experiments).
        - transform (function): Transformation applied to every input x.
    Output:
        - subset (WILDSSubsetWithIdx): The (possibly sub-sampled) subset.
    Raises:
        - ValueError: if `split` is not a key of self.split_dict.
    """
    if split not in self.split_dict:
        raise ValueError(f"Split {split} not found in dataset's split_dict.")

    # Positions of all samples that belong to the requested split
    in_split = self.split_array == self.split_dict[split]
    chosen = np.asarray(in_split).nonzero()[0]

    if frac < 1.0:
        # Keep a random `frac` of the split, restoring ascending order afterwards
        keep = int(np.round(float(len(chosen)) * frac))
        shuffled = np.random.permutation(chosen)
        chosen = np.sort(shuffled[:keep])

    return WILDSSubsetWithIdx(self, chosen, transform)


def metadata_to_int(metadatas, indices=None):
    """
    Map every metadata row to the integer id of its unique row value.

    The unique rows of `metadatas` (as returned by ``np.unique(..., axis=0)``,
    i.e. in sorted order) define the groups; a row's id is the position of its
    value in that list.

    Args:
        metadatas (torch.Tensor or array-like): metadata, one row per sample.
        indices (optional): index (anything NumPy accepts for fancy indexing)
            selecting the rows to convert. Defaults to all rows.

    Returns:
        (torch.Tensor, list[int]): a float tensor holding the group id of each
        selected row, and the number of selected rows per group (one entry for
        every unique row of the *full* `metadatas`, so unselected groups
        count 0).
    """
    if isinstance(metadatas, torch.Tensor):
        # .numpy() only works on CPU tensors that do not require grad
        metadatas = metadatas.detach().cpu().numpy()
    else:
        metadatas = np.asarray(metadatas)

    # return_inverse gives, for every row, the index of its unique value; this
    # replaces the former per-row Python scan over all unique rows.
    types, inverse = np.unique(metadatas, axis=0, return_inverse=True)
    # Some NumPy 2.0.x releases return a 2-D inverse when axis is given
    inverse = inverse.reshape(-1)
    if indices is not None:
        inverse = inverse[indices]

    group_counts = np.bincount(inverse, minlength=len(types))
    return torch.Tensor(inverse.tolist()), group_counts.tolist()


def __getitem_id__(self, idx):
    """
    Alternative item getter: return (x, y, metadata id) for sample `idx`.

    Same layout as a regular (x, y, metadata) item, except that the third
    element is read from `self.metadata_id_array`.
    """
    return self.get_input(idx), self.y_array[idx], self.metadata_id_array[idx]


# Monkey-patch WILDSDataset so that every dataset instance exposes
# get_subset_with_idx as a method (WildsDataset uses it when with_idx=True).
WILDSDataset.get_subset_with_idx = get_subset_with_idx


class WildsDataset():
    """Wrapper around one WILDS benchmark dataset and its train/val/test subsets."""

    def __init__(self, data_name, args, test_type='test', resize_default=False, with_idx=False, group_dro=False):
        """
        Load a WILDS dataset and build its train / val / test subsets.

        Args:
            data_name (str): name of the dataset; must be a key of
                `num_classes_dict`.
            args: parsed command-line arguments; `args.data_path` is used as
                the dataset root directory.
            test_type (str, optional): name of the split used as test set
                (passed to `get_subset`). Defaults to 'test'.
            resize_default (bool, optional): if True, images are resized to
                (64, 64) instead of the size listed in `resize_dict`.
            with_idx (bool, optional): if True, the training subset is built
                with `get_subset_with_idx`, so its items additionally carry
                their position in the subset.
            group_dro (bool, optional): if True, the metadata rows are
                converted to integer group ids (`metadata_to_int`) and stored
                on the underlying dataset as `metadata_id_array`;
                `group_counts` and `id_metadata` are only set in that case.
        """
        self.data_name = data_name
        self.dataset = get_dataset(
            dataset=data_name, download=True, root_dir=args.data_path)
        if resize_default:
            resize = (64, 64)
        else:
            try:
                resize = resize_dict[data_name]
            except KeyError:
                # Only a missing entry is expected here; anything else should surface
                print(
                    f"There is no resize for {data_name}, it uses the default size (64,64).")
                resize = (64, 64)

        if data_name in ['civilcomments']:
            # Text dataset: inputs are tokenized instead of resized/normalized
            self.tokenizer = BertTokenizerFast.from_pretrained(
                'bert-base-uncased')
            self.transform = self.transform_bert
        else:
            self.transform = transforms.Compose([
                transforms.Resize(resize),
                transforms.ToTensor(),
                transforms.Normalize(
                    _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
            ])

        # Only the training subset optionally returns sample positions
        subset_fn = self.dataset.get_subset_with_idx if with_idx else self.dataset.get_subset
        self.train_data = subset_fn(
            "train",
            transform=self.transform
        )
        self.val_data = self.dataset.get_subset(
            'val',
            transform=self.transform
        )
        self.test_data = self.dataset.get_subset(
            test_type,
            transform=self.transform
        )
        self.N_training = len(self.train_data)
        self.N_test = len(self.test_data)
        self.target_dim = num_classes_dict[data_name]
        self.noisy_idx = None
        self.grouper = None
        # Private copy of the original metadata (see convert_to_raw_metadata)
        self.raw_metadata = copy.deepcopy(
            self.train_data.dataset.metadata_array)
        if group_dro:
            self.train_data.dataset.metadata_id_array, self.group_counts = metadata_to_int(
                self.raw_metadata)
            self.id_metadata = copy.deepcopy(
                self.train_data.dataset.metadata_id_array)

    def inject_label_noise(self, p=0.1):
        """
        Replace the labels of randomly chosen training samples by random labels.

        Each sample whose index appears in `self.train_data.indices` is
        selected independently with probability `p`; selected samples get a
        label drawn uniformly from [0, target_dim). The result is written to
        the `_y_array` attribute of the underlying dataset.

        Args:
            p (float, optional): per-sample probability of being relabelled.
        """
        n_total = len(self.train_data.dataset)
        where_label_noise = torch.bernoulli(torch.ones(n_total) * p).bool()
        # Restrict the noise mask to samples that are part of the training subset
        is_train = torch.bincount(torch.from_numpy(
            self.train_data.indices), minlength=n_total)
        where_label_noise = torch.logical_and(
            is_train.bool(), where_label_noise)

        rand_label = torch.randint(
            self.target_dim, (1, n_total)).flatten()
        changed_label = torch.where(
            where_label_noise, rand_label, self.train_data.dataset.y_array)
        self.train_data.dataset._y_array = changed_label

    def get_loader(self, args, shuffle_train=True, train_sampler=None, loader='standard', uniform_over_group=False):
        """
        Build the data loaders of the three subsets.

        Args:
            args: parsed command-line arguments; `args.batch_size` is used for
                all three loaders.
            shuffle_train (bool, optional): shuffle flag of the training loader.
            train_sampler (optional): sampler forwarded to the training loader.
            loader (str, optional): loader type forwarded to `get_train_loader`.
            uniform_over_group (bool, optional): if True, a CombinatorialGrouper
                over all metadata fields is created and forwarded together
                with `uniform_over_groups=True`.

        Returns:
            (train_loader, val_loader, test_loader)
        """
        if uniform_over_group:
            # NOTE: this replaces __getitem__ on the WILDSDataset *class*, so it
            # affects every WILDSDataset instance for the rest of the process.
            WILDSDataset.__getitem__ = __getitem_id__
            self.grouper = CombinatorialGrouper(
                self.dataset, self.dataset.metadata_fields)

        train_loader = get_train_loader(
            loader, self.train_data, batch_size=args.batch_size, sampler=train_sampler, shuffle=shuffle_train, grouper=self.grouper, uniform_over_groups=uniform_over_group, **loader_kwargs)
        test_loader = get_eval_loader(
            "standard", self.test_data, batch_size=args.batch_size, **loader_kwargs)
        val_loader = get_eval_loader(
            "standard", self.val_data, batch_size=args.batch_size, **loader_kwargs)
        return train_loader, val_loader, test_loader

    def convert_to_raw_metadata(self):
        """Point the dataset's `metadata_id_array` at the raw metadata copy."""
        self.train_data.dataset.metadata_id_array = self.raw_metadata

    def convert_to_id_metadata(self):
        """
        Point the dataset's `metadata_id_array` at the integer group ids.

        `self.id_metadata` only exists when the object was created with
        `group_dro=True`.
        """
        self.train_data.dataset.metadata_id_array = self.id_metadata

    def add_training_data(self, idx, times=1):
        """
        Append the samples `idx` to the training subset `times` times.

        Args:
            idx (sequence or np.ndarray of int): dataset indices to append.
            times (int, optional): how often `idx` is appended; values <= 0
                append nothing.
        """
        current = np.asarray(self.train_data.indices)
        # Cast to the existing dtype: concatenating with an empty Python list
        # would otherwise promote the whole index array to float64.
        extra = np.asarray(idx, dtype=current.dtype).reshape(-1)
        self.train_data.indices = np.concatenate(
            (current, np.tile(extra, max(times, 0))))

    def transform_bert(self, text):
        """
        Tokenize `text` with the BERT tokenizer (padded / truncated to 300 tokens).

        Returns:
            torch.Tensor: input ids, attention mask and token type ids stacked
            along the last dimension, with the leading dimension squeezed
            (presumably shape (300, 3) for a single text -- TODO confirm).
        """
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=300,
            return_tensors="pt",
        )
        x = torch.stack(
            (
                tokens["input_ids"],
                tokens["attention_mask"],
                tokens["token_type_ids"],
            ),
            dim=2,
        )
        x = torch.squeeze(x, dim=0)  # First shape dim is always 1
        return x


class WildsDatasetMolPCBA():
    """Wrapper around the WILDS 'ogb-molpcba' dataset and its train/test subsets."""

    def __init__(self, args, test_type='test'):
        """
        Load the 'ogb-molpcba' dataset and build its train / test subsets.

        Args:
            args: parsed command-line arguments (not read here).
            test_type (str, optional): name of the split used as test set
                (passed to `get_subset`). Defaults to 'test'.
        """
        # NOTE(review): unlike WildsDataset, no root_dir is passed here, so the
        # library default location is used -- confirm this is intended.
        self.dataset = get_dataset(dataset='ogb-molpcba', download=True)
        self.train_data = self.dataset.get_subset(
            "train"
        )
        self.test_data = self.dataset.get_subset(
            test_type
        )
        self.N_training = len(self.train_data)
        self.N_test = len(self.test_data)
        self.target_dim = num_classes_dict['ogb-molpcba']

    def get_loader(self, args, shuffle_train=True):
        """
        Build the training and test data loaders.

        Args:
            args: parsed command-line arguments; `args.batch_size` is used for
                both loaders.
            shuffle_train (bool, optional): shuffle flag of the training loader.

        Returns:
            (train_loader, test_loader)
        """
        train_loader = get_train_loader(
            "standard", self.train_data, batch_size=args.batch_size, shuffle=shuffle_train)
        test_loader = get_eval_loader(
            "standard", self.test_data, batch_size=args.batch_size)
        return train_loader, test_loader

    def get_hard_loader(self, args, idx, shuffle=True):
        """
        Build a training loader that only contains the given (hard) samples.

        The subset is also stored as `self.train_data_hard`.

        Args:
            args: parsed command-line arguments; `args.hard_batch_size` is the
                batch size of the returned loader.
            idx (np.array[int]): dataset indices of the hard samples.
            shuffle (bool, optional): shuffle flag of the loader.

        Returns:
            DataLoader over the hard samples.
        """
        # No input transform is applied for this dataset (the other subsets of
        # this class are created without one as well); the former resize
        # lookup read an attribute that this class never sets.
        self.train_data_hard = WILDSSubset(self.dataset, idx, None)
        train_loader_hard = get_train_loader(
            "standard", self.train_data_hard, batch_size=args.hard_batch_size, shuffle=shuffle)
        return train_loader_hard

    def add_training_data(self, idx, times=1):
        """
        Append the samples `idx` to the training subset `times` times.

        Args:
            idx (sequence or np.ndarray of int): dataset indices to append.
            times (int, optional): how often `idx` is appended; values <= 0
                append nothing.
        """
        current = np.asarray(self.train_data.indices)
        # Cast to the existing dtype: concatenating with an empty Python list
        # would otherwise promote the whole index array to float64.
        extra = np.asarray(idx, dtype=current.dtype).reshape(-1)
        self.train_data.indices = np.concatenate(
            (current, np.tile(extra, max(times, 0))))


class WILDSSubsetWithIdx(WILDSSubset):
    """
    WILDSSubset variant whose items are prefixed with the requested position.

    `__getitem__` returns (idx, x, y, metadata) instead of (x, y, metadata),
    where `idx` is the position inside this subset. LfF (Learning from
    Failure) relies on it to track a moving average of the loss *per sample*.
    """

    def __init__(self, dataset, indices, transform, do_transform_y=False):
        super().__init__(dataset, indices, transform, do_transform_y)

    def __getitem__(self, idx):
        x, y, metadata = self.dataset[self.indices[idx]]
        if self.transform is None:
            # Nothing to transform: hand the raw sample back
            return idx, x, y, metadata
        if self.do_transform_y:
            # The transform acts on input and label jointly
            x, y = self.transform(x, y)
        else:
            x = self.transform(x)
        return idx, x, y, metadata


def get_cifar10_dataloader(args):
    """
    Build the CIFAR-10 training and test data loaders.

    The data is downloaded to `args.data_path` if necessary. Training images
    are augmented with a padded random crop and a random horizontal flip.

    Args:
        args: parsed command-line arguments; `args.data_path` (dataset root)
            and `args.batch_size` (training batch size) are read.

    Returns:
        (trainloader, testloader): the test loader uses a fixed batch size
        of 100 and does not shuffle.
    """
    # The module only imports `torchvision.transforms as transforms`, which
    # does not bind the name `torchvision`; import it here for `.datasets`.
    import torchvision

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=args.data_path, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)

    testset = torchvision.datasets.CIFAR10(
        root=args.data_path, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=0)
    return trainloader, testloader


class CIFARC_Dataset(torch.utils.data.Dataset):
    """
    In-memory dataset of (image, label) pairs.

    If a transform is given it is applied to every image once, at construction
    time, and the results are stacked into a single tensor; items are then
    served without further processing.
    """

    def __init__(self, images, labels, transform=None):
        """
        Args:
            images: indexable collection of images.
            labels: indexable collection of labels, aligned with `images`.
            transform (callable, optional): applied eagerly to every image;
                it must return tensors that `torch.stack` can combine.
        """
        if transform is not None:
            self.images = torch.stack([transform(img) for img in images])
        else:
            self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, key):
        """
        Return the (image, label) pair at `key`; a slice returns a new
        CIFARC_Dataset over the selected range.
        """
        if isinstance(key, slice):
            # Images are already transformed, so the sub-dataset gets no transform
            return CIFARC_Dataset(self.images[key], self.labels[key])
        return self.images[key], self.labels[key]


def get_cifarC_dataloader(directory, data_name='cifar10', batch_size=512):
    """
    Build test loaders for every corruption file of a CIFAR-C directory.

    Every `*.npy` file in `directory` except `labels.npy` is treated as one
    corruption type. Each file is split into five consecutive chunks of
    10000 samples, one per shift intensity.

    Args:
        directory (str): directory containing the `.npy` files (with or
            without a trailing path separator).
        data_name (str, optional): kept for backward compatibility; it does
            not influence the result.
        batch_size (int, optional): batch size of every loader.

    Returns:
        dict[str, list[DataLoader]]: maps the intensity '1' ... '5' to the list
        of loaders (one per corruption file, in sorted file-name order).
    """
    import os  # local import: only this function needs it

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    # * CIFAR-C Data loaders
    # Build every path with os.path.join so the result does not depend on
    # whether `directory` ends with a separator; sort for a stable order.
    shift_files = sorted(
        path for path in glob.glob(os.path.join(directory, '*.npy'))
        if os.path.basename(path) != 'labels.npy')
    label = torch.from_numpy(np.load(os.path.join(directory, 'labels.npy')))
    loaders = {'1': [], '2': [], '3': [], '4': [], '5': []}

    for file in shift_files:
        print(f"Make loaders for {file}")
        image = np.load(file)
        testset = CIFARC_Dataset(image, label, transform=transform_test)
        for i in range(5):
            print(f"Shift intensity : [{i+1}]")
            testloader = torch.utils.data.DataLoader(
                testset[i*10000: (i+1)*10000], batch_size=batch_size, shuffle=False)
            loaders[f'{i+1}'].append(testloader)
    return loaders


def class_imbalance_sampler(labels):
    """
    Create a sampler that draws every class with equal probability.

    Each sample is weighted by the inverse frequency of its class.

    Args:
        labels (torch.Tensor): non-negative integer class labels, one per
            sample; any shape is accepted (e.g. (N,) or (N, 1)).

    Returns:
        WeightedRandomSampler drawing as many samples as there are labels.
    """
    # Flatten once and use the flat view everywhere: indexing the class
    # weights with (N, 1) labels would yield 2-D weights, and squeezing a
    # single label would yield a 0-d tensor that bincount rejects.
    flat_labels = labels.reshape(-1)
    class_count = torch.bincount(flat_labels)
    # Absent classes get an infinite weight, but no sample ever selects them
    class_weighting = 1.0 / class_count.double()
    sample_weights = class_weighting[flat_labels]
    return WeightedRandomSampler(sample_weights, len(flat_labels))


# * Adapted from https://github.com/p-lambda/wilds.
def get_train_loader(loader, dataset, batch_size,
                     uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, shuffle=True, sampler=None, **loader_kwargs):
    """
    Constructs and returns the data loader for training.
    Args:
        - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders,
                        which first samples groups and then samples a fixed number of examples belonging
                        to each group.
        - dataset (WILDSDataset or WILDSSubset): Data
        - batch_size (int): Batch size
        - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according
                                              to the natural data distribution.
                                              Setting to None applies the defaults for each type of loaders.
                                              For standard loaders, the default is False. For group loaders,
                                              the default is True.
        - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True
        - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders.
        - n_groups_per_batch (int): Number of groups to sample in each minibatch for group loaders.
        - shuffle (bool): Whether a standard loader shuffles the data. Ignored (forced to False)
                          when `sampler` is given or when uniform_over_groups is True.
        - sampler (Sampler): Optional sampler for standard loaders without uniform_over_groups.
        - loader_kwargs: kwargs passed into torch DataLoader initialization.
    Output:
        - data loader (DataLoader): Data loader.
    Raises:
        - ValueError: if `loader` is neither 'standard' nor 'group', or if n_groups_per_batch
                      exceeds the number of groups of the grouper.
    """
    if loader == 'standard':
        if not uniform_over_groups:
            if sampler is not None:
                # DataLoader does not accept shuffle=True together with a sampler
                shuffle = False
            return DataLoader(
                dataset,
                shuffle=shuffle,  # Shuffle training dataset
                sampler=sampler,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)

        # Weight every example by the inverse size of its group
        groups, group_counts = grouper.metadata_to_group(
            dataset.metadata_array,
            return_counts=True)
        group_weights = 1 / group_counts
        weights = group_weights[groups]

        # Replacement needs to be set to True, otherwise we'll run out of minority samples
        sampler = WeightedRandomSampler(
            weights, len(dataset), replacement=True)
        return DataLoader(
            dataset,
            shuffle=False,  # The WeightedRandomSampler already shuffles
            sampler=sampler,
            collate_fn=dataset.collate,
            batch_size=batch_size,
            **loader_kwargs)

    if loader == 'group':
        # GroupSampler is not imported at module level; import it where it is
        # needed so that this branch no longer fails with a NameError.
        from wilds.common.data_loaders import GroupSampler

        if uniform_over_groups is None:
            uniform_over_groups = True
        assert grouper is not None
        assert n_groups_per_batch is not None
        if n_groups_per_batch > grouper.n_groups:
            raise ValueError(
                f'n_groups_per_batch was set to {n_groups_per_batch} but there are only {grouper.n_groups} groups specified.')

        group_ids = grouper.metadata_to_group(dataset.metadata_array)
        batch_sampler = GroupSampler(
            group_ids=group_ids,
            batch_size=batch_size,
            n_groups_per_batch=n_groups_per_batch,
            uniform_over_groups=uniform_over_groups,
            distinct_groups=distinct_groups)

        return DataLoader(dataset,
                          shuffle=None,
                          sampler=None,
                          collate_fn=dataset.collate,
                          batch_sampler=batch_sampler,
                          drop_last=False,
                          **loader_kwargs)

    # Previously an unknown loader type silently produced None
    raise ValueError(
        f"Unknown loader type {loader!r}; expected 'standard' or 'group'.")
