from load_data import load_data
from options import get_args
from tqdm import tqdm

import os
import pickle
import multiprocessing
import subprocess
import string

from options import get_network_corrupter_args
import time
import random
from utils import img_to_vid, vid_to_img, stderr_suppress, scrape_dirs

from contextlib import suppress

def network_corrupt(src, dest, mode, link, rate=0, mean_on_time=0, mean_off_time=0, port=12345, timeout=20, output_codec='libx264', n_loops=0, ffmpeg_log_level='quiet', framerate=None, verbose=False):
    """Stream ``src`` through an ``mm-loss``/``mm-onoff`` link and record it to ``dest``.

    A sending ffmpeg, launched inside the ``mm-*`` command, pushes ``src`` over
    RTSP/UDP to this host. A receiving ffmpeg listens on the same RTSP URL and
    stream-copies whatever arrives into ``dest``.

    Args:
        src: Path of the input video to stream.
        dest: Path the received stream is written to.
        mode: ``'loss'`` (runs ``mm-loss <link> <rate>``) or ``'onoff'``
            (runs ``mm-onoff <link> <mean_on_time> <mean_off_time>``).
        link: ``'uplink'`` or ``'downlink'``; passed straight to the mm command.
        rate: Argument for ``mm-loss``.
        mean_on_time, mean_off_time: Arguments for ``mm-onoff``.
        port: RTSP port used by both ends.
        timeout: Value for the receiving ffmpeg's ``-timeout`` option.
        output_codec: ``-vcodec`` used by the sending ffmpeg.
        n_loops: ``-stream_loop`` count for the sending ffmpeg.
        ffmpeg_log_level: ``-loglevel`` for both ffmpeg processes.
        framerate: If truthy, passed as ``-r`` to the receiving ffmpeg.
        verbose: Print both commands and leave stderr alone. Otherwise stderr
            is wrapped in ``stderr_suppress()``.

    Raises:
        ValueError: If ``link`` or ``mode`` is not one of the accepted values.
    """
    if link not in ('uplink', 'downlink'):
        raise ValueError("Parameter #3 (link) must be either 'uplink' or 'downlink'")
    if mode == 'loss':
        mm_cmd = 'mm-loss'
        mm_args = [str(rate)]
    elif mode == 'onoff':
        mm_cmd = 'mm-onoff'
        mm_args = [str(mean_on_time), str(mean_off_time)]
    else:
        raise ValueError("Parameter #2 (mode) must be either 'loss' or 'onoff'")

    # Both ends use the first address reported by `hostname -I`.
    ip_addr = subprocess.check_output(['hostname', '-I']).decode('utf-8').strip().split(' ')[0]
    rtsp_url = 'rtsp://{}:{}/live.sdp'.format(ip_addr, port)

    # Sender: ffmpeg run as the child command of mm-loss / mm-onoff.
    server_cmd = [mm_cmd, link] + mm_args + [
        'ffmpeg', '-y', '-loglevel', ffmpeg_log_level, '-fflags', '+genpts', '-re',
        '-stream_loop', str(n_loops), '-i', src, '-f', 'rtsp', '-vcodec', output_codec,
        '-strict', '2', '-rtsp_transport', 'udp', rtsp_url,
    ]

    # Receiver: ffmpeg in RTSP listen mode, copying the received stream into `dest`.
    client_cmd = ['ffmpeg', '-fflags', '+genpts', '-y', '-re', '-loglevel', ffmpeg_log_level]
    if framerate:
        client_cmd += ['-r', str(framerate)]
    client_cmd += ['-rtsp_flags', 'listen', '-timeout', str(timeout), '-i', rtsp_url, '-c', 'copy', dest]

    if verbose:
        print("Client command:", client_cmd)
        print("Server command", server_cmd)
        output_ctx = suppress()
    else:
        output_ctx = stderr_suppress()

    with output_ctx:
        # The receiver is the listening side, so it is launched first.
        client = subprocess.Popen(client_cmd)
        server = None
        try:
            server = subprocess.Popen(server_cmd)
            client.wait()
            server.wait()
        finally:
            # Normally both have already exited. This only kills and reaps processes
            # left running when an exception (e.g. KeyboardInterrupt) interrupts the waits.
            for proc in (client, server):
                if proc is not None and proc.poll() is None:
                    proc.kill()
                    proc.wait()

def set_process_port(global_args):
    """Pool initializer: publish the per-worker globals that ``process_file`` reads.

    Every worker gets a distinct ``port`` (the base ``global_args.port`` plus the
    worker's pool identity) and a module-level reference to the options as
    ``args``. For datasets outside ``RAW_VIDEO_DATASETS``, the worker also draws
    two random 16-character ``.avi`` names (``vid_name`` and ``corrupt_name``)
    for its temporary videos, so workers use separate temp files.
    """
    global port, args, ext, vid_name, corrupt_name
    worker_identity = int(multiprocessing.current_process()._identity[0])
    port = global_args.port + worker_identity
    args = global_args
    if args.dataset in RAW_VIDEO_DATASETS:
        # Raw videos are corrupted directly; no temporary file names are needed.
        return

    # Reseed per worker (pid + wall clock) so each process draws its own names.
    random.seed(os.getpid() + time.time())
    alphabet = string.ascii_letters + string.digits

    def random_stem():
        return ''.join([random.choice(alphabet) for _ in range(16)])

    ext = ".avi"
    vid_name = random_stem() + ext
    corrupt_name = random_stem() + ext


def form_corruption_name(args):
    """Return the directory name that identifies one corruption configuration.

    The result is ``<mode>_<link>_<params>[_loop<N>]_ver<V>``:

    * ``<params>`` is ``on<on>_off<off>`` for ``'onoff'`` mode and
      ``str(packet_loss_rate)`` otherwise.
    * ``_loop<N>`` is present only when ``stream_loops > 0``.
    """
    if args.network_error_mode == 'onoff':
        params = 'on{}_off{}'.format(*args.onofftime)
    else:
        params = str(args.packet_loss_rate)
    loop_part = ["loop{}".format(args.stream_loops)] if args.stream_loops > 0 else []
    pieces = [args.network_error_mode, args.link_mode, params, *loop_part,
              "ver{}".format(args.corruption_version)]
    return '_'.join(pieces)

def _run_corruption(src, dest, framerate):
    """Run ``network_corrupt`` on ``src`` -> ``dest`` with the worker's global ``args``/``port``."""
    network_corrupt(src, dest, args.network_error_mode, args.link_mode,
                    rate=args.packet_loss_rate,
                    mean_on_time=args.onofftime[0], mean_off_time=args.onofftime[1],
                    port=port, verbose=(args.limit is not None),
                    ffmpeg_log_level=args.ffmpeg_log_level,
                    n_loops=args.stream_loops, framerate=framerate)


def _corrupt_video_file(fname, base_corrupted_path):
    """Corrupt one video file into ``<base>/<dataset>/<parent dir>/<file name>``."""
    new_path = os.path.join(base_corrupted_path, args.dataset, *(fname.split('/')[-2:]))
    # exist_ok tolerates other workers creating the same directory concurrently,
    # but still surfaces real failures such as permission errors.
    os.makedirs(os.path.dirname(new_path), exist_ok=True)
    _run_corruption(fname, new_path, framerate=None)
    return new_path


def _corrupt_frame_dir(fname, base_corrupted_path):
    """Encode a frame directory to video, corrupt it, and decode it back into frames."""
    temp_video = img_to_vid(args, fname, vid_name)
    try:
        new_path = os.path.join(base_corrupted_path, *(fname.split('/')[-5:]))
        print(fname, "to", new_path)
        os.makedirs(new_path, exist_ok=True)

        temp_dir = 'network_corruption_temp'
        os.makedirs(temp_dir, exist_ok=True)
        corruption_path = os.path.join(temp_dir, corrupt_name)
        # A leftover file would satisfy the readability check below without a new
        # corruption having been produced, so start from a clean slate.
        with suppress(FileNotFoundError):
            os.unlink(corruption_path)

        # With enforce_readability, retry until the receiver actually wrote a file.
        while True:
            _run_corruption(temp_video, corruption_path, framerate=args.stream_framerate)
            if os.path.isfile(corruption_path) or not args.enforce_readability:
                break
    finally:
        os.unlink(temp_video)

    try:
        _ = vid_to_img(args, corruption_path, new_path)
    finally:
        # Without enforce_readability the corruption may have produced no file at all.
        with suppress(FileNotFoundError):
            os.unlink(corruption_path)
    return new_path


def process_file(fname):
    """Create the network-corrupted version of one dataset entry.

    Reads the per-worker globals set by ``set_process_port``: ``args`` and
    ``port``, plus ``vid_name``/``corrupt_name`` for directory inputs.

    Args:
        fname: A video file, or a directory of frames.

    Returns:
        Path of the corrupted output. This is a video file for file inputs and a
        frame directory for directory inputs.

    Raises:
        ValueError: If ``fname`` is neither a file nor a directory.
    """
    base_corrupted_path = os.path.join(args.base_path, 'network_corruptions', form_corruption_name(args))
    if os.path.isfile(fname):
        return _corrupt_video_file(fname, base_corrupted_path)
    if os.path.isdir(fname):
        return _corrupt_frame_dir(fname, base_corrupted_path)
    raise ValueError("{} is not a file or directory".format(fname))

def create_dataset(fnames, args):
    """Run ``process_file`` over ``fnames`` in a worker pool, showing a progress bar.

    Args:
        fnames: Sliceable sequence of dataset entries (video files or frame
            directories) accepted by ``process_file``.
        args: Parsed options. If ``args.limit`` is truthy, only the first
            ``args.limit`` entries are processed. ``args.num_workers`` sets the
            pool size. The whole namespace is passed to every worker through the
            ``set_process_port`` initializer.
    """
    if args.limit:
        fnames = fnames[:args.limit]
    with multiprocessing.Pool(args.num_workers, initializer=set_process_port, initargs=(args,)) as pool:
        with tqdm(total=len(fnames)) as pbar:
            for _ in pool.imap_unordered(process_file, fnames):
                pbar.update()

# Datasets whose entries are video files, corrupted directly.
RAW_VIDEO_DATASETS = ['hmdb51', 'ucf101']
# Datasets whose entries are frame directories, converted to temporary videos first.
FRAME_BASED_DATASETS = ['mot15', 'mot20']

if __name__ == '__main__':
    args = get_args()
    if args.dataset in RAW_VIDEO_DATASETS:
        cache_path = "{}_fnames.pkl".format(args.dataset)
        if os.path.exists(cache_path) and not args.rebuild_filename_cache:
            print("Filename cache found")
            # NOTE: pickle.load can execute arbitrary code; only load caches this script wrote.
            with open(cache_path, 'rb') as cache_file:
                fnames = pickle.load(cache_file)
        else:
            print("Filename cache not found. Loading data...")
            _, val = load_data(args)
            fnames = val.data_labels[:, 0]
            with open(cache_path, 'wb') as cache_file:
                pickle.dump(fnames, cache_file)
        print("Creating corrupted dataset...")
    elif args.dataset in FRAME_BASED_DATASETS:
        fnames = scrape_dirs(args)
    else:
        raise ValueError("No procedure defined for corrupting dataset '{}'.".format(args.dataset))
    create_dataset(fnames, args)


