import numpy as np
import tensorflow_datasets as tfds
from torchvision import datasets, transforms
import torch, os
import pickle

def load_cifar10(data_dir=None):
  """Return the CIFAR-10 training and test datasets as dicts of numpy arrays.

  Each split is loaded in full (``batch_size=-1``) as ``(image, label)`` pairs
  (``as_supervised=True``) and converted with ``tfds.as_numpy``, so the values
  are numpy arrays, not jax arrays.

  Args:
    data_dir: Optional tfds data directory, forwarded to ``tfds.load``. The
      default None lets tfds use its default location.

  Returns:
    A ``(train_ds, test_ds)`` tuple; each is a dict with keys ``"images_u8"``
    and ``"labels"``.
  """
  def _load_split(split):
    # One tfds call per split; the result unpacks into (images, labels).
    images_u8, labels = tfds.as_numpy(
        tfds.load("cifar10", split=split, batch_size=-1, as_supervised=True,
                  data_dir=data_dir))
    return {"images_u8": images_u8, "labels": labels}

  return _load_split("train"), _load_split("test")

import tensorflow_datasets as tfds

def load_cifar100(data_dir=None):
    """Return the CIFAR-100 training and test datasets as dicts of numpy arrays.

    Each dataset is a dict ``{"images_u8": ..., "labels": ...}`` holding the
    full split as produced by ``tfds.as_numpy`` on ``tfds.load(...,
    batch_size=-1, as_supervised=True)``.

    Args:
        data_dir: Directory used both as the tfds ``data_dir`` and as a pickle
            cache holding ``cifar100_train.pkl`` and ``cifar100_test.pkl``.
            If both pickles exist, they are loaded and tfds is not called.
            Otherwise the data is loaded through tfds and the pickles are
            written. If None, tfds uses its default data directory and no
            pickle cache is read or written.

    Returns:
        A ``(train_ds, test_ds)`` tuple.

    Raises:
        Any exception raised while writing the pickle cache is printed and
        re-raised.
    """
    if data_dir is None:
        # There is no directory to cache into (os.path.join(None, ...) would
        # raise TypeError), so load straight from tfds' default location.
        return _load_cifar100_tfds(None)

    # Define paths for saved dataset
    train_pickle = os.path.join(data_dir, "cifar100_train.pkl")
    test_pickle = os.path.join(data_dir, "cifar100_test.pkl")

    # Use the cache only when both splits are present.
    if os.path.exists(train_pickle) and os.path.exists(test_pickle):
        print(f"Loading dataset from {data_dir}")
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point data_dir at a directory you trust.
        with open(train_pickle, "rb") as f:
            train_ds = pickle.load(f)
        with open(test_pickle, "rb") as f:
            test_ds = pickle.load(f)
        return train_ds, test_ds

    os.makedirs(data_dir, exist_ok=True)
    train_ds, test_ds = _load_cifar100_tfds(data_dir)

    print(f"Saving dataset to {data_dir}")
    try:
        _dump_pickle_atomic(train_ds, train_pickle)
        _dump_pickle_atomic(test_ds, test_pickle)
        print(f"Successfully saved datasets to {data_dir}")
    except Exception as e:
        print(f"Error saving datasets: {e}")
        raise

    return train_ds, test_ds

def _load_cifar100_tfds(data_dir):
    """Load both CIFAR-100 splits via tfds into ``images_u8``/``labels`` dicts."""
    datasets_by_split = []
    for split in ("train", "test"):
        images_u8, labels = tfds.as_numpy(
            tfds.load("cifar100", split=split, batch_size=-1,
                      as_supervised=True, data_dir=data_dir))
        datasets_by_split.append({"images_u8": images_u8, "labels": labels})
    return datasets_by_split[0], datasets_by_split[1]

def _dump_pickle_atomic(obj, path):
    """Pickle ``obj`` to ``path`` via a temp file so a failed write never leaves a truncated cache."""
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        # os.replace is atomic on the same filesystem, so ``path`` is either
        # absent or complete; the existence check above can then trust it.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def _split_cifar(train_ds, label_split: int, n_swap: int = 5000,
                 seed: int = 123, num_examples: int = 50_000):
  """Split a CIFAR-style training set into two label-biased subsets.

  The examples are first shuffled with a fixed-seed permutation. They are then
  divided into a "low" group (``label < label_split``) and a "high" group
  (``label >= label_split``).

  ``s1`` takes the first ``n_swap`` low examples plus every high example after
  the first ``n_swap``. ``s2`` takes the complement: the first ``n_swap`` high
  examples plus the remaining low examples. Together the two subsets contain
  every input example exactly once.

  Args:
    train_ds: dict with ``"images_u8"`` and ``"labels"`` arrays, indexed by
      example along axis 0.
    label_split: label threshold separating the low and high groups.
    n_swap: number of examples from each group that are placed in the
      "other" subset.
    seed: seed for the shuffling permutation (kept fixed for reproducibility).
    num_examples: required number of examples in both arrays.

  Returns:
    ``(s1, s2)``, two dicts with keys ``"images_u8"`` and ``"labels"``.

  Raises:
    ValueError: if either array does not have ``num_examples`` rows, or if
      either label group has fewer than ``n_swap`` examples.
  """
  images_u8 = train_ds["images_u8"]
  labels = train_ds["labels"]
  # Explicit checks instead of `assert`, which is stripped under `python -O`.
  if images_u8.shape[0] != num_examples or labels.shape[0] != num_examples:
    raise ValueError(
        f"expected {num_examples} examples, got {images_u8.shape[0]} images "
        f"and {labels.shape[0]} labels")

  # We randomly permute the training data, just in case there's some kind of
  # non-iid ordering coming out of tfds.
  perm = np.random.default_rng(seed).permutation(num_examples)
  images_u8 = images_u8[perm]
  labels = labels[perm]

  low = labels < label_split
  high = labels >= label_split
  lt_images_u8, lt_labels = images_u8[low], labels[low]
  gte_images_u8, gte_labels = images_u8[high], labels[high]
  # For the 50k CIFAR training set, the callers' label_split values give a
  # clean 25000/25000 split. Guard against groups too small to donate n_swap
  # examples, which would silently unbalance the subsets.
  if lt_labels.shape[0] < n_swap or gte_labels.shape[0] < n_swap:
    raise ValueError(
        f"label groups of size {lt_labels.shape[0]} and {gte_labels.shape[0]} "
        f"cannot each supply n_swap={n_swap} examples")

  s1 = {
      "images_u8": np.concatenate(
          (lt_images_u8[:n_swap], gte_images_u8[n_swap:]), axis=0),
      "labels": np.concatenate((lt_labels[:n_swap], gte_labels[n_swap:]), axis=0),
  }
  s2 = {
      "images_u8": np.concatenate(
          (gte_images_u8[:n_swap], lt_images_u8[n_swap:]), axis=0),
      "labels": np.concatenate((gte_labels[:n_swap], lt_labels[n_swap:]), axis=0),
  }
  return s1, s2

def load_cifar10_split():
  """Load CIFAR-10 and divide its training set into two label-biased subsets.

  Returns:
    ``(s1, s2, test_ds)``: the two training subsets from ``_split_cifar``
    (threshold label 5) followed by the unmodified test set.
  """
  train_ds, test_ds = load_cifar10()
  return (*_split_cifar(train_ds, label_split=5), test_ds)

def load_cifar100_split(data_dir=None):
  """Load CIFAR-100 and divide its training set into two label-biased subsets.

  Args:
    data_dir: forwarded unchanged to ``load_cifar100`` (its tfds data
      directory and pickle cache location).

  Returns:
    ``(s1, s2, test_ds)``: the two training subsets from ``_split_cifar``
    (threshold label 50) followed by the unmodified test set.
  """
  train_ds, test_ds = load_cifar100(data_dir=data_dir)
  s1, s2 = _split_cifar(train_ds, label_split=50)
  return s1, s2, test_ds