from utils import img_to_vid, stderr_suppress, vid_to_img, scrape_dirs, gen_temp
import os
import time
import string
import random

from options import get_args
import datasets.video_corrupt as video_corrupt
import multiprocessing
import subprocess
from tqdm import tqdm


def _worker_init(local_args):
    """Initialise per-process state for a pool worker.

    Stores ``local_args`` in the module global ``args`` and picks two random
    16-character alphanumeric file names (with the ``.avi`` extension) that
    are published as the globals ``vid_name`` and ``corrupt_name``.

    Args:
        local_args: Options object to expose to ``process_file`` as ``args``.
    """
    global args, ext, vid_name, corrupt_name

    # Mix the pid into the seed so each worker process seeds differently.
    random.seed(os.getpid() + time.time())
    alphabet = string.ascii_letters + string.digits

    ext = ".avi"
    # The two names are drawn one after the other from the same generator.
    vid_name = ''.join([random.choice(alphabet) for _ in range(16)]) + ext
    corrupt_name = ''.join([random.choice(alphabet) for _ in range(16)]) + ext
    args = local_args

def form_corruption_name(args):
    """Return the label ``<corrupt_mode>_<corrupt_prob>`` for these options.

    Args:
        args: Object with ``corrupt_mode`` (a string) and ``corrupt_prob``.

    Returns:
        The mode and the stringified probability joined by an underscore.
    """
    mode = args.corrupt_mode
    probability = str(args.corrupt_prob)
    return '_'.join((mode, probability))

def process_file(fname):
    """Corrupt one image-sequence directory and write the result as images.

    The frames in ``fname`` are encoded to a video with ``img_to_vid``, run
    through ``gen_temp`` to produce an ``.mp4`` intermediate, corrupted with
    ``video_corrupt.flip`` and decoded back to frames with ``vid_to_img``.
    The output directory is built under
    ``<args.base_path>/file_corruptions/<corruption name>`` from the last five
    ``/``-separated components of ``fname``.

    Reads the module globals ``args``, ``vid_name`` and ``corrupt_name``,
    which ``_worker_init`` sets in each worker process.

    Args:
        fname: Path to a directory holding an image sequence.

    Returns:
        The value returned by ``vid_to_img`` for the corrupted video.

    Raises:
        ValueError: If ``fname`` is not a directory.
        AssertionError: If this worker's corrupted-video scratch file already
            exists before processing starts.
    """
    if not os.path.isdir(fname):
        raise ValueError("{} is not a directory".format(fname))

    corruption_name = form_corruption_name(args)
    base_corrupted_path = os.path.join(args.base_path, 'file_corruptions', corruption_name)

    # Scratch path for the corrupted video. It is checked before any other
    # intermediate is created so that a failure here leaves nothing behind.
    corrupt_video_path = os.path.join('file_corruption_temp', corrupt_name)
    assert not os.path.isfile(corrupt_video_path), \
        "stale corrupted video: {}".format(corrupt_video_path)

    temp_video = None
    tmp = None
    try:
        # Turn image sequence into a video
        print(fname, vid_name)
        temp_video = img_to_vid(args, fname, vid_name)

        # Construct the new path for the final image sequence.
        # exist_ok covers the race where another worker creates the same
        # directory first, without hiding unrelated OS errors.
        new_path = os.path.join(base_corrupted_path, *(fname.split('/')[-5:]))
        os.makedirs(new_path, exist_ok=True)

        # Transcode to H.264 (as noted by the original author; gen_temp is
        # defined in utils)
        random.seed(time.time() + os.getpid())
        rand_string = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(16))
        tmp = os.path.join('file_corruption_temp', rand_string + ".mp4")
        gen_temp(temp_video, tmp)

        # Corrupt video. With enforce_readability the attempt is repeated
        # until an output file exists, so the transcoded input must stay on
        # disk until the loop is finished.
        while not os.path.isfile(corrupt_video_path):
            video_corrupt.flip(tmp, corrupt_video_path, 'h264', mode=args.corrupt_mode, p=args.corrupt_prob)
            if not args.enforce_readability:
                break
        os.unlink(tmp)
        os.unlink(temp_video)

        # Transform corrupted video into images
        new_dir = vid_to_img(args, corrupt_video_path, new_path)
        os.unlink(corrupt_video_path)
        return new_dir
    except BaseException:
        # Best-effort removal of whatever intermediates are still on disk;
        # the original error is always re-raised.
        for leftover in (tmp, temp_video, corrupt_video_path):
            if leftover is None:
                continue
            try:
                os.unlink(leftover)
            except OSError:
                pass
        raise

def create_dataset(fnames, args):
    """Run ``process_file`` over ``fnames`` in a worker pool with a progress bar.

    Args:
        fnames: Sequence of paths to process; truncated to the first
            ``args.limit`` entries when ``args.limit`` is truthy.
        args: Options object; ``args.num_workers`` sizes the pool and the
            whole object is handed to every worker through ``_worker_init``.
    """
    if args.limit:
        fnames = fnames[:args.limit]

    with multiprocessing.Pool(
        args.num_workers, initializer=_worker_init, initargs=(args,)
    ) as workers, tqdm(total=len(fnames)) as progress:
        # Results arrive in completion order; only their count is used.
        for _finished in workers.imap_unordered(process_file, fnames):
            progress.update()

if __name__ == '__main__':
    # Script entry point: read the options, collect the input paths with
    # scrape_dirs and build the corrupted dataset from them.
    print("Creating corrupted dataset...")
    args = get_args()
    create_dataset(scrape_dirs(args), args)
