'''
This script trains and evaluates models.
Tuned hyperparameters are loaded from results/hyper_parameters/<loader>[_notime]/.
Trained models are saved under results/evaluation/<loader>[_notime]/<model>/<dataset>/.
Evaluation results (UEA only) are saved in results/evaluation/<loader>[_notime]_evaluation.csv.
'''

import os
import sys
import time
import glob
import numpy as np
import pandas as pd
import argparse
from tasks.paramsearch import *
import train_eval_utils.utils_data as datautils
import torch
import random
from model import traffic_scl
from train_eval_utils.utils_general import *
from train_eval_utils.utils_eval import *
from tasks.classification import eval_classification
from tasks.clustering import eval_clustering


def parse_args(argv=None):
    """Parse command-line options and attach the default training parameters.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options (with the 0/1 integer
        flags converted to ``bool``) plus the default model/training
        parameters assigned below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--loader', type=str, required=True, help='The data loader used to load the experimental data. This can be set to UEA, MacroTraffic, MicroTraffic')
    parser.add_argument('--add_time_feature', type=int, default=1, help='Whether to add a time feature to the input data (defaults to 1)')
    parser.add_argument('--gpu', type=str, default='0', help='The gpu number to use for training and inference (defaults to 0 for CPU only, can be "1,2" for multi-gpu)')
    parser.add_argument('--seed', type=int, default=None, help='The random seed (overridden by the fixed reproduction seed unless --reproduction is set to 0)')
    parser.add_argument('--reproduction', type=int, default=1, help='Whether this run is for reproduction, if set to True, the random seed would be fixed (defaults to True)')
    parser.add_argument('--reverse_list', type=int, default=0, help='Whether to reverse the dataset list (defaults to 0)')
    args = parser.parse_args(argv)

    # The flags are passed as integers on the command line; expose them as booleans.
    args.add_time_feature = bool(args.add_time_feature)
    args.reproduction = bool(args.reproduction)
    args.reverse_list = bool(args.reverse_list)

    # Set default parameters (not exposed on the command line)
    args.sliding_padding = 0
    args.repr_dims = 320
    args.tau_inst = 0
    args.tau_temp = 0
    args.temporal_hierarchy = None
    args.regularizer = None
    args.bandwidth = 1.
    args.iters = None
    args.epochs = 100
    args.batch_size = 8
    args.lr = 0.001
    args.weight_lr = 0.01

    return args


def _list_datasets(loader):
    """Return the list of datasets to process for ``loader``.

    Raises:
        ValueError: If ``loader`` has no directory under ``datasets/`` or is
            not one of the loaders handled by this script.
    """
    available_loaders = [entry.name for entry in os.scandir('datasets/') if entry.is_dir()]
    if loader not in available_loaders:
        raise ValueError(f"Unknown dataset loader: {loader}")
    if loader == 'UEA':
        dataset_dir = os.path.join('datasets/', loader)
        return sorted(entry.name for entry in os.scandir(dataset_dir) if entry.is_dir())
    if loader == 'MacroTraffic':
        return [['2019']]
    if loader == 'MicroTraffic':
        return [['train1']]
    # A directory exists for this loader, but nothing below knows how to load it.
    raise ValueError(f"Unknown dataset loader: {loader}")


def _keep_latest_checkpoint(save_dir, model_type):
    """Keep only the most recently modified ``*_net.pth`` in ``save_dir``.

    Older checkpoints are deleted, together with their ``_loss_log_vars.npy``
    companion file for the topo-/ggeo- model types.

    Returns:
        The checkpoint name without directory and ``_net.pth`` suffix
        (the part of the file name starting at ``model``).

    Raises:
        FileNotFoundError: If ``save_dir`` contains no ``*_net.pth`` file.
    """
    existing_models = glob.glob(f'{save_dir}/*_net.pth')
    if not existing_models:
        raise FileNotFoundError(f'No model checkpoint (*_net.pth) found in {save_dir}')
    existing_models.sort(key=os.path.getmtime, reverse=True)
    for model_epoch in existing_models[1:]:
        os.remove(model_epoch)
        if model_type in ['topo-ts2vec', 'ggeo-ts2vec', 'topo-softclt', 'ggeo-softclt']:
            loss_vars_file = model_epoch.replace('_net.pth', '_loss_log_vars.npy')
            # A missing companion file should not abort the whole run.
            if os.path.exists(loss_vars_file):
                os.remove(loss_vars_file)
    return 'model' + existing_models[0].split('model')[-1].split('_net')[0]


def main(args):
    """Train (if needed) and evaluate every model on every dataset of ``args.loader``.

    For each dataset the tuned hyperparameters are read from
    ``results/hyper_parameters/``; datasets without such a file are skipped.
    Each model is trained unless a ``loss_log.csv`` already exists in its save
    directory, then only the latest checkpoint is kept and loaded. For the UEA
    loader the model is evaluated and the results are written to the
    evaluation CSV after every (model, dataset) pair. The process exits with
    status 0 when done.

    Args:
        args: Namespace produced by ``parse_args``; it is updated in place
            (seed, distance metric, epochs, tuned hyperparameters).

    Raises:
        ValueError: If ``args.loader`` is not an available, supported loader.
        FileNotFoundError: If no checkpoint is found for a model that is
            expected to have been trained.
    """
    initial_time = time.time()
    print('Available cpus:', torch.get_num_threads(), 'available gpus:', torch.cuda.device_count())

    # Set the random seed
    # NOTE: with --reproduction enabled (the default) a --seed given on the
    # command line is overridden by the fixed seed below.
    if args.reproduction:
        args.seed = 131 # Fix the random seed for reproduction
    if args.seed is None:
        args.seed = random.randint(0, 1000)
    print(f"Random seed is set to {args.seed}")
    fix_seed(args.seed, deterministic=args.reproduction)

    # Initialize the deep learning program
    print(f'--- Cuda available: {torch.cuda.is_available()} ---')
    if torch.cuda.is_available():
        print(f'--- Cuda device count: {torch.cuda.device_count()}, Cuda device name: {torch.cuda.get_device_name()}, Cuda version: {torch.version.cuda}, Cudnn version: {torch.backends.cudnn.version()} ---')
    device = init_dl_program(args.gpu)
    print(f'--- Device: {device}, Pytorch version: {torch.__version__} ---')

    # Create the directory to save the evaluation results
    # (results_dir is the path of the evaluation CSV file, not a directory)
    if args.add_time_feature:
        run_dir = f'results/evaluation/{args.loader}/'
        results_dir = f'results/evaluation/{args.loader}_evaluation.csv'
    else:
        run_dir = f'results/evaluation/{args.loader}_notime/'
        results_dir = f'results/evaluation/{args.loader}_notime_evaluation.csv'
    os.makedirs(run_dir, exist_ok=True)

    # Read the dataset list
    dataset_list = _list_datasets(args.loader)

    # Initialize evaluation results for UEA (traffic tasks will be evaluated using other scripts)
    if args.add_time_feature:
        model_list = ['ts2vec', 'topo-ts2vec', 'ggeo-ts2vec', 'softclt', 'topo-softclt', 'ggeo-softclt']
    else:
        model_list = ['topo-ts2vec', 'ggeo-ts2vec', 'topo-softclt', 'ggeo-softclt']

    if args.loader == 'UEA':
        clf_clr_metrics = ['svm_acc', 'svm_auprc', 'kmeans_ari', 'kmeans_ami'] # Classification and clustering results
        knn_metrics = ['mean_shared_neighbours', 'mean_dist_mrre', 'mean_trustworthiness', 'mean_continuity'] # kNN-based, averaged over various k
        density_metrics = ['density_kl_global_001', 'density_kl_global_01', 'density_kl_global_1', 'density_kl_global_10'] # Density-based

        def read_saved_results():
            """Load the evaluation CSV indexed by (model, dataset)."""
            eval_results = pd.read_csv(results_dir)
            eval_results['dataset'] = eval_results['dataset'].astype(str)
            eval_results = eval_results.set_index(['model', 'dataset'])
            return eval_results

        if os.path.exists(results_dir):
            eval_results = read_saved_results()
        else:
            metrics = clf_clr_metrics + ['local_'+metric for metric in (knn_metrics+density_metrics)] + ['global_'+metric for metric in (knn_metrics+density_metrics)]
            # Zero-filled table: a value of 0 marks a (model, dataset) pair as not yet evaluated.
            eval_results = pd.DataFrame(np.zeros((len(dataset_list)*len(model_list), len(metrics)), dtype=np.float64), columns=metrics,
                                        index=pd.MultiIndex.from_product([model_list,dataset_list], names=['model','dataset']))
            eval_results.to_csv(results_dir)

    bad_datasets = ['DuckDuckGeese',
                    'EigenWorms',
                    'MotorImagery',
                    'PEMS-SF'] # Datasets that are too resource-consuming to compute DTW or TAM

    # Evaluate for each dataset
    if args.reverse_list:
        dataset_list = dataset_list[::-1]
    for dataset in dataset_list:
        # Load dataset
        if args.loader == 'UEA':
            loaded_data = datautils.load_UEA(dataset)
            train_data, train_labels, test_data, test_labels = loaded_data
        elif args.loader == 'MacroTraffic':
            loaded_data = datautils.load_MacroTraffic(dataset, time_interval=5, horizon=15, observation=20)
            train_data, val_data, test_data = loaded_data
            dataset = '2019'
        elif args.loader == 'MicroTraffic':
            loaded_data = datautils.load_MicroTraffic(dataset)
            train_data, val_data, test_data = loaded_data
            # Collapse the list of split names into a single string label
            dataset = 'train'+''.join(dataset).replace('train', '')

        # Load tuned hyperparameters
        if args.add_time_feature:
            tuned_params_dir = f'results/hyper_parameters/{args.loader}/{dataset}_tuned_hyperparameters.csv'
        else:
            tuned_params_dir = f'results/hyper_parameters/{args.loader}_notime/{dataset}_tuned_hyperparameters.csv'
        if os.path.exists(tuned_params_dir):
            tuned_params = pd.read_csv(tuned_params_dir, index_col=0)
        else:
            print(f'****** {tuned_params_dir} not found ******')
            continue

        # Choose the distance metric used to build the similarity matrix
        if args.loader == 'MacroTraffic' or args.loader == 'MicroTraffic':
            args.dist_metric = 'EUC'
        else:
            if dataset in bad_datasets:
                print(f"Dataset {dataset} is too resource-consuming to compute DTW or TAM, switch to EUC by default.")
                args.dist_metric = 'EUC'
            else:
                args.dist_metric = 'DTW'
        sim_mat = datautils.get_sim_mat(args.loader, train_data, dataset, args.dist_metric)

        # Choose the number of training epochs from the loader and the data size
        train_size = train_data.shape[0]
        feature_size = train_data.shape[-1]
        if args.loader != 'UEA':
            args.epochs = 300
            verbose = 2
        else:
            if train_size < 1000 and train_data.shape[-2] < 1000:
                args.epochs = 1000
            elif train_size < 3000:
                args.epochs = 600
            else:
                args.epochs = 400
            verbose = 1

        # Iterate over different models
        for model_type in model_list:
            if args.loader == 'UEA':
                # A positive metric means this pair already has saved results
                if eval_results.loc[(model_type, dataset), 'global_mean_continuity'] > 0:
                    # Epoch number embedded in the saved checkpoint name
                    final_epoch = eval_results.loc[(model_type, dataset), 'model_used'].split('epo')[0].split('_')[-1]
                    if final_epoch[-2:] != '00':
                        print(f'--- {model_type} {dataset} has been evaluated (not 00), skipping evaluation ---')
                        continue
                    elif int(final_epoch) == args.epochs:
                        print(f'--- {model_type} {dataset} has been trained (==epochs), skipping evaluation ---')
                        continue
            start_time = time.time()
            save_dir = os.path.join(run_dir, f'{model_type}/{dataset}')
            os.makedirs(save_dir, exist_ok=True)

            # Set hyperparameters and configure model
            args = load_tuned_hyperparameters(args, tuned_params, model_type)
            model_config = configure_model(args, feature_size, device)

            # Train model if not already trained
            if os.path.exists(f'{save_dir}/loss_log.csv'):
                print(f'--- {model_type} {dataset} has been trained, loading final model ---')
            else:
                # Create model
                model_config['after_epoch_callback'] = save_checkpoint_callback(save_dir, 0, unit='epoch')
                model = traffic_scl(args.loader, **model_config)

                scheduler = 'reduced'
                print(f'--- {model_type}_{dataset} training with ReduceLROnPlateau scheduler ---')
                soft_assignments = datautils.assign_soft_labels(sim_mat, args.tau_inst)
                loss_log = model.fit(dataset, train_data, soft_assignments, args.epochs, args.iters, scheduler, verbose=verbose)
                # Save loss log
                save_loss_log(loss_log, save_dir, regularizer=args.regularizer)
                print('Training time elapsed: ' + time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time)))

            # Reserve the latest model and remove the rest
            best_model = _keep_latest_checkpoint(save_dir, model_type)

            # Load best model for evaluation
            model = traffic_scl(args.loader, **model_config)
            model.load(f'{save_dir}/{best_model}')

            # Evaluate model for UEA datasets
            if args.loader == 'UEA':
                print(f'Evaluating with {best_model} ...')

                # Reload data for evaluation
                loaded_data = datautils.load_UEA(dataset)
                train_data, train_labels, test_data, test_labels = loaded_data

                ## classification and clustering results
                _, acc = eval_classification(model, train_data, train_labels, test_data, test_labels)
                test_ari, test_ami = eval_clustering(model, train_data, train_labels, test_data, test_labels)
                clf_clr_results = {'svm_acc': acc['acc'],
                                   'svm_auprc': acc['auprc'],
                                   'kmeans_ari': test_ari,
                                   'kmeans_ami': test_ami}

                ## distance and density results
                local_dist_dens_results = evaluate(test_data, test_labels, model, batch_size=128, local=True, save_latents=False, save_dir=save_dir)
                global_dist_dens_results = evaluate(test_data, test_labels, model, batch_size=128, local=False, save_latents=True, save_dir=save_dir)

                ## loss results
                test_data, test_labels = datautils.modify_train_data(test_data, test_labels)
                test_sim_mat = datautils.get_sim_mat(args.loader, test_data, dataset, args.dist_metric, prefix='test')
                test_soft_assignments = datautils.assign_soft_labels(test_sim_mat, args.tau_inst)
                loss_results = model.compute_loss(test_data, test_soft_assignments, non_regularized=False)
                loss_results = {'scl_loss': loss_results[1],
                                'sp_loss': loss_results[3] if args.regularizer is not None else np.nan}

                # Save evaluation results
                key_values = {**clf_clr_results, **loss_results, **local_dist_dens_results, **global_dist_dens_results}
                keys = list(key_values.keys())
                values = np.array(list(key_values.values())).astype(np.float64)
                eval_results = read_saved_results() # read saved results again to avoid overwriting
                eval_results.loc[(model_type, dataset), keys] = values
                eval_results.loc[(model_type, dataset), 'model_used'] = best_model

                # Save evaluation results per dataset and model
                eval_results.to_csv(results_dir)
            else:
                print(f'Best model {best_model} will be evaluated in traffic prediction tasks.')

    print('--- Total time elapsed: ' + time.strftime('%H:%M:%S', time.gmtime(time.time() - initial_time)) + ' ---')
    sys.exit(0)


if __name__ == '__main__':
    # Script entry point: switch stdout to line buffering, then hand the
    # parsed command-line options straight to main().
    sys.stdout.reconfigure(line_buffering=True)
    main(parse_args())

