# this file provides utilities for loading the various datasets used in the paper

import gzip
import struct
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.datasets import load_digits

def load_mnist_images(filename):
    """Load a gzip-compressed IDX image file (e.g. MNIST ``*-images-idx3-ubyte.gz``).

    Args:
        filename: path to the gzip-compressed IDX image file.

    Returns:
        uint8 array of shape ``(num_images, rows, cols)``. It is a view over the
        decompressed bytes (``np.frombuffer``), so it is read-only.

    Raises:
        ValueError: if the magic number is not 2051 (IDX image file) or the
            payload size does not match the header dimensions.
    """
    with gzip.open(filename, 'rb') as f:
        # Header: four big-endian uint32 values - magic number, image count, rows, cols.
        magic, num_images, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError(
                f"{filename}: expected IDX image magic number 2051, got {magic}"
            )
        # Remaining bytes are the pixels, one unsigned byte each.
        images = np.frombuffer(f.read(), dtype=np.uint8)

    expected = num_images * rows * cols
    if images.size != expected:
        raise ValueError(
            f"{filename}: header declares {num_images}x{rows}x{cols} pixels "
            f"({expected}) but file holds {images.size}"
        )
    return images.reshape((num_images, rows, cols))

def load_mnist_labels(filename):
    """Load a gzip-compressed IDX label file (e.g. MNIST ``*-labels-idx1-ubyte.gz``).

    Args:
        filename: path to the gzip-compressed IDX label file.

    Returns:
        1-D uint8 array with one label per sample. It is a view over the
        decompressed bytes (``np.frombuffer``), so it is read-only.

    Raises:
        ValueError: if the magic number is not 2049 (IDX label file) or the
            number of labels differs from the count stored in the header.
    """
    with gzip.open(filename, 'rb') as f:
        # Header: two big-endian uint32 values - magic number, label count.
        magic, num_labels = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError(
                f"{filename}: expected IDX label magic number 2049, got {magic}"
            )
        labels = np.frombuffer(f.read(), dtype=np.uint8)

    if labels.size != num_labels:
        raise ValueError(
            f"{filename}: header declares {num_labels} labels but file holds {labels.size}"
        )
    return labels

def unpickle(file):
    """Load and return the object pickled in ``file`` (used here for the CIFAR batch files).

    ``encoding='bytes'`` makes 8-bit strings pickled by Python 2 load as
    ``bytes``. This is why callers index the result with keys such as ``b"data"``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code while loading.
    Only call this on trusted files, such as the official dataset archives.
    """
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch

# Helper function to read the raw IDX files
def read_idx_file(filename):
    """Decode a gzip-compressed IDX file containing either images or labels.

    The file begins with a big-endian uint32 magic number that selects the layout:

    * 2051 (images): three more uint32 values give count, rows and cols. The
      payload is returned as a uint8 array of shape ``(count, rows, cols)``.
    * 2049 (labels): one more uint32 gives the count. The payload is returned
      as a flat uint8 array.

    Raises:
        ValueError: if the magic number is neither 2051 nor 2049.
    """
    with gzip.open(filename, 'rb') as stream:
        (magic_number,) = struct.unpack(">I", stream.read(4))

        if magic_number == 2051:
            count, height, width = struct.unpack(">III", stream.read(12))
            pixels = np.frombuffer(stream.read(), dtype=np.uint8)
            return pixels.reshape(count, height, width)

        if magic_number == 2049:
            # Label count: consumed to advance past the header; not otherwise used.
            struct.unpack(">I", stream.read(4))
            return np.frombuffer(stream.read(), dtype=np.uint8)

    raise ValueError(f"Unknown magic number {magic_number} in {filename}")

# Function to load FashionMNIST from raw .gz files
def load_fashion_mnist_raw(train_image_file, train_label_file, test_image_file, test_label_file):
    """Read the four FashionMNIST IDX archives without any preprocessing.

    Each path is decoded with ``read_idx_file`` in this order: train images,
    train labels, test images, test labels. Values are returned exactly as
    decoded; no normalisation is applied.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``.
    """
    archives = (train_image_file, train_label_file, test_image_file, test_label_file)
    return tuple(read_idx_file(path) for path in archives)


def generate_bars_and_stripes(n_samples, height, width, noise_std, seed=0):
    """Generate a labelled bars-and-stripes image dataset.

    Every image starts filled with -1.0. With probability 1/2 it becomes a
    "bars" sample (label 1): each row is independently set to +1.0 with
    probability 1/2. Otherwise it becomes a "stripes" sample (label 0): each
    column is independently set to +1.0 with probability 1/2. A sample may
    therefore have no lit lines at all.

    Args:
        n_samples: number of images to generate.
        height: image height in pixels.
        width: image width in pixels.
        noise_std: standard deviation of additive Gaussian noise applied to
            every pixel. Values <= 0 disable noise. The noise is drawn after all
            clean patterns, so the clean patterns and labels are identical for
            every noise level at a given seed.
        seed: seed for ``np.random.default_rng``. The default (0) reproduces
            the previously hard-coded dataset.

    Returns:
        X: float64 array of shape ``(n_samples, 1, height, width)``.
        y: float64 array of shape ``(n_samples,)`` with 1.0 for bars and 0.0
           for stripes.
    """
    rng = np.random.default_rng(seed=seed)

    X = np.full((n_samples, 1, height, width), -1.0)
    y = np.zeros(n_samples)

    for i in range(n_samples):
        if rng.random() > 0.5:
            lit_rows = np.where(rng.random(height) > 0.5)[0]
            X[i, 0, lit_rows, :] = 1.0
            y[i] = 1
        else:
            lit_columns = np.where(rng.random(width) > 0.5)[0]
            X[i, 0, :, lit_columns] = 1.0
            y[i] = 0

    # Previously noise_std was accepted but silently ignored.
    if noise_std > 0:
        X += rng.normal(0.0, noise_std, size=X.shape)

    return X, y

# Datasets small enough to be encoded without padding. Every other dataset
# goes through the padded "overflow" encoding and is truncated afterwards.
_SMALL_DATASETS = ("bars", "digits", "NIST")

# Number of train and test samples kept for MNIST / Fashion / CIFAR. The
# scaler is fitted on the full filtered training split *before* truncation.
_MAX_LARGE_DATASET_SAMPLES = 200


def _flatten_images(images, order):
    """Flatten each image into one float64 row with ``np.reshape(..., order=order)``.

    ``order="Z"`` (Z-curve ordering) is not implemented and aborts the run.
    """
    if order == "Z":
        raise SystemExit("Z not supported")
    images = np.asarray(images)
    num_features = int(np.prod(images.shape[1:]))
    flat = np.empty((len(images), num_features))
    for i, image in enumerate(images):
        flat[i, :] = np.reshape(image, num_features, order=order)
    return flat


def _binarize_labels(labels, classes):
    """Return a copy of ``labels`` with ``classes[0]`` -> 0 and ``classes[1]`` -> 1.

    Both masks are computed from the original labels, so the mapping is
    correct even when ``classes[1] == 0``. Remapping in place would first turn
    ``classes[0]`` into 0 and then, wrongly, into 1. Any other value is kept.
    """
    labels = np.asarray(labels)
    binary = labels.copy()
    binary[labels == classes[0]] = 0
    binary[labels == classes[1]] = 1
    return binary


def _row_normalize(data):
    """Divide every row of ``data`` by its own L2 norm."""
    normed = np.zeros_like(data)
    for i, row in enumerate(data):
        normed[i, :] = row / np.linalg.norm(row)
    return normed


def _cifar_greyscale(data, num_pixels=1024):
    """Average the three consecutive ``num_pixels``-wide colour planes of each row."""
    return (
        data[:, 0:num_pixels]
        + data[:, num_pixels:2 * num_pixels]
        + data[:, 2 * num_pixels:3 * num_pixels]
    ) / 3


def _load_bars():
    """Return the 80/20 train/test split of 300 flattened 4x4 bars-and-stripes images."""
    data, labels = generate_bars_and_stripes(300, 4, 4, 0.0)
    data = np.reshape(data, (300, 16))
    return train_test_split(data, labels, test_size=0.2, random_state=42)


def _load_digits_split(config):
    """Return the 80/20 split of the sklearn 8x8 digits restricted to ``config["classes"]``."""
    digit_data = load_digits()
    labels = digit_data.target
    data = _flatten_images(digit_data.images, config["encoding_order"])

    mask = np.isin(labels, config["classes"])  # keep only the configured classes
    train_data, test_data, train_labels, test_labels = train_test_split(
        data[mask, :], labels[mask], test_size=0.2, random_state=42
    )
    return (
        train_data,
        test_data,
        _binarize_labels(train_labels, config["classes"]),
        _binarize_labels(test_labels, config["classes"]),
    )


def _filter_and_flatten(train_images, train_labels, test_images, test_labels, config):
    """Keep ``config["classes"]``, flatten the images and binarize the labels."""
    classes = config["classes"]
    train_mask = np.isin(train_labels, classes)
    test_mask = np.isin(test_labels, classes)

    print(np.shape(train_labels))
    train_data = _flatten_images(train_images[train_mask], config["encoding_order"])
    test_data = _flatten_images(test_images[test_mask], config["encoding_order"])
    return (
        train_data,
        test_data,
        _binarize_labels(train_labels[train_mask], classes),
        _binarize_labels(test_labels[test_mask], classes),
    )


def _load_mnist_split(config):
    """Load the MNIST IDX archives from ``data/MNIST`` for ``config["classes"]``."""
    return _filter_and_flatten(
        load_mnist_images("data/MNIST/train-images-idx3-ubyte.gz"),
        load_mnist_labels("data/MNIST/train-labels-idx1-ubyte.gz"),
        load_mnist_images("data/MNIST/t10k-images-idx3-ubyte.gz"),
        load_mnist_labels("data/MNIST/t10k-labels-idx1-ubyte.gz"),
        config,
    )


def _load_fashion_split(config):
    """Load the FashionMNIST IDX archives from ``data/FashionMNIST`` for ``config["classes"]``."""
    train_images, train_labels, test_images, test_labels = load_fashion_mnist_raw(
        "data/FashionMNIST/train-images-idx3-ubyte-2.gz",
        "data/FashionMNIST/train-labels-idx1-ubyte-2.gz",
        "data/FashionMNIST/t10k-images-idx3-ubyte-2.gz",
        "data/FashionMNIST/t10k-labels-idx1-ubyte-2.gz",
    )
    return _filter_and_flatten(train_images, train_labels, test_images, test_labels, config)


def _load_cifar_split(config):
    """Load CIFAR batch 1 and the test batch for ``config["classes"]`` as greyscale rows."""
    classes = config["classes"]

    meta_data = unpickle("data/CIFAR/batches.meta")
    print(meta_data[b"label_names"][classes[0]], meta_data[b"label_names"][classes[1]])

    # Cast pixels to float64 *before* averaging the colour planes. The raw
    # batch arrays are presumably uint8, where summing three planes wraps
    # around; the test split previously skipped this cast.
    train_dict = unpickle("data/CIFAR/data_batch_1")
    train_data = np.array(train_dict[b"data"], dtype=np.float64)
    # Train labels were float64 in the original code; kept for callers.
    train_labels = np.array(train_dict[b"labels"], dtype=np.float64)
    del train_dict

    test_dict = unpickle("data/CIFAR/test_batch")
    test_data = np.array(test_dict[b"data"], dtype=np.float64)
    test_labels = np.array(test_dict[b"labels"])
    del test_dict

    train_mask = np.isin(train_labels, classes)
    test_mask = np.isin(test_labels, classes)

    return (
        _cifar_greyscale(train_data[train_mask, :]),
        _cifar_greyscale(test_data[test_mask, :]),
        _binarize_labels(train_labels[train_mask], classes),
        _binarize_labels(test_labels[test_mask], classes),
    )


def _load_split(config):
    """Return ``(train_data, test_data, train_labels, test_labels)`` for ``config["dataset"]``."""
    dataset = config["dataset"]
    if dataset == "bars":
        return _load_bars()
    if dataset in ("digits", "NIST"):
        return _load_digits_split(config)
    if dataset == "MNIST":
        return _load_mnist_split(config)
    if dataset == "Fashion":
        return _load_fashion_split(config)
    if dataset == "CIFAR":
        return _load_cifar_split(config)
    raise ValueError(f"Unknown dataset {dataset!r}")


def _scale_small(config, train_data, test_data):
    """Encode bars/digits features and return ``(scaled_train, scaled_test, scaler)``.

    ``scaler`` is the fitted ``MinMaxScaler``, or ``None`` for "normed".
    """
    encoding = config["encoding_scaler"]

    if encoding == "normed":
        return _row_normalize(train_data), _row_normalize(test_data), None

    if encoding == "overflow":
        # Columns 1: are min-max scaled so that the 2**num_qubits - 1 of them
        # can contribute at most 1 to the squared norm. Column 0 is then
        # overwritten with sqrt(1 - that sum). Test data is not clipped, so
        # out-of-range test features can make the square root NaN.
        scaler = MinMaxScaler(feature_range=(0, 1 / np.sqrt(2**config["num_qubits"] - 1)))
        scaled_train = np.zeros_like(train_data)
        scaled_test = np.zeros_like(test_data)

        scaled_train[:, 1:] = scaler.fit_transform(train_data[:, 1:])
        scaled_train[:, 0] = np.sqrt(1 - np.sum(scaled_train[:, 1:]**2, axis=1))

        scaled_test[:, 1:] = scaler.transform(test_data[:, 1:])
        scaled_test[:, 0] = np.sqrt(1 - np.sum(scaled_test[:, 1:]**2, axis=1))
        return scaled_train, scaled_test, scaler

    if encoding == "angle":
        scaler = MinMaxScaler(feature_range=(0, np.pi))
        return scaler.fit_transform(train_data), scaler.transform(test_data), scaler

    # NOTE(review): "whiten" also ends here. The original code standardized
    # the digits and then exited at this same check, so "whiten" has never
    # produced output. Confirm the intended pipeline before enabling it.
    raise SystemExit("Invalid envoding scaler chosen")


def _scale_padded_overflow(config, train_data, test_data):
    """Overflow-encode MNIST/Fashion/CIFAR features into ``2**num_qubits`` columns.

    Each of the ``k`` features is min-max scaled (clipped) into
    ``[0, 1/sqrt(k)]`` and written to the first ``k`` columns of a zero array.
    The last column holds ``sqrt(1 - sum of squares)``, so every row has unit
    norm. Returns ``(scaled_train, scaled_test, scaler)``.
    """
    if config["encoding_scaler"] != "overflow":
        # Other scalers previously returned all-zero arrays without any error.
        raise SystemExit("Invalid envoding scaler chosen")

    num_features = train_data.shape[1]
    num_amplitudes = 2**config["num_qubits"]
    if num_amplitudes <= num_features:
        raise ValueError(
            f"num_qubits={config['num_qubits']} gives {num_amplitudes} amplitudes, "
            f"but {num_features} features plus one overflow amplitude are needed"
        )

    scaler = MinMaxScaler(feature_range=(0, 1 / np.sqrt(num_features)), clip=True)

    scaled_train = np.zeros((len(train_data), num_amplitudes))
    scaled_test = np.zeros((len(test_data), num_amplitudes))
    scaled_train[:, :num_features] = scaler.fit_transform(train_data)
    scaled_test[:, :num_features] = scaler.transform(test_data)

    # With clip=True the radicand is >= 0 up to floating-point rounding. Clip
    # that rounding error so a row of maximal features yields 0 rather than NaN.
    for scaled in (scaled_train, scaled_test):
        radicand = 1 - np.sum(scaled[:, :num_features]**2, axis=1)
        scaled[:, -1] = np.sqrt(np.clip(radicand, 0.0, None))
    return scaled_train, scaled_test, scaler


def _prepare_encoded_data(config):
    """Load, split, encode and (for large datasets) truncate the configured data.

    Returns ``(scaled_train, scaled_test, train_labels, test_labels, scaler)``.
    """
    train_data, test_data, train_labels, test_labels = _load_split(config)

    if config["dataset"] in _SMALL_DATASETS:
        scaled_train, scaled_test, scaler = _scale_small(config, train_data, test_data)
        return scaled_train, scaled_test, train_labels, test_labels, scaler

    scaled_train, scaled_test, scaler = _scale_padded_overflow(config, train_data, test_data)
    n = _MAX_LARGE_DATASET_SAMPLES
    return scaled_train[:n, :], scaled_test[:n, :], train_labels[:n], test_labels[:n], scaler


def return_data_scale_factors(config):
    """Return ``(scale_, min_)`` of the MinMaxScaler used by ``return_scaled_data``.

    Runs the same pipeline as ``return_scaled_data``. The fitted scaler maps
    a raw feature ``x`` to ``x * scale_ + min_``. For CIFAR the factors are
    computed on the greyscale features, matching ``return_scaled_data``; they
    were previously computed on the colour data.

    Raises:
        ValueError: for the "normed" encoding (no scaler is fitted), for an
            unknown dataset, or if ``2**num_qubits`` is too small.
        SystemExit: for an unsupported encoding scaler or "Z" ordering.
    """
    *_, scaler = _prepare_encoded_data(config)
    if scaler is None:
        raise ValueError(
            f"encoding_scaler {config['encoding_scaler']!r} fits no scaler, "
            "so it has no scale factors"
        )
    return scaler.scale_, scaler.min_


def return_scaled_data(config):
    """Load ``config["dataset"]``, split it into train/test and encode the features.

    Recognised ``config`` keys:

    * ``dataset``: "bars", "digits"/"NIST", "MNIST", "Fashion" or "CIFAR".
    * ``encoding_scaler`` and ``num_qubits``.
    * Every dataset except "bars" also needs ``classes`` (``classes[0]`` ->
      label 0, ``classes[1]`` -> label 1) and ``encoding_order`` (the numpy
      reshape order used to flatten images; "Z" is not supported).

    Encodings:

    * bars / digits:
      - "normed": per-sample L2 normalisation.
      - "overflow": columns 1: are min-max scaled and column 0 is replaced by
        ``sqrt(1 - sum of squares)``.
      - "angle": min-max scaled to ``[0, pi]``.
    * MNIST / Fashion / CIFAR (greyscale): only "overflow". Features are
      clipped min-max scaled into the first columns of a
      ``2**num_qubits``-wide array, and the last column completes each row to
      unit norm. Only the first 200 train and test samples are returned.

    Returns:
        ``(scaled_train_data, scaled_test_data, train_labels, test_labels)``.

    Raises:
        SystemExit: unsupported encoding scaler or "Z" encoding order.
        ValueError: unknown dataset, or ``2**num_qubits`` too small for the
            padded overflow encoding.
    """
    scaled_train, scaled_test, train_labels, test_labels, _ = _prepare_encoded_data(config)
    return scaled_train, scaled_test, train_labels, test_labels

def _select_classes(data, labels, classes):
    """Return the rows of ``data`` and ``labels`` whose label is in ``classes``."""
    mask = np.isin(labels, classes)
    return data[mask], labels[mask]


def _flatten_images(images, order):
    """Flatten each image of a stack into one row of a float64 matrix.

    Each image is flattened on its own with ``np.reshape(..., order=order)``,
    so ``order`` controls the pixel ordering within a single image. An empty
    stack yields a ``(0, num_features)`` matrix instead of raising.
    """
    images = np.asarray(images)
    num_features = int(np.prod(images.shape[1:]))
    flat = np.empty((len(images), num_features))
    for row, image in enumerate(images):
        flat[row, :] = np.reshape(image, num_features, order=order)
    return flat


def _binarize_labels(labels, classes):
    """Relabel ``classes[0]`` as 0 and ``classes[1]`` as 1; return a new array.

    Both masks are computed from the original labels before anything is
    overwritten. Remapping one class at a time in place is wrong whenever
    ``classes[1] == 0``. For ``classes == [1, 0]``, the first step turns every
    1 into 0, and the second step then turns *all* zeros into 1. Labels that
    are neither of the first two classes are left unchanged.
    """
    labels = np.asarray(labels)
    is_first = labels == classes[0]
    is_second = labels == classes[1]
    binary = labels.copy()
    binary[is_first] = 0
    binary[is_second] = 1
    return binary


def _supported_encoding_order(config):
    """Return ``config["encoding_order"]``; exit if it is the unsupported "Z".

    Raising ``SystemExit`` has the same effect as the former ``exit(...)``
    call, but it does not rely on the ``site`` builtin. It also fires before
    any data is loaded.
    """
    order = config["encoding_order"]
    if order == "Z":
        raise SystemExit("Z not supported")
    return order


def return_unscaled_data(config, *, max_samples=200):
    """Load a train/test split of a dataset without rescaling its features.

    Args:
        config: mapping with key ``"dataset"``. Its value must be one of
            ``"bars"``, ``"digits"``/``"NIST"``, ``"MNIST"``, ``"Fashion"``
            or ``"CIFAR"``. Every dataset except ``"bars"`` also reads
            ``"classes"``: samples are restricted to those labels, and
            ``classes[0]`` / ``classes[1]`` are relabelled 0 / 1.
            ``"digits"``/``"NIST"``, ``"MNIST"`` and ``"Fashion"`` also read
            ``"encoding_order"``, the ``np.reshape`` order used to flatten
            each image (``"Z"`` is not supported).
        max_samples: number of leading rows kept in each split. It is not
            applied to ``"bars"``; ``None`` keeps every row. It defaults to
            200, the previously hard-coded value.

    Returns:
        ``(train_data, test_data, train_labels, test_labels)``.

    Raises:
        ValueError: if ``config["dataset"]`` is not a recognised name.
        SystemExit: if ``config["encoding_order"]`` is ``"Z"``.

    Data files are read from paths relative to the current working directory
    (``data/MNIST/``, ``data/FashionMNIST/``, ``data/CIFAR/``).
    """
    dataset = config["dataset"]

    if dataset == "bars":
        # generate_bars_and_stripes(300, 4, 4, 0.0) is reshaped to 300 rows of
        # 16 features. This dataset gets no class filtering, relabelling or
        # truncation.
        data, labels = generate_bars_and_stripes(300, 4, 4, 0.0)
        data = np.reshape(data, (300, 16))
        train_data, test_data, train_labels, test_labels = train_test_split(
            data, labels, test_size=0.2, random_state=42
        )
        return train_data, test_data, train_labels, test_labels

    if dataset not in ("digits", "NIST", "MNIST", "Fashion", "CIFAR"):
        raise ValueError(f"Unknown dataset: {dataset!r}")

    classes = config["classes"]

    if dataset in ("digits", "NIST"):
        order = _supported_encoding_order(config)
        digit_data = load_digits()
        images, labels = _select_classes(digit_data.images, digit_data.target, classes)
        data = _flatten_images(images, order)
        train_data, test_data, train_labels, test_labels = train_test_split(
            data, labels, test_size=0.2, random_state=42
        )

    elif dataset in ("MNIST", "Fashion"):
        order = _supported_encoding_order(config)
        if dataset == "MNIST":
            train_images = load_mnist_images("data/MNIST/train-images-idx3-ubyte.gz")
            train_labels = load_mnist_labels("data/MNIST/train-labels-idx1-ubyte.gz")
            test_images = load_mnist_images("data/MNIST/t10k-images-idx3-ubyte.gz")
            test_labels = load_mnist_labels("data/MNIST/t10k-labels-idx1-ubyte.gz")
        else:
            train_images, train_labels, test_images, test_labels = load_fashion_mnist_raw(
                "data/FashionMNIST/train-images-idx3-ubyte-2.gz",
                "data/FashionMNIST/train-labels-idx1-ubyte-2.gz",
                "data/FashionMNIST/t10k-images-idx3-ubyte-2.gz",
                "data/FashionMNIST/t10k-labels-idx1-ubyte-2.gz",
            )

        # Pixel values are kept as loaded (no division by 255).
        train_images, train_labels = _select_classes(train_images, train_labels, classes)
        test_images, test_labels = _select_classes(test_images, test_labels, classes)
        train_data = _flatten_images(train_images, order)
        test_data = _flatten_images(test_images, order)

    else:  # "CIFAR"
        meta_data = unpickle("data/CIFAR/batches.meta")
        label_names = meta_data[b"label_names"]
        print(label_names[classes[0]], label_names[classes[1]])

        # Only data_batch_1 is loaded for training. The raw dicts are deleted
        # right after extraction to release their memory early.
        train_dict = unpickle("data/CIFAR/data_batch_1")
        train_data, train_labels = _select_classes(
            np.array(train_dict[b"data"]), np.array(train_dict[b"labels"]), classes
        )
        del train_dict

        test_dict = unpickle("data/CIFAR/test_batch")
        test_data, test_labels = _select_classes(
            np.array(test_dict[b"data"]), np.array(test_dict[b"labels"]), classes
        )
        del test_dict

    train_labels = _binarize_labels(train_labels, classes)
    test_labels = _binarize_labels(test_labels, classes)

    return (
        train_data[:max_samples],
        test_data[:max_samples],
        train_labels[:max_samples],
        test_labels[:max_samples],
    )
