import numpy as np
import pickle


class GetData:
    """Generate random Euclidean TSP instances in the unit square.

    Each instance is a tuple ``(coordinates, distances)``:

    * ``coordinates`` -- ``(n_cities, 2)`` array of points drawn uniformly
      from ``[0, 1)``.
    * ``distances`` -- ``(n_cities, n_cities)`` array of pairwise Euclidean
      distances between those points (symmetric, zero diagonal).
    """

    def __init__(self, n_instance, n_cities, seed=None):
        """
        Args:
            n_instance: Number of instances returned by each
                ``generate_instances`` call.
            n_cities: Number of cities (points) in every instance.
            seed: Optional seed (anything accepted by
                ``numpy.random.default_rng``). When ``None`` (the default),
                points are drawn from NumPy's global legacy RNG, so output
                follows any ``np.random.seed`` the caller has set. When given,
                a private generator is used and global RNG state is untouched,
                e.g. ``GetData(64, 50, seed=2024)``.
        """
        self.n_instance = n_instance
        self.n_cities = n_cities
        self.seed = seed

    def generate_instances(self):
        """Return a list of ``n_instance`` ``(coordinates, distances)`` tuples.

        With a non-None ``seed`` a fresh generator is created on every call,
        so repeated calls on the same object return identical data.
        """
        # Private generator only when a seed is supplied; otherwise keep the
        # original global-RNG behaviour for backward compatibility.
        rng = None if self.seed is None else np.random.default_rng(self.seed)

        instance_data = []
        for _ in range(self.n_instance):
            if rng is None:
                coordinates = np.random.rand(self.n_cities, 2)
            else:
                coordinates = rng.random((self.n_cities, 2))
            # Broadcast (n, 1, 2) - (n, 2) -> (n, n, 2), then take the L2 norm
            # over the last axis to get the (n, n) pairwise distance matrix.
            distances = np.linalg.norm(
                coordinates[:, np.newaxis] - coordinates, axis=2
            )
            instance_data.append((coordinates, distances))
        return instance_data


if __name__ == '__main__':
    # For each data split: which city counts to generate, and how many
    # instances to produce for every count.
    exp_specs = {
        'train': {'n_cities': [50], 'n_instance': 64},
        'test': {'n_cities': [50, 100, 200], 'n_instance': 64},
    }

    for split_name, spec in exp_specs.items():
        # city count -> list of (coordinates, distances) instances, built in
        # the same order the city counts are listed above.
        bundle = {
            size: GetData(spec['n_instance'], size).generate_instances()
            for size in spec['n_cities']
        }

        # One pickle file per split holds every city count for that split.
        out_path = f'all_data_{split_name}.pkl'
        with open(out_path, 'wb') as handle:
            pickle.dump(bundle, handle)
