import numpy as np


def _advance_serpentine(bucket_index, direction, num_buckets):
    """Step to the next bucket of the back-and-forth sweep.

    Returns the updated ``(bucket_index, direction)``. When the step would
    leave ``[0, num_buckets)`` the direction is flipped and the edge bucket is
    selected again, so the sweep runs 0..n-1, n-1..0, 0..n-1, ...
    """
    bucket_index += direction
    if bucket_index >= num_buckets or bucket_index < 0:
        direction = -direction
        bucket_index += direction
    return bucket_index, direction


def distribute_objects_optimized(node_index_list, graph_sizes, num_buckets, batch_size):
    """Distribute the sampled nodes of several graphs over fixed-size buckets.

    Graphs are processed from largest to smallest ``graph_sizes`` entry. The
    nodes of each graph are written into the current bucket; whatever does not
    fit spills into the following buckets. Buckets are visited in a serpentine
    (left-to-right, then right-to-left) order and full buckets are skipped.

    Args:
        node_index_list: One sequence of node indices per graph.
        graph_sizes: Size of each graph, same length/order as
            ``node_index_list``. Only used as the sort key.
        num_buckets: Number of buckets to create.
        batch_size: Maximum number of nodes a single bucket can hold.

    Returns:
        A list of ``num_buckets`` dicts with the keys ``"graph_index"``
        (position of the source graph in ``node_index_list``, repeated once
        per node), ``"node_index"`` (the node indices, aligned with
        ``"graph_index"``) and ``"free"`` (remaining capacity of the bucket).

    Raises:
        ValueError: If ``node_index_list`` and ``graph_sizes`` differ in
            length, or if the nodes do not fit into
            ``num_buckets * batch_size`` slots.
    """
    # Materialize once: the input is traversed twice (capacity check + zip).
    node_index_list = list(node_index_list)
    if len(node_index_list) != len(graph_sizes):
        raise ValueError(
            f"node_index_list has {len(node_index_list)} entries but "
            f"graph_sizes has {len(graph_sizes)}"
        )

    # Without this check the "skip full buckets" loop below would spin forever
    # once every bucket is full and nodes are still waiting to be placed.
    total_nodes = sum(len(nodes) for nodes in node_index_list)
    capacity = max(num_buckets, 0) * max(batch_size, 0)
    if total_nodes > capacity:
        raise ValueError(
            f"cannot place {total_nodes} nodes into {num_buckets} buckets "
            f"of size {batch_size}"
        )

    # Largest graphs first; the sort is stable, so ties keep their input order.
    sorted_objects = sorted(
        zip(graph_sizes, np.arange(len(graph_sizes)), node_index_list),
        key=lambda item: item[0],
        reverse=True,
    )

    buckets = [
        {"graph_index": [], "node_index": [], "free": batch_size}
        for _ in range(num_buckets)
    ]

    bucket_index = 0
    direction = 1  # 1 for left-to-right, -1 for right-to-left

    for _, graph_idx, nodes in sorted_objects:
        offset = 0  # number of nodes of this graph that are already placed
        remaining = len(nodes)
        while remaining > 0:
            # Skip full buckets; the capacity check guarantees a free one exists.
            while buckets[bucket_index]["free"] == 0:
                bucket_index, direction = _advance_serpentine(
                    bucket_index, direction, num_buckets
                )

            bucket = buckets[bucket_index]
            take = min(remaining, bucket["free"])
            bucket["graph_index"].extend([graph_idx] * take)
            bucket["node_index"].extend(nodes[offset:offset + take])
            bucket["free"] -= take

            offset += take
            remaining -= take

            # Always move on after a placement so consecutive graphs (and the
            # spill-over of a split graph) land in different buckets.
            bucket_index, direction = _advance_serpentine(
                bucket_index, direction, num_buckets
            )
    return buckets


# Demo configuration: every bucket holds batch_size sampled graph nodes.
batch_size = 4096
accum_gradient_steps = 2
# NOTE(review): presumably the number of compute nodes/workers, not graph
# nodes - confirm.
num_nodes = 8
num_buckets = num_nodes * accum_gradient_steps


# Number of nodes in every available graph, in ascending order.
all_graph_size = np.array(
    [
        78,
        105,
        239,
        580,
        603,
        714,
        894,
        1147,
        1443,
        1797,
        2538,
        2631,
        2788,
        3117,
        3930,
        4275,
        4545,
        4574,
        5000,
        5698,
        5879,
        11331,
        11830,
        12246,
        13482,
        16968,
        19407,
        22620,
        24460,
        44625,
        153431,
        153932,
        253176,
    ]
)
# Draw one graph id per bucket slot, with probability proportional to the
# graph's size, then count how many nodes each drawn graph has to contribute.
graph_index = np.random.choice(
    len(all_graph_size),
    num_buckets * batch_size,
    replace=True,
    p=all_graph_size / np.sum(all_graph_size),
)
unique_graph_index, nodes_per_graph = np.unique(graph_index, return_counts=True)

# Sample distinct node indices inside every drawn graph. The graph's size must
# be looked up by its graph id: np.unique omits graphs that were never drawn,
# so the position in unique_graph_index is not the graph id in general.
node_index_list = [
    np.random.choice(all_graph_size[graph_id], count, replace=False)
    for graph_id, count in zip(unique_graph_index, nodes_per_graph)
]

graph_sizes = all_graph_size[unique_graph_index]

# Spread the sampled nodes over num_buckets buckets of capacity batch_size.
result = distribute_objects_optimized(
    node_index_list=node_index_list,
    graph_sizes=graph_sizes,
    num_buckets=num_buckets,
    batch_size=batch_size,
)

# for i, meta in enumerate(result):
#     unique_graphs = np.unique(meta["graph_index"])
#     graph_sizes_local = [graph_sizes[graph_idx] for graph_idx in unique_graphs]

#     print(f"bucket: {i}, num_graphs: {len(unique_graphs)}, sizes: {graph_sizes_local}")
