# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

from __future__ import print_function, division
import os
import sys
import subprocess
from multiprocessing import Pool
from tqdm import tqdm

# Size of the worker pool created in class_process (one process per
# reported CPU). os.cpu_count() may return None when the count cannot be
# determined; multiprocessing.Pool accepts None and picks its own default.
n_thread = os.cpu_count()


def vid2jpg(file_name, class_path, dst_class_path):
    """Extract the frames of one video into its own directory of JPEGs.

    Frames are written as ``<dst_class_path>/<video name>/%05d.jpg`` by
    ffmpeg, rescaled to a height of 331 pixels.

    Args:
        file_name: Entry name inside ``class_path``. Entries whose name does
            not contain ``'.mp4'`` are ignored.
        class_path: Directory holding the video (usually a symlink to it).
        dst_class_path: Existing directory that receives the per-video
            frame directory.

    Returns:
        None. Failures are reported on stdout instead of being raised so a
        single bad video does not abort the whole worker pool.
    """
    # Function-scope import: the module does not import shutil at file level.
    import shutil

    if '.mp4' not in file_name:
        return

    src_path = os.path.join(class_path, file_name)
    if os.path.islink(src_path):
        video_file_path = os.readlink(src_path)
        # A relative link target is relative to the link's directory, not
        # to the current working directory.
        if not os.path.isabs(video_file_path):
            video_file_path = os.path.join(class_path, video_file_path)
    else:
        # Regular file: os.readlink would raise, so use the entry itself.
        video_file_path = src_path

    name = os.path.splitext(os.path.basename(video_file_path))[0]
    dst_directory_path = os.path.join(dst_class_path, name)

    try:
        if os.path.exists(dst_directory_path):
            # The first frame is used as the "already converted" marker.
            if os.path.exists(os.path.join(dst_directory_path, '00001.jpg')):
                print('*** convert has been done: {}'.format(dst_directory_path))
                return
            # Leftover from an interrupted run: wipe it and start over.
            # No shell is involved, so unusual characters in the path are safe.
            if os.path.isdir(dst_directory_path) and not os.path.islink(dst_directory_path):
                shutil.rmtree(dst_directory_path)
            else:
                os.remove(dst_directory_path)
            print('remove {}'.format(dst_directory_path))
        os.mkdir(dst_directory_path)
    except OSError:
        # Best effort: report the problematic destination and skip the video.
        print(dst_directory_path)
        return

    # Argument list instead of a shell string: paths containing quotes,
    # spaces or `$` reach ffmpeg verbatim and cannot be interpreted by a shell.
    cmd = [
        'ffmpeg',
        '-i', video_file_path,
        '-threads', '1',  # parallelism comes from the process pool
        '-vf', 'scale=-1:331',
        '-q:v', '0',
        os.path.join(dst_directory_path, '%05d.jpg'),
    ]
    try:
        subprocess.call(cmd,
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError as err:
        # E.g. ffmpeg is not installed; keep the run going like the shell
        # version did, but say why nothing was produced.
        print('*** failed to run ffmpeg for {}: {}'.format(video_file_path, err))


def class_process(dir_path, dst_dir_path, class_name):
    """Convert every video of one class directory to JPEG frames in parallel.

    Args:
        dir_path: Directory that contains the class sub-directories.
        dst_dir_path: Directory under which ``<class_name>/`` is created to
            receive the frames. It does not have to exist yet.
        class_name: Name of the class sub-directory inside ``dir_path``.

    Returns:
        None. If ``dir_path/class_name`` is not a directory, a message is
        printed and nothing else happens.
    """
    print('*' * 20, class_name, '*' * 20)
    class_path = os.path.join(dir_path, class_name)
    if not os.path.isdir(class_path):
        print('*** is not a dir {}'.format(class_path))
        return

    dst_class_path = os.path.join(dst_dir_path, class_name)
    if not os.path.exists(dst_class_path):
        # makedirs also creates missing parents (e.g. the per-split
        # directory), which a plain mkdir would fail on.
        os.makedirs(dst_class_path)

    vid_list = sorted(os.listdir(class_path))

    from functools import partial
    worker = partial(vid2jpg, class_path=class_path, dst_class_path=dst_class_path)

    # The context manager tears the pool down even if iteration raises
    # (e.g. on Ctrl-C), so worker processes are never leaked.
    with Pool(n_thread) as p:
        # Results are discarded; iterating only drives the progress bar.
        for _ in tqdm(p.imap_unordered(worker, vid_list), total=len(vid_list)):
            pass
        # Normal path: let the workers finish and exit cleanly.
        p.close()
        p.join()

    print('\n')

if __name__ == "__main__":
    # Default locations; both can be overridden on the command line:
    #   python <script> [VIDEO_ROOT [FRAME_ROOT]]
    dir_path = '/data/whj/CIL_IN_VIDEO/PIVOT-main/kinetics-dataset/k400/videos'  # root directory of the videos
    dst_dir_path = '/data/whj/CIL_IN_VIDEO/PIVOT-main/datasets/path_frames/kinetics'  # where extracted frames are saved
    if len(sys.argv) > 1:
        dir_path = sys.argv[1]
    if len(sys.argv) > 2:
        dst_dir_path = sys.argv[2]

    # Expected layout: <dir_path>/<type>/<class>/<video>, where <type> is a
    # top-level sub-directory of the video root.
    for each_type in sorted(os.listdir(dir_path)):
        type_dir_path = os.path.join(dir_path, each_type)
        if not os.path.isdir(type_dir_path):
            # Stray files next to the type directories would make
            # os.listdir below raise; skip them instead.
            print('*** is not a dir {}'.format(type_dir_path))
            continue

        type_dst_dir_path = os.path.join(dst_dir_path, each_type)
        # class_process creates the class directory beneath this one, so
        # make sure it (and the destination root) exists first.
        os.makedirs(type_dst_dir_path, exist_ok=True)

        for class_name in sorted(os.listdir(type_dir_path)):
            class_process(type_dir_path, type_dst_dir_path, class_name)
