import json
import pyRDDLGym
from regawa.rddl import RDDLModel, RDDLGroundedModel
import numpy as np
import sys


def find_close_split(l: list, iterations: int, train_size: int = 5):
    """Search random partitions for the one whose two groups have the closest means.

    Each iteration shuffles the positions of ``l``. The first ``train_size``
    shuffled positions form group 1 and the rest form group 2. The partition
    with the smallest ``|mean(group1) - mean(group2)|`` seen so far is kept.

    Args:
        l: Numeric values; anything accepted by ``np.asarray``.
        iterations: Number of random partitions to try.
        train_size: Number of positions in the first group (default 5).

    Returns:
        ``(idx1, idx2)``: integer index arrays into ``l`` for the best partition.
        Returns ``(None, None)`` if no partition was evaluated
        (``iterations <= 0``), or if every distance was NaN (for example when
        a group is empty because ``len(l) <= train_size``).
    """
    values = np.asarray(l)  # fancy indexing below needs an ndarray, not a list
    # Start at +inf so the first finite distance is always accepted, however large.
    lowest_distance = float("inf")
    store1 = None
    store2 = None
    for _ in range(iterations):
        indexes = np.arange(0, len(values))
        np.random.shuffle(indexes)
        first, rest = indexes[:train_size], indexes[train_size:]

        distance = abs(values[first].mean() - values[rest].mean())
        if distance < lowest_distance:
            lowest_distance = distance
            # Copy so the stored result is independent of the loop's array.
            store1 = first.copy()
            store2 = rest.copy()

    return store1, store2


def get_split(l: list, train_size: int = 5):
    """Randomly partition the positions of ``l`` into two index groups.

    Args:
        l: Any sized sequence. Only ``len(l)`` is used; the values are ignored.
        train_size: Number of positions in the first group (default 5).

    Returns:
        ``(first, rest)``: ``first`` holds ``train_size`` shuffled positions
        (fewer if ``len(l) < train_size``), and ``rest`` holds the remaining
        positions in shuffled order.
    """
    indexes = np.arange(0, len(l))
    np.random.shuffle(indexes)
    # Copy the slices so the returned arrays do not share memory with `indexes`.
    return indexes[:train_size].copy(), indexes[train_size:].copy()


def main(argv=None):
    """Print a random train/test split of a domain's instances as JSON.

    The domain name is ``argv[1]``. Instances ``"1"`` through ``"10"`` are
    each loaded with ``pyRDDLGym.make``, and their object counts are collected
    in order. ``get_split`` then assigns five positions to "train" and the
    rest to "test". Positions are 0-based, so position ``k`` is instance
    ``k + 1``.

    Args:
        argv: Argument vector to read; defaults to ``sys.argv``.

    Raises:
        SystemExit: With a usage message if no domain argument is supplied.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = argv[0] if argv else "script"
        sys.exit(f"usage: {prog} DOMAIN")
    d = argv[1]

    data: list[int] = []
    for i in range(1, 11):
        env = pyRDDLGym.make(d, str(i))
        model = RDDLModel(env.model)
        num_objects = len(model._obj_to_type)
        data.append(num_objects)

    # get_split only looks at len(data); the counts themselves are not used
    # in the split.
    s1, s2 = get_split(np.array(data))
    splits = {
        "domain": d,
        "train": s1.tolist(),
        "test": s2.tolist(),
    }

    print(json.dumps(splits))
