'''
This script trains and evaluates prediction models on the MicroTraffic dataset.
The prediction model (UQnet) is reused from the original paper's code.
'''

import os
import sys
import time
import random
import torch
import argparse
import numpy as np
import pandas as pd
from micro_modules.interaction_model import UQnet
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from micro_modules.utils import *
from micro_modules.train import *
from micro_modules.interaction_dataset import *
from micro_modules.losses import *
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from train_eval_utils.utils_eval import *
from train_eval_utils.utils_data import load_MicroTraffic
import utils_pretain as utils_pre


def parse_args():
    """Parse the command-line options of this script.

    The integer flags ``--add_time_feature`` and ``--reproduction`` are converted
    to booleans before returning.

    Returns:
        argparse.Namespace with fields ``gpu`` (str), ``add_time_feature`` (bool),
        ``epochs`` (int), ``seed`` (int or None) and ``reproduction`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='0', help='The gpu number to use for training and inference (defaults to 0 for CPU only, can be "1,2" for multi-gpu)')
    parser.add_argument('--add_time_feature', type=int, default=1, help='Whether to add a time feature to the input data (defaults to 1)')
    parser.add_argument('--epochs', type=int, default=25, help='The number of epochs')
    parser.add_argument('--seed', type=int, default=None, help='The random seed')
    # The help text previously claimed "(defaults to False)" although the default is 1.
    parser.add_argument('--reproduction', type=int, default=1, help='Whether this run is for reproduction; if set to 1, the random seed is fixed to 131 and any --seed value is ignored (defaults to 1)')
    args = parser.parse_args()
    args.add_time_feature = bool(args.add_time_feature)
    args.reproduction = bool(args.reproduction)
    return args


def _pretrained_encoder_dir(model_type, add_time_feature):
    """Return the directory holding the pretrained `model_type` trajectory encoder.

    Both training and evaluation build the encoder through this helper so they
    always read from the same directory (the `_notime` variant when the time
    feature is disabled).
    """
    if add_time_feature:
        return f'./results/evaluation/MicroTraffic/{model_type}/train1'
    return f'./results/evaluation/MicroTraffic_notime/{model_type}/train1'


def _read_saved_results(results_dir):
    """Load an evaluation CSV written by `main`, indexed by ('model', 'dataset')."""
    eval_results = pd.read_csv(results_dir)
    # Dataset names must compare equal to the str keys used for lookups.
    eval_results['dataset'] = eval_results['dataset'].astype(str)
    return eval_results.set_index(['model', 'dataset'])


def _is_evaluated(eval_results, model_type, dataset):
    """Return True if `eval_results` already stores results for (model_type, dataset).

    Rows are initialised with zeros, so a positive 'global_mean_continuity' marks
    a pair that has already been evaluated. A missing row (e.g. a model added to
    the list after the CSV was created) counts as not evaluated.
    """
    key = (model_type, dataset)
    return key in eval_results.index and eval_results.loc[key, 'global_mean_continuity'] > 0


def main(args):
    """Train and evaluate all MicroTraffic model variants and store their metrics.

    Two modes are run, 'fixed' and 'continued', which differ only in the
    `continue_training` flag passed to `utils_pre.define_encoder`. For each mode
    and each model in `model_list`:
      * the model is trained unless its decoder weights already exist in the
        mode's save directory;
      * it is evaluated (prediction metrics plus local/global encoding metrics)
        unless the mode's results CSV already holds its results, and the CSV is
        rewritten after every evaluated model.

    Args:
        args: namespace from `parse_args` (gpu, add_time_feature, epochs, seed,
            reproduction). `args.seed` is overwritten with the seed actually used.

    Exits the process with status 0 when done.
    """
    initial_time = time.time()
    print('Available cpus:', torch.get_num_threads(), 'available gpus:', torch.cuda.device_count())

    # Set the random seed; reproduction runs always use 131, overriding --seed.
    if args.reproduction:
        args.seed = 131  # Fix the random seed for reproduction
    if args.seed is None:
        args.seed = random.randint(0, 1000)
    print(f"Random seed is set to {args.seed}")
    utils_pre.fix_seed(args.seed, deterministic=args.reproduction)

    # Initialize the deep learning program via `utils_pre.init_dl_program`
    print(f'--- Cuda available: {torch.cuda.is_available()} ---')
    if torch.cuda.is_available():
        print(f'--- Cuda device count: {torch.cuda.device_count()}, Cuda device name: {torch.cuda.get_device_name()}, Cuda version: {torch.version.cuda}, Cudnn version: {torch.backends.cudnn.version()} ---')
    device = utils_pre.init_dl_program(args.gpu)
    print(f'--- Device: {device}, Pytorch version: {torch.__version__} ---')

    # Output locations: a CSV of evaluation metrics and a directory of trained weights per mode
    if args.add_time_feature:
        continued_results_dir = './results/evaluation/MicroTraffic_continued_evaluation.csv'
        continued_save_dir = './results/train/MicroTraffic_continued/'
        fixed_results_dir = './results/evaluation/MicroTraffic_fixed_evaluation.csv'
        fixed_save_dir = './results/train/MicroTraffic_fixed/'
    else:
        continued_results_dir = './results/evaluation/MicroTraffic_notime_continued_evaluation.csv'
        continued_save_dir = './results/train/MicroTraffic_notime_continued/'
        fixed_results_dir = './results/evaluation/MicroTraffic_notime_fixed_evaluation.csv'
        fixed_save_dir = './results/train/MicroTraffic_notime_fixed/'
    # Create the weight directories and the CSV parent directories up front:
    # `DataFrame.to_csv` does not create missing parent directories.
    for out_dir in [continued_save_dir, fixed_save_dir,
                    os.path.dirname(continued_results_dir), os.path.dirname(fixed_results_dir)]:
        os.makedirs(out_dir, exist_ok=True)

    # Define hyper parameters
    paralist = utils_pre.config_micro()
    paralist['encoder_attention_size'] = 128
    paralist['use_sem'] = False
    paralist['epochs'] = args.epochs
    paralist['mode'] = 'lanescore'
    paralist['prob_mode'] = 'ce'
    paralist['batch_size'] = 8
    train_set = ['train1']
    dataset = 'train1'

    # Models and metric columns of the evaluation CSVs
    model_list = ['original', 'ts2vec', 'topo-ts2vec', 'ggeo-ts2vec', 'softclt', 'topo-softclt', 'ggeo-softclt']
    pred_metrics = ['min_fde', 'mr_05', 'mr_1', 'mr_2']
    knn_metrics = ['mean_shared_neighbours', 'mean_dist_mrre', 'mean_trustworthiness', 'mean_continuity']  # kNN-based, averaged over various k
    density_metrics = ['density_kl_global_001', 'density_kl_global_01', 'density_kl_global_1', 'density_kl_global_10']  # Density-based
    encoding_metrics = knn_metrics + density_metrics
    metrics = (pred_metrics
               + ['local_' + metric for metric in encoding_metrics]
               + ['global_' + metric for metric in encoding_metrics])

    # Test data for the encoding evaluation does not depend on the mode or the model,
    # so it is loaded once, on first use. NOTE(review): assumes `evaluate` does not
    # modify `test_data` in place — confirm.
    test_data = None

    for continue_training, results_dir, save_dir in zip([False, True], [fixed_results_dir, continued_results_dir], [fixed_save_dir, continued_save_dir]):
        if os.path.exists(results_dir):
            eval_results = _read_saved_results(results_dir)
        else:
            index = pd.MultiIndex.from_product([model_list, train_set], names=['model', 'dataset'])
            eval_results = pd.DataFrame(np.zeros((len(index), len(metrics)), dtype=np.float64),
                                        columns=metrics, index=index)
            eval_results.to_csv(results_dir)

        # Load the dataset
        trainset = InteractionDataset(train_set, 'train', paralist, paralist['mode'], device)
        validationset = InteractionDataset(['val'], 'val', paralist, paralist['mode'], device)
        validation_loader = DataLoader(validationset, batch_size=paralist['batch_size'], shuffle=False)
        testset = InteractionDataset(['test'], 'test', paralist, paralist['mode'], device)

        BATCH_SIZE = paralist['batch_size']
        EPOCH_NUMBER = paralist['epochs']
        loss = OverAllLoss(paralist).to(device)

        for model_type in model_list:
            # Train models
            paralist['resolution'] = 1.
            paralist['inference'] = False

            start_time = time.time()
            if model_type == 'original':
                model = UQnet(paralist, test=False, drivable=False).to(device)
            else:
                sp_encoder = utils_pre.define_encoder('MicroTraffic', device,
                                                      model_dir=_pretrained_encoder_dir(model_type, args.add_time_feature),
                                                      continue_training=continue_training)
                model = UQnet(paralist, test=False, drivable=False, traj_encoder=sp_encoder).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            scheduler_heatmap = StepLR(optimizer, step_size=1, gamma=0.975)
            scheduler_epoch = ReduceLROnPlateau(optimizer, mode='min', factor=0.6, patience=2, cooldown=5,
                                                threshold=1e-3, threshold_mode='rel', min_lr=1e-6)

            if os.path.exists(os.path.join(save_dir, f'decoder_{model_type}.pth')):
                print(f'--- {model_type} has been trained ---')
            else:
                print(f'--- Training {model_type} ---')
                train_model(EPOCH_NUMBER, BATCH_SIZE, trainset, model, optimizer, validation_loader, loss,
                            scheduler_heatmap, scheduler_epoch, mode=paralist['mode'])
                # The 'original' model does not use `continue_training`, so its weights
                # are written to both modes' directories and trained only once.
                target_dirs = [fixed_save_dir, continued_save_dir] if model_type == 'original' else [save_dir]
                for out_dir in target_dirs:
                    torch.save(model.encoder.state_dict(), os.path.join(out_dir, f'encoder_{model_type}.pth'))
                    torch.save(model.decoder.state_dict(), os.path.join(out_dir, f'decoder_{model_type}.pth'))
                print(f'Training time for {model_type}: {time.time() - start_time}')

            # Evaluate models
            if _is_evaluated(eval_results, model_type, dataset):
                print(f'--- {model_type} {dataset} has been evaluated, skipping evaluation ---')
                continue
            paralist['resolution'] = 0.5
            paralist['inference'] = True
            if model_type == 'original':
                model = UQnet(paralist, test=True, drivable=False)  # set test=True here
            else:
                # Same encoder directory as in training (previously hard-coded to the
                # with-time-feature directory, even for `_notime` runs).
                sp_encoder = utils_pre.define_encoder('MicroTraffic', device,
                                                      model_dir=_pretrained_encoder_dir(model_type, args.add_time_feature),
                                                      continue_training=continue_training)
                model = UQnet(paralist, test=True, drivable=False, traj_encoder=sp_encoder)

            model.encoder.load_state_dict(torch.load(os.path.join(save_dir, f'encoder_{model_type}.pth'),
                                                     map_location=device, weights_only=True))
            model.decoder.load_state_dict(torch.load(os.path.join(save_dir, f'decoder_{model_type}.pth'),
                                                     map_location=device, weights_only=True))
            model = model.to(device)
            model.eval()
            Yp, Ua, Ue, Y = inference_model([model], testset, paralist)

            # Prediction evaluation
            min_fde, mr_list = ComputeError(Yp, Y, r_list=[0.5, 1., 2.], sh=6)  # r is the radius of error in meters
            pred_results = {'min_fde': min_fde, 'mr_05': mr_list[0], 'mr_1': mr_list[1], 'mr_2': mr_list[2]}

            # Encoding evaluation
            if test_data is None:
                _, _, test_data = load_MicroTraffic(train_set, dataset_dir='./datasets')
            test_labels = np.zeros(test_data.shape[0])
            local_dist_dens_results = evaluate(test_data, test_labels, model, batch_size=128, local=True, save_latents=False)
            global_dist_dens_results = evaluate(test_data, test_labels, model, batch_size=128, local=False, save_latents=False)

            key_values = {**pred_results, **local_dist_dens_results, **global_dist_dens_results}
            keys = list(key_values.keys())
            values = np.array(list(key_values.values())).astype(np.float64)
            eval_results = _read_saved_results(results_dir)  # read saved results again to avoid overwriting
            eval_results.loc[(model_type, dataset), keys] = values

            # Save evaluation results per dataset and model
            eval_results.to_csv(results_dir)

    print(f"Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - initial_time))}")
    sys.exit(0)


if __name__ == '__main__':
    # Flush stdout at every newline so progress messages appear promptly when
    # output is redirected to a log file.
    sys.stdout.reconfigure(line_buffering=True)
    main(parse_args())