import gc

import numpy as np

import glob
from timeit import default_timer as timer

from sktime.utils.data_processing import from_3d_numpy_to_nested
from sktime.transformations.panel.rocket import MiniRocketMultivariate


def measure_transform_times(
    data_dir="Industrial",
    n_runs=100,
    output_path="jetson_minirocket_transform_times.csv",
    seed=None,
):
    """Benchmark single-instance MiniRocketMultivariate transform latency.

    For every ``*.npz`` file in ``data_dir`` (processed in sorted order), fit a
    ``MiniRocketMultivariate(random_state=0)`` on the file's ``train_x`` array.
    Then time ``n_runs`` transforms, each on one randomly chosen ``test_x``
    instance. Only the ``transform`` call itself is timed. Data conversion,
    garbage collection and freeing of the result all happen outside the timed
    region.

    Parameters
    ----------
    data_dir : str
        Directory searched for ``*.npz`` files. Each file must contain
        ``train_x`` and ``test_x`` arrays. Both are passed to
        ``from_3d_numpy_to_nested``, so they are presumably 3-D
        (instances, channels, timepoints); TODO confirm with the data producer.
    n_runs : int
        Number of timed transforms per dataset.
    output_path : str
        CSV file receiving the timings. It has one row per dataset and one
        column per run, in seconds as measured by ``timeit.default_timer``.
    seed : int or None
        Seed for picking test instances. ``None`` uses numpy's global random
        state, exactly as the original script did.

    Returns
    -------
    numpy.ndarray
        The timings array written to ``output_path``. If no ``*.npz`` file
        matches, the array is empty and an empty CSV is written.
    """
    # With no seed, keep the legacy global-RNG behaviour so callers that seed
    # np.random themselves still get the same instance selection.
    rng = np.random if seed is None else np.random.default_rng(seed)
    global_times = []
    for filename in sorted(glob.glob(f"{data_dir}/*.npz")):
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once both arrays have been materialised.
        with np.load(filename) as data:
            train_x = data["train_x"].astype(np.float64)
            test_x = data["test_x"].astype(np.float64)
        train_x = from_3d_numpy_to_nested(train_x)

        minirocket = MiniRocketMultivariate(random_state=0)
        minirocket.fit(train_x)

        local_times = []
        for _ in range(n_runs):
            ind = rng.choice(test_x.shape[0])
            # Slice (rather than index) to keep a batch dimension of one.
            tr_test = from_3d_numpy_to_nested(test_x[ind:ind + 1])
            gc.collect()
            gc_was_enabled = gc.isenabled()
            gc.disable()  # keep collector pauses out of the measurement
            try:
                start = timer()
                transformed = minirocket.transform(tr_test)
                end = timer()
            finally:
                if gc_was_enabled:
                    gc.enable()
            # Free the result outside the timed window. Rebinding it on the
            # next iteration would otherwise charge its deallocation to that run.
            del transformed
            local_times.append(end - start)
        global_times.append(local_times)

    times = np.array(global_times)
    np.savetxt(output_path, times, delimiter=",")
    return times


if __name__ == "__main__":
    # Run the benchmark only when this file is executed directly,
    # never as a side effect of importing it.
    measure_transform_times()
