import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import argparse
import datetime
import shutil
from pathlib import Path
from utils.config import get_config
from utils.optimizer import build_optimizer, build_scheduler
from utils.tools import AverageMeter, reduce_tensor, epoch_saving, load_checkpoint, generate_text, auto_resume_helper
from datasets.build import build_dataloader
from utils.logger import create_logger
import time
import numpy as np
import random
from apex import amp
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from datasets.blending import CutmixMixupBlending
from utils.config import get_config
from models import timerewarder
import wandb
import csv
import matplotlib.pyplot as plt

def parse_option():
    """Parse the command line and build the experiment configuration.

    The flags are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence as before.

    Returns:
        tuple: ``(args, config)`` where ``args`` is the ``argparse.Namespace``
        produced from the process arguments and ``config`` is whatever
        ``get_config(args)`` returns.
    """
    cli = argparse.ArgumentParser()

    # (flag names, add_argument keyword options), in registration order.
    argument_specs = (
        (('--config', '-cfg'), dict(required=True, type=str, default='configs/k400/32_8.yaml')),
        (("--opts",), dict(help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+')),
        (('--output',), dict(type=str, default="exp")),
        (('--resume',), dict(type=str)),
        (('--pretrained',), dict(type=str)),
        (('--only_test',), dict(action='store_true')),
        (('--batch-size',), dict(type=int)),
        (('--accumulation-steps',), dict(type=int)),
        (("--local_rank",), dict(type=int, default=-1, help='local rank for DistributedDataParallel')),
    )
    for flag_names, options in argument_specs:
        cli.add_argument(*flag_names, **options)

    parsed = cli.parse_args()
    return parsed, get_config(parsed)


def main(config): 
    """Build data and model, then run evaluation-only modes or the training loop.

    Depending on ``config`` this function does one of the following:

    * ``config.TEST.ONLY_TEST_CLIP``: evaluate a stock CLIP ViT-B/16 with
      ``test_clip`` and return.
    * ``config.TEST.ONLY_TEST``: load the model and either call ``validate``
      (metrics) or ``visualize`` (when ``config.TEST.VISUALIZE``), then return.
    * otherwise: build criterion/optimizer/scheduler, wrap the model in
      DistributedDataParallel, optionally resume from a checkpoint, and train
      for ``config.TRAIN.EPOCHS`` epochs with periodic evaluation and saving.

    Args:
        config: Experiment configuration object (attribute-style access to
            ``DATA``, ``MODEL``, ``TRAIN``, ``TEST``, ``AUG``, ...).

    NOTE(review): ``logger``, ``validate`` and the tail of ``visualize`` are not
    defined in the part of the file visible here; they are presumably
    module-level names created elsewhere -- confirm.
    """
    # build_dataloader returns five datasets followed by the five matching loaders.
    train_data, val_data, val_data_outdomain, common_data, rare_data, train_loader, val_loader, val_loader_outdomain, val_loader_common, val_loader_rare = build_dataloader(logger, config)
    # Temporal length handed to the model: per-clip frame count in order mode,
    # full frame count otherwise. class_num is the same in both branches.
    if config.DATA.USE_ORDER:
        t_model = config.DATA.NUM_FRAMES_CLIP
        class_num = config.DATA.NUM_CLASSES
    else:
        t_model = config.DATA.NUM_FRAMES
        class_num = config.DATA.NUM_CLASSES

    # Text labels come from the training set when there is one, else from the validation set.
    if train_data is not None:
        text_labels = generate_text(train_data)
    else:
        text_labels = generate_text(val_data)
    # text_labels_val is only bound when an out-domain loader exists; every
    # later use in this function sits behind the same None check.
    if val_loader_outdomain is not None:
        text_labels_val = generate_text(val_data_outdomain)

    # Mode 1: baseline evaluation with an off-the-shelf CLIP model, then exit.
    if config.TEST.ONLY_TEST_CLIP:
        import clip
        clipmodel, _ = clip.load(name="ViT-B/16", device="cuda")
        acc_clip, nofail_clip = test_clip(val_loader, text_labels, clipmodel, config)
        logger.info(f"Accuracy of the network on the {len(val_data)} in-domain test videos: {acc_clip:.1f}%"
                    f"Not all-fail rate: {nofail_clip:.1f}%")
        if val_loader_outdomain is not None:
            acc_clip_outdomain, nofail_outdomain = test_clip(val_loader_outdomain, text_labels_val, clipmodel, config)
            logger.info(f"Accuracy of the network on the {len(val_data_outdomain)} out-domain test videos: {acc_clip_outdomain:.1f}%"
                        f"Not all-fail rate: {nofail_outdomain:.1f}%")
        return
    
    # The model is loaded on the CPU first and moved to the GPU afterwards.
    model, _ = timerewarder.load(config.MODEL.PRETRAINED, config.MODEL.ARCH, 
                         device="cpu", jit=False, 
                         T=t_model, 
                         droppath=config.MODEL.DROP_PATH_RATE, 
                         mlp_droprate=config.MODEL.MLP_DROPOUT,
                         use_checkpoint=config.TRAIN.USE_CHECKPOINT, 
                         use_cache=config.MODEL.FIX_TEXT,
                         train_order=config.TRAIN.TRAIN_ORDER,
                         train_class=config.TRAIN.TRAIN_CLASS,
                         two_head=config.TRAIN.TWO_HEAD,
                         logger=logger,
                        )

    model = model.cuda()

    # Mode 2: evaluation only (metrics or visualisation), then exit.
    if config.TEST.ONLY_TEST:
        if not config.TEST.VISUALIZE:
            # The third value returned by validate is logged as an average rank
            # in two-head mode and as a wrong-text accuracy otherwise.
            acc1, acc_order, avg_rank = validate(val_loader, text_labels, model, config)
            logger.info(f"Order accuracy of the network on the {len(val_data)} in-domain test videos: {acc_order:.1f}%")
            if acc1 is not None:
                logger.info(f"Accuracy of the network on the {len(val_data)} in-domain test videos: {acc1:.1f}%")
            if avg_rank is not None:
                if config.TRAIN.TWO_HEAD:
                    logger.info(f"Average rank of the correct class: {avg_rank:.1f}")
                else:
                    logger.info(f"Accuracy of identifying the wrong text: {avg_rank:.1f}")
            if val_loader_outdomain is not None:
                acc1_outdomain, acc_order_outdomain, avg_rank_outdomain = validate(val_loader_outdomain, text_labels_val, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(val_data_outdomain)} out-domain test videos: {acc_order_outdomain:.1f}%")
                if acc1_outdomain is not None:
                    logger.info(f"Accuracy of the network on the {len(val_data_outdomain)} out-domain test videos: {acc1_outdomain:.1f}%")
                if avg_rank_outdomain is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_outdomain:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_outdomain:.1f}")

            # NOTE(review): val_loader_rare is evaluated whenever val_loader_common
            # exists, without its own None check -- presumably the two loaders are
            # always built together; confirm in build_dataloader.
            if val_loader_common is not None:
                acc1_common, acc_order_common, avg_rank_common = validate(val_loader_common, text_labels, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(common_data)} common test videos: {acc_order_common:.1f}%")
                if acc1_common is not None:
                    logger.info(f"Accuracy of the network on the {len(common_data)} common test videos: {acc1_common:.1f}%")
                if avg_rank_common is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_common:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_common:.1f}")
                acc1_rare, acc_order_rare, avg_rank_rare = validate(val_loader_rare, text_labels, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(rare_data)} rare test videos: {acc_order_rare:.1f}%")
                if acc1_rare is not None:
                    logger.info(f"Accuracy of the network on the {len(rare_data)} rare test videos: {acc1_rare:.1f}%")
                if avg_rank_rare is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_rare:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_rare:.1f}")
            return
        else:
            # NOTE(review): an import inside a function binds `clip` only in main's
            # local scope, so this does not make `clip` visible inside visualize(),
            # which calls clip.load in its two-head path. Confirm that `clip` is
            # also imported at module level somewhere outside this view.
            if config.TRAIN.TWO_HEAD:
                import clip
            visualize(val_loader, text_labels, model, config.DATA.LABEL_LIST, config, title=config.DATA.VAL_TITLE)
            if val_loader_outdomain is not None:
                visualize(val_loader_outdomain, text_labels_val, model, config.DATA.LABEL_LIST_VAL, config, title=config.DATA.VAL_TITLE + '_test_outdomain')
            if val_loader_common is not None:
                visualize(val_loader_common, text_labels, model, config.DATA.LABEL_LIST, config, title=config.DATA.VAL_TITLE + '_train_common')
                visualize(val_loader_rare, text_labels, model, config.DATA.LABEL_LIST, config, title=config.DATA.VAL_TITLE + '_train_rare')
            return

    # Mode 3: training. The criterion depends on the augmentation settings:
    # mixup produces soft targets, otherwise label smoothing or plain cross-entropy.
    mixup_fn = None
    if config.AUG.MIXUP > 0:
        criterion = SoftTargetCrossEntropy()
        mixup_fn = CutmixMixupBlending(num_classes=class_num, 
                                       smoothing=config.AUG.LABEL_SMOOTH, 
                                       mixup_alpha=config.AUG.MIXUP, 
                                       cutmix_alpha=config.AUG.CUTMIX, 
                                       switch_prob=config.AUG.MIXUP_SWITCH_PROB)
    elif config.AUG.LABEL_SMOOTH > 0:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.AUG.LABEL_SMOOTH)
    else:
        criterion = nn.CrossEntropyLoss()
    
    optimizer = build_optimizer(config, model)
    lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
    # Apex AMP is initialised for every opt level except 'O0', and this happens
    # before the model is wrapped in DistributedDataParallel.
    if config.TRAIN.OPT_LEVEL != 'O0':
        model, optimizer = amp.initialize(models=model, optimizers=optimizer, opt_level=config.TRAIN.OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False, find_unused_parameters=True)

    start_epoch, max_accuracy = 0, 0.0

    # Auto-resume overrides config.MODEL.RESUME with the checkpoint that
    # auto_resume_helper finds in the output directory, if any.
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    # The checkpoint is loaded into the unwrapped module (model.module), not the DDP wrapper.
    if config.MODEL.RESUME:
        start_epoch, max_accuracy = load_checkpoint(config, model.module, optimizer, lr_scheduler, logger)

    for epoch in range(start_epoch, config.TRAIN.EPOCHS):
        # Pass the epoch to the sampler before iterating the loader.
        train_loader.sampler.set_epoch(epoch)                
        train_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, text_labels, config, mixup_fn)
        # With the accuracy-based block below disabled (`if 0:`), is_best stays
        # True for every epoch and max_accuracy keeps its initial/resumed value.
        is_best = True
        # Loss-based evaluation runs on even epochs and on the final epoch.
        if epoch % 2 == 0 or epoch == (config.TRAIN.EPOCHS - 1):
            test_loss, class_loss, order_loss, bin_metric = validate_progresspredict(val_loader, text_labels, model, config)
            
            logger.info(f"Test loss of the network on the {len(val_data)} in-domain test videos: {test_loss:.4f}")
            # Only rank 0 reports to wandb.
            if dist.get_rank() == 0 and config.DATA.USE_WANDB:
                wandb.log({"test_loss": test_loss})
                if class_loss is not None:
                    wandb.log({"test_class_loss": class_loss})
                if order_loss is not None:
                    wandb.log({"test_order_loss": order_loss})
                for k, v in bin_metric.items():
                    wandb.log({k: v})
            if val_loader_outdomain is not None:
                test_loss_outdomain, class_loss_outdomain, order_loss_outdomain, bin_metric_outdomain = validate_progresspredict(val_loader_outdomain, text_labels_val, model, config)
                logger.info(f"Test loss of the network on the {len(val_data_outdomain)} out-domain test videos: {test_loss_outdomain:.4f}")
                if dist.get_rank() == 0 and config.DATA.USE_WANDB:
                    wandb.log({"test_loss_outdomain": test_loss_outdomain})
                    if class_loss_outdomain is not None:
                        wandb.log({"test_class_loss_outdomain": class_loss_outdomain})
                    if order_loss_outdomain is not None:
                        wandb.log({"test_order_loss_outdomain": order_loss_outdomain})
                    # Rename the bin keys so out-domain metrics do not collide with in-domain ones.
                    for k, v in bin_metric_outdomain.items():
                        k = k.replace('bin_', 'bin_outdomain_')
                        wandb.log({k: v})
        
        # Disabled accuracy-based evaluation and best-checkpoint tracking.
        # `if 0:` makes this whole block unreachable.
        if 0:
            acc1, acc_order, avg_rank = validate(val_loader, text_labels, model, config)
            logger.info(f"Order accuracy of the network on the {len(val_data)} in-domain test videos: {acc_order:.1f}%")
            if acc1 is not None:
                logger.info(f"Accuracy of the network on the {len(val_data)} in-domain test videos: {acc1:.1f}%")
            if avg_rank is not None:
                if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank:.1f}")
                else:
                    logger.info(f"Accuracy of identifying the wrong text: {avg_rank:.1f}")
            if val_loader_outdomain is not None:
                acc1_outdomain, acc_order_outdomain, avg_rank_outdomain = validate(val_loader_outdomain, text_labels_val, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(val_data_outdomain)} out-domain test videos: {acc_order_outdomain:.1f}%")
                if acc1_outdomain is not None:
                    logger.info(f"Accuracy of the network on the {len(val_data_outdomain)} out-domain test videos: {acc1_outdomain:.1f}%")
                if avg_rank_outdomain is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_outdomain:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_outdomain:.1f}")
            if val_loader_common is not None:
                acc1_common, acc_order_common, avg_rank_common = validate(val_loader_common, text_labels, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(common_data)} common test videos: {acc_order_common:.1f}%")
                if acc1_common is not None:
                    logger.info(f"Accuracy of the network on the {len(common_data)} common test videos: {acc1_common:.1f}%")
                if avg_rank_common is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_common:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_common:.1f}")
                acc1_rare, acc_order_rare, avg_rank_rare = validate(val_loader_rare, text_labels, model, config, outdomain=True)
                logger.info(f"Order accuracy of the network on the {len(rare_data)} rare test videos: {acc_order_rare:.1f}%")
                if acc1_rare is not None:
                    logger.info(f"Accuracy of the network on the {len(rare_data)} rare test videos: {acc1_rare:.1f}%")
                if avg_rank_rare is not None:
                    if config.TRAIN.TWO_HEAD:
                        logger.info(f"Average rank of the correct class: {avg_rank_rare:.1f}")
                    else:
                        logger.info(f"Accuracy of identifying the wrong text: {avg_rank_rare:.1f}")
            
            # Best-model criterion: order accuracy when there is no class accuracy
            # or the model is single-head, class accuracy otherwise.
            if acc1 is None or not config.TRAIN.TWO_HEAD:
                is_best = acc_order > max_accuracy
                max_accuracy = max(max_accuracy, acc_order)
                logger.info(f'Max order test accuracy: {max_accuracy:.2f}%')
            else:
                is_best = acc1 > max_accuracy
                max_accuracy = max(max_accuracy, acc1)
                logger.info(f'Max test accuracy: {max_accuracy:.2f}%')
            if dist.get_rank() == 0 and config.DATA.USE_WANDB:
                if val_loader_outdomain is not None:
                    if acc1 is None:
                        wandb.log({"valset_indomain_order_acc": acc_order, "valset_outdomain_order_acc": acc_order_outdomain})
                    else:
                        if config.TRAIN.TWO_HEAD:
                            wandb.log({"valset_indomain_acc": acc1, "valset_outdomain_acc": acc1_outdomain,
                            "valset_indomain_order_acc": acc_order, "valset_outdomain_order_acc": acc_order_outdomain,
                            "valset_indomain_avg_rank": avg_rank, "valset_outdomain_avg_rank": avg_rank_outdomain})
                        else:
                            wandb.log({"valset_indomain_acc": acc1, "valset_outdomain_acc": acc1_outdomain,
                            "valset_indomain_order_acc": acc_order, "valset_outdomain_order_acc": acc_order_outdomain,
                            "valset_indomain_wrongtext_acc": avg_rank, "valset_outdomain_wrongtext_acc": avg_rank_outdomain})
                else:
                    if acc1 is None:
                        wandb.log({"valset_order_acc": acc_order})
                    else:
                        wandb.log({"valset_acc": acc1, "valset_order_acc": acc_order, "valset_avg_rank": avg_rank})
                if val_loader_common is not None:
                    if acc1_common is None:
                        wandb.log({"trainset_common_order_acc": acc_order_common, "trainset_rare_order_acc": acc_order_rare})
                    else:
                        if config.TRAIN.TWO_HEAD:
                            wandb.log({"trainset_common_acc": acc1_common, "trainset_rare_acc": acc1_rare,
                            "trainset_common_order_acc": acc_order_common, "trainset_rare_order_acc": acc_order_rare,
                            "trainset_common_avg_rank": avg_rank_common, "trainset_rare_avg_rank": avg_rank_rare})
                        else:
                            wandb.log({"trainset_common_acc": acc1_common, "trainset_rare_acc": acc1_rare,
                            "trainset_common_order_acc": acc_order_common, "trainset_rare_order_acc": acc_order_rare,
                            "trainset_common_wrongtext_acc": avg_rank_common, "trainset_rare_wrongtext_acc": avg_rank_rare})
            
        # Checkpoints are written by rank 0 only, every SAVE_FREQ epochs and on the
        # final epoch, but never for epoch 0.
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)) and epoch > 0:
            epoch_saving(config, epoch, model.module, max_accuracy, optimizer, lr_scheduler, logger, config.OUTPUT, is_best)

@torch.no_grad()
def validate_progresspredict(val_loader, text_labels, model, config, outdomain=False):
    """Evaluate the model's losses over ``val_loader`` without gradients.

    Args:
        val_loader: Iterable of dict batches with keys ``"imgs"``, ``"label"``,
            ``"neg_labels"`` and optionally ``"progress"``.
        text_labels: Tensor-like text inputs; moved to the GPU once and, when
            its first dimension is 1, reshaped to ``(1, -1)``.
        model: Network called as ``model(images, texts, label_id, USE_ORDER,
            NUM_CLIPS, label_id_neg=..., progress=..., show_bin=False)`` and
            expected to return a 4-tuple ``(total_loss, classification_loss,
            regression_loss, logits_order)``.
        config: Experiment configuration; ``config.DATA.USE_ORDER`` and
            ``config.DATA.NUM_CLIPS`` are forwarded to the model.
        outdomain: Accepted for signature compatibility; not read by this
            function.

    Returns:
        tuple: ``(test_loss_avg, classification_loss_avg, regression_loss_avg,
        bin_metric_avg)`` where the first three are the ``avg`` of the
        corresponding meters and ``bin_metric_avg`` maps every bin-metric key
        to its meter's ``avg``.
    """
    model.eval()
    test_loss_meter = AverageMeter()
    classification_loss_meter = AverageMeter()
    regression_loss_meter = AverageMeter()
    intervals = [
        (-1.0, -0.5),
        (-0.5, -0.25),
        (-0.25, -0.125),
        (-0.125, -0.0625),
        (-0.0625, 0),
        (0, 0.0625),
        (0.0625, 0.125),
        (0.125, 0.25),
        (0.25, 0.5),
        (0.5, 1.0),
    ]
    # One meter per interval and metric; all sign-error keys come first, then
    # all MAE keys, so the dict order matches the original two-step build.
    bin_metrics_meter = {
        f"bin_[{low},{high}]_sign_error": AverageMeter()
        for low, high in intervals
    }
    bin_metrics_meter.update({
        f"bin_[{low},{high}]_mae": AverageMeter()
        for low, high in intervals
    })

    texts = text_labels.cuda(non_blocking=True)
    if texts.shape[0] == 1:
        texts = texts.view(1, -1)

    for batch_data in val_loader:
        images = batch_data["imgs"].cuda(non_blocking=True)
        progress = batch_data["progress"].cuda(non_blocking=True) if 'progress' in batch_data else None
        label_id = batch_data["label"].cuda(non_blocking=True).reshape(-1)
        label_id_neg = batch_data["neg_labels"].cuda(non_blocking=True)

        # NOTE: the model is called with show_bin=False, so it returns no
        # per-bin metrics and the bin meters above are never updated here.
        # Their reported averages are whatever AverageMeter starts with.
        total_loss, classification_loss, regression_loss, _ = model(images, texts, label_id, config.DATA.USE_ORDER, config.DATA.NUM_CLIPS, label_id_neg = label_id_neg, progress = progress, show_bin=False)

        batch_len = len(label_id)
        test_loss_meter.update(total_loss.item(), batch_len)
        if classification_loss is not None:
            classification_loss_meter.update(classification_loss.item(), batch_len)
        if regression_loss is not None:
            regression_loss_meter.update(regression_loss.item(), batch_len)

    # Built once after the loop: previously this was assigned inside the loop,
    # which left the name unbound (UnboundLocalError) for an empty loader.
    bin_metric_avg = {key: meter.avg for key, meter in bin_metrics_meter.items()}
    return test_loss_meter.avg, classification_loss_meter.avg, regression_loss_meter.avg, bin_metric_avg


def train_one_epoch(epoch, model, criterion, optimizer, lr_scheduler, train_loader, text_labels, config, mixup_fn):
    """Train ``model`` for one pass over ``train_loader`` and log running statistics.

    Args:
        epoch: Index of the current epoch; used in log lines and to compute the
            global step passed to ``lr_scheduler.step_update``.
        model: Network called as ``model(images, texts, label_id, USE_ORDER,
            NUM_CLIPS, label_id_neg=..., progress=...)``.
        criterion: Loss callable applied to ``(logits, targets)``.
        optimizer: Optimizer, stepped every ``config.TRAIN.ACCUMULATION_STEPS``
            batches.
        lr_scheduler: Scheduler exposing ``step_update(global_step)``.
        train_loader: Iterable of dict batches with keys ``"imgs"``,
            ``"label"``, ``"neg_labels"`` and optionally ``"progress"``.
        text_labels: Tensor-like text inputs; moved to the GPU once per epoch.
        config: Experiment configuration (``DATA``, ``TRAIN``, ``PRINT_FREQ``).
        mixup_fn: Optional callable ``(images, labels) -> (images, labels)``;
            skipped when ``None``.

    Returns:
        None. Progress is reported through ``logger`` and, on rank 0 with
        ``config.DATA.USE_WANDB`` set, through ``wandb``.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = len(train_loader)
    batch_time = AverageMeter()
    tot_loss_meter = AverageMeter()
    # Every meter is created up front so that none can be missing in any
    # configuration; unused meters are simply never updated or reported.
    class_loss_meter, class_acc_meter = AverageMeter(), AverageMeter()
    order_loss_meter, order_acc_meter = AverageMeter(), AverageMeter()
    # Per-batch component metrics. Bound here so the end-of-epoch reporting is
    # well-defined even when the loader is empty or USE_ORDER is off.
    class_loss = class_acc = order_loss = order_acc = None

    start = time.time()
    end = time.time()

    texts = text_labels.cuda(non_blocking=True)
    # Loop-invariant, so done once instead of on every batch.
    if texts.shape[0] == 1:
        texts = texts.view(1, -1)

    for idx, batch_data in enumerate(train_loader):
        # Reset so a component the model does not produce for this batch is
        # skipped rather than raising UnboundLocalError or re-recording the
        # previous batch's stale value.
        class_loss = class_acc = order_loss = order_acc = None

        images = batch_data["imgs"].cuda(non_blocking=True)
        progress = batch_data["progress"].cuda(non_blocking=True) if 'progress' in batch_data else None
        label_id = batch_data["label"].cuda(non_blocking=True)
        label_id_neg = batch_data["neg_labels"].cuda(non_blocking=True)
        label_id = label_id.reshape(-1)
        if config.DATA.USE_ORDER:
            # One label copy per unordered pair of clips: C(NUM_CLIPS, 2).
            num_pairs = config.DATA.NUM_CLIPS * (config.DATA.NUM_CLIPS - 1) // 2
            if mixup_fn is not None:
                label_id = label_id.repeat(num_pairs, 1)
            else:
                label_id = label_id.repeat(num_pairs)

        if mixup_fn is not None:
            images, label_id = mixup_fn(images, label_id)

        output = model(images, texts, label_id, config.DATA.USE_ORDER, config.DATA.NUM_CLIPS, label_id_neg = label_id_neg, progress = progress)
        if config.DATA.USE_ORDER:
            if config.TRAIN.TWO_HEAD:
                logits_class, logits_order = output
                # NOTE(review): if both heads return None, total_loss stays the
                # int 0 and the backward call below fails -- confirm the model
                # always returns at least one head.
                total_loss = 0
                if logits_class is not None:
                    # Every row uses class index 0 as its target.
                    class_loss = criterion(logits_class, torch.zeros(logits_class.shape[0], dtype=torch.int64).cuda())
                    class_loss = class_loss / config.TRAIN.ACCUMULATION_STEPS
                    total_loss += class_loss
                    class_acc = (logits_class.argmax(dim=-1) == 0).float().mean()
                if logits_order is not None:
                    order_loss = criterion(logits_order, torch.zeros(logits_order.shape[0], dtype=torch.int64).cuda()) * config.TRAIN.LOSS_RATIO
                    order_loss = order_loss / config.TRAIN.ACCUMULATION_STEPS
                    total_loss += order_loss
                    order_acc = (logits_order.argmax(dim=-1) == 0).float().mean()
            else:
                # NOTE(review): unlike the two-head and plain branches, the
                # losses in this branch are not divided by ACCUMULATION_STEPS
                # -- confirm this is intended.
                if len(output) == 2:
                    logits_wrong_order, logits_order = output
                    order_loss = criterion(logits_order, torch.zeros(logits_order.shape[0], dtype=torch.int64).cuda())
                    order_acc = (logits_order.argmax(dim=-1) == 0).float().mean()
                    if logits_wrong_order is None:
                        total_loss = order_loss
                    else:
                        # The softmax output is fed to the criterion with
                        # target index 2 for every row.
                        logits_wrong_order = logits_wrong_order.softmax(dim=-1)
                        class_loss = criterion(logits_wrong_order, 2 * torch.ones(logits_wrong_order.shape[0], dtype=torch.int64).cuda()) * config.TRAIN.LOSS_RATIO
                        total_loss = order_loss + class_loss
                        class_acc = (logits_wrong_order.argmax(dim=-1) == 2).float().mean()
                else:
                    # The model computed the losses itself; no accuracies here.
                    total_loss, classification_loss, regression_loss, logits_order = output
                    order_loss = regression_loss
                    class_loss = classification_loss
        else:
            total_loss = criterion(output, label_id)
            total_loss = total_loss / config.TRAIN.ACCUMULATION_STEPS

        if config.TRAIN.ACCUMULATION_STEPS == 1:
            optimizer.zero_grad()
        if config.TRAIN.OPT_LEVEL != 'O0':
            with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            total_loss.backward()
        if config.TRAIN.ACCUMULATION_STEPS > 1:
            # Step only once per accumulation window.
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        batch_len = len(label_id)
        tot_loss_meter.update(total_loss.item(), batch_len)
        if config.DATA.USE_ORDER:
            if class_loss is not None:
                class_loss_meter.update(class_loss.item(), batch_len)
            if class_acc is not None:
                class_acc_meter.update(class_acc.item() * 100, batch_len)
            if order_loss is not None:
                order_loss_meter.update(order_loss.item(), batch_len)
            if order_acc is not None:
                order_acc_meter.update(order_acc.item() * 100, batch_len)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            # The headline line is the same in both modes.
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.9f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'tot_loss {tot_loss_meter.val:.4f} ({tot_loss_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
            if not config.DATA.USE_ORDER:
                if dist.get_rank() == 0 and config.DATA.USE_WANDB:
                    wandb.log({"train_tot_loss": tot_loss_meter.avg, "epoch":epoch})
            else:
                if class_loss is not None:
                    logger.info(
                        f'class_loss {class_loss_meter.val:.4f} ({class_loss_meter.avg:.4f})\t')
                if class_acc is not None:
                    logger.info(
                        f'class_acc {class_acc_meter.val:.4f} ({class_acc_meter.avg:.4f})\t')
                if order_loss is not None:
                    logger.info(
                        f'order_loss {order_loss_meter.val:.4f} ({order_loss_meter.avg:.4f})\t')
                if order_acc is not None:
                    logger.info(
                        f'order_acc {order_acc_meter.val:.4f} ({order_acc_meter.avg:.4f})\t')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
    if dist.get_rank() == 0 and config.DATA.USE_WANDB:
        wandb.log({"epoch_tot_loss": tot_loss_meter.avg})
        # Which components are reported is decided by the last batch's values.
        if class_loss is not None:
            wandb.log({"epoch_class_loss": class_loss_meter.avg})
        if class_acc is not None:
            wandb.log({"epoch_class_acc": class_acc_meter.avg})
        if order_loss is not None:
            wandb.log({"epoch_order_loss": order_loss_meter.avg})
        if order_acc is not None:
            wandb.log({"epoch_order_acc": order_acc_meter.avg})

@torch.no_grad()
def test_clip(val_loader, text_labels, clipmodel, config):
    """Score a CLIP model on ``val_loader`` and return two running averages.

    For every sample the model is given the sample's frames and the text
    entries selected by that sample's ``neg_labels``; a prediction counts as
    a hit when the top-1 text index is 0.

    Args:
        val_loader: Iterable of dict batches with keys ``"imgs"`` and
            ``"neg_labels"``.
        text_labels: Tensor-like text inputs, moved to the GPU and indexed by
            each sample's ``neg_labels``.
        clipmodel: Callable ``(images, texts)`` returning two values; the first
            is used as the per-image logits.
        config: Not read by this function.

    Returns:
        tuple: ``(acc_meter.avg, nofail_meter.avg)`` -- the hit percentage
        weighted by number of rows, and the percentage of samples with at
        least one hit.
    """
    gpu_texts = text_labels.cuda()
    acc_meter, nofail_meter = AverageMeter(), AverageMeter()

    for idx, batch_data in enumerate(val_loader):
        label_id_neg = batch_data["neg_labels"]
        # Keep only index 0 along the third axis of the batch.
        # NOTE(review): presumably this selects one frame/channel slice -- confirm the layout.
        frames = batch_data["imgs"].cuda(non_blocking=True)[:, :, 0]

        for sample_idx, sample_frames in enumerate(frames):
            candidate_texts = gpu_texts[label_id_neg[sample_idx]]
            image_logits, _ = clipmodel(sample_frames, candidate_texts)
            top1 = image_logits.softmax(dim=-1).topk(1, dim=-1).indices
            hits = (top1 == 0).sum().item()
            rows = top1.shape[0]
            acc_meter.update(hits / rows * 100, rows)
            # A sample scores 100 as soon as any of its rows is a hit.
            nofail_meter.update(0 if hits == 0 else 100, 1)

        logger.info(f'Test: [{idx}/{len(val_loader)}]\t'
                    f'acc {acc_meter.avg:.4f})\t'
                    f'nofail {nofail_meter.avg:.4f})\t')

    return acc_meter.avg, nofail_meter.avg

@torch.no_grad()
def visualize_clip(image, label_ids ,text_inputs, clipmodel, config):
    """Run a CLIP model on every frame of a batch and regroup its text logits.

    Args:
        image: 5-D tensor; its first two dimensions are merged before the
            forward pass and restored on the result.
        label_ids: Index used to select entries from ``text_inputs``.
        text_inputs: Indexable text inputs handed to the model after selection.
        clipmodel: Callable ``(images, texts)`` returning two values; only the
            second one is used.
        config: Not read by this function.

    Returns:
        The second model output reshaped to ``(n, t, -1)``, where ``n`` and
        ``t`` are the first two dimensions of ``image``.
    """
    n, t, c, h, w = image.shape
    # Merge the two leading dimensions so each frame is an independent image.
    flat_frames = image.reshape(n * t, c, h, w).cuda(non_blocking=True)
    _, text_logits = clipmodel(flat_frames, text_inputs[label_ids])
    print('logits_per_text:', text_logits.shape)
    return text_logits.reshape(n, t, -1)

def _vis_load_clip_model():
    """Load the CLIP ViT-B/16 model used as a baseline in the figures.

    ``clip`` is imported here, and only here, so that (a) every branch that
    needs CLIP sees the module and (b) the branch that does not need it keeps
    working when the package is not installed.
    """
    import clip
    clipmodel, _ = clip.load(name="ViT-B/16", device="cuda")
    return clipmodel


def _vis_clips_to_image(clips):
    """Tile one sample into a single de-normalised image for ``plt.imshow``.

    ``clips`` is unpacked as (n, t, c, h, w). The result is a NumPy array of
    shape (n*h, t*w, c): one image row per index of the first axis and one
    image column per index of the second.
    """
    n, t, c, h, w = clips.shape
    grid = clips.permute(0, 3, 1, 4, 2).reshape(n * h, t * w, c).cpu().numpy()
    # Undo the per-channel normalisation, then map 0..255 to 0..1.
    # NOTE(review): assumes the loader normalised with exactly these statistics - confirm.
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])
    return (grid * std + mean) / 255.0


def _vis_pair_matrix(num_clips, pairs, upper, lower=None):
    """Scatter per-pair scores into a (num_clips, num_clips) matrix.

    ``upper[j]`` is written at ``pairs[j] == (first, second)`` and, when
    ``lower`` is given, ``lower[j]`` at the mirrored cell. All other cells
    stay 0.
    """
    matrix = np.zeros((num_clips, num_clips))
    for j, (first, second) in enumerate(pairs):
        matrix[first][second] = upper[j].item()
        if lower is not None:
            matrix[second][first] = lower[j].item()
    return matrix


def _vis_draw_matrix(matrix, title, cells='all', axis_labels=True, clip_ticks=True):
    """Draw ``matrix`` as an annotated heat-map on the current axes.

    ``cells`` selects which cells show their value (the others show ``--``):
    ``'all'``, ``'offdiag'`` (row != column) or ``'upper'`` (row < column).
    """
    rows, cols = matrix.shape
    plt.imshow(matrix, vmin=0, vmax=1)
    if axis_labels:
        plt.xlabel('Second Clip Number')
        plt.ylabel('First Clip Number')
    if clip_ticks:
        # Label the axes with 1-based clip numbers.
        plt.xticks(range(cols), range(1, cols + 1))
        plt.yticks(range(rows), range(1, rows + 1))
    for r in range(rows):
        for c in range(cols):
            shown = cells == 'all' or (cells == 'offdiag' and r != c) or (cells == 'upper' and r < c)
            plt.text(c, r, f'{matrix[r][c]:.2f}' if shown else '--', ha='center', va='center', color='white')
    plt.colorbar()
    plt.title(title)


def _vis_two_head(val_loader, text_inputs, model, label_dict, config, title, visualize_dir, pairs):
    """Two-head mode: one 2x3 figure per sample (CLIP/model predictions, order, class and combined scores)."""
    num_clips = config.DATA.NUM_CLIPS_VAL
    # 1-based "(first,second)" caption of every clip pair, in pair order.
    pair_names = [f'({first + 1},{second + 1})' for first, second in pairs]
    out_dirs = {name: os.path.join(visualize_dir, name) for name in ('success', 'fail', 'all_fail')}
    for path in out_dirs.values():
        os.makedirs(path, exist_ok=True)
    clipmodel = _vis_load_clip_model()

    for idx, batch_data in enumerate(val_loader):
        image = batch_data["imgs"]
        label_id = batch_data["label"]
        label_id_neg = batch_data["neg_labels"]
        batch_size = image.shape[0]
        label_id = label_id.reshape(-1)
        if config.DATA.USE_ORDER:
            # The label vector is tiled once per clip pair.
            label_id = label_id.repeat(len(pairs))

        label_id = label_id.cuda(non_blocking=True)
        image_input = image.cuda(non_blocking=True)
        label_id_neg = label_id_neg.cuda(non_blocking=True)
        if config.TRAIN.OPT_LEVEL == 'O2':
            image_input = image_input.half()

        logits_class, logits_order = model(image_input, text_inputs, label_id, config.DATA.USE_ORDER, num_clips, label_id_neg=label_id_neg)
        logits_class = logits_class.softmax(dim=-1)
        logits_order = logits_order.softmax(dim=-1)
        _, top1 = logits_class.topk(1, dim=-1)
        top1 = top1.reshape(-1, batch_size, 1)
        class_scores = logits_class.reshape(-1, batch_size, logits_class.shape[-1])
        # First half of the order logits is read as "in order", second half as "inverse".
        half = logits_order.shape[0] // 2
        order_fwd = logits_order[:half].reshape(-1, batch_size, 2)
        order_bwd = logits_order[half:].reshape(-1, batch_size, 2)

        for i in range(batch_size):
            # CLIP baseline on index 0 of the third axis of this sample's input.
            logits_per_image, _ = clipmodel(image_input[i, :, 0], text_inputs[label_id_neg[i]])
            _, clip_top1 = logits_per_image.softmax(dim=-1).topk(1, dim=-1)
            clip_pred = label_id_neg[i, clip_top1]

            label = label_dict[str(label_id[i].item())]
            chosen = top1[:, i, 0]
            num_hits = torch.sum(chosen == 0)
            all_correct = num_hits == chosen.shape[0]
            all_fail = num_hits == 0
            pred = label_id_neg[i, chosen]

            plt.figure(figsize=(15, 8))
            plt.subplot(2, 3, 1)
            plt.axis('off')
            plt.text(0, 1, 'label:' + label)
            for l in range(clip_pred.shape[0]):
                color = 'green' if clip_pred[l] == label_id[i] else 'red'
                plt.text(0, 1 - 0.2 * (l + 1), 'pre_clip_' + str(l + 1) + ':' + label_dict[str(clip_pred[l].item())], color=color)

            plt.subplot(2, 3, 2)
            plt.axis('off')
            plt.imshow(_vis_clips_to_image(image[i]))

            plt.subplot(2, 3, 3)
            plt.axis('off')
            plt.text(0, 1, 'label:' + label)
            for j in range(chosen.shape[0]):
                color = 'green' if pred[j] == label_id[i] else 'red'
                plt.text(0, 1 - 0.1 * (j + 1), 'pre_' + pair_names[j] + ':' + label_dict[str(pred[j].item())], color=color)

            plt.subplot(2, 3, 4)
            order_matrix = _vis_pair_matrix(num_clips, pairs, order_fwd[:, i, 0], order_bwd[:, i, 1])
            _vis_draw_matrix(order_matrix, 'Order_' + label, cells='offdiag', clip_ticks=False)

            plt.subplot(2, 3, 5)
            class_matrix = _vis_pair_matrix(num_clips, pairs, class_scores[:, i, 0])
            _vis_draw_matrix(class_matrix, 'Class_' + label, cells='upper')

            plt.subplot(2, 3, 6)
            _vis_draw_matrix(order_matrix * class_matrix, 'Multi_' + label, cells='upper')

            plt.tight_layout()
            file_name = f'{title}_{idx}_{i}.jpg'
            plt.savefig(os.path.join(out_dirs['success' if all_correct else 'fail'], file_name))
            if all_fail:
                plt.savefig(os.path.join(out_dirs['all_fail'], file_name))
            # Close (not just clear) the figure, otherwise one figure per sample stays alive.
            plt.close()


def _vis_label_swap(val_loader, text_inputs, model, label_dict, config, title, save_dir, pairs, num_class):
    """Single-head mode without hard negatives: order scores for the real label and for a shifted ("fake") label."""
    num_clips = config.DATA.NUM_CLIPS_VAL
    for idx, batch_data in enumerate(val_loader):
        if idx > 3 and 'sthv2' in title:
            # Titles containing 'sthv2' are limited to the first four batches.
            print(title + ' finished!')
            break
        image = batch_data["imgs"]
        label_id = batch_data["label"]
        label_id_neg = batch_data["neg_labels"]
        batch_size = image.shape[0]
        label_id = label_id.reshape(-1).repeat(len(pairs)).cuda(non_blocking=True)
        # "Fake" labels: every id shifted by 11 (mod number of classes).
        label_id_fake = (label_id + 11) % num_class
        label_id_neg_fake = (label_id_neg + 11) % num_class
        image_input = image.cuda(non_blocking=True)
        if config.TRAIN.OPT_LEVEL == 'O2':
            image_input = image_input.half()

        _, logits_order = model(image_input, text_inputs, label_id, config.DATA.USE_ORDER, num_clips, label_id_neg=label_id_neg)
        _, logits_order_fake = model(image_input, text_inputs, label_id_fake, config.DATA.USE_ORDER, num_clips, label_id_neg=label_id_neg_fake)
        logits_order = logits_order.softmax(dim=-1)
        logits_order_fake = logits_order_fake.softmax(dim=-1)

        # First half of the rows is read as "in order", second half as "inverse".
        half = logits_order.shape[0] // 2
        order_fwd = logits_order[:half].reshape(-1, batch_size, logits_order.shape[-1])
        order_bwd = logits_order[half:].reshape(-1, batch_size, logits_order.shape[-1])
        fake_fwd = logits_order_fake[:half].reshape(-1, batch_size, logits_order_fake.shape[-1])
        fake_bwd = logits_order_fake[half:].reshape(-1, batch_size, logits_order_fake.shape[-1])

        for i in range(batch_size):
            label = label_dict[str(label_id[i].item())]
            label_fake = label_dict[str(label_id_fake[i].item())]

            plt.figure(figsize=(18, 6))
            plt.subplot(1, 4, 1)
            plt.axis('off')
            plt.imshow(_vis_clips_to_image(image[i]))
            plt.title('label:' + label)

            plt.subplot(1, 4, 2)
            _vis_draw_matrix(
                _vis_pair_matrix(num_clips, pairs, order_fwd[:, i, 0], order_bwd[:, i, 1]),
                'Order prediction score', cells='offdiag')

            plt.subplot(1, 4, 3)
            _vis_draw_matrix(
                _vis_pair_matrix(num_clips, pairs, fake_fwd[:, i, 0], fake_bwd[:, i, 1]),
                f"'{label_fake} (logit 0)'", cells='offdiag')

            plt.subplot(1, 4, 4)
            _vis_draw_matrix(
                _vis_pair_matrix(num_clips, pairs, fake_fwd[:, i, 2], fake_bwd[:, i, 2]),
                f"'{label_fake} (logit 2)'", cells='offdiag')

            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, f'{label}_{idx}_{i}.jpg'))
            # Close (not just clear) the figure, otherwise one figure per sample stays alive.
            plt.close()


def _vis_hard_negatives(val_loader, text_inputs, model, label_dict, config, title, save_dir):
    """Single-head mode with hard negatives: per-candidate model matrices above the matching CLIP matrices."""
    clipmodel = _vis_load_clip_model()
    for idx, batch_data in enumerate(val_loader):
        if idx > 3 and 'sthv2' in title:
            # Titles containing 'sthv2' are limited to the first four batches.
            print(title + ' finished!')
            break
        image = batch_data["imgs"]
        batch_size = image.shape[0]
        label_id_neg = batch_data["neg_labels"]
        print('label_id_neg:', label_id_neg.shape)
        image_input = image.cuda(non_blocking=True)
        logits_matrix = model.visualize(image_input, text_inputs, num_clips=config.DATA.NUM_CLIPS_VAL, label_id_neg=label_id_neg)
        logits_matrix = logits_matrix.cpu().numpy()
        num_texts = logits_matrix.shape[-1]

        for i in range(batch_size):
            # The figure is named after the first candidate of the sample.
            label = label_dict[str(label_id_neg[i][0].item())]
            clip_logits = visualize_clip(image[i], label_id_neg[i], text_inputs, clipmodel, config)
            clip_logits = clip_logits.cpu().numpy()

            plt.figure(figsize=(25, 8))
            plt.subplot(2, num_texts + 1, 1)
            plt.axis('off')
            plt.imshow(_vis_clips_to_image(image[i]))
            plt.title('label:' + label)
            for j in range(num_texts):
                this_label = label_dict[str(label_id_neg[i][j].item())]

                # Top row: model scores for candidate j.
                plt.subplot(2, num_texts + 1, j + 2)
                _vis_draw_matrix(logits_matrix[i, :, :, j], this_label)

                # Bottom row: CLIP logits for the same candidate, divided by 100.
                plt.subplot(2, num_texts + 1, j + num_texts + 3)
                clip_matrix = clip_logits[:, :, j] / 100.
                print('clip_matrix:', clip_matrix.shape)
                _vis_draw_matrix(clip_matrix, 'CLIP:' + this_label, axis_labels=False, clip_ticks=False)

            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, f'{label}_{idx}_{i}.jpg'))
            # Close (not just clear) the figure, otherwise one figure per sample stays alive.
            plt.close()


@torch.no_grad()
def visualize(val_loader, text_labels, model, label_list, config, title, hard_neg=True):
    """Render per-sample diagnostic figures for a validation loader.

    Figures are written below ``<config.OUTPUT>/visualize``:

    * ``config.TRAIN.TWO_HEAD``: into ``success`` / ``fail`` (and additionally
      ``all_fail``), named ``{title}_{batch}_{sample}.jpg``.
    * otherwise: into a ``title`` sub-directory, named
      ``{label}_{batch}_{sample}.jpg``; ``hard_neg`` selects between the
      hard-negative comparison with CLIP and the real-vs-shifted-label
      comparison.

    Args:
        val_loader: iterable of dicts with ``"imgs"``, ``"neg_labels"`` and
            (except in hard-negative mode) ``"label"``.
        text_labels: tensor indexed by label id; moved to the GPU here.
        model: network under inspection; switched to eval mode.
        label_list: path of a CSV file whose first two columns map a label id
            to its display name.
        config: experiment configuration.
        title: tag used in file/directory names; titles containing ``'sthv2'``
            stop after four batches in the single-head modes.
        hard_neg: only read when ``config.TRAIN.TWO_HEAD`` is false.

    Returns:
        None.
    """
    model.eval()
    visualize_dir = os.path.join(config.OUTPUT, 'visualize')
    os.makedirs(visualize_dir, exist_ok=True)
    with open(label_list) as csvfile:
        label_dict = {row[0]: row[1] for row in csv.reader(csvfile)}
    text_inputs = text_labels.cuda()

    # All (first, second) clip index pairs with first < second, in row-major order.
    first_idx, second_idx = torch.triu_indices(config.DATA.NUM_CLIPS_VAL, config.DATA.NUM_CLIPS_VAL, offset=1)
    pairs = list(zip(first_idx.tolist(), second_idx.tolist()))

    if config.TRAIN.TWO_HEAD:
        _vis_two_head(val_loader, text_inputs, model, label_dict, config, title, visualize_dir, pairs)
        return

    save_dir = os.path.join(visualize_dir, title)
    os.makedirs(save_dir, exist_ok=True)
    if 'outdomain' in title:
        num_class = config.DATA.NUM_CLASSES_VAL
    else:
        num_class = config.DATA.NUM_CLASSES
    print('hard_neg:', hard_neg)
    if hard_neg:
        _vis_hard_negatives(val_loader, text_inputs, model, label_dict, config, title, save_dir)
    else:
        _vis_label_swap(val_loader, text_inputs, model, label_dict, config, title, save_dir, pairs, num_class)
    return
        

@torch.no_grad()
def validate(val_loader, text_labels, model, config, trainset=False, outdomain=False):
    """Evaluate ``model`` on ``val_loader`` and log the aggregated metrics.

    Which metrics are tracked depends on the configuration:

    * ``config.DATA.USE_ORDER`` false: plain top-1 / top-5 accuracy against
      ``label``.
    * ``USE_ORDER`` with ``TRAIN_CLASS`` and ``TWO_HEAD``: top-1 / top-5
      accuracy and mean rank of candidate 0 in the class logits.
    * ``USE_ORDER`` with ``TRAIN_ORDER``: share of order predictions equal to 0.
    * ``USE_ORDER`` without ``TWO_HEAD``: share of class-logit arg-maxes equal
      to 0 and equal to 2.

    Args:
        val_loader: iterable of dicts with ``"imgs"``, ``"label"`` and
            ``"neg_labels"``.
        text_labels: tensor of text inputs; moved to the GPU here.
        model: network to evaluate; switched to eval mode.
        config: experiment configuration.
        trainset: unused; kept for signature compatibility.
        outdomain: enables the three-value "wrong order" return described below.

    Returns:
        * not ``USE_ORDER``: top-1 accuracy.
        * ``USE_ORDER``, not ``TWO_HEAD``, ``NUM_NEGATIVE > 0``, ``outdomain``
          and ``NUM_CLASSES_VAL > 1``: ``(order_meter_wrong.avg,
          order_meter.avg, acc_textwrong_meter.avg)``.
        * otherwise ``(class_acc1, order_acc, correct_rank)`` with ``None`` for
          the parts that are not trained, or ``None`` when neither
          ``TRAIN_CLASS`` nor ``TRAIN_ORDER`` is set.
    """
    model.eval()
    if config.DATA.USE_ORDER:
        if config.TRAIN.TRAIN_CLASS and config.TRAIN.TWO_HEAD:
            acc1_meter_class, acc5_meter_class, correct_rank_meter = AverageMeter(), AverageMeter(), AverageMeter()
        if config.TRAIN.TRAIN_ORDER:
            order_meter = AverageMeter()
        if not config.TRAIN.TWO_HEAD:
            order_meter_wrong = AverageMeter()
            acc_textwrong_meter = AverageMeter()
    else:
        acc1_meter, acc5_meter = AverageMeter(), AverageMeter()

    text_inputs = text_labels.cuda()
    for idx, batch_data in enumerate(val_loader):
        image = batch_data["imgs"] #(b, t, c, h, w) or (b, num_clips, t, c, h, w)
        label_id = batch_data["label"]
        label_id_neg = batch_data["neg_labels"]
        label_id = label_id.reshape(-1)
        if config.DATA.USE_ORDER:
            # The label vector is tiled once per unordered clip pair.
            label_id = label_id.repeat(config.DATA.NUM_CLIPS_VAL * (config.DATA.NUM_CLIPS_VAL - 1) // 2)

        label_id = label_id.cuda(non_blocking=True)
        image_input = image.cuda(non_blocking=True)
        label_id_neg = label_id_neg.cuda(non_blocking=True)

        if config.TRAIN.OPT_LEVEL == 'O2':
            image_input = image_input.half()

        output = model(image_input, text_inputs, label_id, config.DATA.USE_ORDER, config.DATA.NUM_CLIPS_VAL, label_id_neg = label_id_neg)

        if not config.DATA.USE_ORDER:
            similarity = output.softmax(dim=-1)
            # One prediction row is compared with every label of the batch.
            b = label_id.shape[0]
            _, indices_1 = similarity.topk(1, dim=-1)
            # topk raises when k exceeds the number of classes, so clamp it.
            _, indices_5 = similarity.topk(min(5, similarity.shape[-1]), dim=-1)
            acc1, acc5 = 0, 0
            for i in range(b):
                if indices_1[i] == label_id[i]:
                    acc1 += 1
                if label_id[i] in indices_5[i]:
                    acc5 += 1

            acc1_meter.update(float(acc1) / b * 100, b)
            acc5_meter.update(float(acc5) / b * 100, b)
        else:
            logits_class, logits_order = output
            if logits_class is not None:
                similarity = logits_class.softmax(dim=-1)

            if config.TRAIN.TRAIN_CLASS and config.TRAIN.TWO_HEAD:
                # Position of candidate 0 in the descending ranking of every row.
                sorted_indices = torch.argsort(similarity, descending=True)
                correct_rank = torch.where(sorted_indices == 0)[1]
                b = correct_rank.shape[0]
                correct_rank_meter.update(correct_rank.float().mean().item(), b)

                _, indices_1 = similarity.topk(1, dim=-1)
                # topk raises when k exceeds the number of candidates, so clamp it.
                _, indices_5 = similarity.topk(min(5, similarity.shape[-1]), dim=-1)
                acc1_class = torch.sum(indices_1 == 0)
                acc5_class = torch.sum(indices_5 == 0)

                acc1_meter_class.update(float(acc1_class) / b * 100, b)
                acc5_meter_class.update(float(acc5_class) / b * 100, b)
            if config.TRAIN.TRAIN_ORDER:
                order_pred = logits_order.argmax(dim=-1)
                acc_order = torch.sum(order_pred == 0)
                b = order_pred.shape[0]
                order_meter.update(float(acc_order) / b * 100, b)
            if not config.TRAIN.TWO_HEAD:
                if logits_class is not None:
                    order_pred_wrong = logits_class.argmax(dim=-1)
                    b = order_pred_wrong.shape[0]
                    acc_order_wrong = torch.sum(order_pred_wrong == 0)
                    order_meter_wrong.update(float(acc_order_wrong) / b * 100, b)
                    acc_textwrong = torch.sum(order_pred_wrong == 2)
                    acc_textwrong_meter.update(float(acc_textwrong) / b * 100, b)

        if idx % config.PRINT_FREQ == 0:
            logger.info(f'Test: [{idx}/{len(val_loader)}]\t')

    if not config.DATA.USE_ORDER:
        acc1_meter.sync()
        acc5_meter.sync()
        logger.info(f'* Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
        return acc1_meter.avg
    else:
        if config.TRAIN.TRAIN_CLASS and config.TRAIN.TWO_HEAD:
            acc1_meter_class.sync()
            acc5_meter_class.sync()
            correct_rank_meter.sync()
            logger.info(f'Acc_class@1 {acc1_meter_class.avg:.3f} Acc_class@5 {acc5_meter_class.avg:.3f} Correct_rank {correct_rank_meter.avg:.3f}')
        if config.TRAIN.TRAIN_ORDER:
            order_meter.sync()
            logger.info(f'Acc_order {order_meter.avg:.3f}')
        if not config.TRAIN.TWO_HEAD and config.DATA.NUM_NEGATIVE > 0:
            if outdomain and config.DATA.NUM_CLASSES_VAL > 1:
                order_meter_wrong.sync()
                # Sync this meter like its sibling so the returned value does not
                # depend on which rank reports it.
                acc_textwrong_meter.sync()
                logger.info(f'Acc_order_wrong {order_meter_wrong.avg:.3f}')
                # NOTE(review): order_meter only exists when TRAIN_ORDER is set - confirm
                # this branch is never reached without it.
                return order_meter_wrong.avg, order_meter.avg, acc_textwrong_meter.avg
        # NOTE(review): the class meters only exist when TWO_HEAD is set; with
        # TRAIN_CLASS but without TWO_HEAD the returns below raise NameError - confirm
        # that combination is not used.
        if config.TRAIN.TRAIN_CLASS and config.TRAIN.TRAIN_ORDER:
            return acc1_meter_class.avg, order_meter.avg, correct_rank_meter.avg
        elif config.TRAIN.TRAIN_CLASS:
            return acc1_meter_class.avg, None, correct_rank_meter.avg  
        elif config.TRAIN.TRAIN_ORDER:
            return None, order_meter.avg, None
        


if __name__ == '__main__':
    # Build the run configuration from the command line.
    args, config = parse_option()

    # Distributed bootstrap: take rank / world size from the environment when
    # both are present, otherwise pass -1 for both to init_process_group.
    if not ('RANK' in os.environ and 'WORLD_SIZE' in os.environ):
        rank = world_size = -1
    else:
        rank, world_size = int(os.environ["RANK"]), int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    dist.barrier(device_ids=[args.local_rank])

    # Offset the configured seed by the process rank so every process seeds
    # its Python, NumPy and torch generators differently.
    seed = config.SEED + dist.get_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark = True

    # Make sure the output directory exists before the logger writes into it.
    Path(config.OUTPUT).mkdir(parents=True, exist_ok=True)

    # Module-level logger; the functions in this file read it as a global.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.ARCH}")
    logger.info(f"working dir: {config.OUTPUT}")

    # Rank 0 only, and only when wandb is enabled: log the config, copy the
    # config file next to the outputs and start the wandb run.
    if dist.get_rank() == 0 and config.DATA.USE_WANDB:
        logger.info(config)
        shutil.copy(args.config, config.OUTPUT)
        # The run id is the output path with its first nine characters removed.
        wandb.init(project="timerewarder", id=config.OUTPUT[9:], config=config)

    main(config)