import os
import sys
import subprocess
import warnings
from operator import itemgetter
import re
import itertools
import random
import shlex

import numpy as np
from tqdm import tqdm

# DIVX-family (MPEG4 Part 2) magic patterns
# Byte sequence 00 00 01 b6; the mpeg4 scan searches for it to find each VOP.
vop_start = bytearray.fromhex('000001b6')
# Mask/shift applied to the byte that immediately follows vop_start; a result
# of 0 is treated as an I-frame by get_iframe_locations.
vop_code_mask = 0xc0
vop_right_shift = 6
vop_offset = len(vop_start)  # length in bytes of the start code itself
# Markers bounding the mpeg4 scan: the first frame_start found after
# movi_start is where the search begins, and movi_end is where it stops.
frame_start = b'00d'
movi_start = b'movi'
movi_end = b'idx1'

# AVC (MPEG4 Part 10) magic patterns
# NOTE(review): the three header names below are not referenced in the portion
# of the file reviewed -- presumably MP4 box tags kept for future use; confirm.
iframe_index_table_header = b'stss'
frame_size_header = b'stsz'
frame_offset_header = b'stco'
int_width = 4  # width in bytes of the integer fields skipped/read in the h264 path
# Bytes regex consumed by get_nal_intervals: matches 00 00 01, 00 00 02 or 00 00 03.
nal_start = br'\x00\x00\x01|\x00\x00\x02|\x00\x00\x03'

def intify(x):
    """Interpret a bytes-like object as an unsigned big-endian integer."""
    value = int.from_bytes(x, 'big')
    return value

def get_ext(path):
    """Return the text after the last '.' in path.

    When path contains no '.', the whole path is returned unchanged.
    """
    _, _, suffix = path.rpartition(".")
    return suffix

def fill_corrupt(path, dest, start, length, pad_val=0):
    """Overwrite a byte range of a file with a constant value.

    The requested range [start, start + length) is intersected with the file
    contents, so the output always has the same size as the input; a range
    that lies (partly) outside the file only affects the bytes that exist.

    Args:
        path: file to read.
        dest: file to write the modified bytes to; nothing is written when
            this is falsy.
        start: byte offset of the first byte to overwrite.
        length: number of bytes to overwrite.
        pad_val: replacement byte value, 0..255.

    Raises:
        ValueError: if pad_val is not a valid byte value.
    """
    with open(path, 'rb') as f:
        data = bytearray(f.read())

    # Clamp both ends to the file so the slice assignment below can never
    # insert bytes (which would silently grow the file).
    begin = min(max(start, 0), len(data))
    stop = min(max(start + length, begin), len(data))
    data[begin:stop] = bytes([pad_val]) * (stop - begin)

    if dest:
        with open(dest, 'wb') as g:
            g.write(data)


def flip(path, dest, codec=None, mode='random', p=1.):
    """Corrupt a whole video file and report what was hit.

    Args:
        path: video file to read.
        dest: file to write the corrupted bytes to; nothing is written when
            this is falsy.
        codec: codec name passed to get_iframe_locations; looked up with
            get_codec(path) when None.
        mode: 'random' (independent bit flips) or 'contiguous' (one run of
            random bytes).
        p: corruption strength forwarded to the corruption helper; 0 leaves
            the data untouched.

    Returns:
        dict with keys 'corrupted_iframe' (bool) and 'locations' (whatever
        the corruption helper reported; an empty list when p == 0).

    Raises:
        NotImplementedError: if p != 0 and mode is not recognised.
    """
    if codec is None:
        codec = get_codec(path)
    raw_frames = get_iframe_locations(path, codec)

    with open(path, 'rb') as f:
        a = bytearray(f.read())

    # Defaults cover the p == 0 case, where no corruption helper runs.
    locations = []
    corrupted_iframe = False
    if p != 0:
        if mode == 'contiguous':
            a, locations, corrupted_iframe = contiguous_corrupt(a, 0, len(a), p, raw_frames)
        elif mode == 'random':
            a, locations, corrupted_iframe = random_corrupt(a, 0, len(a), p, raw_frames)
        else:
            raise NotImplementedError()

    if dest:
        with open(dest, 'wb') as g:
            g.write(a)

    return {
        'corrupted_iframe': corrupted_iframe,
        'locations': locations,
    }


def get_codec(path, encoding='utf-8'):
    """Ask ffprobe for the codec name of the first video stream in path.

    Args:
        path: media file; '~' is expanded before it is handed to ffprobe.
        encoding: encoding used to decode ffprobe's output.

    Returns:
        ffprobe's output, stripped of surrounding whitespace, as a str.

    Raises:
        subprocess.CalledProcessError: if ffprobe exits with a non-zero status.
    """
    command = (
        'ffprobe', '-loglevel', 'quiet', '-v', 'error',
        '-select_streams', 'v:0',
        '-show_entries', 'stream=codec_name',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        os.path.expanduser(path),
    )
    # stderr is folded into the captured output, exactly as the caller expects.
    raw = subprocess.check_output(command, stderr=subprocess.STDOUT)
    return raw.strip().decode(encoding)

def get_nal_intervals(a):
    """Return the (start, end) offsets of every nal_start match in a.

    Args:
        a: bytes-like object to scan.

    Returns:
        list of (start, end) tuples, in order of appearance.
    """
    return [hit.span() for hit in re.finditer(nal_start, a)]

def get_iframe_locations(path, codec, encoding='utf-8', loglevel=0):
    """Locate the byte ranges of I-frame payloads inside a video file.

    Args:
        path: video file to scan.
        codec: codec name; 'mpeg4' and 'h264' are supported.
        encoding: encoding used to decode ffprobe's CSV output (h264 only).
        loglevel: when > 0, informational messages are printed to stderr.

    Returns:
        list of (start, end) byte-offset tuples, one per I-frame found;
        empty when none were found.

    Raises:
        NotImplementedError: if codec is not supported.
        ValueError: if ffprobe's frame listing cannot be parsed.
        subprocess.CalledProcessError: if ffprobe fails (h264 only).
    """
    if loglevel > 0:
        print("\u001b[34;1m[INFO] File {} has codec {}\u001b[0m".format(path, codec), file=sys.stderr)

    with open(path, 'rb') as f:
        a = bytearray(f.read())

    iframes = []
    if codec == 'mpeg4':
        end = a.find(movi_end)
        first_frame = a.find(frame_start, a.find(movi_start), end)
        start = a.find(vop_start, first_frame, end)
        while start != -1 and start < end:
            # The next VOP both terminates the current frame and is the next
            # loop position, so look it up once.
            next_vop = a.find(vop_start, start + vop_offset, end)
            if not (a[start + vop_offset] & vop_code_mask) >> vop_right_shift:
                # Coding-type bits are zero: treat this VOP as an I-frame.
                corrupt_end = end if next_vop == -1 else next_vop
                corrupt_start = start + 8  # skip the VOP start header and the 4-byte metadata
                iframes.append((corrupt_start, corrupt_end))
            start = next_vop
    elif codec == 'h264':
        # Let ffprobe list position and size of every key-frame packet.
        offsets = subprocess.check_output(('ffprobe', '-loglevel', 'quiet','-skip_frame', 'nokey', '-select_streams', 'v:0', '-show_entries', 'frame=pkt_pos,pkt_size', '-of', 'csv', path), stderr=subprocess.DEVNULL)
        raw_iframe_loc_text = [row for row in offsets.strip().split(b"\n") if len(row)]
        try:
            # Columns 1 and 2 of each CSV row are pkt_pos and pkt_size; any
            # non-digit characters (e.g. a trailing '\r') are dropped.
            iframelist = sorted(
                ([int(''.join(c for c in n.decode(encoding) if c.isdigit())) for n in row.split(b",")[1:3]]
                 for row in raw_iframe_loc_text),
                key=itemgetter(0))
        except ValueError:
            print("Raw text:", raw_iframe_loc_text)
            raise
        # Skip the first 3 * int_width bytes of each packet and leave the
        # last int_width bytes alone.
        iframes = [(start + 3 * int_width, start + length - int_width) for start, length in iframelist]
        if iframes:
            # The first frame carries extra header metadata: the big-endian
            # integer at the very start of the packet says how many more
            # bytes to skip.
            first = iframes[0][0]
            header_len = int.from_bytes(a[first - 3 * int_width:first - 2 * int_width], byteorder='big', signed=False)
            iframes[0] = (first + 3 * int_width + header_len, iframes[0][1])
    else:
        raise NotImplementedError()

    if loglevel > 0:
        print("\u001b[34;1m[INFO] {} I-frame(s) located\u001b[0m".format(len(iframes)), file=sys.stderr)
    return iframes

def contiguous_corrupt(arr, start, end, p, raw_frames):
    """Replace one randomly placed contiguous run of arr with random bytes.

    The run is int(p * (end - start)) bytes long and lies inside [start, end).

    Args:
        arr: mutable byte buffer, modified in place.
        start, end: bounds of the region the run may be placed in.
        p: fraction of the region to overwrite.
        raw_frames: iterable of (start, end) I-frame byte ranges.

    Returns:
        (arr, [run_start, run_end], iframe) where iframe is True when the run
        touches or encloses any range in raw_frames.
    """
    span = int(p * (end - start))
    lo = random.randrange(start, int(end - p * (end - start)) + 1)
    hi = lo + span
    arr[lo:hi] = os.urandom(hi - lo)

    hit = any(
        f_lo <= lo <= f_hi or f_lo <= hi <= f_hi or lo <= f_lo <= f_hi <= hi
        for f_lo, f_hi in raw_frames
    )
    return arr, [lo, hi], hit

def random_corrupt(arr, start, end, p, raw_frames):
    """Flip each bit of arr[start:end] independently with probability p.

    Args:
        arr: mutable byte buffer, modified in place; must contain at least
            `end` bytes.
        start, end: byte bounds of the region to corrupt.
        p: per-bit flip probability.
        raw_frames: iterable of (start, end) I-frame byte ranges, expressed
            as absolute byte offsets into arr (bounds inclusive).

    Returns:
        (arr, locations, iframe) where locations is a numpy array holding the
        indices of the flipped bits, counted from the first bit of
        arr[start], and iframe is True when any flipped bit lies in a byte
        covered by raw_frames.
    """
    n_bytes = end - start
    flips = np.random.binomial(1, p, size=n_bytes * 8)
    locations = np.where(flips == 1)[0]

    # Pack the per-bit decisions into a byte mask and XOR it onto the data.
    mask = np.packbits(flips, bitorder='big')
    original = np.frombuffer(bytes(arr[start:end]), dtype=np.uint8)
    arr[start:end] = np.bitwise_xor(original, mask).tobytes()

    # locations are bit indices relative to `start`; the frame ranges are
    # absolute byte offsets, so convert before comparing.
    byte_offsets = start + locations // 8
    iframe = any(
        bool(np.any((byte_offsets >= i_start) & (byte_offsets <= i_end)))
        for i_start, i_end in raw_frames
    )
    return arr, locations, iframe

def range_diff(r1, r2):
    """Return the pieces of range r1 that are not covered by range r2.

    Both arguments are (start, end) pairs. The result is a list of zero, one
    or two (start, end) pairs. When the ranges share an endpoint a
    zero-length piece such as (x, x) can appear in the result.
    """
    (lo1, hi1), (lo2, hi2) = r1, r2
    pts = sorted((lo1, lo2, hi1, hi2))

    pieces = []
    # r1 starts first: everything up to the second endpoint survives.
    if pts[0] == lo1:
        pieces.append((pts[0], pts[1]))
    # r1 ends last: everything from the third endpoint on survives.
    if pts[3] == hi1:
        pieces.append((pts[2], pts[3]))
    return pieces

def multirange_diff(r1_list, r2_list):
    """Subtract every range in r2_list from every range in r1_list.

    Each range in r2_list is applied in turn with range_diff to all pieces
    that are still left. Returns the list of remaining (start, end) pieces;
    when r2_list is empty, r1_list itself is returned.
    """
    remaining = r1_list
    for cut in r2_list:
        pieces = []
        for keep in remaining:
            pieces.extend(range_diff(keep, cut))
        remaining = pieces
    return remaining

def whack_mpeg_iframes(path, dest, mode='random', p=1.):
    """Corrupt only the I-frame payloads of a video file.

    I-frame byte ranges come from get_iframe_locations; every span matched by
    get_nal_intervals is cut out of them first, so those byte patterns are
    left untouched.

    Args:
        path: video file to read.
        dest: file to write the result to; nothing is written when falsy.
        mode: 'random' (independent bit flips within each range) or
            'contiguous' (one run of random bytes within each range).
        p: corruption strength per range; 0 leaves the data unchanged.

    Raises:
        ValueError: if mode is not 'random' or 'contiguous'.
    """
    if mode not in ['random', 'contiguous']: raise ValueError("Keyword 'mode' must be in '['random, 'contiguous'] but got {}".format(mode))
    # dest may be falsy ("do not write"), in which case there is no
    # extension to compare.
    if dest and get_ext(dest) != get_ext(path):
        warnings.warn("Different extensions detected! Source is of type '{}' while destination is of type '{}'".format(get_ext(path), get_ext(dest)))

    with open(path, 'rb') as f:
        a = bytearray(f.read())
    raw_frames = get_iframe_locations(path, get_codec(path))
    bad_intervals = get_nal_intervals(a)
    frames = multirange_diff(raw_frames, bad_intervals)

    if p != 0:
        for i_start, i_end in frames:
            if mode == 'contiguous':
                # Ranges of at most 1/p bytes are skipped: p * length would
                # be at most a single byte.
                if i_end - i_start <= 1 / p:
                    continue
                a, _, _ = contiguous_corrupt(a, i_start, i_end, p, raw_frames)
            else:
                a, _, _ = random_corrupt(a, i_start, i_end, p, raw_frames)

    # Write once, after all ranges are processed, so dest is produced even
    # when no range was corrupted.
    if dest:
        with open(dest, 'wb') as g:
            g.write(a)
 
if __name__ == '__main__':
    # Make the parent of this file's directory importable before importing
    # the `options` module.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from options import get_corrupter_args

    args = get_corrupter_args()
    successes = 0
    attempts = 0
    corrupted_iframe = False  # only the 'flip' corruption reports this
    path = args.source
    temp = "temp.mp4"

    orig_codec = get_codec(path)
    if orig_codec != 'h264':
        print("Transcoding to H.264 (libx264)")
        # Argument list instead of a shell string: no quoting needed and no
        # shell involved.
        subprocess.run(
            ['ffmpeg', '-loglevel', 'quiet', '-i', args.source,
             '-vcodec', 'libx264', '-strict', '-2', temp],
            check=True)
        path = temp
        orig_codec = get_codec(path)
    # Check the file that will actually be corrupted (temp exists only when
    # a transcode happened).
    if orig_codec != 'h264':
        raise RuntimeError("Expected an h264 stream in {} but found {}".format(path, orig_codec))

    for _ in tqdm(range(args.attempts)):
        attempts += 1
        if args.corruption == 'fill':
            if not 0 <= args.fill_val < 256:
                raise ValueError("fill_val must be in [0, 256) but got {}".format(args.fill_val))
            fill_corrupt(path, args.dest, args.start, args.length, pad_val=args.fill_val)
        elif args.corruption == 'flip':
            corrupted_iframe = flip(path, args.dest, mode=args.mode, p=args.p)['corrupted_iframe']
        elif args.corruption == 'whack':
            whack_mpeg_iframes(path, args.dest, mode=args.mode, p=args.p)
        else:
            raise NotImplementedError
        # Without a destination there is nothing to validate or retry.
        if not args.dest: break
        # An attempt counts as a success when ffprobe can still open the
        # result and reports the same codec as the input.
        probe = subprocess.run(['ffprobe', '-loglevel', 'quiet', args.dest])
        if probe.returncode == 0 and get_codec(args.dest) == orig_codec:
            successes += 1
            if args.stop_on_success: break

    if args.dest:
        if corrupted_iframe: print("hit iframe")
        print("Saved corrupted video to", args.dest)
    print("{}/{} (successes/attempts)".format(successes, attempts))
