import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler


import torchvision
import torchvision.transforms as transforms
import lib.custom_transforms as custom_transforms

import os
import argparse
import time
from tqdm import tqdm
import json


from lib.NCEAverage import NCEAverage_fp_fs_pcl_5_norm_4
from lib.NCECriterion import NCESoftmaxLoss
from lib.utils import AverageMeter, adjust_learning_rate, accuracy
import lib.transforms as T
import lib.augmentation as A

from datasets.custom_dataset import RetrievalDataSet, SSLDataSet
from models.r21d import R2Plus1DNet_Focus_MPred_Pro_v2
from models.r3d import R3DNet_Focus_MPred_Pro_v2

from torch.utils.data import DataLoader, random_split

import random
import numpy as np
import ast

from config import *
from logger import Logger

def parse_option():
    """Parse command-line options for training / retrieval evaluation.

    Besides the raw CLI values, two derived attributes are set:

    * ``opt.use_focus`` -- False when focus is disabled (``--focus_level -1``
      or ``--focus_num 1``), True otherwise. It is always set so callers can
      read it unconditionally.
    * ``opt.feat_dim`` -- 9216 when ``--return_conv`` is true, else 512.

    Returns:
        argparse.Namespace with all options.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--cluster_freq', type=int, default=5, help='cluster frequency')
    parser.add_argument('--tb_freq', type=int, default=50, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=60, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use') # 8
    parser.add_argument('--epochs', type=int, default=300, help='number of training epochs')

    # Device options
    parser.add_argument('--gpu-id', default='0', type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')

    # optimization
    parser.add_argument('--opt_type', type=str, default='sgd', help='optimizer type')
    parser.add_argument('--learning_rate', type=float, default=0.1, help='learning rate')
    parser.add_argument('--pred_loss', type=str, default='l1', help='loss for prediction')
    parser.add_argument('--lr_ratio', type=float, default=0.001, help='learning rate ratio of proj layer')
    parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[90, 180, 240], help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2, help='decay rate for learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--cos', type=ast.literal_eval, default=False, help='whether to use cos anealing')

    # resume path
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--focus_init', default='', type=str, metavar='PATH',
                        help='path to focus model checkpoint for initlization (default: none)')
    parser.add_argument('--model_postfix', default='', type=str,
                        help='postfix of model name (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')

    # model definition
    parser.add_argument('--model', type=str, default='r3d', choices=['r3d', 'c3d', 'r21d', 'r3d_norm'])
    parser.add_argument('--nce_k', type=int, default=1024)
    parser.add_argument('--nce_t', type=float, default=0.07)
    parser.add_argument('--nce_m', type=float, default=0.5)
    parser.add_argument('--conv_level', type=int, default=5, help='level of conv features for prediction')
    parser.add_argument('--return_conv', type=ast.literal_eval, default=False)
    parser.add_argument('--proj', type=int, default=0, help='0: linear, 1: relu+linear, 2: linear+relu')
    parser.add_argument('--proj_dim', type=int, default=512, help='dim of projection head output')
    parser.add_argument('--f_req_clust', type=int, default=None)
    parser.add_argument('--num_clusters', type=int, nargs='+', default=[1500, 1500, 1500],
                        help='number of clusters for each set of prototypes (one value per set)')
    parser.add_argument('--fs_warmup_epoch', type=int, default=181)
    parser.add_argument('--faiss_m', type=float, default=1)
    parser.add_argument('--faiss_th', type=float, default=0.8)
    parser.add_argument('--f_decay_rate', type=float, default=0.9)
    parser.add_argument('--f_decay_epochs', type=int, nargs='+', default=[60, 120],
                        help='where to decay faiss_m, can be a list')
    parser.add_argument('--pro_p', type=float, default=1.0, help='power of cosine distance')
    parser.add_argument('--pro_clamp_value', type=float, default=0.0, help='clamp value of cosine distance')

    # dataset
    parser.add_argument('--dataset', type=str, default='ucf101', choices=['ucf101', 'hmdb51'])
    parser.add_argument('--eval_dataset', type=str, default='ucf101', choices=['ucf101', 'hmdb51'])
    parser.add_argument('--split', type=str, default='1', choices=['1', '2', '3'])
    parser.add_argument('--clip_len', type=int, default=16, help='number of frames in a clip')
    parser.add_argument('--crop_size', type=int, default=112, help='spatial size (pixels) of the cropped frames fed to the model')
    parser.add_argument('--img_dim', type=int, default=196, help='center-crop size (pixels) before rescaling to crop_size')
    parser.add_argument('--bottom_area', type=float, default=0.175, help='bottom_area argument of RandomSizedCrop')
    parser.add_argument('--flip_consist', type=ast.literal_eval, default=True)
    parser.add_argument('--crop_consist', type=ast.literal_eval, default=True)
    parser.add_argument('--jitter_consist', type=ast.literal_eval, default=True)

    # specify folder
    parser.add_argument('--model_name', type=str, default='', help='model name')
    parser.add_argument('--model_path', type=str, default='./ckpt/', help='path to save model')
    parser.add_argument('--tb_path', type=str, default='./logs/', help='path to tensorboard')

    # add new views
    parser.add_argument('--debug', type=ast.literal_eval, default=False)
    parser.add_argument('--modality', type=str, default='res', choices=['rgb', 'res', 'u', 'v'],
                                    help='Modality for View #2. (View #1 is a RGB video clip)')
    parser.add_argument('--neg', type=str, default='repeat', choices=['repeat', 'shuffle'])
    parser.add_argument('--seed', type=int, default=632)

    # focus related params
    parser.add_argument('--sample_num', type=int, default=1)
    parser.add_argument('--focus_level', type=int, default=0, choices=list(range(-1, 6)))
    parser.add_argument('--focus_res', type=ast.literal_eval, default=True)
    parser.add_argument('--f_sigma_div', type=float, default=3.0)
    parser.add_argument('--focus_num', type=int, default=5, choices=[1,5,9])

    # retrieve related params
    parser.add_argument('--eval_retrieve', type=ast.literal_eval, default=True)
    parser.add_argument('--retrieve', type=ast.literal_eval, default=False)
    parser.add_argument('--r_sample_num', type=int, default=10)
    parser.add_argument('--r_batch_size', type=int, default=8, help='retrieval batch_size')
    parser.add_argument('--feature_base', type=str, default='features', help='dir to store feature.npy')
    parser.add_argument('--r_return_conv', type=ast.literal_eval, default=True)

    opt = parser.parse_args()

    # Always define use_focus (previously it was only set when False, so
    # reading it with focus enabled raised AttributeError).
    opt.use_focus = not (opt.focus_level == -1 or opt.focus_num == 1)

    # Width of the backbone feature handed to the contrast head.
    opt.feat_dim = 9216 if opt.return_conv else 512

    return opt


def set_model(args, n_data):
    """Build the backbone, contrast head and loss functions.

    Args:
        args: parsed options. Reads ``model``, ``return_conv``, ``focus_level``,
            ``focus_res``, ``conv_level``, ``sample_num``, ``pro_p``,
            ``pro_clamp_value``, ``feat_dim``, ``nce_k``, ``nce_t``, ``nce_m``,
            ``focus_num``, ``proj_dim`` and ``pred_loss``. Sets
            ``args.pred_sign``.
        n_data: forwarded to ``NCEAverage_fp_fs_pcl_5_norm_4`` as its second
            argument (the caller passes the training-set length).

    Returns:
        ``(model, contrast, criterion, criterion_p, criterion_pred)``. ``model``
        and ``contrast`` are wrapped in ``DataParallel`` and moved to the GPU.

    Raises:
        ValueError: if ``args.model`` is not 'r3d'/'r21d' or ``args.pred_loss``
            is not 'l1'/'l2'/'cos'. These cases previously surfaced as an
            UnboundLocalError.
    """
    # Validate before building anything so a bad flag fails fast and clearly.
    if args.model not in ('r3d', 'r21d'):
        raise ValueError("Unsupported --model '{}'; expected one of: r3d, r21d".format(args.model))
    if args.pred_loss not in ('l1', 'l2', 'cos'):
        raise ValueError("Unsupported --pred_loss '{}'; expected one of: l1, l2, cos".format(args.pred_loss))

    # set the model
    backbone_cls = R3DNet_Focus_MPred_Pro_v2 if args.model == 'r3d' else R2Plus1DNet_Focus_MPred_Pro_v2
    model = backbone_cls(layer_sizes=(1, 1, 1, 1), with_classifier=False, return_conv=args.return_conv,
                         focus_level=args.focus_level, focus_res=args.focus_res,
                         conv_level=args.conv_level, sample_num=args.sample_num,
                         pro_p=args.pro_p, pro_clamp_value=args.pro_clamp_value)

    contrast = NCEAverage_fp_fs_pcl_5_norm_4(args.feat_dim, n_data, args.nce_k, args.nce_t, args.nce_m, args.focus_num,
                                             proj_dim=args.proj_dim)

    criterion = NCESoftmaxLoss()
    criterion_p = nn.CrossEntropyLoss()

    # pred_sign turns the cosine *similarity* into a quantity to minimise;
    # L1/L2 are already distances, so they keep a positive sign.
    if args.pred_loss == 'l1':
        criterion_pred = nn.L1Loss()
        args.pred_sign = 1.0
    elif args.pred_loss == 'l2':
        criterion_pred = nn.MSELoss()
        args.pred_sign = 1.0
    else:  # 'cos'
        criterion_pred = nn.CosineSimilarity(dim=1)
        args.pred_sign = -1.0

    # GPU mode
    model = torch.nn.DataParallel(model).cuda()
    contrast = torch.nn.DataParallel(contrast).cuda()
    criterion = criterion.cuda()
    criterion_p = criterion_p.cuda()
    cudnn.benchmark = True

    return model, contrast, criterion, criterion_p, criterion_pred


def set_optimizer_mp(args, model, contrast):
    """Create the optimizer with separate learning rates for model and contrast head.

    Two parameter groups are used: ``model`` at ``args.learning_rate`` and
    ``contrast`` at ``args.learning_rate * args.lr_ratio``.

    Args:
        args: options. Reads ``opt_type`` ('sgd' or 'adam'), ``learning_rate``,
            ``lr_ratio`` and ``weight_decay``, plus ``momentum`` for SGD.
        model: module whose parameters form the first group.
        contrast: module whose parameters form the second group.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        ValueError: for any other ``args.opt_type``. This previously surfaced
            as an UnboundLocalError.
    """
    if args.opt_type not in ('sgd', 'adam'):
        raise ValueError("Unsupported --opt_type '{}'; expected 'sgd' or 'adam'".format(args.opt_type))

    param_groups = [{"params": model.parameters()},
                    {"params": contrast.parameters(), "lr": args.learning_rate * args.lr_ratio}]

    if args.opt_type == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        print('Using SGD ...')
    else:
        # NOTE(review): --beta1/--beta2 are parsed but not passed here, so Adam
        # runs with torch's default betas. Kept unchanged to preserve results.
        optimizer = torch.optim.Adam(param_groups,
                                     lr=args.learning_rate,
                                     weight_decay=args.weight_decay)
        print('Using Adam ...')
    return optimizer


def train(epoch, train_loader, model, contrast, criterion, criterion_p, criterion_pred, optimizer, opt, logger, transforms_cuda):
    """Run one epoch of training.

    The per-batch objective is::

        loss_c                                   # criterion(out) on contrast logits
      + opt.pred_sign * (loss_p1 + loss_p2) / 2  # prediction losses vs. detached targets
      + mean_s criterion_p(proto_out_s, target_s)  # only when contrast returns prototypes

    where the prototype term is averaged over ``len(opt.num_clusters)`` sets.

    Args:
        epoch: current epoch number (display only).
        train_loader: yields ``(inputs, f_maps, _, index, f_index)`` batches.
        model, contrast: networks (``contrast`` is expected to be DataParallel-
            wrapped, since ``contrast.module`` is accessed for logging).
        criterion, criterion_p, criterion_pred: contrastive, prototype and
            prediction losses, as returned by ``set_model``.
        optimizer: optimizer stepping model + contrast parameters.
        opt: options. Reads ``clip_len``, ``crop_size``, ``focus_num``,
            ``pred_sign``, ``num_clusters``, ``tb_freq``, ``epochs`` and
            ``sample_num``.
        logger: object with ``log_value(name, value, step)``.
        transforms_cuda: transform applied to the reshaped GPU clip batch.

    Returns:
        ``(loss_avg, prob_avg, loss_total_avg, acc_proto_avg)``: epoch averages
        of the loss without the prototype term, of ``out[:, 0]``, of the total
        loss, and of the first value returned by ``accuracy`` on the prototype
        logits.
    """
    model.train()
    contrast.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    loss_c_meter = AverageMeter()
    loss_p_meter = AverageMeter()
    loss_total_meter = AverageMeter()
    prob_meter = AverageMeter()
    acc_proto = AverageMeter()

    bar = tqdm(train_loader)
    end = time.time()
    for idx, (inputs, f_maps, _, index, f_index) in enumerate(bar):
        data_time.update(time.time() - end)

        bsz = inputs.size(0)
        inputs = inputs.float().cuda()
        f_maps = f_maps.to(inputs.device)
        index = index.to(inputs.device)
        f_index = f_index.to(inputs.device)

        # ===================forward=====================

        # Flatten views into the batch dim: v1 clips first, then v2 clips, then v3 clips, ...
        inputs = inputs.reshape((-1, 3, opt.clip_len, opt.crop_size, opt.crop_size))
        inputs = transforms_cuda(inputs)
        f_maps = f_maps.reshape((-1, 1, opt.crop_size, opt.crop_size))

        feat, feat_pred_1, feat_pred_2, feat_tgt_1, feat_tgt_2 = model(inputs, f_maps)

        # Combine (sample index, focus index) into one id: index * focus_num + f_index.
        mem_index = index.view(-1) * opt.focus_num + f_index.view(-1)
        out, out_proto, target_proto = contrast(feat, mem_index)

        loss_c = criterion(out)
        # Targets are detached so the prediction loss only updates the predictor side.
        loss_p1 = criterion_pred(feat_pred_1, feat_tgt_1.detach()).mean()
        loss_p2 = criterion_pred(feat_pred_2, feat_tgt_2.detach()).mean()
        # pred_sign is -1 for cosine similarity (so similarity is maximised) and +1 for L1/L2.
        loss_p = opt.pred_sign * (loss_p1 + loss_p2) / 2.
        loss = loss_c + loss_p
        prob = out[:, 0].mean()

        # ===================meters=====================
        loss_c_meter.update(loss_c.item(), bsz)
        loss_p_meter.update(loss_p.item(), bsz)
        loss_meter.update(loss.item(), bsz)
        prob_meter.update(prob.item(), bsz)

        if out_proto is not None:
            loss_proto = 0
            for proto_out, proto_target in zip(out_proto, target_proto):
                loss_proto += criterion_p(proto_out, proto_target)
                accp = accuracy(proto_out, proto_target)[0]
                acc_proto.update(accp[0], bsz)

            # Average over all prototype sets. Uses the `opt` parameter; this
            # previously read the module-global `args` by mistake.
            loss_proto /= len(opt.num_clusters)
            loss = loss + loss_proto

        loss_total_meter.update(loss.item(), bsz)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % opt.tb_freq == 0:
            # NOTE(review): the step is the batch index within this epoch, so
            # steps repeat across epochs. Confirm this is what Logger expects.
            logger.log_value('loss', loss_meter.avg, idx)
            logger.log_value('prob', prob_meter.avg, idx)
            logger.log_value('loss_total', loss_total_meter.avg, idx)
            logger.log_value('acc_proto', acc_proto.avg, idx)
            logger.log_value('feat_grad', contrast.module.feature_proj[0].weight.grad.abs().mean().item(), idx)

        bar.set_description('Train: [{0}/{1}][{2}/{3}]|'
                        'BS {batch_size}|SN {sample_num}|'
                        'l {loss.val:.3f} ({loss.avg:.3f})|'
                        'l_c {loss_c.val:.3f} ({loss_c.avg:.3f})|'
                        'l_p {loss_p.val:.3f} ({loss_p.avg:.3f})|'
                        'l_t {loss_total.val:.3f} ({loss_total.avg:.3f})|'
                        'ac_p {acc_proto.val:.3f} ({acc_proto.avg:.3f})|'
                        'prob {prob.val:.3f} ({prob.avg:.3f})|'.format(
        epoch, opt.epochs, idx + 1, len(train_loader), batch_size=bsz, sample_num=opt.sample_num,
        loss_total = loss_total_meter, acc_proto = acc_proto,
        loss=loss_meter, loss_c=loss_c_meter, loss_p=loss_p_meter, prob=prob_meter))

    return loss_meter.avg, prob_meter.avg, loss_total_meter.avg, acc_proto.avg


# Options are parsed at import time. CUDA_VISIBLE_DEVICES only takes effect
# if it is in the environment before CUDA is initialised.
args = parse_option()
os.environ.update(CUDA_VISIBLE_DEVICES=args.gpu_id)

def main(args):
    """Prepare data, model, optimizer and checkpoints, then optionally evaluate.

    Steps, in order:
      1. Require CUDA; derive ``args.model_name`` and create
         ``args.model_folder`` / ``args.tb_folder``.
      2. Seed ``random``, NumPy and torch from ``args.seed``.
      3. Build the SSL training loader, the retrieval loaders (only when
         ``args.eval_retrieve``) and the loader used for clustering.
      4. Build model / contrast / criteria (``set_model``) and the optimizer.
      5. Optionally load ``args.focus_init`` (under ``args.model_path``) and/or
         resume from ``args.resume`` (under ``args.model_folder``).
      6. With ``args.evaluate``, run ``test_retrieval`` and return.

    The training loop itself is not implemented yet.

    Raises:
        Exception: if CUDA is not available.
        ValueError: if ``--evaluate`` is requested without ``--eval_retrieve``,
            or if a focus-init checkpoint's focus settings disagree with ``args``.
        SystemExit: (code 1) if a requested checkpoint file does not exist.
    """
    if not torch.cuda.is_available():
        raise Exception('Only support GPU mode')
    if args.evaluate and not args.eval_retrieve:
        # Evaluation needs the retrieval loaders, which are only built with eval_retrieve.
        raise ValueError('--evaluate requires --eval_retrieve True')

    gpu_num = torch.cuda.device_count()
    if not args.model_name:
        # parse_option may leave use_focus unset when focus is enabled, so a
        # missing attribute means "focus on".
        if getattr(args, 'use_focus', True):
            print('[Warning] using focus models')
            args.model_name = 'focus_{}_{}'.format(args.model, time.strftime('%m%d'))
        else:
            print('[Warning] using baseline')
            args.model_name = '{}_{}'.format(args.model, time.strftime('%m%d'))

    args.model_name = args.model_name + args.model_postfix

    args.model_folder = os.path.join(args.model_path, args.model_name)
    os.makedirs(args.model_folder, exist_ok=True)

    args.tb_folder = os.path.join(args.tb_path, args.model_name)
    os.makedirs(args.tb_folder, exist_ok=True)

    # parse the args
    print(vars(args))

    # Fix all parameters for reproducibility
    seed = args.seed
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing,
    # which works against deterministic=True; set_model() re-enables it anyway.
    torch.backends.cudnn.benchmark = True
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    print('[Warning] The training modalities are RGB and [{}]'.format(args.modality))

    def _dataset_root(name):
        # Resolved lazily so only the path constant that is actually needed is looked up.
        if name == 'ucf101':
            return UCF101_PATH
        if name == 'hmdb51':
            return HMDB51_PATH
        raise ValueError('Unsupported dataset: {}'.format(name))

    def _eval_transforms():
        # Center crop + rescale with light colour jitter; shared by the
        # retrieval loaders and the clustering loader.
        return transforms.Compose([
            A.CenterCrop(size=(args.img_dim, args.img_dim)),
            A.Scale(size=(args.crop_size, args.crop_size)),
            A.ColorJitter(0.2, 0.2, 0.2, 0.1, p=0.3, consistent=True),
            A.ToTensor()])

    # ===================== data =====================
    train_transforms = transforms.Compose([
        A.RandomSizedCrop(size=args.crop_size, consistent=args.crop_consist, seq_len=args.clip_len,
                          bottom_area=args.bottom_area),
        transforms.RandomApply([
            A.ColorJitter(0.4, 0.4, 0.4, 0.1, p=1.0, consistent=args.jitter_consist, seq_len=args.clip_len)
        ], p=0.8),
        transforms.RandomApply([A.GaussianBlur([.1, 2.], seq_len=args.clip_len)], p=0.5),
        A.RandomHorizontalFlip(consistent=args.flip_consist, seq_len=args.clip_len),
        A.ToTensor(),
    ])

    transform_train_cuda = transforms.Compose([
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225], channel=1)])

    # NOTE(review): the SSL training set always uses UCF101_PATH, while the
    # clustering loader below follows args.dataset. Confirm this is intended.
    trainset = SSLDataSet(UCF101_PATH, transforms_=train_transforms, sample_num=args.sample_num, focus_num=args.focus_num,
                          split=args.split)

    print('gpu num:', gpu_num)
    train_loader = DataLoader(trainset, batch_size=args.batch_size * gpu_num, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True)

    n_data = len(trainset)

    # prepare dataloaders for retrieval
    if args.eval_retrieve:
        # clip_len follows args.clip_len because test_retrieval reshapes batches with it.
        train_dataset_r = RetrievalDataSet(_dataset_root(args.eval_dataset), clip_len=args.clip_len,
                                           transforms_=_eval_transforms(), train=True,
                                           sample_num=args.r_sample_num, focus_num=args.focus_num, retrieve=True,
                                           split=args.split)
        train_dataloader_r = DataLoader(train_dataset_r, batch_size=args.r_batch_size * gpu_num, shuffle=False,
                                        num_workers=args.num_workers, pin_memory=False, drop_last=True)

        test_dataset_r = RetrievalDataSet(_dataset_root(args.eval_dataset), clip_len=args.clip_len,
                                          transforms_=_eval_transforms(), train=False,
                                          sample_num=args.r_sample_num, focus_num=args.focus_num, retrieve=True,
                                          split=args.split)
        test_dataloader_r = DataLoader(test_dataset_r, batch_size=args.r_batch_size * gpu_num, shuffle=False,
                                       num_workers=args.num_workers, pin_memory=False, drop_last=True)

    # prepare dataloaders for updating memory_curr
    train_dataset_u = RetrievalDataSet(_dataset_root(args.dataset), clip_len=args.clip_len,
                                       transforms_=_eval_transforms(), train=True,
                                       sample_num=args.sample_num, focus_num=args.focus_num,
                                       retrieve=False, split=args.split)
    train_loader_clust = DataLoader(train_dataset_u, batch_size=args.batch_size * gpu_num, shuffle=False,
                                    num_workers=args.num_workers, pin_memory=False, drop_last=False)

    # ===================== model / optimizer =====================
    model, contrast, criterion, criterion_p, criterion_pred = set_model(args, n_data)
    optimizer = set_optimizer_mp(args, model, contrast)

    args.start_epoch = 1
    if args.focus_init:
        init_ckpt_path = os.path.join(args.model_path, args.focus_init)
        if not os.path.isfile(init_ckpt_path):
            print("=> no focus init checkpoint found at '{}'".format(args.focus_init))
            sys.exit(1)
        print("=> loading checkpoint '{}'".format(args.focus_init))
        checkpoint = torch.load(init_ckpt_path, map_location='cpu')
        args.start_epoch = checkpoint['epoch'] + 1
        # Explicit checks rather than assert, which is stripped under `python -O`.
        for arg_name, ckpt_key in (('focus_level', 'f_level'),
                                   ('focus_res', 'f_res'),
                                   ('focus_num', 'f_num')):
            if getattr(args, arg_name) != checkpoint[ckpt_key]:
                raise ValueError('focus init checkpoint has {}={!r} but --{}={!r}'.format(
                    ckpt_key, checkpoint[ckpt_key], arg_name, getattr(args, arg_name)))
        model.load_state_dict(checkpoint['model'])
        # The optimizer state is deliberately not restored, so a new lr can be used.
        contrast.load_state_dict(checkpoint['contrast'], strict=False)
        print("=> loaded focus init checkpoint '{}'...".format(args.focus_init))
        del checkpoint

    # optionally resume from a checkpoint (args.resume is a file name inside model_folder)
    if args.resume:
        ckpt_path = os.path.join(args.model_folder, args.resume)
        if not os.path.isfile(ckpt_path):
            print("=> no checkpoint found at '{}'".format(args.resume))
            sys.exit(1)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        args.start_epoch = checkpoint['epoch'] + 1
        args.focus_level = checkpoint['f_level']
        args.focus_res = checkpoint['f_res']
        args.focus_num = checkpoint['f_num']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        contrast.load_state_dict(checkpoint['contrast'], strict=False)
        print("=> loaded checkpoint '{}'".format(args.resume))
        if not args.evaluate:
            with torch.no_grad():
                contrast.module.update_clust(args.num_clusters)
        del checkpoint
        torch.cuda.empty_cache()
    else:
        print("=> training from scratch ...")

    if args.evaluate:
        with torch.no_grad():
            test_retrieval(args.start_epoch - 1, model, contrast, train_dataloader_r, test_dataloader_r,
                           args, transform_train_cuda)
        return

    # TODO: training loop not implemented yet (train_loader / train_loader_clust are prepared for it).

    print(args.model_name)


def _extract_retrieval_features(model, contrast, dataloader, args, transforms_cuda):
    """Run model + projection head over ``dataloader``; return (feats, feats_proj, labels) as NumPy arrays."""
    features, features_proj, classes = [], [], []
    for sampled_clips, f_maps, idxs, _, _ in tqdm(dataloader):
        inputs = sampled_clips.reshape((-1, 3, args.clip_len, args.crop_size, args.crop_size)).cuda()
        inputs = transforms_cuda(inputs)
        f_maps = f_maps.reshape((-1, 1, args.crop_size, args.crop_size)).to(inputs.device)

        outputs, _, _, _, _ = model(inputs, f_maps)
        outputs_proj = contrast(outputs, None, None, mode='eval')
        if args.r_sample_num > 1:
            # Average the r_sample_num clips of each item into one descriptor
            # and keep the label of the first clip.
            outputs = outputs.reshape((-1, args.r_sample_num, outputs.size(1))).mean(dim=1)
            outputs_proj = outputs_proj.reshape((-1, args.r_sample_num, outputs_proj.size(1))).mean(dim=1)
            idxs = idxs[:, 0]

        # float64 matches the values the previous list-based conversion produced.
        features.append(outputs.detach().cpu().numpy().astype(np.float64))
        features_proj.append(outputs_proj.detach().cpu().numpy().astype(np.float64))
        classes.append(idxs.detach().cpu().numpy().reshape(-1))

    return (np.concatenate(features, axis=0),
            np.concatenate(features_proj, axis=0),
            np.concatenate(classes, axis=0))


def _retrieval_topk(sim, y_train, y_test, ks, class_num):
    """Top-k nearest-neighbour accuracy from a (n_test, n_train) similarity matrix.

    A test query is a hit at k if any of its k most similar training samples
    shares its label.

    Returns:
        ``(acc, per_class, hits)``. ``acc[k]`` is the mean hit rate.
        ``per_class[k]`` maps class id to its hit rate (NaN when a class has no
        test samples). ``hits`` is a bool tensor of shape ``(len(ks), n_test)``.
    """
    acc, per_class, hits = {}, {}, []
    for k in ks:
        # Clamp so a training set smaller than k does not make topk raise.
        _, topk_idx = torch.topk(sim, min(k, sim.size(1)), dim=1)
        result = torch.any(y_train[topk_idx] == y_test.unsqueeze(1), dim=1)
        hits.append(result.unsqueeze(0))
        per_class[k] = {c: result[y_test == c].float().mean().item() for c in range(class_num)}
        acc[k] = result.float().mean().item()
        print('Top-%d acc = %.4f' % (k, acc[k]))
    return acc, per_class, torch.cat(hits, dim=0)


def test_retrieval(epoch, model, contrast, train_dataloader, test_dataloader, args, transforms_cuda):
    """Evaluate nearest-neighbour video retrieval on ``args.eval_dataset``.

    Extracts backbone features and projected features for the train and test
    loaders, with ``model.module.focus_level`` temporarily set to -1. Features
    are centred (per split) and L2-normalised, and each test item is ranked
    against all training items by dot product.

    Args:
        epoch: used only to name the output directory (``<model_name>_e<epoch>``).
        model, contrast: DataParallel-wrapped networks (``.module`` is accessed).
        train_dataloader, test_dataloader: yield ``(clips, f_maps, labels, _, _)``.
        args: options. Reads ``feature_base``, ``eval_dataset``, ``model_name``,
            ``clip_len``, ``crop_size``, ``r_sample_num`` and ``focus_level``.
            Sets ``args.feature_dir``.
        transforms_cuda: transform applied to each GPU clip batch.

    Side effects:
        Writes ``topk_correct.json`` (raw features), ``topk_correct_proj.json``,
        ``topk_stat_proj.json`` and ``topk_result_proj.txt`` into
        ``args.feature_dir``. ``model.module.focus_level`` is restored to
        ``args.focus_level`` even if extraction fails.
    """
    args.feature_dir = os.path.join(args.feature_base, args.eval_dataset, args.model_name + '_e%d' % epoch)
    os.makedirs(args.feature_dir, exist_ok=True)

    class_num = {'ucf101': 101, 'hmdb51': 51}[args.eval_dataset]

    model.eval()
    contrast.eval()

    # Extract with focus disabled, and always restore the configured level.
    model.module.focus_level = -1
    try:
        X_train, X_train_proj, y_train = _extract_retrieval_features(
            model, contrast, train_dataloader, args, transforms_cuda)
        X_test, X_test_proj, y_test = _extract_retrieval_features(
            model, contrast, test_dataloader, args, transforms_cuda)
    finally:
        model.module.focus_level = args.focus_level
    print('Saving features to ...', args.feature_dir)

    train_feature = torch.tensor(X_train).cuda()
    device = train_feature.device
    test_feature = torch.tensor(X_test).to(device)
    train_feature_proj = torch.tensor(X_train_proj).to(device)
    test_feature_proj = torch.tensor(X_test_proj).to(device)
    y_train = torch.tensor(y_train).to(device)
    y_test = torch.tensor(y_test).to(device)
    del X_train, X_test, X_train_proj, X_test_proj

    def _center_normalize(x):
        # Subtract the split's own mean, then L2-normalise each row.
        return F.normalize(x - x.mean(dim=0, keepdim=True), p=2, dim=1)

    test_feature = _center_normalize(test_feature)
    train_feature = _center_normalize(train_feature)
    test_feature_proj = _center_normalize(test_feature_proj)
    train_feature_proj = _center_normalize(train_feature_proj)

    # Cosine similarity between every test and training item.
    sim = test_feature.matmul(train_feature.t())
    sim_proj = test_feature_proj.matmul(train_feature_proj.t())

    ks = [1, 5, 10, 20, 50]

    # Raw-feature accuracy was previously never computed (the file held zeros).
    print('----- feature dim %d -----' % test_feature.size(-1))
    topk_correct, _, _ = _retrieval_topk(sim, y_train, y_test, ks, class_num)
    del sim

    print('----- feature_proj dim %d -----' % test_feature_proj.size(-1))
    topk_correct_proj, top_k_stat, topk_result_proj = _retrieval_topk(sim_proj, y_train, y_test, ks, class_num)

    total = test_feature.size(0)
    topk_correct['total'] = total
    topk_correct_proj['total'] = total

    with open(os.path.join(args.feature_dir, 'topk_correct.json'), 'w', encoding='utf-8') as fp:
        json.dump(topk_correct, fp)

    with open(os.path.join(args.feature_dir, 'topk_correct_proj.json'), 'w', encoding='utf-8') as fp:
        json.dump(topk_correct_proj, fp)

    with open(os.path.join(args.feature_dir, 'topk_stat_proj.json'), 'w', encoding='utf-8') as fp:
        json.dump(top_k_stat, fp)

    # One row per test item, one tab-separated 0/1 column per k.
    np.savetxt(os.path.join(args.feature_dir, 'topk_result_proj.txt'),
               topk_result_proj.t().cpu().numpy(), fmt='%d', delimiter='\t')



if __name__ == '__main__':
    # Script entry point: run with the options parsed at module import.
    main(args=args)
