import zipfile, json, os
from sklearn.model_selection import train_test_split

original_zip_path = "data/afhqv2-64x64.zip"
train_zip_path = "data/afhqv2-64x64_train.zip"
test_zip_path = "data/afhqv2-64x64_test.zip"
split_ratio = 0.9  # 90% train, 10% test

# Load the label table and all image filenames from the source archive.
with zipfile.ZipFile(original_zip_path, "r") as zin:
    all_names = zin.namelist()
    # A missing dataset.json (or a missing "labels" key) is treated as an
    # unlabeled dataset rather than aborting the split.
    if "dataset.json" in all_names:
        raw_labels = json.loads(zin.read("dataset.json")).get("labels")
    else:
        raw_labels = None
    # Normalize labels to a {filename: label} dict. dict() accepts either a
    # mapping or a list of [filename, label] pairs (the layout this script
    # writes back out); null/empty becomes {}. Without this, testing a
    # filename for membership in a list of pairs never matches, and testing
    # it against null raises TypeError.
    labels = dict(raw_labels) if raw_labels else {}
    image_fnames = [f for f in all_names if f.lower().endswith((".png", ".jpg"))]

# Split (fixed seed so the same archive always yields the same split)
train_files, test_files = train_test_split(
    image_fnames, train_size=split_ratio, random_state=42
)


# Helper to write zip
def write_split_zip(split_zip_path, split_files):
    """Write one split of the source archive to a new zip file.

    Copies every member named in ``split_files`` from ``original_zip_path``
    into ``split_zip_path`` and adds a ``dataset.json`` whose "labels" entry
    is a list of [filename, label] pairs for the files in this split. When
    the source has no labels, "labels" is null.

    Args:
        split_zip_path: Destination zip path (opened in "w" mode, so an
            existing file is replaced).
        split_files: Archive member names to copy into this split.
    """
    # The module-level ``labels`` may be a dict, a list of [filename, label]
    # pairs, or null/empty; dict() handles the first two. Looking a filename
    # up directly in a list of pairs would never match.
    label_map = dict(labels) if labels else {}
    with zipfile.ZipFile(split_zip_path, "w") as zout, zipfile.ZipFile(
        original_zip_path, "r"
    ) as zin:
        split_labels = {}
        for fname in split_files:
            zout.writestr(fname, zin.read(fname))
            if fname in label_map:
                split_labels[fname] = label_map[fname]
        # Keep an unlabeled source unlabeled: writing [] would claim the
        # split carries labels (that just happen to be missing).
        out_labels = list(split_labels.items()) if label_map else None
        zout.writestr("dataset.json", json.dumps({"labels": out_labels}))


# Materialize both splits, train first, as (destination zip, member names).
for _out_path, _members in (
    (train_zip_path, train_files),
    (test_zip_path, test_files),
):
    write_split_zip(_out_path, _members)
