import scipy.io
import os
import numpy as np
from numba import jit
import torch

@jit(nopython=True)
def DTW(x, y, w=0):
    """Return the banded dynamic-time-warping distance between x and y.

    The local cost of aligning ``x[i]`` with ``y[j]`` is
    ``np.linalg.norm(x[i] - y[j])``; the result is the minimum accumulated
    cost of a monotone alignment path from the first to the last elements,
    restricted to cells with ``|i - j| <= w``.

    Args:
        x: Array whose first axis (length ``m``) is the one being warped.
        y: Array whose first axis (length ``n``) is the one being warped.
        w: Requested half-width of the band around the diagonal. It is
            widened to at least ``|m - n|`` so the end cell stays reachable.
            NOTE: with the default ``w=0`` and ``m == n`` the band is the
            main diagonal only, i.e. no warping happens and the result is
            the sum of ``norm(x[k] - y[k])`` over ``k``.

    Returns:
        The accumulated cost at cell ``(m, n)``; ``inf`` if exactly one of
        the inputs is empty, ``0.0`` if both are.
    """
    m = x.shape[0]
    n = y.shape[0]
    # Cells outside the band must never win the min() below, so they are
    # marked with a true infinity; a large finite sentinel would silently
    # cap the result once real accumulated costs grew beyond it.
    D = np.full((m + 1, n + 1), np.inf)
    w = max(w, np.abs(m - n))
    D[0, 0] = 0
    for i in range(1, m + 1):
        # Column range of the band on this row, clipped to the matrix.
        j_start = max(1, i - w)
        j_stop = min(n, i + w) + 1
        for j in range(j_start, j_stop):
            cost = np.linalg.norm(x[i - 1] - y[j - 1])
            # Cheapest predecessor: step in x only, in y only, or in both.
            best_prev = min(min(D[i - 1, j], D[i, j - 1]), D[i - 1, j - 1])
            D[i, j] = cost + best_prev
    return D[m, n]

def computeDTW(data):
    """Build the symmetric matrix of pairwise DTW distances.

    Entry ``[i, j]`` holds ``DTW(data[i], data[j])`` for the instances
    indexed along the first axis of ``data``. The diagonal is left at zero
    and each unordered pair is evaluated only once.
    """
    count = data.shape[0]
    matrix = np.zeros((count, count))

    # Strict upper triangle (k=1 skips the diagonal), in row-major order.
    rows, cols = np.triu_indices(count, k=1)
    for a, b in zip(rows.tolist(), cols.tolist()):
        distance = DTW(data[a], data[b])
        matrix[a, b] = distance
        matrix[b, a] = distance  # mirror into the lower triangle

    return matrix

def save_files(x_data, y_data, dtw_distances, dataset, subject,
               root='../datasets/eeg'):
    """Write one subject's arrays under ``<root>/<dataset>/``.

    Three files are produced, all prefixed with ``subject_<subject>``:
    ``_dtw.npy`` (``dtw_distances``), ``_x.npy`` (``x_data``) and
    ``_y.txt`` (``y_data`` written as integers, via ``np.savetxt``).

    Args:
        x_data: Array saved with ``np.save``.
        y_data: Array saved as text with the ``%d`` format.
        dtw_distances: Array saved with ``np.save``.
        dataset: Name of the sub-directory to write into.
        subject: Subject identifier used in the file names.
        root: Parent output directory; the default keeps the original
            hard-coded location (relative to the current working directory).
    """
    out_dir = f'{root}/{dataset}'
    # exist_ok avoids the check-then-create race and is a no-op on reruns.
    os.makedirs(out_dir, exist_ok=True)

    prefix = f'{out_dir}/subject_{subject}'
    np.save(f'{prefix}_dtw.npy', dtw_distances)
    np.save(f'{prefix}_x.npy', x_data)
    np.savetxt(f'{prefix}_y.txt', y_data, fmt='%d')

# Subject identifiers processed for each dataset.
subjects = {
    'bci': list(range(1, 10)),
    'mamem': list(range(1, 12)),
    'bcicha': [2, 6, 7, 11, 12, 13, 14, 16, 17, 18, 20, 21, 22, 23, 24, 26],
}


def _load_bci(subject):
    """Load one 'bci' subject: training and evaluation files concatenated."""
    train_mat = scipy.io.loadmat(f'BCICIV_2a_mat/BCIC_S0{subject}_T.mat')
    eval_mat = scipy.io.loadmat(f'BCICIV_2a_mat/BCIC_S0{subject}_E.mat')

    x_train, y_train = train_mat['x_train'], train_mat['y_train']
    x_test, y_test = eval_mat['x_test'], eval_mat['y_test']

    # Training rows first, evaluation rows after, along the first axis.
    x_all = np.concatenate((x_train, x_test), axis=0)
    y_all = np.concatenate((y_train, y_test), axis=0)
    return x_all, y_all


def _load_single_mat(path):
    """Load the 'x_test' / 'y_test' arrays from a single .mat file."""
    contents = scipy.io.loadmat(path)
    return contents['x_test'], contents['y_test']


def _load_mamem(subject):
    """Load one 'mamem' subject (file name uses a 3-digit subject id)."""
    return _load_single_mat(f'MAMEM/U{int(subject):03d}.mat')


def _load_bcicha(subject):
    """Load one 'bcicha' subject (file name uses a 2-digit subject id)."""
    return _load_single_mat(f'BCIcha/Data_S{int(subject):02d}_Sess.mat')


# Dataset name -> function returning (x_data, y_data) for one subject.
_LOADERS = {
    'bci': _load_bci,
    'mamem': _load_mamem,
    'bcicha': _load_bcicha,
}

for dataset in ('bci', 'mamem', 'bcicha'):
    load_subject = _LOADERS[dataset]
    for subject in subjects[dataset]:
        print(f'Creating files for {dataset}: Subject {subject}')
        x_data, y_data = load_subject(subject)
        dtw_distances = computeDTW(x_data)
        save_files(x_data, y_data, dtw_distances, dataset, subject)

print('Done!')
