import numpy as np
import tensorflow_datasets as tfds
from torchvision import datasets, transforms
import torch, os
import pickle

def load_cifar10(data_dir=None):
    """Load the full CIFAR-10 train and test splits into memory.

    Args:
        data_dir: Directory passed through to ``tfds.load`` for downloading /
            reading the dataset; ``None`` lets TFDS use its default location.

    Returns:
        A ``(train_ds, test_ds)`` tuple. Each element is a dict with keys
        ``"images_u8"`` (the image array) and ``"labels"`` (the label array).
        The values are NumPy arrays produced by ``tfds.as_numpy`` -- not
        ``jnp`` arrays; convert them yourself if JAX arrays are needed.
    """
    # batch_size=-1 makes tfds.load return each split as one single batch, and
    # as_supervised=True yields (image, label) pairs, so each call unpacks into
    # two whole-split arrays.
    train_ds_images_u8, train_ds_labels = tfds.as_numpy(
        tfds.load("cifar10", split="train", batch_size=-1,
                  as_supervised=True, data_dir=data_dir))
    test_ds_images_u8, test_ds_labels = tfds.as_numpy(
        tfds.load("cifar10", split="test", batch_size=-1,
                  as_supervised=True, data_dir=data_dir))
    train_ds = {"images_u8": train_ds_images_u8, "labels": train_ds_labels}
    test_ds = {"images_u8": test_ds_images_u8, "labels": test_ds_labels}
    return train_ds, test_ds

import tensorflow_datasets as tfds

def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` without ever leaving a truncated file there.

    The data is written to a sibling temporary file first and then moved into
    place with ``os.replace``. An interrupted or failed write therefore never
    produces a partial ``path`` that a later cache lookup would mistake for a
    complete file. Any error is re-raised after the temp file is cleaned up.
    """
    tmp_path = f"{path}.tmp-{os.getpid()}"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is
        # what matters, so it is re-raised unchanged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def load_cifar100(data_dir=None):
    """Load CIFAR-100 train/test splits, caching them as pickles in ``data_dir``.

    If ``data_dir`` already holds ``cifar100_train.pkl`` and
    ``cifar100_test.pkl``, they are loaded directly. Otherwise the splits are
    fetched with ``tfds.load`` (using ``data_dir`` as the TFDS data directory)
    and then pickled into ``data_dir`` for next time. A cache that cannot be
    unpickled is rebuilt.

    Args:
        data_dir: Directory for the TFDS download and the pickle cache. If
            ``None``, the data is loaded from TFDS's default location and no
            pickle cache is read or written.

    Returns:
        A ``(train_ds, test_ds)`` tuple. Each element is a dict with keys
        ``"images_u8"`` and ``"labels"`` holding NumPy arrays from
        ``tfds.as_numpy``.

    Raises:
        Exception: Any error raised while writing the pickle cache is
            reported and re-raised.
    """
    def _load_from_tfds():
        # batch_size=-1 returns each split as one whole-split batch;
        # as_supervised=True makes it unpack into (images, labels).
        train_images, train_labels = tfds.as_numpy(
            tfds.load("cifar100", split="train", batch_size=-1,
                      as_supervised=True, data_dir=data_dir))
        test_images, test_labels = tfds.as_numpy(
            tfds.load("cifar100", split="test", batch_size=-1,
                      as_supervised=True, data_dir=data_dir))
        return ({"images_u8": train_images, "labels": train_labels},
                {"images_u8": test_images, "labels": test_labels})

    if data_dir is None:
        # There is nowhere to put the pickle cache: os.path.join(None, ...)
        # would raise TypeError. Just load from TFDS's default location.
        return _load_from_tfds()

    # Define paths for saved dataset
    train_pickle = os.path.join(data_dir, "cifar100_train.pkl")
    test_pickle = os.path.join(data_dir, "cifar100_test.pkl")

    # Check if dataset already exists
    if os.path.exists(train_pickle) and os.path.exists(test_pickle):
        print(f"Loading dataset from {data_dir}")
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point data_dir at a directory you trust.
        try:
            with open(train_pickle, "rb") as f:
                train_ds = pickle.load(f)
            with open(test_pickle, "rb") as f:
                test_ds = pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as e:
            # A truncated or corrupt cache (e.g. left by an interrupted save)
            # would otherwise fail on every call; rebuild it instead.
            print(f"Cached dataset in {data_dir} is unreadable ({e}); rebuilding")
        else:
            return train_ds, test_ds

    os.makedirs(data_dir, exist_ok=True)

    # Load CIFAR-100 from TFDS into in-memory dicts.
    train_ds, test_ds = _load_from_tfds()

    # Save datasets to pickle files atomically, so a crash mid-write cannot
    # leave a half-written cache that passes the existence check above.
    print(f"Saving dataset to {data_dir}")
    try:
        _atomic_pickle_dump(train_ds, train_pickle)
        _atomic_pickle_dump(test_ds, test_pickle)
    except Exception as e:
        print(f"Error saving datasets: {e}")
        raise
    print(f"Successfully saved datasets to {data_dir}")

    return train_ds, test_ds