import tensorflow_datasets as tfds
import sys
import jax
import matplotlib
import seaborn as sns
from matplotlib import pyplot as plt

# Make modules from a local checkout of google-research importable. The path
# is relative, so Python resolves it against the current working directory,
# not against this file's location.
sys.path.append("google-research/")

# CIFAR-10-C (`cifar10_corrupted`) publishes only a "test" split. Calling
# `tfds.load` without `split=` returns a dict keyed by split name, so unpacking
# it into two names raises ValueError. Request each split explicitly instead.
# NOTE(review): following the usual CIFAR-10-C protocol, training data comes
# from clean CIFAR-10 and evaluation from the corrupted set. Confirm this
# matches the intended experiment.
# Without a "/<corruption>_<severity>" suffix, TFDS uses the builder's default
# config. Pin one explicitly (e.g. "cifar10_corrupted/gaussian_noise_5") if a
# specific corruption is intended.
train_dataset = tfds.load(name="cifar10", split="train")
test_dataset = tfds.load(name="cifar10_corrupted", split="test")
