import logging
import os

import numpy as np
import scipy.fft
from skimage.filters import gaussian
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
import joblib
from tqdm import tqdm

from autovar.base import RegisteringChoiceType, register_var, VariableClass

# Module-wide log format. basicConfig only installs a handler when the root
# logger has none yet; at WARNING level the logger.info calls below are hidden.
logging.basicConfig(
    level=logging.WARNING,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
)
logger = logging.getLogger(__name__)

def get_mnist():
    """Fetch MNIST through Keras as float32 image batches with a channel axis.

    A singleton channel axis is appended to each image batch and pixel values
    are divided by 255. Labels are returned as loaded.

    Returns:
        ``(x_train, y_train, x_test, y_test, spurious_ind)``; ``spurious_ind``
        is an empty array because the clean dataset has no injected samples.
    """
    from tensorflow.keras.datasets import mnist

    def to_float_with_channel(images):
        # Add a trailing channel axis, then rescale.
        return images[:, :, :, np.newaxis].astype(np.float32) / 255

    (trn_images, trn_labels), (tst_images, tst_labels) = mnist.load_data()
    return (to_float_with_channel(trn_images), trn_labels,
            to_float_with_channel(tst_images), tst_labels,
            np.empty(0))

def get_fashion():
    """Fetch Fashion-MNIST through Keras as float32 image batches with a channel axis.

    Images gain a trailing channel axis and are divided by 255. Labels are
    copied into fresh arrays so callers can modify them in place.

    Returns:
        ``(x_train, y_train, x_test, y_test, spurious_ind)``; ``spurious_ind``
        is an empty array because the clean dataset has no injected samples.
    """
    from tensorflow.keras.datasets import fashion_mnist

    (raw_trn_x, raw_trn_y), (raw_tst_x, raw_tst_y) = fashion_mnist.load_data()

    # Fresh, owned (hence writable) label arrays.
    trn_y = np.array(raw_trn_y, copy=True)
    tst_y = np.array(raw_tst_y, copy=True)

    # astype() already allocates a new writable float32 buffer for the images.
    trn_x = raw_trn_x[:, :, :, np.newaxis].astype(np.float32) / 255
    tst_x = raw_tst_x[:, :, :, np.newaxis].astype(np.float32) / 255

    return trn_x, trn_y, tst_x, tst_y, np.empty(0)

def get_cifar100(label_mode="fine"):
    """Fetch CIFAR-100 through Keras as float32 images and flat label vectors.

    Args:
        label_mode: Forwarded to ``cifar100.load_data``; this file uses
            ``"fine"`` and ``"coarse"``.

    Returns:
        ``(x_train, y_train, x_test, y_test, spurious_ind)`` with images divided
        by 255, labels flattened to 1-D, and an empty ``spurious_ind``.
    """
    from tensorflow.keras.datasets import cifar100

    splits = cifar100.load_data(label_mode)
    images = [split[0].astype(np.float32) / 255 for split in splits]
    labels = [split[1].reshape(-1) for split in splits]
    return images[0], labels[0], images[1], labels[1], np.empty(0)

def get_cifar10():
    """Fetch CIFAR-10 through Keras as float32 images and flat label vectors.

    Returns:
        ``(x_train, y_train, x_test, y_test, spurious_ind)`` with images divided
        by 255, labels flattened to 1-D, and an empty ``spurious_ind``.
    """
    from tensorflow.keras.datasets import cifar10

    def rescale(images):
        return images.astype(np.float32) / 255

    (trn_images, trn_labels), (tst_images, tst_labels) = cifar10.load_data()
    return (rescale(trn_images), trn_labels.reshape(-1),
            rescale(tst_images), tst_labels.reshape(-1),
            np.empty(0))

def get_imgnet100(trn_transform, tst_transform, root="/tmp2/ImageNet100/"):
    """Build ImageNet-100 train/validation ``ImageFolder`` datasets.

    Args:
        trn_transform: Passed as ``transform=`` to the training ``ImageFolder``.
        tst_transform: Passed as ``transform=`` to the validation ``ImageFolder``.
        root: Directory holding the ``ILSVRC2012_img_train/`` and
            ``ILSVRC2012_img_val/`` class-folder trees. Defaults to the
            location that used to be hard-coded here.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    from torchvision.datasets import ImageFolder

    imgnet_trn_dir = os.path.join(root, "ILSVRC2012_img_train/")
    imgnet_val_dir = os.path.join(root, "ILSVRC2012_img_val/")
    return (
        ImageFolder(imgnet_trn_dir, transform=trn_transform),
        ImageFolder(imgnet_val_dir, transform=tst_transform),
    )


def add_spurious_correlation(X, version, seed):
    """Return a copy of an image batch with a synthetic spurious feature added.

    The indexing treats ``X`` as ``(N, H, W, C)``: axes 1 and 2 are spatial.
    The FFT versions only edit channel 0.

    Args:
        X: 4-D image batch (asserted). It is copied first and never modified.
        version: Perturbation to apply:
            ``None``/``""``/``"v1"``: top-left pixel set to 1;
            ``"v2"``: first row set to 1;
            ``"v3"``/``"v8"``: top-left 3x3 / 5x5 square set to 1;
            ``"v4"``: staircase of 1+2+3 pixels in the top-left corner;
            ``"v5"``: rows 0-2 of column 1 set to 1;
            ``"v30"``: rows 3-24, columns 13-15 set to 1;
            ``"v9"``/``"v10"``/``"v11"``: add 0.1 / 0.3 / 0.5 to every pixel;
            ``"v6"``/``"v7"``: zero the ``[:1, :1]`` (DC) / ``[:3, :3]``
            coefficients of the spatial FFT of channel 0;
            ``"vgau1"``/``"vgau2"``/``"vgau4"``: Gaussian blur, sigma 1 / 2 / 4;
            ``"v18"``/``"v19"``/``"v20"``: add 0.25 / 0.5 / 1.0 times a single
            uniform-noise image drawn from ``seed`` (same image for all samples);
            ``"v21"``: add a seeded noise image whose four FFT corner blocks
            are zeroed.
        seed: Seed for the noise image used by ``v18``-``v21``.

    Returns:
        The perturbed batch. Every version except the pixel patches is
        clipped to [0, 1].

    Raises:
        ValueError: If ``version`` is not recognised.
    """
    assert len(X.shape) == 4, X.shape
    # Work on a private copy. The patch and offset branches write in place and
    # would otherwise leave the caller's array modified (for v9-v11: shifted
    # but not clipped).
    X = np.array(X, copy=True)

    offsets = {"v9": 0.1, "v10": 0.3, "v11": 0.5}
    blur_sigmas = {"vgau1": 1, "vgau2": 2, "vgau4": 4}
    noise_scales = {"v18": 0.25, "v19": 0.5, "v20": 1.0}

    if not version or version == "v1":
        X[:, 0, 0] = 1
    elif version == "v2":
        X[:, 0, :] = 1
    elif version == "v3":
        X[:, :3, :3] = 1
    elif version == "v4":
        X[:, 0, :1] = X[:, 1, :2] = X[:, 2, :3] = 1
    elif version == "v5":
        X[:, :3, 1] = 1
    elif version == "v8":
        X[:, :5, :5] = 1
    elif version == "v30":
        X[:, 3:25, 13:16, :] = 1
    elif version in offsets:
        X += offsets[version]
        X = np.clip(X, 0, 1)
    elif version in ("v6", "v7"):
        k = 1 if version == "v6" else 3
        fftX = scipy.fft.fft2(X, axes=[1, 2])
        fftX[:, :k, :k, 0] = 0
        X = np.real(scipy.fft.ifft2(fftX, axes=[1, 2])).astype(np.float32)
        X = np.clip(X, 0, 1)
    elif version in blur_sigmas:
        sigma = blur_sigmas[version]
        blurred = np.zeros_like(X)
        for i, image in enumerate(X):
            blurred[i] = gaussian(image, sigma=sigma)
        X = np.clip(blurred, 0, 1)
    elif version in noise_scales:
        noise = np.random.RandomState(seed).rand(*X.shape[1:])
        X = np.clip(X + noise_scales[version] * noise[np.newaxis], 0, 1)
    elif version == "v21":
        keep_fraction = 0.1
        noise = np.random.RandomState(seed).rand(*X.shape[1:])
        noise_fft = scipy.fft.fft2(noise, axes=[0, 1])
        # Corner extents, computed exactly as in the original formulation.
        # The negative extents floor-divide a negative float, so they can be
        # one element wider than the positive ones.
        h_lo = int((X.shape[1] * (1 - keep_fraction)) // 2)
        h_hi = int(-(X.shape[1] * (1 - keep_fraction)) // 2)
        w_lo = int((X.shape[2] * (1 - keep_fraction)) // 2)
        w_hi = int(-(X.shape[2] * (1 - keep_fraction)) // 2)
        noise_fft[:h_lo, w_hi:] = 0
        noise_fft[h_hi:, :w_lo] = 0
        noise_fft[:h_lo, :w_lo] = 0
        noise_fft[h_hi:, w_hi:] = 0
        noise = scipy.fft.ifft2(noise_fft, axes=[0, 1])
        X = np.clip(X + noise[np.newaxis].real, 0, 1)
    else:
        raise ValueError(f"version: {version} not supported")
    return X

def add_colored_spurious_correlation(X, version, seed):
    """Add a synthetic spurious feature to every channel of an image batch.

    This is the channel-agnostic counterpart of ``add_spurious_correlation``:
    patches and FFT edits use ``:`` on the last axis, so ``X`` is indexed as
    ``(N, H, W, C)``.

    Args:
        X: Image batch indexed as ``(N, H, W, C)``.
        version: Perturbation to apply:
            ``None``/``""``/``"v1"``: top-left pixel set to 1 (a missing
            version falls back to v1, matching ``add_spurious_correlation``);
            ``"v3"``/``"v8"``/``"v81"``: top-left 3x3 / 5x5 / 10x10 square set to 1;
            ``"v30"``: rows 3-24, columns 13-15 set to 1;
            ``"v6"``/``"v7"``: zero the ``[:1, :1]`` (DC) / ``[:3, :3]`` spatial
            FFT coefficients of every channel;
            ``"v9"``/``"v10"``/``"v11"``: add 0.1 / 0.3 / 0.5 to every pixel;
            ``"v18"``/``"v19"``/``"v20"``: add 0.25 / 0.5 / 1.0 times a single
            uniform-noise image drawn from ``seed``;
            ``"vgau1"``/``"vgau2"``: Gaussian blur, sigma 1 / 2.
        seed: Seed for the noise image used by ``v18``-``v20``.

    Returns:
        The perturbed batch. Every version except the pixel patches is
        clipped to [0, 1].

    Raises:
        ValueError: If ``version`` is not recognised.

    NOTE(review): the patch versions and v9-v11 write into ``X`` in place.
    Callers in this module pass fancy-indexed copies, except ``imgnet100_aug``,
    which passes the view ``x[None, ...]``. Confirm that the mutation is
    intended there.
    """
    offsets = {"v9": 0.1, "v10": 0.3, "v11": 0.5}
    noise_scales = {"v18": 0.25, "v19": 0.5, "v20": 1.0}
    blur_sigmas = {"vgau1": 1, "vgau2": 2}
    square_sizes = {"v3": 3, "v8": 5, "v81": 10}

    if not version or version == "v1":
        X[:, 0, 0, :] = 1
    elif version in square_sizes:
        size = square_sizes[version]
        X[:, :size, :size, :] = 1
    elif version == "v30":
        X[:, 3:25, 13:16, :] = 1
    elif version in ("v6", "v7"):
        k = 1 if version == "v6" else 3
        fftX = scipy.fft.fft2(X, axes=[1, 2])
        fftX[:, :k, :k, :] = 0
        X = np.real(scipy.fft.ifft2(fftX, axes=[1, 2])).astype(np.float32)
        X = np.clip(X, 0, 1)
    elif version in offsets:
        X += offsets[version]
        X = np.clip(X, 0, 1)
    elif version in noise_scales:
        noise = np.random.RandomState(seed).rand(*X.shape[1:])
        X = np.clip(X + noise_scales[version] * noise[np.newaxis], 0, 1)
    elif version in blur_sigmas:
        sigma = blur_sigmas[version]
        blurred = np.zeros_like(X)
        for i, image in enumerate(X):
            # NOTE(review): no ``channel_axis`` is passed. Depending on the
            # installed scikit-image version, RGB images may also be blurred
            # across the colour axis; confirm before relying on vgau* for RGB.
            blurred[i] = gaussian(image, sigma=sigma)
        X = np.clip(blurred, 0, 1)
    else:
        raise ValueError(f"version: {version} not supported")
    return X

def get_program_data(data_dir):
    """Load program features and their spurious counterparts on one column layout.

    Both pickles are read with ``joblib.load`` and unpacked as
    ``(features, labels, column_names)``; the spurious file's labels are
    discarded. Both returned matrices use the column order
    ``[names in both files] + [names only in the original file] +
    [names only in the spurious file]``. A column missing from one file is
    zero-filled in that file's matrix. ``final_colnames`` follows that same
    order.

    Args:
        data_dir: Directory containing ``program_features.pkl`` and
            ``program_features_spurious.pkl``.

    Returns:
        ``(X, Xspu, y, final_colnames)``.

    Note:
        ``joblib.load`` unpickles arbitrary objects; only point ``data_dir``
        at trusted files.
    """
    X, y, colnames = joblib.load(os.path.join(data_dir, "program_features.pkl"))
    Xspu, _, spu_colnames = joblib.load(os.path.join(data_dir, "program_features_spurious.pkl"))

    # First position of every spurious column name (same as np.where(...)[0][0]).
    spu_pos = {}
    for j, name in enumerate(spu_colnames):
        spu_pos.setdefault(name, j)
    orig_names = set(colnames)

    shared, shared_spu, orig_only = [], [], []
    for i, name in enumerate(colnames):
        if name in spu_pos:
            shared.append(i)
            shared_spu.append(spu_pos[name])
        else:
            orig_only.append(i)
    spu_only = [j for j, name in enumerate(spu_colnames) if name not in orig_names]

    # Names must follow the reordered column layout built below.
    final_colnames = ([colnames[i] for i in shared]
                      + [colnames[i] for i in orig_only]
                      + [spu_colnames[j] for j in spu_only])

    # Explicit integer dtype: np.array([]) would be float64, and indexing
    # with an empty float array raises IndexError.
    shared, shared_spu, orig_only, spu_only = (
        np.asarray(positions, dtype=np.intp)
        for positions in (shared, shared_spu, orig_only, spu_only)
    )

    X = np.concatenate((X[:, shared], X[:, orig_only], np.zeros((X.shape[0], len(spu_only)))), axis=1)
    Xspu = np.concatenate((Xspu[:, shared_spu], np.zeros((Xspu.shape[0], len(orig_only))), Xspu[:, spu_only]), axis=1)
    return X, Xspu, y, final_colnames


def _split_four_ways(X, y, seed):
    """Shuffle ``(X, y)`` with ``seed`` and cut it into four consecutive quarters.

    Returns ``(mem_trnX, mem_trny, mem_tstX, mem_tsty, nonmem_trnX,
    nonmem_trny, nonmem_tstX, nonmem_tsty)``. The first three quarters hold
    ``len(X) // 4`` samples each; the last one also receives the remainder.
    """
    idx = np.arange(len(X))
    np.random.RandomState(seed).shuffle(idx)
    each = len(idx) // 4
    bounds = (0, each, 2 * each, 3 * each, len(idx))
    parts = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        chunk = idx[start:stop]
        parts.extend((X[chunk], y[chunk]))
    return tuple(parts)


def _load_mem_split(loader, seed):
    """Pool the train and test splits returned by ``loader()`` and split them four ways."""
    x_train, y_train, x_test, y_test, _ = loader()
    X = np.concatenate((x_train, x_test), axis=0)
    y = np.concatenate((y_train, y_test))
    return _split_four_ways(X, y, seed)


def _inject_spurious(X, y, cls_no, sp_counts, random_state, perturb, version, seed):
    """Replace ``sp_counts`` random class-``cls_no`` rows of ``X`` in place.

    Row indices are drawn without replacement from ``random_state``. The
    selected rows become ``perturb(X[ind], version, seed)``. Returns the indices.
    """
    ind = random_state.choice(np.where(y == cls_no)[0], size=sp_counts, replace=False)
    X[ind] = perturb(X[ind], version, seed)
    return ind


def _mem_split_with_spurious(loader, perturb, version, sp_counts, cls_no, seed, also_nonmem=False):
    """Four-way split of ``loader()`` data with spurious samples injected.

    Always injects into the member training split. When ``also_nonmem`` is
    true, the same RandomState then injects into the non-member training
    split. Returns the eight split arrays followed by the chosen indices: an
    array, or ``[mem_ind, nonmem_ind]`` when ``also_nonmem``.
    """
    splits = _load_mem_split(loader, seed)
    random_state = np.random.RandomState(seed)
    # splits[0:2] = member training (X, y); splits[4:6] = non-member training (X, y).
    mem_ind = _inject_spurious(splits[0], splits[1], cls_no, sp_counts,
                               random_state, perturb, version, seed)
    if not also_nonmem:
        return (*splits, mem_ind)
    nonmem_ind = _inject_spurious(splits[4], splits[5], cls_no, sp_counts,
                                  random_state, perturb, version, seed)
    return (*splits, [mem_ind, nonmem_ind])


def _one_vs_others(x_train, y_train, version, sp_counts, cls_no, seed, main_perturb):
    """Perturb one sample of every other class, then ``sp_counts`` samples of ``cls_no``.

    Classes 0-9 other than ``cls_no`` get ``add_colored_spurious_correlation``,
    drawn from one RandomState. ``cls_no`` then gets ``main_perturb``, drawn
    from a fresh RandomState with the same seed. Modifies ``x_train`` in place
    and returns only the ``cls_no`` indices.
    """
    random_state = np.random.RandomState(seed)
    for other in range(10):
        if other != cls_no:
            _inject_spurious(x_train, y_train, other, 1, random_state,
                             add_colored_spurious_correlation, version, seed)
    return _inject_spurious(x_train, y_train, cls_no, sp_counts, np.random.RandomState(seed),
                            main_perturb, version, seed)


def _program_fold(X, y, foldno):
    """Return ``(trn_idx, tst_idx)`` for fold ``foldno`` of a seeded 8-fold stratified split."""
    skf = StratifiedKFold(n_splits=8, random_state=0, shuffle=True)
    return list(skf.split(X, y))[foldno]


def _minmax_scale(trnX, tstX):
    """Fit a MinMaxScaler on ``trnX`` and apply it to both matrices."""
    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler()
    return scaler.fit_transform(trnX), scaler.transform(tstX)


class DatasetVarClass(VariableClass, metaclass=RegisteringChoiceType):
    """Defines the dataset to use.

    Each entry is registered with ``register_var``. Named groups of its
    ``argument`` regex arrive as parameters; numeric groups are converted with
    ``int(...)`` inside each method.

    Plain entries return ``(x_train, y_train, x_test, y_test, spurious_ind)``.
    ``mem*`` entries return the member/non-member train/test splits (eight
    arrays) followed by the spurious indices.

    NOTE(review): several entries (program, fashion, cifar10 variants)
    register ``shown_name="mnist"``. This looks like copy-paste; confirm
    before relying on shown names.
    """
    var_name = 'dataset'

    @register_var(argument=r"oriprogram-(?P<foldno>[0-9]+)", shown_name="mnist")
    @staticmethod
    def oriprogram(auto_var, foldno, data_dir="./data/"):
        """Raw program features: fold ``foldno`` of 8, min-max scaled on the train fold."""
        foldno = int(foldno)
        X, y, _ = joblib.load(os.path.join(data_dir, "program_features.pkl"))
        trn_idx, tst_idx = _program_fold(X, y, foldno)
        trnX, tstX = _minmax_scale(X[trn_idx], X[tst_idx])
        return trnX, y[trn_idx], tstX, y[tst_idx], np.empty(0)

    @register_var(argument=r"program-(?P<foldno>[0-9]+)", shown_name="mnist")
    @staticmethod
    def program(auto_var, foldno, data_dir="./data/"):
        """Column-aligned program features (see ``get_program_data``), fold ``foldno`` of 8."""
        foldno = int(foldno)
        X, _, y, _ = get_program_data(data_dir)
        trn_idx, tst_idx = _program_fold(X, y, foldno)
        trnX, tstX = _minmax_scale(X[trn_idx], X[tst_idx])
        return trnX, y[trn_idx], tstX, y[tst_idx], np.empty(0)

    @register_var(argument=r"program(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)-(?P<foldno>[0-9]+)", shown_name="mnist")
    @staticmethod
    def program_aug(auto_var, version, sp_counts, cls_no, seed, foldno, data_dir="./data/"):
        """Program features where ``sp_counts`` class-``cls_no`` training rows use spurious features.

        Returns ``spurious_ind`` as positions within the training fold.
        """
        foldno = int(foldno)
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        assert version == "v1"
        X, Xspu, y, _ = get_program_data(data_dir)

        trn_idx, tst_idx = _program_fold(X, y, foldno)
        trnX, trny, tstX, tsty = X[trn_idx], y[trn_idx], X[tst_idx], y[tst_idx]

        random_state = np.random.RandomState(seed)
        spurious_ind = random_state.choice(np.where(trny == cls_no)[0], size=sp_counts, replace=False)
        # spurious_ind indexes trnX (fold positions); map back to dataset rows
        # before reading Xspu. Assumes Xspu row i is the spurious variant of
        # X row i, since both come from get_program_data.
        trnX[spurious_ind] = Xspu[trn_idx[spurious_ind]]

        trnX, tstX = _minmax_scale(trnX, tstX)
        return trnX, trny, tstX, tsty, spurious_ind

    @register_var(argument=r"mnist", shown_name="mnist")
    @staticmethod
    def mnist(auto_var):
        """Clean MNIST."""
        return get_mnist()

    @register_var(argument=r"mnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def mnist_aug(auto_var, version, sp_counts, cls_no, seed):
        """MNIST with ``sp_counts`` class-``cls_no`` training images perturbed."""
        logger.info(f"[dataset] MNIST, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_mnist()
        spurious_ind = _inject_spurious(x_train, y_train, cls_no, sp_counts, np.random.RandomState(seed),
                                        add_spurious_correlation, version, seed)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"memmnist", shown_name="mnist")
    @staticmethod
    def memmnist(auto_var, seed=0):
        """MNIST train+test pooled, shuffled with ``seed``, and cut into four quarters."""
        logger.info(f"[dataset] memMNIST")
        return (*_load_mem_split(get_mnist, seed), np.empty(0))

    @register_var(argument=r"memmnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def mem_mnist_aug(auto_var, version, sp_counts, cls_no, seed):
        """``memmnist`` with spurious samples injected into the member training split."""
        logger.info(f"[dataset] MNIST, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        return _mem_split_with_spurious(get_mnist, add_spurious_correlation,
                                        version, sp_counts, cls_no, seed)

    @register_var(argument=r"spmemmnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def spmem_mnist_aug(auto_var, version, sp_counts, cls_no, seed):
        """``memmnist`` with spurious samples in both member and non-member training splits."""
        logger.info(f"[dataset] MNIST, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        return _mem_split_with_spurious(get_mnist, add_spurious_correlation,
                                        version, sp_counts, cls_no, seed, also_nonmem=True)

    @register_var(argument=r"memfashion", shown_name="mnist")
    @staticmethod
    def memfashion(auto_var, seed=0):
        """Fashion-MNIST train+test pooled, shuffled with ``seed``, and cut into four quarters."""
        logger.info(f"[dataset] memFashion")
        return (*_load_mem_split(get_fashion, seed), np.empty(0))

    @register_var(argument=r"memfashion(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def mem_fashion_aug(auto_var, version, sp_counts, cls_no, seed):
        """``memfashion`` with spurious samples injected into the member training split."""
        logger.info(f"[dataset] Fashion, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        return _mem_split_with_spurious(get_fashion, add_spurious_correlation,
                                        version, sp_counts, cls_no, seed)

    @register_var(argument=r"spmemfashion(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def spmem_fashion_aug(auto_var, version, sp_counts, cls_no, seed):
        """``memfashion`` with spurious samples in both member and non-member training splits."""
        logger.info(f"[dataset] Fashion, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        return _mem_split_with_spurious(get_fashion, add_spurious_correlation,
                                        version, sp_counts, cls_no, seed, also_nonmem=True)

    @register_var(argument=r"oneothersmnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def oneothers_mnist_aug(auto_var, version, sp_counts, cls_no, seed):
        """MNIST with one perturbed image per other class and ``sp_counts`` for ``cls_no``."""
        logger.info(f"[dataset] MNIST, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_mnist()
        spurious_ind = _one_vs_others(x_train, y_train, version, sp_counts, cls_no, seed,
                                      add_spurious_correlation)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"twoclassmnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no1>[0-9])-(?P<cls_no2>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def twoclasslmnist_aug(auto_var, version, sp_counts, cls_no1, cls_no2, seed):
        """MNIST with ``sp_counts`` perturbed images in each of two classes."""
        logger.info(f"[dataset] MNIST, {version}, {sp_counts}, {cls_no1}, {cls_no2}, {seed}")
        sp_counts, cls_no1, cls_no2, seed = int(sp_counts), int(cls_no1), int(cls_no2), int(seed)
        x_train, y_train, x_test, y_test, _ = get_mnist()

        random_state = np.random.RandomState(seed)
        spurious_ind1 = _inject_spurious(x_train, y_train, cls_no1, sp_counts, random_state,
                                         add_spurious_correlation, version, seed)
        spurious_ind2 = _inject_spurious(x_train, y_train, cls_no2, sp_counts, random_state,
                                         add_spurious_correlation, version, seed)
        return x_train, y_train, x_test, y_test, np.concatenate((spurious_ind1, spurious_ind2))

    @register_var(argument=r"fashion", shown_name="fashion mnist")
    @staticmethod
    def fashion(auto_var):
        """Clean Fashion-MNIST."""
        return get_fashion()

    @register_var(argument=r"fashion(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def fashion_aug(auto_var, version, sp_counts, cls_no, seed):
        """Fashion-MNIST with ``sp_counts`` class-``cls_no`` training images perturbed."""
        logger.info(f"[dataset] fashion, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_fashion()
        spurious_ind = _inject_spurious(x_train, y_train, cls_no, sp_counts, np.random.RandomState(seed),
                                        add_spurious_correlation, version, seed)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"oneothersfashion(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def oneothers_fashion_aug(auto_var, version, sp_counts, cls_no, seed):
        """Fashion-MNIST with one perturbed image per other class and ``sp_counts`` for ``cls_no``."""
        logger.info(f"[dataset] fashion, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_fashion()
        spurious_ind = _one_vs_others(x_train, y_train, version, sp_counts, cls_no, seed,
                                      add_spurious_correlation)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"cifar100", shown_name="Cifar100")
    @staticmethod
    def cifar100(auto_var):
        """CIFAR-100 with fine labels."""
        return get_cifar100("fine")

    @register_var(argument=r"cifar100coarse", shown_name="Cifar100")
    @staticmethod
    def cifar100coarse(auto_var):
        """CIFAR-100 with coarse labels."""
        return get_cifar100("coarse")

    @register_var(argument=r"cifar10", shown_name="Cifar10")
    @staticmethod
    def cifar10(auto_var):
        """Clean CIFAR-10."""
        return get_cifar10()

    @register_var(argument=r"cifar10(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def cifar10_aug(auto_var, version, sp_counts, cls_no, seed):
        """CIFAR-10 with ``sp_counts`` class-``cls_no`` training images perturbed on all channels."""
        logger.info(f"[dataset] CIFAR10, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_cifar10()
        spurious_ind = _inject_spurious(x_train, y_train, cls_no, sp_counts, np.random.RandomState(seed),
                                        add_colored_spurious_correlation, version, seed)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"memcifar10", shown_name="mnist")
    @staticmethod
    def memcifar10(auto_var, seed=0):
        """CIFAR-10 train+test pooled, shuffled with ``seed``, and cut into four quarters."""
        logger.info(f"[dataset] memCifar10")
        return (*_load_mem_split(get_cifar10, seed), np.empty(0))

    @register_var(argument=r"memcifar10(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def memcifar10_aug(auto_var, version, sp_counts, cls_no, seed):
        """``memcifar10`` with spurious samples injected into the member training split."""
        logger.info(f"[dataset] memCIFAR10, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        return _mem_split_with_spurious(get_cifar10, add_colored_spurious_correlation,
                                        version, sp_counts, cls_no, seed)

    @register_var(argument=r"oneotherscifar10(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="mnist")
    @staticmethod
    def oneothers_cifar10_aug(auto_var, version, sp_counts, cls_no, seed):
        """CIFAR-10 with one perturbed image per other class and ``sp_counts`` for ``cls_no``."""
        logger.info(f"[dataset] CIFAR10, {version}, {sp_counts}, {cls_no}, {seed}")
        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        x_train, y_train, x_test, y_test, _ = get_cifar10()
        spurious_ind = _one_vs_others(x_train, y_train, version, sp_counts, cls_no, seed,
                                      add_colored_spurious_correlation)
        return x_train, y_train, x_test, y_test, spurious_ind

    @register_var(argument=r"(?P<dataaug>[a-zA-Z0-9]+-)?imgnet100", shown_name="imgnet100")
    @staticmethod
    def imgnet100(auto_var, dataaug, trn_transform, tst_transform):
        """ImageNet-100 datasets built by ``get_imgnet100`` from the sibling ``imgnet100`` module."""
        from .imgnet100 import get_imgnet100
        from spurious_ml.models.torch_utils import data_augs
        if dataaug is None:
            # NOTE(review): trn_transform/tst_transform are not forwarded in
            # this branch; confirm that is intended.
            trn_dset, tst_dset = get_imgnet100(None, None, None, None, None)
        else:
            pre_transform, _ = getattr(data_augs, dataaug[:-1])()
            trn_dset, tst_dset = get_imgnet100(trn_transform, tst_transform, pre_transform, None, None)
        return trn_dset, tst_dset, np.empty(0)

    @register_var(argument=r"(?P<dataaug>[a-zA-Z0-9]+-)?imgnet100(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)", shown_name="imgnet100")
    @staticmethod
    def imgnet100_aug(auto_var, dataaug, version, sp_counts, cls_no, seed, trn_transform, tst_transform):
        """ImageNet-100 where ``sp_counts`` class-``cls_no`` training images get a spurious feature.

        The chosen indices and the perturbation function are passed to
        ``get_imgnet100`` from the sibling ``imgnet100`` module, which applies them.
        """
        from .imgnet100 import get_imgnet100
        from spurious_ml.models.torch_utils import data_augs
        import torch
        from torchvision.datasets import ImageFolder
        torch.multiprocessing.set_sharing_strategy('file_system')
        logger.info(f"[dataset] Imgnet100, {version}, {sp_counts}, {cls_no}, {seed}")

        imgnet_trn_dir = "/tmp2/ImageNet100/ILSVRC2012_img_train/"
        # ImageFolder stores each sample's class index in ``targets``, in sample
        # order, so labels are read without decoding every training image.
        trny = np.asarray(ImageFolder(imgnet_trn_dir).targets)

        # ``dataaug`` is an optional regex group. When it is absent, use no
        # pre-transform (as ``imgnet100`` does) instead of failing on None[:-1].
        if dataaug is None:
            pre_transform = None
        else:
            pre_transform, _ = getattr(data_augs, dataaug[:-1])()

        sp_counts, cls_no, seed = int(sp_counts), int(cls_no), int(seed)
        random_state = np.random.RandomState(seed)
        spurious_ind = random_state.choice(np.where(trny == cls_no)[0], size=sp_counts, replace=False)
        spurious_fn = lambda x: add_colored_spurious_correlation(x[None, :, :, :], version, seed)

        trn_dset, tst_dset = get_imgnet100(trn_transform, tst_transform, pre_transform=pre_transform,
                                           spurious_ind=spurious_ind, spurious_fn=spurious_fn)

        return trn_dset, tst_dset, spurious_ind
