import random
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
from tensorflow.keras import datasets
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tinyimagenet import *

_SUPPORTED_DATASETS=('mnist','fmnist','kmnist','cifar10','cifar100','tinyimagenet')


def _print_split_shapes(train_X,train_y,test_X,test_y):
  """Print the array shapes of the train/test images and labels."""
  print('Train data shape: ', train_X.shape)
  print('Train labels shape: ', train_y.shape)
  print('Test data shape: ', test_X.shape)
  print('Test labels shape: ', test_y.shape)


def _kmnist_split_to_numpy(split):
  """Load one supervised KMNIST split via tfds and return (images, labels) as NumPy arrays."""
  examples=tfds.load('kmnist',split=split,as_supervised=True)
  images,labels=[],[]
  for image,label in examples:
    images.append(image.numpy())
    labels.append(label.numpy())
  return np.array(images),np.array(labels)


def get_data(dataset,path_to_data=None):
  """Load a dataset, normalize it, and build its augmentation generator.

  Args:
    dataset: Dataset name, matched case-insensitively against 'mnist',
      'fmnist', 'kmnist', 'cifar10', 'cifar100' and 'tinyimagenet'.
    path_to_data: Directory that contains the 'tiny-imagenet-200' folder.
      Only read (and then required) when dataset is 'tinyimagenet'.

  Returns:
    A tuple (datagen, train_X, train_y, test_X, test_y). Images are divided
    by 255, then shifted and scaled by the per-dataset mean/std constants
    below, and get a channel axis appended when they have fewer than four
    dimensions. Labels are converted with tf.keras.utils.to_categorical.
    datagen is an ImageDataGenerator built from the per-dataset augmentation
    settings and fitted on train_X.

  Raises:
    ValueError: If dataset is not a supported name, or if 'tinyimagenet' is
      requested without path_to_data.
  """
  name=dataset.lower()
  aug_config={}
  # Flip/shift augmentation shared by the colour-image datasets.
  flip_and_shift={'horizontal_flip':True,'width_shift_range':4,'height_shift_range':4}

  if name=='mnist':
    (train_X,train_y),(test_X,test_y)=datasets.mnist.load_data()
    mean,std=0.1307,0.3081
    num_classes=10
    _print_split_shapes(train_X,train_y,test_X,test_y)
  elif name=='fmnist':
    (train_X,train_y),(test_X,test_y)=datasets.fashion_mnist.load_data()
    mean,std=0.286156,0.35314866
    num_classes=10
  elif name=='kmnist':
    train_X,train_y=_kmnist_split_to_numpy('train')
    test_X,test_y=_kmnist_split_to_numpy('test')
    # NOTE: reseeds the global NumPy and Python RNGs as a side effect. The
    # validation split these seeds were set for is currently disabled; the
    # seeding is kept so callers observe the same RNG state as before.
    np.random.seed(42)
    random.seed(42)
    _print_split_shapes(train_X,train_y,test_X,test_y)
    mean,std=0.19173359,0.34831141
    num_classes=10
  elif name=='cifar10':
    (train_X,train_y),(test_X,test_y)=datasets.cifar10.load_data()
    aug_config.update(flip_and_shift)
    mean,std=[0.4914,0.4822,0.4465],[0.2470,0.2435,0.2616]
    num_classes=10
  elif name=='cifar100':
    (train_X,train_y),(test_X,test_y)=datasets.cifar100.load_data()
    aug_config.update(flip_and_shift)
    mean,std=[0.5071,0.4865,0.4409],[0.2673,0.2564,0.2762]
    num_classes=100
  elif name=='tinyimagenet':
    if path_to_data is None:
      # Fail with a clear message instead of a TypeError from None + str.
      raise ValueError("path_to_data must be provided when dataset is 'tinyimagenet'")
    aug_config.update(flip_and_shift)
    train_X,train_y,test_X,test_y=load_images(path_to_data+'/tiny-imagenet-200',200)
    # Move axis 1 to the end; presumably load_images returns channels-first
    # arrays -- TODO confirm against tinyimagenet.load_images.
    train_X=np.float32(np.transpose(train_X,axes=(0,2,3,1)))
    test_X=np.float32(np.transpose(test_X,axes=(0,2,3,1)))
    mean,std=[0.4802,0.4481,0.3975],[0.2770,0.2691,0.2821]
    num_classes=200
  else:
    # Previously an unknown name surfaced later as an unbound-variable error.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected one of {', '.join(_SUPPORTED_DATASETS)}")

  # Further processing of loaded data
  train_y=tf.keras.utils.to_categorical(train_y,num_classes)
  test_y=tf.keras.utils.to_categorical(test_y,num_classes)
  train_X,test_X=np.divide(train_X,255),np.divide(test_X,255)
  train_X,test_X=(train_X-mean)/std,(test_X-mean)/std
  # Add a channel axis when it is not there (e.g. grayscale arrays without one).
  if train_X.ndim<4:
    train_X=np.expand_dims(train_X,axis=3)
  if test_X.ndim<4:
    test_X=np.expand_dims(test_X,axis=3)
  datagen=ImageDataGenerator(**aug_config)
  datagen.fit(train_X)
  return datagen,train_X,train_y,test_X,test_y