from keras import datasets
import os
import numpy as np
import tensorflow as tf
import scipy.io as sio
from keras.utils import to_categorical

def load_cifar_corruption(corrupt_no, serverity_level, show_corruption=None):
    """Load one corruption type at one severity level from CIFAR-10-C.

    Reads ``labels.npy`` and ``<corruption>.npy`` from ``test_core/CIFAR10_C/``
    and keeps the rows belonging to the requested severity level: levels 1-4
    take 10000 consecutive rows starting at ``(level - 1) * 10000``; level 5
    takes every row from 40000 to the end of the file.

    Args:
        corrupt_no: Index into the corruption-name list below
            (0 = 'brightness' ... 18 = 'zoom_blur').
        serverity_level: Severity level, one of 1, 2, 3, 4, 5.
        show_corruption: If truthy, print the selected corruption name.

    Returns:
        ``(x, y)``: ``x`` is the selected images divided by 255 as float32,
        ``y`` the labels one-hot encoded over 10 classes.

    Raises:
        IndexError: If ``corrupt_no`` is not in ``[0, 18]``.
        ValueError: If ``serverity_level`` is not one of 1..5.
    """
    corrupt_list = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
                    'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
                    'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise',
                    'snow', 'spatter', 'speckle_noise', 'zoom_blur']
    per_level = 10000

    # Validate the cheap arguments before loading the (large) .npy files.
    # Negative indices are rejected explicitly: plain list indexing would
    # silently wrap around (e.g. -1 -> 'zoom_blur') and load the wrong data.
    if not 0 <= corrupt_no < len(corrupt_list):
        raise IndexError(
            f"corrupt_no must be in [0, {len(corrupt_list) - 1}], got {corrupt_no!r}")
    s = serverity_level
    if s not in [1, 2, 3, 4, 5]:
        raise ValueError("Invalid value of Serverity Level")
    corruption = corrupt_list[corrupt_no]

    full_path = 'test_core/CIFAR10_C/'
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can run code from a tampered file; plain numeric .npy files do not need it.
    labels = np.load(full_path + 'labels.npy', allow_pickle=True)
    inputs = np.load(full_path + corruption + '.npy', allow_pickle=True)

    start_index = (s - 1) * per_level
    # Level 5 keeps everything up to the end of the file.
    stop_index = None if s == 5 else start_index + per_level
    x = inputs[start_index:stop_index]
    y = labels[start_index:stop_index]
    y = to_categorical(y, 10)
    if show_corruption:
        print('Corruption Type: ', corruption)

    # Cast before dividing: avoids a float64 temporary twice the size of the
    # float32 result (values are identical for integer pixel data).
    x = x.astype('float32') / 255.0
    # NOTE(review): unlike load_cifar10, no per-channel mean/std normalization is
    # applied here — confirm callers normalize if the model expects it.
    return x, y

def load_cifar100_corruption(corrupt_no, serverity_level, show_corruption=None):
    """Load one corruption type at one severity level from CIFAR-100-C.

    Reads ``labels.npy`` and ``<corruption>.npy`` from ``test_core/CIFAR100_C/``
    and keeps the rows belonging to the requested severity level: levels 1-4
    take 10000 consecutive rows starting at ``(level - 1) * 10000``; level 5
    takes every row from 40000 to the end of the file.

    Args:
        corrupt_no: Index into the corruption-name list below
            (0 = 'brightness' ... 18 = 'zoom_blur').
        serverity_level: Severity level, one of 1, 2, 3, 4, 5.
        show_corruption: If truthy, print the selected corruption name.

    Returns:
        ``(x, y)``: ``x`` is the selected images divided by 255 as float32,
        ``y`` the labels one-hot encoded over 100 classes.

    Raises:
        IndexError: If ``corrupt_no`` is not in ``[0, 18]``.
        ValueError: If ``serverity_level`` is not one of 1..5.
    """
    corrupt_list = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
                    'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
                    'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise',
                    'snow', 'spatter', 'speckle_noise', 'zoom_blur']
    per_level = 10000

    # Validate the cheap arguments before loading the (large) .npy files.
    # Negative indices are rejected explicitly: plain list indexing would
    # silently wrap around (e.g. -1 -> 'zoom_blur') and load the wrong data.
    if not 0 <= corrupt_no < len(corrupt_list):
        raise IndexError(
            f"corrupt_no must be in [0, {len(corrupt_list) - 1}], got {corrupt_no!r}")
    s = serverity_level
    if s not in [1, 2, 3, 4, 5]:
        raise ValueError("Invalid value of Serverity Level")
    corruption = corrupt_list[corrupt_no]

    full_path = 'test_core/CIFAR100_C/'
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can run code from a tampered file; plain numeric .npy files do not need it.
    labels = np.load(full_path + 'labels.npy', allow_pickle=True)
    inputs = np.load(full_path + corruption + '.npy', allow_pickle=True)

    start_index = (s - 1) * per_level
    # Level 5 keeps everything up to the end of the file.
    stop_index = None if s == 5 else start_index + per_level
    x = inputs[start_index:stop_index]
    y = labels[start_index:stop_index]
    y = to_categorical(y, 100)
    if show_corruption:
        print('Corruption Type: ', corruption)

    # Cast before dividing: avoids a float64 temporary twice the size of the
    # float32 result (values are identical for integer pixel data).
    x = x.astype('float32') / 255.0
    # NOTE(review): no per-channel mean/std normalization is applied here —
    # confirm callers normalize if the model expects it.
    return x, y

def load_cinic10_test():
    """Return the CINIC-10 test split stored under ``test_core/CINIC10/``.

    The arrays come back exactly as saved on disk: no scaling, normalization
    or one-hot encoding is applied here.

    Returns:
        ``(inputs, labels)`` read from ``x_test.npy`` and ``y_test.npy``.
    """
    data_dir = 'test_core/CINIC10/'
    # The generator is consumed in order, so labels are read before images.
    labels, inputs = (
        np.load(data_dir + file_name, allow_pickle=True)
        for file_name in ('y_test.npy', 'x_test.npy')
    )
    return inputs, labels


def load_svhn_test():
    """Load the SVHN test split, preprocessed like ``load_cifar10``.

    Reads ``test_core/SVHN/test_32x32.mat``, moves the sample axis to the
    front, maps label 10 to 0, scales pixels by 1/255, standardizes each
    channel with the CIFAR-10 mean/std constants and one-hot encodes the
    labels over 10 classes.

    Returns:
        ``(x_test, y_test)``: float32 standardized images and one-hot labels.
    """
    test_data = sio.loadmat('test_core/SVHN/test_32x32.mat')
    # 'X' is stored with the sample axis last; move it to the front.
    x_test = np.transpose(test_data['X'], (3, 0, 1, 2))
    y_test = test_data['y']
    # SVHN stores the digit 0 under label 10 (digits 1-9 keep their own value).
    y_test[y_test == 10] = 0
    y_test = to_categorical(y_test, 10)

    # CIFAR-10 channel statistics, presumably so SVHN inputs match a model
    # trained on data from load_cifar10 — confirm. The constants are float32
    # on purpose: float64 constants would silently upcast the whole array
    # back to float64 and undo the float32 cast.
    mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
    std = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
    # Cast before dividing to avoid a float64 temporary.
    x_test = (x_test.astype(np.float32) / 255.0 - mean) / std
    return x_test, y_test

def load_cifar10():
    """Load CIFAR-10 via Keras and standardize it per channel.

    Pixels are scaled by 1/255 and then standardized with fixed per-channel
    mean/std constants; labels are one-hot encoded over 10 classes.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` with float32 images.
    """
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

    # float32 constants keep the result float32. Float64 constants would
    # silently upcast the normalized arrays to float64 (doubling memory) and
    # undo the explicit float32 cast.
    mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
    std = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)

    def _standardize(images):
        # Scale to [0, 1] in float32 (no float64 temporary), then normalize.
        scaled = images.astype(np.float32) / 255.0
        return (scaled - mean) / std

    x_train = _standardize(x_train)
    x_test = _standardize(x_test)
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)
    return (x_train, y_train), (x_test, y_test)

def load_tinyimage_test():
    """Load the Tiny ImageNet test split and standardize it per channel.

    Reads ``x_test.npy`` / ``y_test.npy`` from ``test_core/TinyImage/``,
    scales pixels by 1/255, one-hot encodes the labels over 200 classes and
    standardizes with the mean/std of this same test set, reduced over axes
    0-2 (presumably N, H, W with channels last — confirm).

    Returns:
        ``(x_test, y_test)``: float32 standardized images and one-hot labels.
    """
    x_test = np.load('test_core/TinyImage/x_test.npy')
    y_test = np.load('test_core/TinyImage/y_test.npy')

    # Cast before dividing to avoid a float64 temporary.
    x_test = x_test.astype('float32') / 255.0
    y_test = to_categorical(y_test, 200)

    # Accumulate the statistics in float64: summing millions of float32 pixels
    # in float32 loses precision (see numpy.mean notes). The results are cast
    # back to float32 so the normalized images stay float32.
    # NOTE(review): these are test-set statistics, not the statistics the model
    # was trained with — confirm this is intended.
    mean = np.mean(x_test, axis=(0, 1, 2), dtype=np.float64).astype(np.float32)
    std = np.std(x_test, axis=(0, 1, 2), dtype=np.float64).astype(np.float32)
    # A constant channel has std == 0; leave it unscaled instead of producing
    # inf/NaN.
    std = np.where(std == 0, np.float32(1.0), std)

    x_test = (x_test - mean) / std
    return x_test, y_test