from __future__ import print_function
import os
import csv
import math
import time
import wandb
import argparse
import numpy as np
import pandas as pd
import prettytable as pt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataloader_avvp import on_LLP_dataset
from main_network_avvp import PreFM_Net
from utils.eval_metrics import segment_level, event_level_all
import datetime

def seed_everything(seed_value):
    """Seed every random number generator this script may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch's CPU and CUDA generators with the same value so that repeated
    runs with the same seed draw the same random numbers.

    Args:
        seed_value: Integer seed shared by all generators.
    """
    # Function-scope import: ``random`` is not imported at module level.
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

def calculate_grad_norm(model):
    """Return the global L2 norm of the gradients currently stored on ``model``.

    Parameters that have no gradient yet, or that do not require gradients,
    are skipped. The result is the square root of the sum of the squared
    per-parameter L2 norms.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The total gradient norm as a Python float (``0.0`` when no parameter
        carries a gradient).
    """
    squared_norms = (
        param.grad.detach().data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None and param.requires_grad
    )
    return sum(squared_norms) ** 0.5

def lr_warm_up_cos_anneal(optimizer, cur_epoch, warm_up_epoch, max_epoch, lr_min, lr_max):
    """Apply a linear warm-up followed by cosine annealing to ``optimizer``.

    While ``cur_epoch < warm_up_epoch`` the learning rate grows linearly as
    ``cur_epoch / warm_up_epoch * lr_max``. Afterwards it follows a half
    cosine that starts at ``lr_max`` (``cur_epoch == warm_up_epoch``) and
    reaches ``lr_min`` at ``cur_epoch == max_epoch``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts); the
            ``'lr'`` entry of every group is overwritten.
        cur_epoch: Epoch the learning rate is computed for.
        warm_up_epoch: Number of warm-up epochs.
        max_epoch: Epoch at which the cosine reaches ``lr_min``.
        lr_min: Learning rate at the end of the cosine phase.
        lr_max: Learning rate at the end of the warm-up phase.

    Returns:
        The learning rate that was written into every param group.
    """
    if cur_epoch < warm_up_epoch:
        lr = cur_epoch / warm_up_epoch * lr_max
    else:
        decay_span = max_epoch - warm_up_epoch
        # With no epochs left for the cosine phase the ratio below would
        # divide by zero; stay at the start of the cosine (lr_max) instead.
        if decay_span != 0:
            progress = (cur_epoch - warm_up_epoch) / decay_span
        else:
            progress = 0.0
        lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(progress * math.pi))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

def get_evaluation_result_table(mode, F_scores_dict):
    """Render the headline evaluation metrics as a one-row PrettyTable.

    Args:
        mode: Label written into the "Dataset" column (e.g. "Val", "Test").
        F_scores_dict: Mapping that must contain every metric name listed
            below; each value is formatted with two decimals.

    Returns:
        A ``prettytable.PrettyTable`` with a "Dataset" column followed by the
        twelve metric columns.
    """
    # The twelve metrics shown, in column order.
    metric_names = [
        'sf_a', 'sf_v', 'sf_av',
        'pfmap_a', 'pfmap_v', 'pfmap_av',
        'ef05_a', 'ef05_v', 'ef05_av',
        'efavg_a', 'efavg_v', 'efavg_av',
    ]

    table = pt.PrettyTable()
    table.field_names = ["Dataset"] + metric_names

    row = [mode]
    for name in metric_names:
        row.append('{:.2f}'.format(F_scores_dict[name]))
    table.add_row(row)

    return table

import pdb
def gaussian_weights(num_steps, center, sigma):
    """Build Gaussian-shaped per-step weights whose sum equals ``num_steps``.

    Args:
        num_steps: Number of steps; weights are evaluated at 0..num_steps-1.
        center: Position of the Gaussian peak (may be fractional).
        sigma: Standard deviation of the Gaussian, in steps.

    Returns:
        A float32 CPU tensor of shape ``[num_steps]``, rescaled so that the
        weights average to 1.
    """
    steps = torch.arange(num_steps, dtype=torch.float32)
    normalized_offset = (steps - center) / sigma
    raw = torch.exp(-0.5 * normalized_offset**2)

    # Rescale so the weights sum to num_steps (i.e. have mean 1).
    scale = num_steps / raw.sum()
    return raw * scale

def kd_cos_loss(stu, tea, time_weights):
    """Time-weighted cosine distillation loss between two feature tensors.

    The cosine similarity of ``stu`` and ``tea`` is taken over the last axis,
    multiplied by ``time_weights`` broadcast as ``(1, T, 1)``, and averaged.
    The loss is ``mean(time_weights) - mean(weighted similarity)``, so it is
    zero when every pair of feature vectors points in the same direction.

    Args:
        stu: Predicted features; after reducing the last axis the result must
            broadcast against ``(1, T, 1)``.
        tea: Reference features, same shape as ``stu``.
        time_weights: 1-D tensor of ``T`` per-step weights.

    Returns:
        A scalar tensor holding the loss.
    """
    num_steps = time_weights.shape[0]
    similarity = F.cosine_similarity(stu, tea, dim=-1)
    weighted_similarity = similarity * time_weights.view(1, num_steps, 1)

    return torch.mean(time_weights, dim=0) - weighted_similarity.mean()


def train(args, model, train_loader, optimizer, criterion, epoch, device):
    """Train ``model`` for one epoch and return the per-sample mean of each loss.

    The optimized objective is the sum of three terms:
      * a BCE-with-logits classification loss over the current and future
        steps, weighted per step by a Gaussian,
      * a cosine distillation loss on the current-step features,
      * a cosine distillation loss on the future-step features.

    Args:
        args: Parsed CLI namespace; ``future_length`` and ``grad_norm`` are read.
        model: Network called as
            ``model(audio, visual, st, curr_f_label, futu_f_label)`` and
            returning ``(all_frame_logits, curr_f_pred, curr_f, futu_f_pred,
            futu_f)``.
        train_loader: Iterable of batch dicts that also supports ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Unused here; kept so the call signature stays unchanged.
        epoch: Epoch number, only used in the progress log.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with keys ``'loss_cls'``, ``'loss_kd_curr'``, ``'loss_kd_futu'``
        and ``'loss_all'``, each the sample-weighted average over the epoch.
    """
    model.train()
    train_loss = {'total': 0, 'loss_cls': 0, 'loss_kd_curr': 0, 'loss_kd_futu': 0,
                  'loss_all': 0}

    # Number of "current" steps in the label/logit sequence; the following
    # args.future_length steps are the "future" ones.
    curr_len = 10
    all_len = curr_len + args.future_length

    # The per-step weights do not depend on the batch, so build them once.
    time_weights = gaussian_weights(num_steps=20, center=9.5, sigma=4).to(device)

    # Log roughly 20 times per epoch. Clamp to 1 so that a loader with fewer
    # than 20 batches does not trigger a modulo-by-zero.
    log_interval = max(1, len(train_loader) // 20)

    for batch_idx, batch_data in enumerate(train_loader):
        # Shape comments are the original author's annotations for the
        # default configuration.
        audio = batch_data['audio'].to(device)  # B 10 768
        visual = batch_data['visual'].to(device)
        st = batch_data['st'].to(device)
        curr_label_a = batch_data['curr_label_a'].float().to(device)  # B 10 25
        curr_label_v = batch_data['curr_label_v'].float().to(device)
        futu_label_a = batch_data['futu_label_a'].float().to(device)  # B 5 25
        futu_label_v = batch_data['futu_label_v'].float().to(device)
        curr_f_label_a = batch_data['curr_f_label_a'].to(device)  # B 10 1536
        curr_f_label_v = batch_data['curr_f_label_v'].to(device)
        futu_f_label_a = batch_data['futu_f_label_a'].to(device)  # B 5 1536
        futu_f_label_v = batch_data['futu_f_label_v'].to(device)

        # Concatenate current and future labels along time, then stack the
        # audio / visual modalities on a new axis.
        all_label_a = torch.cat([curr_label_a, futu_label_a], dim=1)
        all_label_v = torch.cat([curr_label_v, futu_label_v], dim=1)  # B 15 25
        all_label = torch.stack([all_label_a, all_label_v], dim=2)  # B 15 2 25

        curr_f_label = torch.stack([curr_f_label_a, curr_f_label_v], dim=2)  # B 10 2 1536
        futu_f_label = torch.stack([futu_f_label_a, futu_f_label_v], dim=2)  # B 5 2 1536

        batch_size = visual.size(0)

        optimizer.zero_grad()

        all_frame_logits, curr_f_pred, curr_f, futu_f_pred, futu_f = model(
            audio, visual, st, curr_f_label, futu_f_label)

        # Classification loss, weighted per time step.
        l_cls = F.binary_cross_entropy_with_logits(all_frame_logits, all_label, reduction='none')
        loss_cls = l_cls * time_weights[:all_len].view(1, all_len, 1, 1)
        loss_cls = loss_cls.mean()

        # Feature distillation losses for the current and the future steps.
        loss_kd_curr = kd_cos_loss(curr_f_pred, curr_f, time_weights[:curr_len])
        loss_kd_futu = kd_cos_loss(futu_f_pred, futu_f, time_weights[curr_len:all_len])

        loss = loss_cls + loss_kd_curr + loss_kd_futu

        # Accumulate sample-weighted sums; they are normalized after the loop.
        train_loss['loss_cls'] += (loss_cls.item() * batch_size)
        train_loss['loss_kd_curr'] += (loss_kd_curr.item() * batch_size)
        train_loss['loss_kd_futu'] += (loss_kd_futu.item() * batch_size)
        train_loss['loss_all'] += (loss.item() * batch_size)
        train_loss['total'] += batch_size

        loss.backward()
        if args.grad_norm > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch [{epoch}], Iter [{batch_idx}/{len(train_loader)}]: "
                  f"Loss cls: {loss_cls.item():.4f},  "
                  f"Loss kd_curr: {loss_kd_curr.item():.4f},  "
                  f"Loss kd_future: {loss_kd_futu.item():.4f},  "
                  f"Total Loss: {loss.item():.4f}")

    num_train_data = train_loss['total']
    train_loss = {key: (float(value) / num_train_data) for key, value in train_loss.items() if 'loss' in key}

    return train_loss


from sklearn.metrics import average_precision_score
def compute_perframe_map(pred, gt):
    """Average the per-row average precision of ``pred`` against ``gt``.

    For every index along the first axis of ``pred`` the average precision of
    ``pred[i]`` with respect to ``gt[i]`` is computed; rows whose ground
    truth sums to zero contribute ``0.0``.

    Args:
        pred: Array of scores; its first axis is iterated.
        gt: Array of binary labels indexed the same way as ``pred``.

    Returns:
        The mean of the per-row average precisions.
    """
    num_rows = pred.shape[0]
    per_row_ap = [
        0.0 if np.sum(gt[row]) == 0 else average_precision_score(gt[row], pred[row])
        for row in range(num_rows)
    ]
    return np.mean(per_row_ap)

def eval(args, model, data_loader, label_dir, criterion, device, mode):
    """Evaluate ``model`` on ``data_loader`` and return metrics and mean losses.

    Three families of metrics are computed per batch and averaged:
      * per-frame mAP (``pfmap_*``),
      * segment-level F-score (``sf_*``),
      * event-level F-score at tIoU thresholds 0.1, 0.2, ..., 0.9
        (``ef01_*`` .. ``ef09_*``) and their mean (``efavg_*``).
    Each family is reported for audio (``_a``), visual (``_v``) and
    audio-visual (``_av``), scaled by 100.

    NOTE(review): the video name and length are taken from the first item of
    each batch and ground truth is loaded once per batch, so each batch is
    presumably expected to hold the samples of a single video - confirm
    against the dataset implementation.

    Args:
        args: Parsed CLI namespace (not read here).
        model: Network called as ``model(audio, visual, st, curr_f_label)``
            and returning ``(all_frame_logits, curr_f_pred, curr_f)``.
        data_loader: Iterable of batch dicts.
        label_dir: Directory containing ``audio/<name>.npy`` and
            ``visual/<name>.npy`` ground-truth arrays.
        criterion: When ``None`` the validation losses are skipped (and
            reported as 0); any other value enables them. The object itself
            is never called.
        device: Device the batch tensors are moved to.
        mode: Not read here; kept so the call signature stays unchanged.

    Returns:
        ``(F_scores, val_loss)``: the metric dict described above, and a dict
        with the sample-weighted mean of ``'loss_cls_curr'``,
        ``'loss_kd_curr'`` and ``'loss_all'``.
    """
    model.eval()

    pfmap_a = []
    pfmap_v = []
    pfmap_av = []
    sf_a = []
    sf_v = []
    sf_av = []
    # One bucket per tIoU threshold (0.1 .. 0.9).
    ef_a = [[] for _ in range(9)]
    ef_v = [[] for _ in range(9)]
    ef_av = [[] for _ in range(9)]

    # only curr data is available when val or test.
    val_loss = {'total': 0, 'loss_cls_curr': 0, 'loss_kd_curr': 0, 'loss_all': 0}

    compute_loss = criterion is not None
    # The per-step weights do not depend on the batch, so build them once.
    time_weights = None
    if compute_loss:
        time_weights = gaussian_weights(num_steps=20, center=9.5, sigma=4).to(device)

    with torch.no_grad():
        for batch_data in data_loader:

            video_name = batch_data['name'][0]
            length = batch_data['length'][0]

            # Shape comments are the original author's annotations.
            audio, visual, st = batch_data['audio'].to(device), batch_data['visual'].to(device), batch_data['st'].to(device) # B 10 768
            curr_label_a, curr_label_v = batch_data['curr_label_a'].float().to(device), batch_data['curr_label_v'].float().to(device) # B 10 25
            curr_f_label_a, curr_f_label_v = batch_data['curr_f_label_a'].to(device), batch_data['curr_f_label_v'].to(device) # B 10 1536

            curr_label = torch.stack([curr_label_a, curr_label_v], dim=2) # B 10 2 25
            curr_f_label = torch.stack([curr_f_label_a, curr_f_label_v], dim=2) # B 10 2 1536

            batch_size = visual.size(0)

            all_frame_logits, curr_f_pred, curr_f = model(audio, visual, st, curr_f_label)

            # ================= Calculate Validation Loss ===================
            if compute_loss:
                # Only the first 10 (current) steps have labels at eval time.
                l_cls_curr = F.binary_cross_entropy_with_logits(all_frame_logits[:,:10,:,:], curr_label, reduction='none')
                loss_cls_curr = l_cls_curr * time_weights[:10].view(1, 10, 1, 1)
                loss_cls_curr = loss_cls_curr.mean()

                loss_kd_curr = kd_cos_loss(curr_f_pred, curr_f, time_weights[:10])

                loss = loss_cls_curr + loss_kd_curr

                val_loss['loss_cls_curr'] += (loss_cls_curr.item()*batch_size)
                val_loss['loss_kd_curr'] += (loss_kd_curr.item()*batch_size)
                val_loss['loss_all'] += (loss.item()*batch_size)
            val_loss['total'] += batch_size
            # ================================================================

            GT_a = np.load(os.path.join(label_dir, 'audio', video_name + '.npy')) # T,25
            GT_v = np.load(os.path.join(label_dir, 'visual', video_name + '.npy')) # T,25

            # Keep the first `length` samples and read step index 9 of each.
            curr_prob = torch.sigmoid(all_frame_logits[:length,9,:,:]) # B,15,2,25-->B,2,25-->T,2,25
            Pre_av = curr_prob.cpu().detach().numpy() # T,2,25

            # pf_map: audio-visual score is the geometric mean of both modalities.
            map_a = compute_perframe_map(Pre_av[:,0,:], GT_a)
            map_v = compute_perframe_map(Pre_av[:,1,:], GT_v)
            map_av = compute_perframe_map(np.sqrt(Pre_av[:,0,:] * Pre_av[:,1,:]), GT_a * GT_v)
            pfmap_a.append(map_a)
            pfmap_v.append(map_v)
            pfmap_av.append(map_av)

            # seg_f: binarize at 0.5 and move the class axis first.
            GT_a_t = np.transpose(GT_a)
            GT_v_t = np.transpose(GT_v) # 25,T
            Pre_av = (Pre_av >= 0.5).astype(np.int_)
            Pre_av_t = Pre_av.transpose(2,1,0) # 25,2,T
            f_a = segment_level(Pre_av_t[:,0,:], GT_a_t)
            f_v = segment_level(Pre_av_t[:,1,:], GT_v_t)
            f_av = segment_level(Pre_av_t[:,0,:] * Pre_av_t[:,1,:], GT_a_t * GT_v_t)
            sf_a.append(f_a)
            sf_v.append(f_v)
            sf_av.append(f_av)

            # event f
            true_length = Pre_av.shape[0]
            event_scores_a = event_level_all(Pre_av_t[:,0,:], GT_a_t, true_length, tiou_list=np.arange(0.1, 1.0, 0.1))
            for i, score in enumerate(event_scores_a):
                ef_a[i].append(score)

            event_scores_v = event_level_all(Pre_av_t[:,1,:], GT_v_t, true_length, tiou_list=np.arange(0.1, 1.0, 0.1))
            for i, score in enumerate(event_scores_v):
                ef_v[i].append(score)

            event_scores_av = event_level_all(Pre_av_t[:,0,:] * Pre_av_t[:,1,:], GT_a_t * GT_v_t, true_length, tiou_list=np.arange(0.1, 1.0, 0.1))
            for i, score in enumerate(event_scores_av):
                ef_av[i].append(score)

    # Average each tIoU bucket over all batches, as a percentage.
    ef_a_list = []
    ef_v_list = []
    ef_av_list = []
    for i in range(9):
        ef_a_list.append(100 * np.mean(np.array(ef_a[i])))
        ef_v_list.append(100 * np.mean(np.array(ef_v[i])))
        ef_av_list.append(100 * np.mean(np.array(ef_av[i])))

    F_scores = {
        'sf_a': 100 * np.mean(np.array(sf_a)),
        'sf_v': 100 * np.mean(np.array(sf_v)),
        'sf_av': 100 * np.mean(np.array(sf_av)),

        'pfmap_a': 100 * np.mean(np.array(pfmap_a)),
        'pfmap_v': 100 * np.mean(np.array(pfmap_v)),
        'pfmap_av': 100 * np.mean(np.array(pfmap_av)),

        **{f'ef{i+1:02d}_a': ef_a_list[i] for i in range(9)},
        **{f'ef{i+1:02d}_v': ef_v_list[i] for i in range(9)},
        **{f'ef{i+1:02d}_av': ef_av_list[i] for i in range(9)},
        'efavg_a': np.mean(ef_a_list),
        'efavg_v': np.mean(ef_v_list),
        'efavg_av': np.mean(ef_av_list)
    }

    num_eval_data = val_loss['total']
    val_loss = {key: (float(value) / num_eval_data) for key, value in val_loss.items() if 'loss' in key}

    return F_scores, val_loss


def main():
    """Parse CLI arguments, then train or test the model.

    In ``train`` mode the arguments are dumped to a timestamped run folder,
    the model is trained for ``--epochs`` epochs, and every ``--interval``
    epochs it is evaluated on the validation and test splits; a checkpoint is
    saved at each evaluation and the one with the best validation
    ``pfmap_av`` is additionally saved as ``checkpoint_best.pt``.

    In ``test`` mode ``checkpoint_best.pt`` is loaded and evaluated on the
    test split.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Official Implementation of PreFM-On-AVVP')

    parser.add_argument("--pd_dir_train", type=str, default='./data_online/train_pd_online.csv',
                        help="pd_dir_train")
    parser.add_argument("--pd_dir_val", type=str, default='./data_online/val_pd_online.csv',
                        help="pd_dir_val")
    parser.add_argument("--pd_dir_test", type=str, default='./data_online/test_pd_online.csv',
                        help="pd_dir_test")

    parser.add_argument("--audio_dir", type=str, default='./data_online/feats/clap',
                        help="audio features dir")
    parser.add_argument("--visual_dir", type=str, default='./data_online/feats/clip',
                        help="2D visual features dir")
    parser.add_argument("--st_dir", type=str, default='./data_online/feats/r2plus1d_18',
                        help="3D visual features dir")
    parser.add_argument("--label_dir", type=str, default='./data_online/label',
                        help="segment-level pseudo labels dir")
    parser.add_argument("--f_label_dir", type=str, default='./data_online/feature_label',
                        help="feature labels dir")

    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument("--mode", type=str, default='train', choices=['train', 'val', 'test'],
                        help="which mode to use")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--val_batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--optimizer', type=str, default='adamw')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--grad_norm', type=float, default=1.0,
                        help='the value for gradient clipping (0 means no gradient clipping)')

    # optimizer hyper-parameters
    parser.add_argument('--weight_decay', type=float, default=1e-3,
                        help='weight decay for optimizer')
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--eps', type=float, default=1e-8)

    # scheduler hyper-parameters
    parser.add_argument('--scheduler', type=str, default='steplr', help='which scheduler to use')
    parser.add_argument('--stepsize', type=int, default=10, help='step size of learning scheduler')
    parser.add_argument('--gamma', type=float, default=0.1, help='gamma of learning scheduler')
    parser.add_argument('--warm_up_epoch', type=int, default=5, help='the number of epochs for warm up')
    parser.add_argument('--lr_min', type=float, default=1e-6, help='the minimum lr for lr decay')

    # model hyper-parameters
    parser.add_argument("--model", type=str, default='PreFM_Net', help="which model to use")
    parser.add_argument("--input_v_dim", type=int, default=2048)
    parser.add_argument("--input_a_dim", type=int, default=128)
    parser.add_argument("--hidden_dim", type=int, default=512)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--ff_dim", type=int, default=1024)
    parser.add_argument("--num_layers", type=int, default=1)
    parser.add_argument("--norm_where", type=str, default="post_norm", choices=['post_norm', 'pre_norm'])

    parser.add_argument("--model_name", type=str,
                        help="the name for the model")
    parser.add_argument("--model_save_dir", type=str, default='models/',
                        help="where to save the trained model")

    # wandb configurations
    parser.add_argument("--use_wandb", action="store_true",
                        help="use wandb or not")
    parser.add_argument("--wandb_project_name", type=str, default='Baseline')
    parser.add_argument("--wandb_run_name", type=str)

    parser.add_argument("--future_length", type=int, default=5)
    parser.add_argument("--current_length", type=int, default=10)
    parser.add_argument("--interval", type=int, default=6)

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.model_name is None:
        args.model_name = args.model
    # NOTE(review): these substring checks are case-sensitive, so the
    # lower-case default feature dirs above do not trigger them - confirm
    # that this is intended.
    if 'CLAP' in args.audio_dir:
        print('reset args.input_a_dim')
        args.input_a_dim = 768
    if 'CLIP' in args.visual_dir:
        print('reset args.input_v_dim')
        args.input_v_dim = 768             # 1024 or 768 (before visual projection / after visual projection)
    print('args =', args)

    # Checkpoints are written to (and, in test mode, read from) this folder.
    # NOTE(review): the timestamped run folder created below only receives
    # arguments.txt, so runs sharing a model_name overwrite each other's
    # checkpoints - confirm whether that is intended.
    checkpoint_dir = os.path.join(args.model_save_dir, args.model_name)

    if args.mode == 'train':
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = os.path.join(checkpoint_dir, timestamp)
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if os.path.exists(run_dir):
            raise FileExistsError("{} already exists. Please specify another model_name.".format(args.model_name))

        os.makedirs(run_dir, exist_ok=False)
        with open(os.path.join(run_dir, "arguments.txt"), 'w') as f:
            f.write('-------------------------start-------------------------\n')
            for key, value in args.__dict__.items():
                f.write(key + ': ' + str(value) + '\n')
            f.write('--------------------------end--------------------------\n')

    # Initialize wandb
    if args.use_wandb:
        wandb.init(project=args.wandb_project_name)
        if args.wandb_run_name is not None:
            wandb.run.name = args.wandb_run_name
        wandb.config.update(args)

    # Set random seed and device
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    seed_everything(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = PreFM_Net(args).to(device)

    if args.mode == 'train':

        train_dataset = on_LLP_dataset(mode=args.mode, pd_dir=args.pd_dir_train, audio_dir=args.audio_dir, visual_dir=args.visual_dir, st_dir=args.st_dir,
                                    label_dir=args.label_dir, f_label_dir=args.f_label_dir, f_lens = args.future_length, c_lens = args.current_length)
        val_dataset   = on_LLP_dataset(mode='val', pd_dir=args.pd_dir_val, audio_dir=args.audio_dir, visual_dir=args.visual_dir, st_dir=args.st_dir,
                                    label_dir=args.label_dir, f_label_dir=args.f_label_dir, f_lens = args.future_length, c_lens = args.current_length)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=32, pin_memory=True)
        val_loader   = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=15, pin_memory=True)

        test_dataset = on_LLP_dataset(mode='test', pd_dir=args.pd_dir_test, audio_dir=args.audio_dir, visual_dir=args.visual_dir, st_dir=args.st_dir,
                                    label_dir=args.label_dir, f_label_dir=args.f_label_dir, f_lens = args.future_length, c_lens = args.current_length)
        test_loader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=4, pin_memory=True)

        # Create loss function(s). eval() only checks whether this is None
        # to decide if validation losses should be computed.
        criterion = nn.BCELoss()

        # Create optimizer, scheduler
        if args.optimizer == 'adamw':
            optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(args.beta1, args.beta2), eps=args.eps)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(args.beta1, args.beta2), eps=args.eps)

        # ``scheduler`` stays None unless a torch scheduler is selected, so
        # the training loop can skip ``scheduler.step()`` safely.
        scheduler = None
        if args.scheduler == 'steplr':
            # Honour --stepsize / --gamma (their defaults are 10 and 0.1).
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
        elif args.scheduler == 'warm_up_cos_anneal':
            print('Using hand-made lr scheduler')
        else:
            print('Warning! Not using any lr scheduler!')

        # Model selection is driven by the validation 'pfmap_av' score.
        best_F = {'pfmap_av': 0.0}
        best_epoch = 0
        for epoch in range(1, args.epochs + 1):
            train_dataset._init_dataset()
            if args.scheduler == 'warm_up_cos_anneal':
                lr_warm_up_cos_anneal(optimizer, epoch, args.warm_up_epoch, args.epochs, args.lr_min, args.lr)

            cur_lr = optimizer.param_groups[0]['lr']
            start_time = time.time()
            train_loss_dict = train(args, model, train_loader, optimizer, criterion, epoch, device)

            if scheduler is not None:
                scheduler.step()

            if epoch % args.interval == 0:
                F_scores, val_loss_dict = eval(args, model, val_loader, args.label_dir, criterion, device, mode = 'val')
                elapse_time = time.time() - start_time

                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "checkpoint_epoch_{}.pt".format(epoch)))
                if F_scores['pfmap_av'] > best_F['pfmap_av']:
                    best_F = F_scores
                    best_epoch = epoch
                    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "checkpoint_best.pt"))

                print('Epoch[{}/{}](Time:{:.2f} sec)(lr:{:.6f}) Train Loss: {:.3f} Val Loss: {:.3f}  Val F: {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(
                        epoch, args.epochs, elapse_time, cur_lr, train_loss_dict['loss_all'], val_loss_dict['loss_all'], F_scores['sf_av'], F_scores['pfmap_av'], F_scores['ef05_av'], F_scores['efavg_av']))

                # just test it
                F_scores_test, _ = eval(args, model, test_loader, args.label_dir, criterion=None, device=device, mode = 'test')
                f_scores_tb = get_evaluation_result_table(mode="Test", F_scores_dict=F_scores_test)
                print('Evaluation result:')
                print(F_scores_test)
                print(f_scores_tb)

        print('-'*30)
        print('Best F scores (at epoch {}):'.format(best_epoch))
        if best_epoch > 0:
            f_scores_tb = get_evaluation_result_table(mode="Val", F_scores_dict=best_F)
            print(f_scores_tb)
        else:
            # best_F still holds only its initial 'pfmap_av' entry (no
            # evaluation ran, or none scored above 0), so the full table
            # cannot be rendered.
            print('No validation result improved on the initial score; nothing to report.')

    # NOTE: mode == 'val' is accepted by the parser but has no dedicated
    # branch; it falls through to the final ``else`` below.

    elif args.mode == 'test':
        test_dataset = on_LLP_dataset(mode='test', pd_dir=args.pd_dir_test, audio_dir=args.audio_dir, visual_dir=args.visual_dir, st_dir=args.st_dir,
                                    label_dir=args.label_dir, f_label_dir=args.f_label_dir, f_lens = args.future_length, c_lens = args.current_length)
        test_loader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=4, pin_memory=True)
        # map_location lets a checkpoint saved on GPU be loaded on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(os.path.join(checkpoint_dir, "checkpoint_best.pt"), map_location=device))

        # Evaluation
        F_scores_test, _ = eval(args, model, test_loader, args.label_dir, criterion=None, device=device, mode = 'test')
        f_scores_tb = get_evaluation_result_table(mode="Test", F_scores_dict=F_scores_test)
        print('Evaluation result:')
        print(F_scores_test)
        print(f_scores_tb)

    else:
        print('Please specify args.mode!')
        

# Run the train/test entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
