'''
This script precomputes the similarity matrices for every dataset found under datasets/<loader>/.
Each similarity matrix is computed from the instance-wise distance between every pair of instances in the dataset,
for one or more distance metrics. Precomputing them saves time when training the model.
'''

import os
import sys
import time
import argparse
import train_eval_utils.utils_data as datautils


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse fall back to ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``loader`` (str) and ``max_train_length`` (int).
    """
    parser = argparse.ArgumentParser()
    # main() only knows how to load UCR and UEA data, so the help text lists just those.
    parser.add_argument('--loader', type=str, required=True, help='The data loader used to load the experimental data. This script supports UCR and UEA; the value must also name a sub-directory of datasets/')
    # NOTE(review): main() in this script never reads this option; presumably it is
    # kept for CLI parity with the training scripts — confirm before removing.
    parser.add_argument('--max_train_length', type=int, default=3000, help='For sequence with a length greater than <max_train_length>, it would be cropped into some sequences, each of which has a length less than <max_train_length> (defaults to 3000)')
    return parser.parse_args(argv)


# Datasets for which only the COS and EUC metrics are precomputed.
# NOTE(review): presumably DTW/TAM are too costly on these — confirm the reason.
_BAD_DATASETS = frozenset({'DuckDuckGeese', 'EigenWorms', 'MotorImagery', 'PEMS-SF'})
_CHEAP_METRICS = ('COS', 'EUC')
_ALL_METRICS = ('COS', 'EUC', 'DTW', 'TAM')
# Loaders that _load_dataset knows how to handle.
_SUPPORTED_LOADERS = ('UCR', 'UEA')
# Root directory holding one sub-directory per loader, relative to the working directory.
_DATASETS_ROOT = 'datasets'


def _list_subdirs(path):
    """Return the sorted names of the immediate sub-directories of *path*."""
    # The context manager closes the directory iterator promptly.
    with os.scandir(path) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


def _format_elapsed(seconds):
    """Format a duration as HH:MM:SS; unlike time.strftime, hours do not wrap at 24."""
    hours, rem = divmod(int(seconds), 3600)
    minutes, secs = divmod(rem, 60)
    return f'{hours:02d}:{minutes:02d}:{secs:02d}'


def _load_dataset(loader, dataset):
    """Load *dataset* with the datautils routine matching *loader*.

    Raises:
        ValueError: if *loader* has no loading routine here.
    """
    if loader == 'UCR':
        return datautils.load_UCR(dataset)
    if loader == 'UEA':
        return datautils.load_UEA(dataset)
    raise ValueError(f"No loading routine for dataset loader: {loader}")


def main(args):
    """Precompute similarity matrices for every dataset of ``args.loader``.

    Datasets are the sub-directories of ``datasets/<args.loader>/``, processed in
    sorted order. For each dataset, the training split is loaded and
    ``datautils.get_sim_mat`` is called once per applicable distance metric.
    The process exits with status 0 when done.

    Args:
        args: parsed CLI namespace; only ``args.loader`` is read here.

    Raises:
        ValueError: if ``datasets/<args.loader>/`` does not exist, or if this
            script has no loading routine for the loader.
    """
    start_time = time.time()

    # Read the dataset list
    if args.loader not in _list_subdirs(_DATASETS_ROOT):
        raise ValueError(f"Unknown dataset loader: {args.loader}")
    # Fail fast: previously an unsupported loader crashed with UnboundLocalError
    # on the first dataset.
    if args.loader not in _SUPPORTED_LOADERS:
        raise ValueError(
            f"Unsupported dataset loader: {args.loader} "
            f"(supported: {', '.join(_SUPPORTED_LOADERS)})"
        )
    dataset_list = _list_subdirs(os.path.join(_DATASETS_ROOT, args.loader))

    for dataset in dataset_list:
        # Load dataset; only the training split is used below.
        train_data, _, _, _ = _load_dataset(args.loader, dataset)
        print(f"------ Loaded dataset: {args.loader}-{dataset}, shape {train_data.shape} ------")

        # Compute similarity matrix (this is instance-wise only)
        dist_metric_list = _CHEAP_METRICS if dataset in _BAD_DATASETS else _ALL_METRICS

        # This call does not depend on the metric, so it is made once per dataset
        # rather than once per metric. Its result is unused; it is kept for any
        # side effect get_sim_mat may have.
        # NOTE(review): train_data is passed with prefix='test' — confirm the
        # test split was not intended here.
        datautils.get_sim_mat(args.loader, train_data, dataset, dist_metric='EUC', prefix='test')
        for dist_metric in dist_metric_list:
            train_sim_mat = datautils.get_sim_mat(args.loader, train_data, dataset, dist_metric=dist_metric, prefix='train')
            if train_sim_mat is not None:
                print(f'Metric: {dist_metric}, shape: {train_sim_mat.shape}, max.: {train_sim_mat.max():.2f}, min.: {train_sim_mat.min():.2f}')
        # Elapsed time is cumulative since start_time, not per dataset.
        print('--- Similarity precomputing completed, time elapsed: ' + _format_elapsed(time.time() - start_time) + ' ---')

    print('--- Time elapsed in total: ' + _format_elapsed(time.time() - start_time) + ' ---')
    sys.exit(0)


if __name__ == '__main__':
    # Line-buffer stdout so each progress line appears immediately, even when
    # output is redirected to a file or pipe.
    sys.stdout.reconfigure(line_buffering=True)
    main(parse_args())
    