from collections import defaultdict
import datetime
import numpy as np
import os
import sys
import time
import torch
from torch.utils.data.dataloader import default_collate
import torchvision

# Custom imports
import models
from samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
from warmup_scheduler import GradualWarmupScheduler
import transforms as T
import utils
from utils import print_or_log


# Number of target classes per supported dataset, keyed by dataset name.
NUM_CLASSES = {
    "hmdb51": 51,
    "ucf101": 101,
    "kinetics400": 400,
}


# Fine-tuning wrapper: backbone followed by a linear classification head
class Finetune_Model(torch.nn.Module):
    """Backbone network with a single linear classifier on top.

    Args:
        base_arch: module applied to the input first; its output is fed
            directly into the linear head, so its last dimension must be
            ``num_ftrs``.
        num_ftrs: input feature size of the linear head.
        num_classes: number of output classes of the linear head.
    """

    def __init__(self, base_arch, num_ftrs=512, num_classes=101):
        super().__init__()
        self.base = base_arch
        self.classifier = torch.nn.Linear(
            in_features=num_ftrs,
            out_features=num_classes,
        )

    def forward(self, x):
        """Return the classifier output for the backbone features of ``x``."""
        features = self.base(x)
        return self.classifier(features)


# Wrap a backbone into a Finetune_Model for the given model family
def load_model_finetune(args, model, model_name, num_ftrs, num_classes):
    """Build the fine-tuning model (backbone + linear head) for ``model_name``.

    Args:
        args: parsed command-line namespace (not used by this function; kept
            for interface compatibility).
        model: backbone module. For the torchvision-style video models its
            ``fc`` attribute is replaced in place by ``models.Identity()``.
        model_name: one of 'av_gdt', 'r2plus1d_18', 'mc3_18', 'r3d_18'.
        num_ftrs: feature size fed to the linear head.
        num_classes: number of output classes.

    Returns:
        A ``Finetune_Model`` wrapping ``model``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
            Previously this case silently returned ``None`` and only failed
            later when the result was used as a model.
    """
    if model_name in ['av_gdt']:
        return Finetune_Model(model, num_ftrs, num_classes)
    if model_name in ['r2plus1d_18', 'mc3_18', 'r3d_18']:
        # Drop the original classification layer so the backbone emits features.
        model.fc = models.Identity()
        return Finetune_Model(model, num_ftrs, num_classes)
    raise ValueError(
        f"Unsupported model name for fine-tuning: {model_name!r} "
        "(expected one of 'av_gdt', 'r2plus1d_18', 'mc3_18', 'r3d_18')"
    )


## Aggregate per-clip scores into video-level top-k accuracy
def aggregrate_video_accuracy(softmaxes, labels, topk=(1,), aggregate="mean"):
    """Compute video-level top-k accuracy from per-clip score vectors.

    The score vectors of all clips belonging to one video are averaged, and
    the averaged vector is compared against the video's label.

    Args:
        softmaxes: dict mapping video id -> non-empty list of 1-D score
            tensors (one per clip), all of the same length.
        labels: dict mapping video id -> 0-dim label tensor; must contain
            every key of ``softmaxes``.
        topk: iterable of k values for which accuracy is reported.
        aggregate: accepted for interface compatibility but currently
            ignored; clip scores are always mean-aggregated.

    Returns:
        A list with one 1-element tensor per entry of ``topk`` holding the
        accuracy in percent over all videos.
    """
    maxk = max(topk)
    video_ids = list(softmaxes)

    # (num_videos, num_classes): mean over the clips of each video.
    output_batch = torch.stack(
        [torch.stack(softmaxes[video_id]).mean(dim=0) for video_id in video_ids]
    )
    num_videos = output_batch.size(0)
    output_labels = torch.stack([labels[video_id] for video_id in video_ids])

    _, pred = output_batch.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, num_videos)
    correct = pred.eq(output_labels.expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
        # handles both layouts.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / num_videos))
    return res


def train_one_epoch(
    args,
    model,
    criterion,
    optimizer,
    lr_scheduler,
    data_loader,
    device,
    epoch,
    print_freq,
    logger=None,
    writer=None,
):
    """Run one optimisation pass over ``data_loader``.

    For every batch: forward pass, loss, clip-level top-1/top-5 accuracy,
    gradient reset, backward pass and optimizer step. Loss, learning rate
    (of the first parameter group), accuracies and throughput are recorded
    in a ``utils.MetricLoggerFinetune``.

    ``lr_scheduler`` is accepted but not used here; this function never
    steps it. Nothing is returned.
    """
    model.train()

    # Meters for learning rate and throughput need explicit formatting.
    metric_logger = utils.MetricLoggerFinetune(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))

    header = f'Epoch: [{epoch}]'
    mode = f'{args.dataset}-{args.fold}/{args.ckpt_epoch}/train'
    batches = metric_logger.log_every(
        data_loader, print_freq, header, logger, writer, mode, epoch=epoch, args=args
    )
    for batch_idx, batch in batches:
        # Each batch is a 4-tuple; only the clip tensor and label are used.
        video, target, _, _ = batch
        if batch_idx == 0:
            print_or_log(video.shape, logger=logger)
        tic = time.time()
        video = video.to(device)
        target = target.to(device)

        # Forward: predictions, loss and clip-level accuracy
        output = model(video)
        loss = criterion(output, target)
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

        # Backward: reset gradients, back-propagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping
        num_clips = video.shape[0]
        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss.item(), lr=current_lr)
        metric_logger.meters['acc1'].update(acc1.item(), n=num_clips)
        metric_logger.meters['acc5'].update(acc5.item(), n=num_clips)
        metric_logger.meters['clips/s'].update(num_clips / (time.time() - tic))


def evaluate(args, model, criterion, data_loader, device, logger=None, writer=None, epoch=0):
    """Evaluate ``model`` on ``data_loader`` at clip and video level.

    Clip-level loss and top-1/top-5 accuracy are tracked with a
    ``utils.MetricLoggerFinetune``. Per-clip outputs are grouped by the video
    index delivered with each batch and passed to
    ``aggregrate_video_accuracy`` for video-level accuracy. The grouped
    outputs are the model outputs as returned (flattened per clip); no
    softmax is applied in this function.

    If ``writer`` is given, video-level accuracies are written as scalars
    for ``epoch``.

    Returns:
        Tuple ``(video_acc1, video_acc5)`` as Python floats.
    """
    # Switch to inference behaviour
    model.eval()

    # video id -> list of per-clip outputs / label of that video
    clip_outputs = {}
    video_labels = {}

    metric_logger = utils.MetricLoggerFinetune(delimiter="  ")
    header = 'Test:'
    torch.cuda.empty_cache()
    mode = f'{args.dataset}-{args.fold}/{args.ckpt_epoch}/val'
    with torch.no_grad():
        batches = metric_logger.log_every(
            data_loader, 10, header, logger, writer, mode, epoch=epoch, args=args
        )
        for batch_idx, batch in batches:
            video, target, _, video_idx = batch
            if batch_idx == 0:
                print_or_log(video.shape, logger=logger)
            tic = time.time()
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            num_clips = video.shape[0]

            # One flat score vector per clip
            output = model(video).view(num_clips, -1)
            loss = criterion(output, target)

            # Clip-level metrics
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            metric_logger.update(loss=loss.item(), lr=0)
            metric_logger.meters['acc1'].update(acc1.item(), n=num_clips)
            metric_logger.meters['acc5'].update(acc5.item(), n=num_clips)
            metric_logger.meters['clips/s'].update(num_clips / (time.time() - tic))

            # Collect clip outputs per video for video-level metrics
            for row, clip_video_idx in enumerate(video_idx):
                video_id = clip_video_idx.item()
                clip_outputs.setdefault(video_id, []).append(output[row])
                video_labels[video_id] = target[row]

    # Video-level acc@1 / acc@5, optionally sent to the tensorboard writer
    video_acc1, video_acc5 = aggregrate_video_accuracy(
        clip_outputs, video_labels, topk=(1, 5)
    )
    if writer is not None:
        writer.add_scalar(f'{mode}/vid_acc1/epoch', video_acc1, epoch)
        writer.add_scalar(f'{mode}/vid_acc5/epoch', video_acc5, epoch)

    # Final summary
    print_or_log(
        f' * Clip Acc@1 {metric_logger.acc1.global_avg:.3f} Clip Acc@5 {metric_logger.acc5.global_avg:.3f}',
        logger=logger
    )
    top1 = video_acc1.item()
    top5 = video_acc5.item()
    print_or_log(f' * Vid Acc@1 {top1:.3f} Video Acc@5 {top5:.3f}', logger=logger)
    return top1, top5


def _load_or_build_dataset(args, mode, transform, subsample, logger=None):
    """Return the dataset for ``mode`` ('train' or 'val'), honouring the on-disk cache."""
    # Log messages refer to the validation split as "dataset_test".
    tag = 'dataset_train' if mode == 'train' else 'dataset_test'
    # Only UCF101 / HMDB51 are cached per fold; every other dataset uses fold 1.
    cache_fold = args.fold if args.dataset in ['ucf101', 'hmdb51'] else 1
    cache_path = utils._get_cache_path(
        args.dataset, mode, cache_fold, args.clip_len, args.steps_bet_clips
    )

    if args.cache_dataset and os.path.exists(cache_path):
        print_or_log(f"Loading {tag} from {cache_path}", logger=logger)
        # NOTE(review): this unpickles a whole dataset object; only use cache
        # files you created yourself.
        dataset = torch.load(cache_path)
        dataset.transform = transform
        return dataset

    dataset = utils.load_dataset(
        dataset_name=args.dataset,
        fold=args.fold,
        mode=mode,
        frames_per_clip=args.clip_len,
        transforms=transform,
        subsample=subsample
    )
    if args.cache_dataset:
        print_or_log(f"Saving {tag} to {cache_path}", logger=logger)
        utils.mkdir(os.path.dirname(cache_path))
        utils.save_on_master((dataset), cache_path)
    return dataset


def _build_lr_scheduler(args, optimizer, logger=None):
    """Return the multi-step LR scheduler (optionally with warmup), or None if disabled."""
    if not args.use_scheduler:
        return None

    # Milestones are given in absolute epochs; shift them by the warmup length.
    milestones = [int(m) - args.lr_warmup_epochs for m in args.lr_milestones.split(',')]
    print_or_log(f"Num. of Epochs: {args.epochs}, Milestones: {milestones}", logger=logger)
    if args.lr_warmup_epochs > 0:
        print_or_log(f"Using LR multi-step scheduler with Gradual warmup for {args.lr_warmup_epochs} epochs", logger=logger)
        scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=args.lr_gamma
        )
        return GradualWarmupScheduler(
            optimizer,
            multiplier=8,  # args.world_size,
            total_epoch=args.lr_warmup_epochs,
            after_scheduler=scheduler_step
        )
    # no warmup, just multi-step
    print_or_log("Using LR multi-step scheduler w/out warmp", logger=logger)
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=milestones,
        gamma=args.lr_gamma
    )


def main(args, logger, writer):
    """Fine-tune (or only evaluate) a video model on one dataset fold.

    Steps: build the backbone, optionally load pretrained weights, wrap it
    with a linear head, set up (distributed) data parallelism, datasets,
    samplers, loaders, optimizer and LR scheduler, optionally resume from a
    checkpoint, then train and evaluate for ``args.epochs`` epochs.

    Args:
        args: namespace produced by ``parse_args``. ``args.start_epoch`` is
            updated in place when resuming from a checkpoint.
        logger: logger passed to ``print_or_log``.
        writer: tensorboard writer or None.

    Returns:
        ``(best_vid_acc_1, best_vid_acc_5, best_epoch)``. With
        ``args.test_only`` a single evaluation is run and
        ``(vid_acc1, vid_acc5, args.start_epoch)`` is returned.
    """
    device = torch.device(args.device)

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Load model
    print_or_log("Loading model", logger=logger)
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=256,
        norm_feat=False,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount
    )

    # Load model weights. An empty string (the CLI default), None and the
    # literal 'None' all mean "no weights file given".
    start = time.time()
    weights_path_given = args.weights_path not in (None, '', 'None')
    if not args.pretrained and weights_path_given:
        print_or_log("Loading model weights", logger=logger)
        if os.path.exists(args.weights_path):
            # Load onto CPU first; the model is moved to the target device below.
            ckpt_dict = torch.load(args.weights_path, map_location='cpu')
            model_weights = ckpt_dict["model"]
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            # Do not fail, but make it visible that training starts without
            # the requested weights.
            print_or_log(
                f"WARNING: weights path '{args.weights_path}' does not exist; "
                "model weights were NOT loaded",
                logger=logger
            )
    print_or_log(f"Time to load model weights: {time.time() - start}", logger=logger)

    # Scale lr and batch-size. The scaled learning rate is kept in a local
    # variable so that repeated calls (one per fold) do not compound the
    # scaling on ``args.head_lr``.
    head_lr = args.head_lr
    train_batch_size = args.batch_size
    val_batch_size = 24
    if args.use_scaling:
        print_or_log(f"Using scaling: {args.world_size}", logger=logger)
        if args.optim_name in ['sgd']:
            head_lr = head_lr * args.world_size
        train_batch_size = args.batch_size * args.world_size
        val_batch_size = val_batch_size * args.world_size
    print_or_log(f"Head_lr: {head_lr}", logger=logger)
    print_or_log(f"Train BS: {train_batch_size}", logger=logger)
    print_or_log(f"Val BS: {val_batch_size}", logger=logger)

    # Add FC layer to model for fine-tuning or feature extracting
    model = load_model_finetune(
        args,
        model.video_network.base if args.model in ['av_gdt'] else model,
        model_name=args.model,
        num_ftrs=512 if args.vid_base_arch in ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50'] else 2048,
        num_classes=NUM_CLASSES[args.dataset],
    )

    # Always place the model on the target device: inputs are moved to
    # ``device`` during training/evaluation, also on single-GPU machines.
    model.to(device)

    # Create (Distributed)DataParallel model when several GPUs are visible
    model_without_ddp = model
    if torch.cuda.device_count() > 1:
        if args.distributed and args.world_size > 8:
            print_or_log("Create model DDP", logger=logger)
            if args.sync_bn:
                print_or_log("Using SYNC BN", logger=logger)
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False
            )
        else:
            print_or_log("Create model data parallel", logger=logger)
            print_or_log(f"Let's use {torch.cuda.device_count()} GPUs!", logger=logger)
            model = torch.nn.DataParallel(model)
        model_without_ddp = model.module

    # Per-parameter optimizer groups: the head always trains with the head
    # learning rate; the base is only added when fine-tuning the whole model.
    params = []
    if args.feature_extract:
        print_or_log("Getting params for feature-extracting", logger=logger)
    else:
        print_or_log("Getting params for finetuning", logger=logger)
    for name, param in model_without_ddp.classifier.named_parameters():
        print(name, param.shape)
        params.append({'params': param, 'lr': head_lr, 'weight_decay': args.weight_decay})
    if not args.feature_extract:
        for name, param in model_without_ddp.base.named_parameters():
            print(name, param.shape)
            params.append({'params': param, 'lr': args.base_lr, 'weight_decay': args.wd_base})

    print_or_log('\n===========Check Grad============', logger=logger)
    for name, param in model_without_ddp.named_parameters():
        print_or_log((name, param.requires_grad), logger=logger)
    print_or_log('=================================\n', logger=logger)

    print_or_log("Creating train and val dataset transforms", logger=logger)
    transform_train, transform_test, subsample = utils.get_transforms(args)

    # Loading Train data
    print_or_log("Loading training data", logger=logger)
    st = time.time()
    dataset = _load_or_build_dataset(args, 'train', transform_train, subsample, logger=logger)
    print_or_log(f"Took {time.time() - st}", logger=logger)

    # Loading Validation data
    print_or_log("Loading validation data", logger=logger)
    dataset_test = _load_or_build_dataset(args, 'val', transform_test, subsample, logger=logger)

    print_or_log("Creating data samplers", logger=logger)
    # random samples for train videos (Temporal Jittering)
    train_sampler = RandomClipSampler(dataset.video_clips, args.train_clips_per_video)
    # uniform samples for test videos
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.val_clips_per_video)
    if args.distributed and args.world_size > 8:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    print_or_log("Creating data loaders", logger=logger)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=train_batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    # NOTE(review): drop_last=True also drops the final incomplete validation
    # batch, so those clips are never evaluated - confirm this is intended.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=val_batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )

    # Set up training params
    criterion = torch.nn.CrossEntropyLoss()

    # Set up optimizer (learning rate already scaled above if requested)
    print_or_log(f"Using {args.optim_name} optimization with lr: {head_lr}, mom: {args.momentum}, wd: {args.weight_decay}", logger=logger)
    optimizer = utils.load_optimizer(
        args.optim_name,
        params,
        lr=head_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # Multi-step LR scheduler (None when disabled)
    lr_scheduler = _build_lr_scheduler(args, optimizer, logger=logger)

    # Checkpointing
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # Only perform evaluation. The result is returned in the same 3-tuple
    # shape as a training run so callers can unpack it uniformly.
    if args.test_only:
        vid_acc1, vid_acc5 = evaluate(
            args,
            model,
            criterion,
            data_loader_test,
            device=device,
            logger=logger,
            writer=writer,
            epoch=args.start_epoch
        )
        return vid_acc1, vid_acc5, args.start_epoch

    start_time = time.time()
    best_vid_acc_1 = -1
    best_vid_acc_5 = -1
    best_epoch = 0
    # Evaluate and checkpoint after every epoch for all datasets.
    eval_freq = 1
    ckpt_freq = 1
    for epoch in range(args.start_epoch, args.epochs):
        print_or_log(f'Start training epoch: {epoch}', logger=logger)
        train_one_epoch(
            args,
            model,
            criterion,
            optimizer,
            lr_scheduler,
            data_loader,
            device,
            epoch,
            args.print_freq,
            logger=logger,
            writer=writer,
        )
        print_or_log(f'Start evaluating epoch: {epoch}', logger=logger)
        if lr_scheduler is not None:
            lr_scheduler.step()
        if epoch % eval_freq == 0:
            vid_acc1, vid_acc5 = evaluate(
                args,
                model,
                criterion,
                data_loader_test,
                device=device,
                logger=logger,
                writer=writer,
                epoch=epoch
            )
            # Track the epoch with the best video-level top-1 accuracy
            if vid_acc1 > best_vid_acc_1:
                best_vid_acc_1 = vid_acc1
                best_vid_acc_5 = vid_acc5
                best_epoch = epoch
        if args.output_dir:
            print_or_log(f'Saving checkpoint to: {args.output_dir}', logger=logger)
            utils.save_checkpoint(args, epoch, model, optimizer, lr_scheduler, ckpt_freq=ckpt_freq)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print_or_log(f'Training time {total_time_str}', logger=logger)
    return best_vid_acc_1, best_vid_acc_5, best_epoch


def parse_args(argv=None):
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: optional list of argument strings. When None (the default)
            the arguments are taken from ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with all options. Boolean options take an
        explicit value (yes/true/t/1 or no/false/f/0, case-insensitive).
    """
    import argparse

    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # ArgumentTypeError makes argparse show this message to the user;
        # with a plain ValueError argparse replaces it by a generic one.
        raise argparse.ArgumentTypeError(
            'Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    parser = argparse.ArgumentParser(description='Video Representation Learning')
    parser.register('type', 'bool', str2bool)
    add = parser.add_argument

    ### DATA
    add('--dataset', default='ucf101', help='name of dataset')
    add('--fold', default='1,2,3', type=str,
        help='comma-separated list of dataset folds to run')
    add('--clip-len', default=30, type=int, metavar='N',
        help='number of frames per clip')
    add('--augtype', default=1, type=int,
        help='augmentation type (default: 1)')
    add('--use-scale-jittering', default='False', type='bool',
        help='scale jittering as augmentations')
    add('--colorjitter', default='False', type='bool',
        help='color jittering as augmentations')
    add('--steps-bet-clips', default=1, type=int, metavar='N',
        help='number of steps between clips in video')
    add('--num-data-samples', default=None, type=int,
        help='number of samples in dataset')
    add('--train-clips-per-video', default=5, type=int, metavar='N',
        help='maximum number of clips per video to consider for training')
    add('--val-clips-per-video', default=10, type=int, metavar='N',
        help='maximum number of clips per video to consider for testing')
    add("--cache-dataset", type='bool', default='False',
        help="Cache the datasets for quicker initialization. It also serializes the transforms")
    add('-j', '--workers', default=0, type=int, metavar='N',
        help='number of data loading workers (default: 0)')

    ### MODEL
    add('--model', default='av_gdt', help='model',
        choices=['r2plus1d_18', 'av_gdt'])
    add('--weights-path', default='', help='Path to weights file')
    add('--ckpt-epoch', default='0', type=str,
        help='Epoch of model checkpoint')
    add('--vid-base-arch', default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50'])
    add('--aud-base-arch', default='resnet18',
        help='Audio Base Arch for A-V model',
        choices=['resnet18', 'vgg_audio', 'resnet34', 'resnet50', 'resnet9'])
    add("--pretrained", type='bool', default='False',
        help="Use pre-trained models from the modelzoo")
    add('--use-mlp', default='True', type='bool',
        help='Use MLP projection head')
    add('--mlptype', default=0, type=int, help='MLP type (default: 0)')
    add('--headcount', type=int, default=1,
        help='how many heads each modality has')
    add("--sync-bn", type='bool', default='True', help="Use sync batch norm")

    ### FINETUNE params
    add("--feature-extract", type='bool', default='False',
        help="Use model as feature extractor; if False, finetune entire model")

    ### TRAINING
    add('-b', '--batch-size', default=24, type=int,
        help='training batch size (default: 24)')
    add('--epochs', default=15, type=int, metavar='N',
        help='number of total epochs to run')
    add("--use-scaling", type='bool', default='False', help="Use LR scaling")
    add('--optim-name', default='sgd', type=str, help='Name of optimizer',
        choices=['sgd', 'adam'])
    add('--head-lr', default=1e-2, type=float,
        help='initial learning rate of the classifier head')
    add('--base-lr', default=1e-4, type=float,
        help='initial learning rate of the base network')
    add('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    add('--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
        help='weight decay (default: 1e-4)', dest='weight_decay')
    add('--wd-base', default=5e-3, type=float,
        help='weight decay of the base network (default: 5e-3)')
    add("--use-scheduler", type='bool', default='True', help="Use LR scheduler")
    add('--lr-warmup-epochs', default=0, type=int,
        help='number of warmup epochs')
    add('--lr-milestones', default='5,10', type=str,
        help='decrease lr on milestones (epochs)')
    add('--lr-gamma', default=0.1, type=float,
        help='decrease lr by a factor of lr-gamma')

    ### LOGGING
    add('--print-freq', default=10, type=int, help='print frequency')
    add('--output-dir', default='.', help='path where to save')

    ### CHECKPOINTING
    add('--resume', default='', help='resume from checkpoint')
    add('--start-epoch', default=0, type=int, metavar='N', help='start epoch')
    add("--test-only", type='bool', default='False', help="Only test the model")

    ### DISTRIBUTED
    add('--device', default='cuda', help='device')
    add('--distributed', type='bool', default='False', help="ddp mode")
    add('--dist-backend', default='nccl', type=str, help='distributed backend')
    add('--dist-url', default='env://',
        help='url used to set up distributed training')
    add('--world-size', default=1, type=int,
        help='number of distributed processes')
    add('--debug-slurm', type='bool', default='False', help="Debug SLURM")
    add('--local-rank', default=-1, type=int, help='Local rank of node')
    add('--master-port', default=-1, type=int, help='Master port of Job')

    return parser.parse_args(argv)


# Script entry point: set up distributed mode, logging and tensorboard, then
# run fine-tuning once per requested fold and report the averaged accuracy.
if __name__ == "__main__":
    args = parse_args()
    print(args)

    # set multi-processing start method
    import torch.multiprocessing as mp
    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        # Best effort: keep whatever start method is already configured.
        pass

    # Init distributed mode
    if torch.cuda.is_available():
        utils.init_distributed_mode(args)

    print(f'Distributed Mode: {args.distributed}')

    # Make output dir
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # Set up logger
    filename = 'logger.out'
    logger = utils.setup_logger(
        "Video_reader, classification",
        args.output_dir,
        True,
        logname=filename
    )

    # Set up tensorboard (only on the master process)
    global_rank = int(os.environ['SLURM_PROCID']) if args.distributed else 0
    is_master = global_rank == 0
    writer = None
    if is_master:
        writer = utils.setup_tbx(
            tbx_path,
            is_master
        )
        writer.add_text("namespace", repr(args))

    # Log version information
    print_or_log(args, logger=logger)
    print_or_log(f"torch version: {torch.__version__}", logger=logger)
    print_or_log(f"torchvision version: {torchvision.__version__}", logger=logger)

    # Run over different folds
    best_accs_1 = []
    best_accs_5 = []
    best_epochs = []
    folds = [int(fold) for fold in args.fold.split(',')]
    print(folds)

    # Only UCF101 / HMDB51 are run per fold; any other dataset is run once
    # with ``args.fold`` left untouched.
    uses_folds = args.dataset in ['ucf101', 'hmdb51']
    for fold in (folds if uses_folds else [None]):
        if fold is not None:
            args.fold = fold
        result = main(args, logger, writer)
        if result is None:
            # main() produced no accuracies for this run; nothing to aggregate.
            continue
        best_acc1, best_acc5, best_epoch = result
        best_accs_1.append(best_acc1)
        best_accs_5.append(best_acc5)
        best_epochs.append(best_epoch)

    if uses_folds:
        print(best_accs_1)
        print(best_accs_5)
    if best_accs_1:
        avg_acc1 = np.mean(best_accs_1)
        avg_acc5 = np.mean(best_accs_5)
        # Report the number of runs that were actually averaged.
        print(f'{len(best_accs_1)}-Fold ({args.dataset}): Vid Acc@1 {avg_acc1:.3f}, Video Acc@5 {avg_acc5:.3f}')
