#!/usr/bin/env python3
import os
from ruamel.yaml import YAML

# Directory that main() scans (non-recursively, via os.listdir) for
# "*.yaml" files; "." is the current working directory at launch.
FOLDER = "."

# One shared ruamel.yaml instance is used for both loading and dumping.
# YAML() defaults to round-trip mode, which keeps comments and key order
# so the rewritten files stay close to the originals.
yaml = YAML()
yaml.width = 4096  # very wide limit so long lines are not re-wrapped
yaml.preserve_quotes = True  # re-emit scalars with their original quoting
yaml.indent(sequence=4, offset=2)  # block lists: indent 4, dash offset 2

def _insert_after(mapping, anchor, key, value):
    """Re-order ``mapping`` in place so ``key: value`` follows ``anchor``.

    No-op when ``key`` already exists, so re-running the script never
    overwrites a hand-edited value. If ``anchor`` is absent the mapping's
    contents are left unchanged. Works for plain dicts and ruamel
    ``CommentedMap`` alike by rebuilding into a fresh instance of the same
    type and copying it back.
    """
    if key in mapping:
        return
    rebuilt = type(mapping)()
    for k, v in mapping.items():
        rebuilt[k] = v
        if k == anchor:
            rebuilt[key] = value
    mapping.clear()
    mapping.update(rebuilt)


def _update_search_space(data):
    """Rewrite the top-level ``search_space`` mapping in place.

    Sets ``n_loops`` to ``[10, 15]``, drops ``batch_add``/``batch_remove``
    and adds ``tau: [10, 100]`` right after ``n_loops``. A missing, null or
    non-mapping ``search_space`` is left untouched.
    """
    ss = data.get("search_space")
    if not isinstance(ss, dict):
        return

    ss["n_loops"] = [10, 15]
    for key in ("batch_add", "batch_remove"):
        if key in ss:
            # `del` (not pop) so ruamel's CommentedMap bookkeeping runs.
            del ss[key]
    _insert_after(ss, "n_loops", "tau", [10, 100])


def _update_pre_transform(data):
    """Switch ``dataset_parameters.pre_transform`` to the SDRF settings.

    A missing ``dataset_parameters`` or ``pre_transform`` key is created
    (as the original script did); an existing value at either level that
    is not a mapping (e.g. null) is left untouched instead of crashing.
    """
    ds = data.get("dataset_parameters", {})
    if not isinstance(ds, dict):
        return
    pre_t = ds.get("pre_transform", {})
    if not isinstance(pre_t, dict):
        return

    pre_t["pre_t_class"] = "sdrf"
    pre_t["n_loops"] = 10
    for key in ("batch_add", "batch_remove"):
        if key in pre_t:
            del pre_t[key]
    _insert_after(pre_t, "n_loops", "tau", 100)

    # Re-attach both levels: either may have been freshly created above.
    ds["pre_transform"] = pre_t
    data["dataset_parameters"] = ds


def update_yaml_file(filepath):
    """Apply the SDRF/tau migration to one YAML file and rewrite it in place.

    Args:
        filepath: Path of the YAML file to migrate.

    Files whose top level is not a mapping (including empty files, which
    load as ``None``) are reported and skipped without being rewritten.
    The output is written to a sibling temporary file and moved over the
    original with ``os.replace``, so a failure while dumping cannot leave
    a truncated config behind.

    Raises:
        OSError: if the file cannot be read or written.
        Any ruamel.yaml error raised while parsing or dumping propagates.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        data = yaml.load(f)

    name = os.path.basename(filepath)
    if not isinstance(data, dict):
        print(f"⚠️ Skipped {name}: top level is not a mapping")
        return

    _update_search_space(data)
    _update_pre_transform(data)

    tmp_path = filepath + ".tmp"  # not "*.yaml", so main() never picks it up
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            yaml.dump(data, f)
        os.replace(tmp_path, filepath)
    finally:
        # After a successful replace the temp file is gone; otherwise clean up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"✅ Updated {name}")

def main():
    """Migrate every entry in ``FOLDER`` whose name ends with ``.yaml``.

    Entries are visited in ``os.listdir`` order; the match is a plain,
    case-sensitive suffix test on the entry name.
    """
    yaml_names = (
        entry for entry in os.listdir(FOLDER) if entry.endswith(".yaml")
    )
    for entry in yaml_names:
        update_yaml_file(os.path.join(FOLDER, entry))


# Run the migration only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
