import h5py
import numpy as np

def map_continuous_to_discrete(labels):
    """Convert continuous factor values into integer class indices.

    Columns 0-5 of ``labels`` are read, in order, as floor hue, wall hue,
    object hue, scale, shape and orientation. Each column is mapped as:

    * hue columns (0-2): ``round(value * 10)``
      (assumes hue values lie on a 0.1 grid -- TODO confirm against the data);
    * scale (3): index of the nearest of 8 evenly spaced values in
      [0.75, 1.25];
    * shape (4): value rounded to the nearest integer;
    * orientation (5): index of the nearest of 15 evenly spaced values in
      [-30, 30].

    Args:
        labels: Array-like of shape ``(N, 6)`` (extra columns are ignored).

    Returns:
        Integer ``numpy.ndarray`` of shape ``(N, 6)`` holding the class index
        of every factor, in the same column order as the input.
    """
    labels = np.asarray(labels)

    def nearest_grid_index(column, grid):
        # Broadcast to an (N, len(grid)) distance table and pick the closest
        # grid point for every row.
        return np.argmin(np.abs(column[:, None] - grid[None, :]), axis=1)

    hue_idx = np.round(labels[:, :3] * 10).astype(int)
    scale_idx = nearest_grid_index(labels[:, 3], np.linspace(0.75, 1.25, 8))
    # Round before casting: a bare astype(int) truncates toward zero, so a
    # float like 2.9999999 would silently land in class 2 instead of 3.
    shape_idx = np.round(labels[:, 4]).astype(int)
    orientation_idx = nearest_grid_index(labels[:, 5], np.linspace(-30, 30, 15))

    return np.column_stack([hue_idx, scale_idx, shape_idx, orientation_idx])

# ===== CONFIGURATION =====
use_subset = True  # Set this to False to use 100% of the dataset
subset_fraction = 0.10  # Share of all samples kept when use_subset is True
train_split_ratio = 0.9  # Share of the kept samples written to the train file

# Single place to change when the dataset moves; every path below derives
# from it, so input and outputs cannot drift apart.
dataset_dir = "/scratch/user/repos/ssl/solo-learn-gaussianization/datasets/3dshapes"
input_path = f"{dataset_dir}/3dshapes.h5"

# Subset runs write to "*_subset.h5" so they never overwrite full-size outputs.
_output_suffix = "_subset" if use_subset else ""
train_output_path = f"{dataset_dir}/3dshapes_train{_output_suffix}.h5"
val_output_path = f"{dataset_dir}/3dshapes_val{_output_suffix}.h5"

# ===== LOAD DATA =====
# Read both datasets completely into memory as NumPy arrays; the HDF5 file is
# closed again as soon as the with-block exits.
with h5py.File(input_path, mode="r") as source_file:
    images, labels = source_file["images"][...], source_file["labels"][...]

print("images and labels loaded")

# ===== CONVERT TO DISCRETE LABELS =====
# Done before any subsetting/shuffling so the rows stay aligned with `images`.
discrete_labels = map_continuous_to_discrete(labels)

# ===== SUBSET OPTION =====
num_total = images.shape[0]
if use_subset:
    subset_size = int(num_total * subset_fraction)
    # Report the configured fraction instead of a hard-coded "10%", so the
    # message stays truthful when subset_fraction is changed.
    print(f"Using {subset_fraction * 100:g}% subset: {subset_size} samples")
    num_kept = subset_size
else:
    print(f"Using full dataset: {num_total} samples")
    num_kept = num_total

# ===== SHUFFLE =====
# The first `num_kept` entries of a random permutation form a random subset
# that is already in random order, so a single permutation both selects and
# shuffles. This fancy-indexes (i.e. copies) the image array once, not twice.
# NOTE: np.random is not seeded here, so the selection differs between runs.
shuffled_indices = np.random.permutation(num_total)[:num_kept]
images = images[shuffled_indices]
labels = labels[shuffled_indices]
discrete_labels = discrete_labels[shuffled_indices]
num_samples = images.shape[0]

# ===== SPLIT =====
# The leading `train_split_ratio` share of the (already shuffled) samples
# becomes the training split; everything after that boundary is validation.
split_idx = int(num_samples * train_split_ratio)
train_rows = slice(None, split_idx)
val_rows = slice(split_idx, None)

train_images, train_labels, train_discrete = (
    images[train_rows],
    labels[train_rows],
    discrete_labels[train_rows],
)
val_images, val_labels, val_discrete = (
    images[val_rows],
    labels[val_rows],
    discrete_labels[val_rows],
)

# ===== SAVE TRAIN / VAL =====
# Each output file receives two gzip-compressed datasets: "images" and
# "labels". The "labels" dataset holds the discretized factor indices; the
# continuous labels are not written.
for saved_message, out_path, split_images, split_discrete in (
    ("Train set saved:", train_output_path, train_images, train_discrete),
    ("Validation set saved:", val_output_path, val_images, val_discrete),
):
    # Mode "w" creates the file, replacing any existing file at that path.
    with h5py.File(out_path, "w") as out_file:
        out_file.create_dataset("images", data=split_images, compression="gzip")
        out_file.create_dataset("labels", data=split_discrete, compression="gzip")

    print(saved_message, out_path)
