'''
This script is used to train the model on a dataset or a list of datasets.
The model is trained using the training data and the similarity matrix precomputed in the precompute_distance.py script.
The model is saved after training and can be used for evaluation.
Note: this script is not used in the experiments; it is provided only to help readers become familiar with the training process.
'''

import torch
import pandas as pd
import argparse
import os
import sys
import time
import random
from model import traffic_scl
import train_eval_utils.utils_data as datautils
from train_eval_utils.utils_general import *
from tasks.classification import eval_classification
from tasks.clustering import eval_clustering

    
def parse_args():
    """Parse the command-line arguments of the training script.

    Reads the arguments from ``sys.argv``. The integer switches
    ``--add_time_feature``, ``--reproduction``, ``--tuned`` and ``--eval``
    are converted to ``bool`` before the namespace is returned.

    Returns:
        argparse.Namespace: The parsed (and partially post-processed) arguments.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--loader', type=str, required=True, help='The data loader used to load the experimental data. This can be set to UEA, MacroTraffic, MicroTraffic')
    parser.add_argument('--add_time_feature', type=int, default=1, help='Whether to add a time feature to the input data (defaults to 1)')
    parser.add_argument('--dataset', type=str, default=None, help='The dataset name used for training and evaluation (defaults to None, which means all datasets in the loader will be used)')
    parser.add_argument('--sliding_padding', type=int, default=0, help='The padding length (history period) for representation encoding in forecasting task (defaults to 0)')
    parser.add_argument('--dist_metric', type=str, default='DTW', help='The distance metric used to calculate the similarity matrix (defaults to DTW, other options are TAM, COS, EUC, GAK)')
    parser.add_argument('--tau_inst', type=float, default=0, help='The temperature parameter for the instance-wise loss (defaults to 0)')
    parser.add_argument('--tau_temp', type=float, default=0, help='The temperature parameter for the temporal loss (defaults to 0)')
    parser.add_argument('--temporal_hierarchy', type=str, default=None, help='The type of temporal hierarchy used in hierarchical contrastive loss (defaults to None, options are "linear", "exponential")')
    parser.add_argument('--regularizer', type=str, default=None, help='The regularizer used to reserve data structure (defaults to None, options are "topology", "geometry")')
    # Help text now states the real default (1.0) instead of the stale "50".
    parser.add_argument('--bandwidth', type=float, default=1., help='The bandwidth parameter for geometry regularizer (defaults to 1.0, needs to be tuned)')
    parser.add_argument('--gpu', type=str, default='0', help='The gpu number to use for training and inference (defaults to 0 for CPU only, can be "1,2" for multi-gpu)')
    parser.add_argument('--batch_size', type=int, default=8, help='The batch size (defaults to 8)')
    parser.add_argument('--lr', type=float, default=0.001, help='The learning rate (defaults to 0.001)')
    parser.add_argument('--weight_lr', type=float, default=0.05, help='The learning rate for the weight parameters (defaults to 0.05)')
    # Help text now states the real default (320) instead of the stale "256".
    parser.add_argument('--repr_dims', type=int, default=320, help='The representation dimension (defaults to 320)')
    parser.add_argument('--iters', type=int, default=None, help='The number of iterations')
    parser.add_argument('--epochs', type=int, default=None, help='The number of epochs')
    parser.add_argument('--scheduler', type=str, default='constant', help='The learning rate scheduler used for training (defaults to "constant", alternatively "reduced")')
    parser.add_argument('--save_every', type=int, default=None, help='Save the checkpoint every <save_every> iterations/epochs (defaults to None, set to 0 to save the last model only)')
    parser.add_argument('--seed', type=int, default=None, help='The random seed')
    parser.add_argument('--reproduction', type=int, default=0, help='Whether this run is for reproduction; if set to 1, the random seed is fixed (defaults to 0)')
    parser.add_argument('--tuned', type=int, default=0, help='Whether this run uses tuned hyperparameters (defaults to 0)')
    parser.add_argument('--eval', type=int, default=0, help='Whether to perform self-evaluation (defaults to 0)')
    args = parser.parse_args()

    # The switches are given as 0/1 on the command line; expose them as booleans.
    for flag in ('add_time_feature', 'reproduction', 'tuned', 'eval'):
        setattr(args, flag, bool(getattr(args, flag)))
    return args


def _resolve_model_type(args):
    """Return the model name used in result paths, from the loss temperatures and the regularizer."""
    prefixes = {None: '', 'topology': 'topo-', 'geometry': 'ggeo-'}
    if args.regularizer not in prefixes:
        # Previously an unknown regularizer left `model_type` unbound (NameError).
        raise ValueError(f"Unknown regularizer: {args.regularizer}")
    base = 'ts2vec' if args.tau_inst == 0 and args.tau_temp == 0 else 'softclt'
    return prefixes[args.regularizer] + base


def _resolve_dataset_list(args):
    """Return the list of datasets to train on for the selected loader."""
    loaders_on_disk = [entry.name for entry in os.scandir('datasets/') if entry.is_dir()]
    if args.loader not in loaders_on_disk:
        raise ValueError(f"Unknown dataset loader: {args.loader}")

    if args.loader == 'UEA':
        if args.dataset is not None:
            return [args.dataset]
        dataset_dir = os.path.join('datasets/', args.loader)
        return sorted(entry.name for entry in os.scandir(dataset_dir) if entry.is_dir())
    if args.loader == 'MacroTraffic':
        return [['2019']] if args.dataset is None else [[args.dataset]]
    if args.loader == 'MicroTraffic':
        return [['train1']] if args.dataset is None else [[args.dataset]]

    # The directory exists under datasets/ but this script has no loading logic for it.
    raise ValueError(f"Dataset loader {args.loader} is not supported by this script")


def _loss_log_to_frame(loss_log, regularizer):
    """Flatten the loss log to 2-D and wrap it in a DataFrame with named columns when the width is known."""
    loss_log = loss_log.reshape(-1, loss_log.shape[-1])
    n_columns = loss_log.shape[-1]
    if n_columns == 2:
        columns = ['loss', 'loss_scl']
    elif n_columns == 5:
        columns = ['loss', 'loss_scl', 'log_var_scl', f'loss_{regularizer}', f'log_var_{regularizer}']
    elif n_columns == 7:
        columns = ['loss', 'loss_scl', 'log_var_scl', 'loss_topo', 'log_var_topo', 'loss_geo', 'log_var_geo']
    else:
        # Unknown layout: still save it (with default integer column names) instead of crashing.
        columns = None
    return pd.DataFrame(loss_log, columns=columns)


def main(args):
    """Train one model per dataset of the selected loader and save the results.

    For every dataset the function loads the training data, obtains the
    similarity matrix and soft assignments, optionally loads tuned
    hyperparameters, trains the model, and writes the loss log (and, when
    ``args.eval`` is set, the evaluation results) below
    ``results/train/<loader>/<model_type>/``.

    Args:
        args: The namespace returned by ``parse_args``. ``args.seed`` and
            ``args.dist_metric`` may be overwritten by this function.

    Raises:
        ValueError: If the loader or the regularizer is not recognised.
        SystemExit: Always raised at the end via ``sys.exit(0)``.
    """
    initial_time = time.time()
    print('Available cpus:', torch.get_num_threads(), 'available gpus:', torch.cuda.device_count())

    # Set the random seed, `fix_seed` is defined in `utils_general.py`
    if args.reproduction:
        args.seed = 131 # Fix the random seed for reproduction
    if args.seed is None:
        args.seed = random.randint(0, 1000)
    print(f"Random seed is set to {args.seed}")
    fix_seed(args.seed, deterministic=args.reproduction)

    # Initialize the deep learning program, `init_dl_program` is defined in `utils_general.py`
    print(f'--- Cuda available: {torch.cuda.is_available()} ---')
    if torch.cuda.is_available():
        print(f'--- Cuda device count: {torch.cuda.device_count()}, Cuda device name: {torch.cuda.get_device_name()}, Cuda version: {torch.version.cuda}, Cudnn version: {torch.backends.cudnn.version()} ---')
    device = init_dl_program(args.gpu)
    print(f'--- Device: {device}, Pytorch version: {torch.__version__} ---')

    # Set result-saving directory and save the arguments for this run
    model_type = _resolve_model_type(args)
    run_dir = f'results/train/{args.loader}/'+model_type+'/'

    if args.tuned:
        run_dir += 'tuned'
        os.makedirs(run_dir, exist_ok=True)
    else:
        if args.regularizer is not None:
            if args.regularizer == 'geometry':
                run_dir += f'bw={round(args.bandwidth,1)}_'
            run_dir += f'tau_inst={args.tau_inst}_tau_temp={args.tau_temp}_hier={args.temporal_hierarchy}_'
        run_dir += f'{args.scheduler}_bs={round(args.batch_size)}_lr={round(args.lr,4)}'
        os.makedirs(run_dir, exist_ok=True)
        # Save the arguments for this run
        with open(f'{run_dir}/args.txt', 'w') as f:
            f.write(str(args))

    # Read the dataset list
    dataset_list = _resolve_dataset_list(args)

    # Accumulated as a plain dict for the whole run; converted to a DataFrame only when written.
    eval_results = dict()

    bad_datasets = ('DuckDuckGeese',
                    'EigenWorms',
                    'MotorImagery',
                    'PEMS-SF') # Datasets that are too resource-consuming to compute DTW or TAM

    # Remember the metric the user asked for, so a per-dataset fallback to EUC
    # does not leak into the datasets processed afterwards.
    requested_dist_metric = args.dist_metric

    # Train the model for each dataset
    for dataset in dataset_list:
        start_time = time.time()

        # Load dataset
        if args.loader == 'UEA':
            loaded_data = datautils.load_UEA(dataset)
            train_data, train_labels, test_data, test_labels = loaded_data
        elif args.loader == 'MacroTraffic':
            loaded_data = datautils.load_MacroTraffic(dataset, time_interval=5, horizon=15, observation=20)
            train_data, _, _ = loaded_data
            # NOTE(review): the name is hard-coded to '2019' even when --dataset is given — confirm this is intended.
            dataset = '2019'
        elif args.loader == 'MicroTraffic':
            loaded_data = datautils.load_MicroTraffic(dataset)
            train_data, _, _ = loaded_data
            dataset = 'train'+''.join(dataset).replace('train', '')

        save_dir = os.path.join(run_dir, dataset)
        os.makedirs(save_dir, exist_ok=True)

        # Load precomputed similarity matrix
        args.dist_metric = requested_dist_metric
        if args.loader == 'MacroTraffic' or args.loader == 'MicroTraffic':
            args.dist_metric = 'EUC'
        elif dataset in bad_datasets:
            print(f"Dataset {dataset} is too resource-consuming to compute DTW or TAM, switch to EUC by default.")
            args.dist_metric = 'EUC'
        sim_mat = datautils.get_sim_mat(args.loader, train_data, dataset, args.dist_metric)
        soft_assignments = datautils.assign_soft_labels(sim_mat, args.tau_inst)
        if soft_assignments is None:
            print('Soft assignment is not used in this run.')

        # Load tuned hyperparameters
        if args.tuned:
            tuned_params = pd.read_csv(f'results/hyper_parameters/{args.loader}/{dataset}_tuned_hyperparameters.csv', index_col=0)
            try:
                args = load_tuned_hyperparameters(args, tuned_params, model_type)
            except Exception as err:
                # Best-effort: keep going with the current arguments, but report why loading failed.
                print(f'Hyperparameters for {model_type} are not available, using default hyperparameters instead. ({type(err).__name__}: {err})')

        # Set model configuration
        model_config = configure_model(args, train_data.shape[-1], device)

        # Define callback function
        if args.save_every is not None:
            unit = 'epoch' if args.epochs is not None else 'iter'
            model_config[f'after_{unit}_callback'] = save_checkpoint_callback(save_dir, args.save_every, unit)

        # Create model
        model = traffic_scl(args.loader, **model_config)

        # Train model
        loss_log = model.fit(dataset, train_data, soft_assignments, args.epochs, args.iters, args.scheduler, verbose=2)

        # Save loss log
        loss_log = _loss_log_to_frame(loss_log, args.regularizer)
        loss_log.to_csv(f'{save_dir}/loss_log.csv', index=False)

        print(f'--- {dataset} training time elapsed: ' + time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time)) + ' ---')

        if args.eval:
            if args.loader in ['UCR', 'UEA']:
                _, acc = eval_classification(model, train_data, train_labels, test_data, test_labels)
                print(f"SVM ACC: {acc['acc']:.4f}, AUPRC: {acc['auprc']:.4f}")
                train_score, test_score = eval_clustering(model, train_data, train_labels, test_data, test_labels)
                print(f'k-means train ARI: {train_score:.4f}, test ARI: {test_score:.4f}')
                eval_results[dataset] = {'svm_acc': acc['acc'],
                                         'svm_auprc': acc['auprc'],
                                         'kmeans_train_ari': train_score,
                                         'kmeans_test_ari': test_score}
                # Write the results gathered so far; `eval_results` itself stays a dict so the
                # next dataset can be added without corrupting the table.
                pd.DataFrame(eval_results).T.to_csv(f'{save_dir}/eval_results.csv', index=True)
            elif args.loader in ['MicroTraffic', 'MacroTraffic']:
                print('--- Evaluation for application on traffic datasets is not implemented yet ---')

    print('--- ' + model_type + ' training time elapsed: ' + time.strftime('%H:%M:%S', time.gmtime(time.time() - initial_time)) + ' ---')
    sys.exit(0)


if __name__ == '__main__':
    # Switch stdout to line buffering so each printed line is flushed as soon as it ends.
    sys.stdout.reconfigure(line_buffering=True)
    # Parse the command line and hand the resulting namespace straight to the training entry point.
    main(parse_args())
