import numpy as np
from PIL import Image

import torch
from torchvision import datasets

from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy


class WILDSNoSpur(WILDSDataset):
    """CIFAR-10 wrapped in the WILDS dataset interface with label-only metadata.

    The torchvision CIFAR-10 train and test sets are concatenated into one pool of
    examples. A random ``val_pct`` fraction of the original training images gets
    split id 1, the original test images get split id 2, and the remaining
    training images keep split id 0. The only metadata field is the class label
    ``y``, which is also what evaluation groups by.

    NOTE(review): unlike WILDSCIFAR10 this class sets no ``_dataset_name``; it looks
    intended as a base class (WILDSIN subclasses it) — confirm before instantiating
    it directly.
    """

    def __init__(self, version=None, root_dir='', download=True, split_scheme='official', val_pct=0.):
        """Load CIFAR-10 and build the label, metadata and split arrays.

        Args:
            version: Accepted for interface compatibility; unused here.
            root_dir: Directory passed to ``initialize_data_dir``.
            download: Whether torchvision may download CIFAR-10 when missing.
            split_scheme: Only ``'official'`` is supported.
            val_pct: Fraction in [0, 1] of the CIFAR-10 training images moved to
                split id 1. The draw uses NumPy's global RNG, so seed it
                (``np.random.seed``) for a reproducible split.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'`` or ``val_pct``
                is outside [0, 1].
        """
        # Validate cheap arguments before any (possibly slow) download happens.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')
        if not 0. <= val_pct <= 1.:
            raise ValueError(f'val_pct must be in [0, 1], got {val_pct}')

        self._data_dir = self.initialize_data_dir(root_dir)

        cifar_train = datasets.cifar.CIFAR10(root=self._data_dir, train=True, download=download)
        cifar_test = datasets.cifar.CIFAR10(root=self._data_dir, train=False, download=download)

        train_X, train_Y = np.array(cifar_train.data), np.array(cifar_train.targets)
        test_X, test_Y = np.array(cifar_test.data), np.array(cifar_test.targets)
        n_train, n_test = len(train_X), len(test_X)

        # Training images first, then test images; the split indices below rely on this order.
        self.X = np.concatenate((train_X, test_X))

        # Labels, in the same train-then-test order as self.X.
        self._y_array = torch.from_numpy(np.concatenate((train_Y, test_Y))).long()
        self._y_size = 1
        # np.unique already returns sorted unique values; only the count is needed.
        self._n_classes = len(np.unique(train_Y))

        # The label is the only metadata field, stored as an (N, 1) column.
        self._metadata_array = self._y_array.reshape(-1, 1)
        self._metadata_fields = ['y']
        self._metadata_map = {'y': cifar_train.classes}

        self._original_resolution = (32, 32)

        self._split_scheme = split_scheme

        # Validation examples are sampled without replacement from the training part only.
        val_idxs = np.random.choice(np.arange(n_train), size=int(n_train * val_pct), replace=False)
        # One integer split id per example (1-D, as WILDSDataset.split_array expects):
        # 0 = remaining train, 1 = sampled val, 2 = original test — presumably matching
        # WILDSDataset's default split_dict {'train': 0, 'val': 1, 'test': 2}.
        split_array = np.zeros(n_train + n_test, dtype=np.int64)
        split_array[val_idxs] = 1
        split_array[n_train:] = 2
        self._split_array = split_array

        # Evaluation groups examples by their class label.
        self._eval_grouper = CombinatorialGrouper(dataset=self, groupby_fields=['y'])

        super().__init__(self._data_dir, split_scheme)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """Compute accuracy grouped by class label via ``standard_group_eval``.

        Args:
            y_pred: Model predictions (passed through ``prediction_fn`` if given).
            y_true: Ground-truth labels.
            metadata: Metadata rows used by the label grouper.
            prediction_fn: Optional function turning raw outputs into predictions.

        Returns:
            The ``(results, results_str)`` pair produced by ``standard_group_eval``.
        """
        metric = Accuracy(prediction_fn=prediction_fn)
        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
        return results, results_str

    def get_input(self, idx):
        """Return example ``idx`` as a PIL image built from its stored pixel array."""
        return Image.fromarray(self.X[idx])


class WILDSCIFAR10(WILDSDataset):
    """CIFAR-10 exposed through the WILDS dataset interface (dataset name ``cifar10``).

    The torchvision CIFAR-10 train and test sets are concatenated into one pool of
    examples. A random ``val_pct`` fraction of the original training images gets
    split id 1, the original test images get split id 2, and the remaining
    training images keep split id 0. The only metadata field is the class label
    ``y``, which is also what evaluation groups by.
    """

    _dataset_name = 'cifar10'

    def __init__(self, version=None, root_dir='', download=True, split_scheme='official', val_pct=0.):
        """Load CIFAR-10 and build the label, metadata and split arrays.

        Args:
            version: Accepted for interface compatibility; unused here.
            root_dir: Directory passed to ``initialize_data_dir`` (default ``''``).
            download: Whether torchvision may download CIFAR-10 when missing.
            split_scheme: Only ``'official'`` is supported.
            val_pct: Fraction in [0, 1] of the CIFAR-10 training images moved to
                split id 1. The draw uses NumPy's global RNG, so seed it
                (``np.random.seed``) for a reproducible split.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'`` or ``val_pct``
                is outside [0, 1].
        """
        # Validate cheap arguments before any (possibly slow) download happens.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')
        if not 0. <= val_pct <= 1.:
            raise ValueError(f'val_pct must be in [0, 1], got {val_pct}')

        # Honour the caller's root_dir instead of always using ''.
        self._data_dir = self.initialize_data_dir(root_dir)

        cifar_train = datasets.cifar.CIFAR10(root=self._data_dir, train=True, download=download)
        cifar_test = datasets.cifar.CIFAR10(root=self._data_dir, train=False, download=download)

        train_X, train_Y = np.array(cifar_train.data), np.array(cifar_train.targets)
        test_X, test_Y = np.array(cifar_test.data), np.array(cifar_test.targets)
        n_train, n_test = len(train_X), len(test_X)

        # Training images first, then test images; the split indices below rely on this order.
        self.X = np.concatenate((train_X, test_X))

        # Labels, in the same train-then-test order as self.X.
        self._y_array = torch.from_numpy(np.concatenate((train_Y, test_Y))).long()
        self._y_size = 1
        # np.unique already returns sorted unique values; only the count is needed.
        self._n_classes = len(np.unique(train_Y))

        # The label is the only metadata field, stored as an (N, 1) column.
        self._metadata_array = self._y_array.reshape(-1, 1)
        self._metadata_fields = ['y']
        self._metadata_map = {'y': cifar_train.classes}

        self._original_resolution = (32, 32)

        self._split_scheme = split_scheme

        # Validation examples are sampled without replacement from the training part only.
        val_idxs = np.random.choice(np.arange(n_train), size=int(n_train * val_pct), replace=False)
        # One integer split id per example (1-D, as WILDSDataset.split_array expects):
        # 0 = remaining train, 1 = sampled val, 2 = original test — presumably matching
        # WILDSDataset's default split_dict {'train': 0, 'val': 1, 'test': 2}.
        split_array = np.zeros(n_train + n_test, dtype=np.int64)
        split_array[val_idxs] = 1
        split_array[n_train:] = 2
        self._split_array = split_array

        # Evaluation groups examples by their class label.
        self._eval_grouper = CombinatorialGrouper(dataset=self, groupby_fields=['y'])

        super().__init__(self._data_dir, split_scheme)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """Compute accuracy grouped by class label via ``standard_group_eval``.

        Args:
            y_pred: Model predictions (passed through ``prediction_fn`` if given).
            y_true: Ground-truth labels.
            metadata: Metadata rows used by the label grouper.
            prediction_fn: Optional function turning raw outputs into predictions.

        Returns:
            The ``(results, results_str)`` pair produced by ``standard_group_eval``.
        """
        metric = Accuracy(prediction_fn=prediction_fn)
        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
        return results, results_str

    def get_input(self, idx):
        """Return example ``idx`` as a PIL image built from its stored pixel array."""
        return Image.fromarray(self.X[idx])


class WILDSIN(WILDSNoSpur):
    _dataset_name = 'imagenet'

    def __init__(self, version=None, root_dir='', download=True, split_scheme='official', val_pct=0.):
        
        self._data_dir = self.initialize_data_dir(root_dir=None, download=download)
        imagenet = datasets.ImageNet(self._data_dir, split='train')
        imagenet_val = datasets.ImageNet(self._data_dir, split='val')
