import numpy as np
import pickle


class GetData():
    """Generate reproducible random instances made of one depot plus city nodes.

    Each instance places the depot at (0.5, 0.5) as node 0, followed by
    ``n_cities`` nodes drawn uniformly from the unit square, each with a
    random integer demand.

    Args:
        n_instance: Number of instances produced by :meth:`generate_instances`.
        n_cities: Number of non-depot nodes per instance.
        capacity: Capacity value stored unchanged in every instance tuple
            (presumably the vehicle capacity -- confirm against consumers).
    """

    def __init__(self, n_instance, n_cities, capacity):
        self.n_instance = n_instance
        self.n_cities = n_cities
        self.capacity = capacity
        # The depot is always node 0, fixed at the centre of the unit square.
        self.depot_coords = np.array([[0.5, 0.5]])

    def generate_instances(self, seed=2025):
        """Return ``n_instance`` random instances drawn from a seeded RNG.

        A private ``np.random.RandomState`` is used instead of reseeding the
        global NumPy RNG. Calling this method therefore no longer resets random
        state that other code relies on. ``RandomState`` uses the same legacy
        algorithm as the global ``np.random`` functions, so the default
        ``seed=2025`` reproduces exactly the data that
        ``np.random.seed(2025)`` produced before.

        Args:
            seed: Seed for the instance RNG (default 2025).

        Returns:
            A list of ``(coordinates, distances, demands, capacity)`` tuples.
            With ``N = n_cities + 1`` nodes (depot first):

            * ``coordinates`` -- ``(N, 2)`` float array; row 0 is the depot.
            * ``distances`` -- ``(N, N)`` pairwise Euclidean distance matrix.
            * ``demands`` -- ``(N,)`` float64 array. The depot demand is 0;
              the other demands are integers in ``[1, 9]``, stored as float
              because they are concatenated with ``np.zeros``.
            * ``capacity`` -- the ``capacity`` passed to the constructor.
        """
        rng = np.random.RandomState(seed)
        instance_data = []
        for _ in range(self.n_instance):
            # Draw order (coordinates, then demands, per instance) must stay
            # fixed so a given seed always yields the same data.
            cities = rng.rand(self.n_cities, 2)
            coordinates = np.concatenate((self.depot_coords, cities), axis=0)
            demands = np.concatenate((np.zeros(1), rng.randint(1, 10, size=self.n_cities)))
            # Broadcasting (N, 1, 2) - (N, 2) -> (N, N, 2) pairwise differences.
            distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
            instance_data.append((coordinates, distances, demands, self.capacity))
        return instance_data


if __name__ == '__main__':
    # Partition name -> generation settings. Each partition is written to its
    # own pickle file, as a dict keyed by city count.
    exp_specs = {
        'train': {'n_cities': [50], 'capacity': 50, 'n_instance': 64},
        # NOTE: the loop below reads setup['capacity']. Add a 'capacity' key to
        # this entry before re-enabling it, or it will raise KeyError.
        # 'test': {'n_cities': [50, 100, 200], 'n_instance': 64},
    }

    for partition, setup in exp_specs.items():
        # Maps city count -> list of generated instance tuples.
        data_dict = {
            city_count: GetData(setup['n_instance'], city_count, setup['capacity']).generate_instances()
            for city_count in setup['n_cities']
        }

        out_path = f'all_data_{partition}.pkl'
        with open(out_path, 'wb') as out_file:
            pickle.dump(data_dict, out_file)
