import numpy as np
import random
import pickle
from itertools import combinations

def allocate_data(data_points, num_models, forget_set):
    """Assign data points to models, giving each forget-set point its own model subset.

    Every element of ``forget_set`` is paired with a distinct combination of
    ``num_models // 2`` model indices (taken in the lexicographic order
    produced by ``itertools.combinations``) and added only to those models.
    Every data point that is not in ``forget_set`` is added to all models.

    Args:
        data_points: Iterable of hashable data points.
        num_models: Number of models; model indices are ``0 .. num_models - 1``.
        forget_set: Sized iterable of hashable points that should each be
            seen by a unique half-sized subset of the models.

    Returns:
        dict mapping each model index to the ``set`` of points allocated to it.

    Raises:
        ValueError: If ``forget_set`` has more elements than there are
            distinct combinations of ``num_models // 2`` models.
    """
    model_allocations = {i: set() for i in range(num_models)}

    half_models = num_models // 2
    # zip() stops as soon as forget_set is exhausted, so only the combinations
    # that are actually needed get generated. Materialising all
    # C(num_models, half_models) combinations grows exponentially with
    # num_models and is never necessary.
    assignments = list(
        zip(forget_set, combinations(range(num_models), half_models))
    )

    # Fewer pairs than forget points means the combinations ran out first.
    if len(forget_set) > len(assignments):
        raise ValueError("Too many elements in the forget set for unique allocation.")

    for data_point, models in assignments:
        for model in models:
            model_allocations[model].add(data_point)

    # Points outside the forget set are shared by every model.
    remaining_data = set(data_points) - set(forget_set)
    for members in model_allocations.values():
        members.update(remaining_data)

    return model_allocations

def allocate_datas(data_points, num_models, forget_set):
    """Randomly allocate data to models, hiding each forget point from half of them.

    Each model starts with the whole forget set. For every forget-set point,
    ``num_models // 2`` models are drawn at random and the point is removed
    from those models. Each model is then topped up with points sampled
    (without replacement) from the non-forget data until it holds
    ``len(data_points) // 2`` points; models already at or above that size
    are left unchanged.

    Args:
        data_points: Sized iterable of hashable data points.
        num_models: Number of models; model indices are ``0 .. num_models - 1``.
        forget_set: Iterable of hashable points to be withheld from a random
            half of the models. Duplicate entries are ignored.

    Returns:
        A ``(model_allocations, map_to_removed)`` tuple where
        ``model_allocations`` maps each model index to its ``set`` of points
        and ``map_to_removed`` maps each forget-set point to the list of
        model indices it was removed from.

    Raises:
        ValueError: Propagated from ``random.sample`` when a model needs more
            top-up points than there are non-forget data points.
    """
    # De-duplicate while keeping first-seen order: a repeated point would
    # otherwise be re-sampled and `.remove()` could fail on a model that has
    # already dropped it.
    forget_points = list(dict.fromkeys(forget_set))
    model_allocations = {i: set(forget_points) for i in range(num_models)}

    map_to_removed = {}
    for data_point in forget_points:
        models_to_remove = random.sample(range(num_models), num_models // 2)
        map_to_removed[data_point] = models_to_remove

        for model in models_to_remove:
            model_allocations[model].remove(data_point)

    # random.sample() requires a sequence (sets raise TypeError on
    # Python 3.11+), so keep the remaining points in an ordered list. This
    # also makes seeded runs independent of set iteration order.
    forget_lookup = set(forget_points)
    remaining_data = [
        point for point in dict.fromkeys(data_points) if point not in forget_lookup
    ]

    target_size = len(data_points) // 2
    for model in range(num_models):
        num_to_add = target_size - len(model_allocations[model])
        if num_to_add > 0:
            model_allocations[model].update(
                random.sample(remaining_data, num_to_add)
            )

    return model_allocations, map_to_removed


if __name__ == '__main__':
    # Demo run: 100 points numbered 0-99, 8 models, and the first 10 points
    # used as the forget set.
    num_models = 8
    data_points = list(range(100))
    forget_set = list(range(10))

    allocations, mapping = allocate_datas(data_points, num_models, forget_set)

    # For every model, show its index and size, followed by its contents.
    for model_id, members in allocations.items():
        print(model_id, len(members))
        print(members)

    # For every forget-set point, show the models it was removed from.
    for point, removed_from in mapping.items():
        print(point, removed_from)

    # Save the per-model allocation to disk.
    with open('partitions.pkl', 'wb') as out_file:
        pickle.dump(allocations, out_file)
