import tensorflow as tf
import tensorflow_datasets as tfds

# Load the CelebA dataset with labels
def load_celeba_dataset():
    """Load the training split of CelebA through TensorFlow Datasets.

    Returns:
        The ``(dataset, info)`` pair that ``tfds.load`` produces when called
        with ``with_info=True``.
    """
    train_split, metadata = tfds.load(
        'celeb_a',
        split='train',
        with_info=True,
    )
    return train_split, metadata



# Preprocess the data (optional, depending on your specific needs)

def preprocess_image(data):
    """Scale an example's image to float32 values in the range [0, 1].

    Args:
        data: Mapping with an "image" entry holding the pixel tensor
            (presumably uint8 in 0..255 -- confirm against the dataset spec).

    Returns:
        A float32 tensor: the image divided by 255 and clipped to [0.0, 1.0].
    """
    # Use tf.cast rather than the float() builtin: Dataset.map traces this
    # function with symbolic (and non-scalar) tensors, which float() cannot
    # convert and would reject with a TypeError.
    pixels = tf.cast(data["image"], tf.float32)
    return tf.clip_by_value(pixels / 255.0, 0.0, 1.0)
def preprocess_label(data):
    """Pack an example's attribute values into a single float32 tensor.

    Args:
        data: Mapping with an "attributes" entry that is itself a mapping of
            attribute name to value.

    Returns:
        A float32 tensor holding the attribute values in the mapping's
        iteration order.
    """
    attribute_map = data["attributes"]
    # Iterating the mapping yields its keys in the same order as .values().
    flags = [attribute_map[name] for name in attribute_map]
    return tf.cast(flags, tf.float32)
# Fetch the training split together with its metadata.
celeba_dataset, celeba_info = load_celeba_dataset()


def _to_example(data):
    """Build the {'image', 'label'} record for one raw dataset element."""
    return {'image': preprocess_image(data), 'label': preprocess_label(data)}


# Apply the per-example preprocessing; tf.data picks the parallelism level.
celeba_data = celeba_dataset.map(
    _to_example,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Report the dataset metadata and the size of the training split.
print("Dataset info:", celeba_info)
print("Number of samples:", celeba_info.splits['train'].num_examples)

once_attributes = {}


# Accessing the data
# Each dataset element is a feature dict (the preprocessing functions index it
# by "image" and "attributes"), so the fields are looked up by key instead of
# tuple-unpacking the element. Only the first example is inspected.
for example in celeba_dataset.take(1):
    image = example["image"]
    attributes = example["attributes"]
    print("Image shape:", image.shape)
    print("Attribute:", type(attributes))
    once_attributes = attributes

if once_attributes['Attractive'].numpy() == 1:
    print("This person is attractive!")
else:
    print("This person is not attractive!")

# Dump every attribute of the inspected example as "name :  value".
for name, value in once_attributes.items():
    print(name, ": ", value.numpy())

# Pack the inspected example's attribute values into one float32 tensor, in
# the attribute dict's iteration order.
attribute_values = tf.cast(list(once_attributes.values()), tf.float32)