# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch as th
import torch.nn.functional as F
import math
import numpy as np
import argparse

from torch.utils.data import DataLoader
from model import get_model
from preprocessing import Preprocessing
from random_sequence_shuffler import RandomSequenceSampler

from tqdm import tqdm
from pathbuilder import PathBuilder
from videoreader import VideoLoader


# Command-line options, declared as (flag, add_argument keyword arguments)
# pairs. They are registered in the order listed, which is also the order in
# which `--help` prints them.
_CLI_OPTIONS = (
    # Forwarded to PathBuilder.build as the video / feature locations.
    ('--vdir', dict(type=str)),
    ('--fdir', dict(type=str)),
    # Forwarded to VideoLoader as `hflip`.
    ('--hflip', dict(type=int, default=0)),
    ('--batch_size', dict(type=int, default=64, help='batch size')),
    # Selects the entry of CONFIGS and the branches of the extraction loop.
    ('--type', dict(type=str, default='2d', help='CNN type')),
    ('--half_precision', dict(type=int, default=0,
                              help='output half precision float')),
    ('--num_decoding_thread', dict(
        type=int, default=4,
        help='Num parallel thread for video decoding')),
    ('--l2_normalize', dict(type=int, default=1,
                            help='l2 normalize feature')),
    ('--resnext101_model_path', dict(
        type=str, default='model/resnext101.pth',
        help='Resnext model path')),
    ('--vmz_model_path', dict(
        type=str,
        default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth',
        help='vmz model path')),
)

parser = argparse.ArgumentParser(description='Easy video feature extractor')
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()


# TODO: refactor all args into config. (current code is from different people.)
# Per-model decoding settings, keyed by the value of --type. Each entry is
# consumed further down the script as follows:
#   fps        -> `framerate` argument of VideoLoader
#   size       -> `size` argument of VideoLoader
#   centercrop -> `centercrop` argument of VideoLoader
#   shards     -> last positional argument of PathBuilder.build
CONFIGS = {
    "2d": {
        "fps": 1,
        "size": 224,
        "centercrop": False,
        "shards": 0,
    },
    "3d": {
        "fps": 24,
        "size": 112,
        "centercrop": True,
        "shards": 0,
    },
    "s3d": {
        "fps": 30,
        "size": 224,
        "centercrop": True,
        "shards": 0,
    },
    "vmz": {
        "fps": 24,
        "size": 112,
        "centercrop": True,
        "shards": 0,
    },
    "vae": {
        "fps": 2,
        "size": 256,
        "centercrop": True,
        "shards": 100,
    },
}

# Report an unsupported --type as a normal CLI usage error that lists the
# valid choices, rather than letting the lookup below fail with a bare
# KeyError traceback.
if args.type not in CONFIGS:
    parser.error(
        "unsupported --type {!r}; expected one of: {}".format(
            args.type, ", ".join(sorted(CONFIGS))
        )
    )

config = CONFIGS[args.type]


# Work out which videos to process and where each feature file is written.
video_dirs, feature_dir = args.vdir, args.fdir
video_dict = PathBuilder.build(
    video_dirs,
    feature_dir,
    ".npy",
    config["shards"],
)

# Decoding dataset; frame rate, frame size and cropping come from the
# per-type entry selected from CONFIGS.
dataset = VideoLoader(
    video_dict=video_dict,
    framerate=config["fps"],
    size=config["size"],
    centercrop=config["centercrop"],
    hflip=args.hflip,
)
n_dataset = len(dataset)

# The sequence sampler is always constructed, but it is only handed to the
# DataLoader when the dataset holds more than 10 items; otherwise the loader
# receives `sampler=None`.
sampler = RandomSequenceSampler(n_dataset, 10)
if n_dataset > 10:
    _loader_sampler = sampler
else:
    _loader_sampler = None

# One video per batch: the extraction loop indexes element [0] of each field.
_loader_options = {
    "batch_size": 1,
    "shuffle": False,
    "num_workers": args.num_decoding_thread,
    "sampler": _loader_sampler,
}
loader = DataLoader(dataset, **_loader_options)

preprocess = Preprocessing(args.type)
model = get_model(args)

# Width and dtype of the per-video feature buffer for each --type. Any type
# not listed here uses the 2048-wide float default.
_FEATURE_SPECS = {
    'vmz': (512, th.float32),
    's3d': (512, th.float32),
    'vae': (1024, th.int64),
}
_DEFAULT_FEATURE_SPEC = (2048, th.float32)


def _forward(video_batch):
    """Run ``model`` on one batch and return only its feature tensor."""
    if args.type in ('2d', 'vae'):
        # These models return the features (image codes for vae) directly.
        return model(video_batch)
    if args.type == 's3d':
        return model(video_batch)['video_embedding']
    # Every other type returns a (predictions, features) pair.
    _, batch_features = model(video_batch)
    return batch_features


def _extract_features(video):
    """Compute the features of one preprocessed video.

    ``video`` is the output of ``preprocess``; it is consumed in slices of
    ``args.batch_size`` entries along its first axis. Returns a NumPy array
    with one row per processed entry.
    """
    # vmz only feeds every third entry to the model.
    stride = 3 if args.type == 'vmz' else 1
    n_chunk = int(math.ceil(len(video) / float(stride)))

    dim, dtype = _FEATURE_SPECS.get(args.type, _DEFAULT_FEATURE_SPEC)
    features = th.zeros(n_chunk, dim, dtype=dtype, device='cuda')

    # The vae buffer is an integer one, so L2 normalisation is skipped for
    # it regardless of --l2_normalize (which defaults to on): normalising
    # would either fail on an integer tensor or be truncated when written
    # into the integer buffer.
    normalize = bool(args.l2_normalize) and args.type != 'vae'

    for start in range(0, n_chunk, args.batch_size):
        stop = start + args.batch_size
        video_batch = video[stride * start:stride * stop:stride].cuda()
        batch_features = _forward(video_batch)
        if normalize:
            batch_features = F.normalize(batch_features, dim=1)
        features[start:stop] = batch_features
    return features.cpu().numpy()


def _cast_for_saving(features):
    """Cast features to the on-disk dtype selected by --half_precision."""
    if args.type == 'vae':
        dtype = np.int16 if args.half_precision else np.int32
    else:
        dtype = 'float16' if args.half_precision else 'float32'
    return features.astype(dtype)


with th.no_grad():
    for data in tqdm(loader, total=len(loader), ascii=True):
        input_file = data['input'][0]
        output_file = data['output'][0]
        raw_video = data['video']

        if raw_video.dim() <= 3:
            print('Video {} error.'.format(input_file))
            continue

        video = raw_video.squeeze()
        if video.dim() < 4:
            # squeeze() strips every singleton axis, not only the batch
            # axis added by the batch_size=1 loader. A tensor with another
            # singleton axis (presumably a one-frame video - TODO confirm)
            # therefore lost a dimension; retry dropping only the batch axis.
            video = raw_video.squeeze(0)
        if video.dim() != 4:
            # Previously skipped without any message.
            print('Video {} skipped: unexpected tensor shape {}.'.format(
                input_file, tuple(raw_video.shape)))
            continue

        features = _extract_features(preprocess(video))
        np.save(output_file, _cast_for_saving(features))
