import os
import pdb 
import json
import pickle
import imageio
import argparse
import multiprocessing

from os import path
from tqdm import tqdm
from functools import partial


def get_overlaped_period(track_1, track_2):
    """Return the frame range shared by two periods.

    Args:
        track_1: Sequence whose first two items are (start_frame, end_frame).
        track_2: Sequence whose first two items are (start_frame, end_frame).

    Returns:
        ``(start_frame, end_frame)`` of the intersection, or ``(None, None)``
        when the periods do not overlap. Periods that merely touch (one ends
        on the frame where the other starts) count as not overlapping.
    """
    first_begin, first_finish = track_1[0], track_1[1]
    second_begin, second_finish = track_2[0], track_2[1]

    disjoint = first_begin >= second_finish or second_begin >= first_finish
    if disjoint:
        return None, None

    # The intersection starts at the later start and ends at the earlier end.
    return max(first_begin, second_begin), min(first_finish, second_finish)

def get_track(clip, bbox_track, index, period_index):
    """Collect the boxes of one track that fall inside a clip's period.

    Args:
        clip: Clip record; its last item is the ``[start, end]`` frame period.
        bbox_track: Container indexed by track id; each entry holds
            ``['track']['frame']`` and ``['track']['bbox']`` in matching order.
        index: Id of the track to extract (stored as ``human_id``).
        period_index: Running index of the clip (stored as ``period_index``).

    Returns:
        Dict with keys ``period_index``, ``human_id``, ``period``, ``boxes``
        and ``frame``; ``boxes``/``frame`` keep only the entries whose frame
        number lies inside the period (both ends inclusive).
    """
    period = clip[-1]
    track = bbox_track[index]['track']
    frame_numbers = track['frame']  # here we assume the frames are continuous
    frame_boxes = track['bbox']

    kept_boxes, kept_frames = [], []
    for position, frame_number in enumerate(frame_numbers):
        if period[0] <= frame_number <= period[1]:
            kept_boxes.append(frame_boxes[position])
            kept_frames.append(frame_number)

    return {
        'period_index': period_index,
        'human_id': index,
        'period': period,
        'boxes': kept_boxes,
        'frame': kept_frames,
    }
     
def get_tracks_with_clip(data, args):
    """Split one video's valid periods into per-speaker clips and crop them.

    For every valid period that is at least ``args.min_frame_period`` frames
    long, each speaking turn of the tracks listed in the period's
    ``track_index`` is intersected with the period. The overlapping turns,
    ordered by the centre frame of the turn, become clips of the form
    ``[speaker, track_index, center_frame, [overlap_start, overlap_end],
    [clip_start, clip_end]]``. The clip list is written to
    ``speak_order.json`` in the output directory, and every clip is cropped
    once for the speaker (``<n>_speak=<id>``) and once for every other track
    of the period (``<n>_listen=<id>``) through ``process_video`` in a
    process pool.

    Args:
        data: Name of the video's sub-directory under each root path.
        args: Parsed command-line namespace. This function reads
            ``croped_video_save_path``, ``video_selection_path``,
            ``seg_speaker_path`` and ``min_frame_period``; the namespace is
            also forwarded to ``process_video``.

    Returns:
        None. Returns immediately when the output directory already exists.
    """
    file_name = data
    croped_video_save_path = path.join(args.croped_video_save_path, file_name)
    # An existing output directory is treated as "already processed".
    if os.path.isdir(croped_video_save_path):
        return
    os.mkdir(croped_video_save_path)

    ## load data
    video_selection_path = path.join(args.video_selection_path, file_name)
    with open(path.join(video_selection_path, 'vaild_video_period.json'), 'r', encoding='utf-8') as f:
        video_selection = json.load(f)['vaild_video_period']
    with open(path.join(video_selection_path, 'track_speaking.json'), 'r', encoding='utf-8') as f:
        track_speaking = json.load(f)['track_speaking']

    bbox_track_path = path.join(args.seg_speaker_path, file_name, 'result', 'tracks.pckl')
    # NOTE(review): pickle executes arbitrary code on load; this is only safe
    # if tracks.pckl was produced by a trusted earlier stage -- confirm.
    with open(bbox_track_path, 'rb') as f:
        bbox_track = pickle.load(f)

    ## get vaild period face boxes and start, end frame
    clip_list_all = []
    for selection in video_selection:
        period_start, period_end = selection['period'][0], selection['period'][1]
        track_index = selection['track_index']
        if period_end - period_start < args.min_frame_period:
            continue

        # Map the centre frame of each speaking turn to its speaker and span.
        # NOTE(review): two turns with the same centre frame collide here and
        # only the later one survives -- confirm this is acceptable.
        speaking_dic = {}
        for speaker in track_index:
            for turn in track_speaking[speaker]:
                center_frame = (turn[0] + turn[1]) / 2
                speaking_dic[center_frame] = {'speaker': speaker, 'period': [turn[0], turn[1]]}

        # Keep the turns that overlap the valid period, in temporal order.
        clip_list = []
        for key in sorted(speaking_dic):
            start, end = get_overlaped_period(selection['period'], speaking_dic[key]['period'])
            if start is None:
                continue
            clip_list.append([speaking_dic[key]['speaker'], track_index, key, [start, end]])

        # Extend every clip towards its neighbours. At this point clip[-1] is
        # the clip's own overlap and clip_list[index+1][-1] the next clip's
        # overlap, but clip_list[index-1][-1] is the previous clip's already
        # extended period, because that was appended one iteration earlier.
        # NOTE(review): if the previous clip's raw overlap was intended for
        # the start boundary, this leaves a gap between consecutive clips.
        last = len(clip_list) - 1
        for index, clip in enumerate(clip_list):
            own_start, own_end = clip[-1][0], clip[-1][1]
            if index == 0:
                start = own_start
            else:
                start = int((own_start + clip_list[index - 1][-1][1]) / 2)
            if index == last:
                end = own_end
            else:
                end = int((own_end + clip_list[index + 1][-1][0]) / 2)
            clip.append([start, end])
        clip_list_all.extend(clip_list)

    speaking_all_result = []
    listening_all_result = []
    for period_index, clip in enumerate(clip_list_all):
        speak_index, all_index = clip[0], clip[1]
        # speaking
        speaking_all_result.append(get_track(clip, bbox_track, speak_index, period_index))
        # listening: every other track of the period. Compare ids by value;
        # identity comparison on ints is not reliable.
        for index in all_index:
            if index != speak_index:
                listening_all_result.append(get_track(clip, bbox_track, index, period_index))

    # One crop job per track; the job name encodes clip order, role and id.
    all_data = []
    for result in speaking_all_result:
        order = str(result['period_index'])
        all_data.append([file_name, result, order + '_speak=' + str(result['human_id'])])
    for result in listening_all_result:
        # The order comes from the listening record itself: with more than
        # two tracks per period the two result lists do not line up.
        order = str(result['period_index'])
        all_data.append([file_name, result, order + '_listen=' + str(result['human_id'])])

    with open(path.join(croped_video_save_path, 'speak_order.json'), 'w') as f:
        json.dump(clip_list_all, f)

    # Nothing to crop: avoid spawning worker processes for an empty job list.
    if not all_data:
        return

    num_processes = min(32, len(all_data))
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(processes=num_processes) as pool:
        func = partial(process_video, args=args)
        pool.map(func, all_data)
    return
        
def get_listening_track(data, args):
    """Crop every listening period of one video, one clip after another.

    Each listening period of every track named in a valid period is
    intersected with that valid period; for each overlap, the track's boxes
    whose frame number lies inside the overlap (both ends inclusive) are
    collected and passed to ``process_video``. Clips are named by their
    running index.

    Args:
        data: Name of the video's sub-directory under each root path.
        args: Parsed command-line namespace. This function reads
            ``croped_video_save_path``, ``video_selection_path`` and
            ``seg_speaker_path``; the namespace is also forwarded to
            ``process_video``.

    Returns:
        None. Returns immediately when the output directory already exists.
    """
    file_name = data
    croped_video_save_path = path.join(args.croped_video_save_path, file_name)
    # An existing output directory is treated as "already processed".
    if os.path.isdir(croped_video_save_path):
        return
    os.mkdir(croped_video_save_path)

    ## load data
    video_selection_path = path.join(args.video_selection_path, file_name)
    with open(path.join(video_selection_path, 'vaild_video_period.json'), 'r', encoding='utf-8') as f:
        video_selection = json.load(f)['vaild_video_period']
    with open(path.join(video_selection_path, 'track_listening.json'), 'r', encoding='utf-8') as f:
        track_listening = json.load(f)['track_listening']

    bbox_track_path = path.join(args.seg_speaker_path, file_name, 'result', 'tracks.pckl')
    # NOTE(review): pickle executes arbitrary code on load; this is only safe
    # if tracks.pckl was produced by a trusted earlier stage -- confirm.
    with open(bbox_track_path, 'rb') as f:
        bbox_track = pickle.load(f)

    ## get vaild period face boxes and start, end frame
    all_result = []
    for selection in video_selection:
        start, end = selection['period'][0], selection['period'][1]
        for index in selection['track_index']:
            for period in track_listening[index]:
                start_frame, end_frame = get_overlaped_period([start, end], period)
                if start_frame is None:
                    continue
                track_frame = bbox_track[index]['track']['frame']  # here we assume the frames are continuous
                track_bboxes = bbox_track[index]['track']['bbox']
                result = {'period': [start_frame, end_frame], 'boxes': [], 'frame': []}
                for frame_index, frame_number in enumerate(track_frame):
                    if start_frame <= frame_number <= end_frame:
                        result['boxes'].append(track_bboxes[frame_index])
                        result['frame'].append(frame_number)
                all_result.append(result)

    # Crop each clip exactly once. With no overlapping listening period the
    # loop simply does nothing.
    for index, result in enumerate(all_result):
        process_video([file_name, result, index], args)

def bb_intersection_over_union(boxA, boxB):
    """Return the intersection-over-union of two boxes.

    Both boxes are indexed as ``(x_min, y_min, x_max, y_max)``. Widths and
    heights are computed with a ``+ 1`` term, i.e. both corner coordinates
    are counted as part of the box.

    Args:
        boxA: First box.
        boxB: Second box.

    Returns:
        The IoU as a float; 0.0 when the boxes do not intersect.
    """
    # Corners of the intersection rectangle.
    inner_left = max(boxA[0], boxB[0])
    inner_top = max(boxA[1], boxB[1])
    inner_right = min(boxA[2], boxB[2])
    inner_bottom = min(boxA[3], boxB[3])

    # A negative extent means no overlap along that axis.
    overlap_w = max(0, inner_right - inner_left + 1)
    overlap_h = max(0, inner_bottom - inner_top + 1)
    overlap_area = overlap_w * overlap_h

    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    union_area = float(area_a + area_b - overlap_area)
    return overlap_area / union_area

def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, outp, image_shape, increase_area=0.1):
    """Build the ffmpeg command that cuts and crops one face tube.

    The tube box is enlarged by at least ``increase_area`` of its size on
    every side, with the shorter dimension enlarged further so the two
    dimensions come out similar, and is then clamped to the frame.

    Args:
        start: First frame of the segment.
        end: Last frame of the segment.
        fps: Frame rate used to convert frame numbers to seconds.
        tube_bbox: ``(left, top, right, bottom)`` box to crop around.
        frame_shape: Frame shape; item 0 is the height, item 1 the width.
        inp: Input video path, inserted into the command as-is.
        outp: Output video path, inserted into the command as-is.
        image_shape: Two values joined as ``a:b`` for the ffmpeg scale filter.
        increase_area: Minimum relative enlargement per side.

    Returns:
        The ffmpeg command line as a single string.
    """
    x_min, y_min, x_max, y_max = tube_bbox
    box_w = x_max - x_min
    box_h = y_max - y_min

    # Computing aspect preserving bbox
    growth = 1 + 2 * increase_area
    pad_x = max(increase_area, (growth * box_h - box_w) / (2 * box_w))
    pad_y = max(increase_area, (growth * box_w - box_h) / (2 * box_h))

    x_min = int(x_min - pad_x * box_w)
    y_min = int(y_min - pad_y * box_h)
    x_max = int(x_max + pad_x * box_w)
    y_max = int(y_max + pad_y * box_h)

    # Clamp the enlarged box to the frame.
    y_min = max(0, y_min)
    y_max = min(y_max, frame_shape[0])
    x_min = max(0, x_min)
    x_max = min(x_max, frame_shape[1])
    crop_h = y_max - y_min
    crop_w = x_max - x_min

    # Frame numbers -> seconds.
    seek = start / fps
    stop = end / fps
    duration = stop - seek

    video_filter = f'crop={crop_w}:{crop_h}:{x_min}:{y_min}, scale={image_shape[0]}:{image_shape[1]}'
    return f'ffmpeg -i {inp} -ss {seek} -t {duration} -filter:v "{video_filter}" {outp} -y'

def compute_bbox_trajectories(trajectories, fps, frame_shape, video_path, save_path, args):
    """Build ffmpeg commands for the trajectories that are long enough.

    Args:
        trajectories: Iterable of 4-item records
            ``(bbox, tube_bbox, start, end)``; only the last three are used.
        fps: Frame rate forwarded to ``compute_bbox``.
        frame_shape: Frame shape forwarded to ``compute_bbox``.
        video_path: Input video path.
        save_path: Output video path (the same for every command).
        args: Namespace providing ``min_frames``, ``image_shape`` and
            ``increase``.

    Returns:
        List of command strings, one per trajectory spanning more than
        ``args.min_frames`` frames, in input order.
    """
    return [
        compute_bbox(
            first_frame,
            last_frame,
            fps,
            tube_bbox,
            frame_shape,
            inp=video_path,
            outp=save_path,
            image_shape=args.image_shape,
            increase_area=args.increase,
        )
        for _initial_bbox, tube_bbox, first_frame, last_frame in trajectories
        if (last_frame - first_frame) > args.min_frames
    ]

def join(tube_bbox, bbox):
    """Return the smallest box enclosing both input boxes.

    Args:
        tube_bbox: Box indexed as ``(x_min, y_min, x_max, y_max)``.
        bbox: Box indexed the same way.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)`` of the enclosing box.
    """
    # Lower corner takes the minima, upper corner the maxima.
    lower = [min(tube_bbox[axis], bbox[axis]) for axis in (0, 1)]
    upper = [max(tube_bbox[axis], bbox[axis]) for axis in (2, 3)]
    return (lower[0], lower[1], upper[0], upper[1])

## crop the video
def process_video(data, args):
    """Crop one face track out of the source video with ffmpeg.

    Reads the frames of ``<seg_speaker_path>/<file_name>/avi/video.avi`` up
    to the end of the requested period, follows the track's face box frame
    by frame, and groups consecutive frames into trajectories: a trajectory
    ends once its IoU with the current box is no longer above
    ``args.iou_with_initial``. Every trajectory spanning more than
    ``args.min_frames`` frames yields one ffmpeg command, and the commands
    are run through the shell at the end.

    All commands write to the same ``<idx>.avi`` with ``-y``, so when several
    trajectories qualify, the last one is the file that remains.

    Args:
        data: ``[file_name, result, idx]``. ``result['period']`` is the
            ``(start_frame, end_frame)`` to crop and ``result['boxes']`` the
            face boxes; ``boxes[k]`` is used for frame ``start_frame + k``,
            which assumes one box per consecutive frame of the period.
        args: Parsed command-line namespace; reads ``seg_speaker_path``,
            ``croped_video_save_path`` and ``iou_with_initial`` here, and is
            forwarded to ``compute_bbox_trajectories``.

    Returns:
        None.
    """
    file_name, result, idx = data[0], data[1], data[2]
    video_path = path.join(args.seg_speaker_path, file_name, 'avi', 'video.avi')
    save_path = path.join(args.croped_video_save_path, file_name, str(idx)+'.avi')

    start_frame, end_frame = result['period']
    boxes = result['boxes']

    # Each trajectory is [initial_bbox, union_bbox, first_frame, last_frame].
    trajectories = []
    commands = []
    frame_shape = None

    video = imageio.get_reader(video_path)
    try:
        fps = video.get_meta_data()['fps']
        for i, frame in tqdm(enumerate(video)):
            if i > end_frame:
                break
            if start_frame <= i <= end_frame:
                frame_shape = frame.shape # (1080, 1920, 3)
                bboxes = [boxes[i-start_frame]]

                ## For each trajectory check the criterion
                not_valid_trajectories = []
                valid_trajectories = []
                for trajectory in trajectories:
                    initial_bbox = trajectory[0]
                    intersection = 0
                    for bbox in bboxes:
                        intersection = max(intersection, bb_intersection_over_union(initial_bbox, bbox))
                    if intersection > args.iou_with_initial:
                        valid_trajectories.append(trajectory)
                    else:
                        not_valid_trajectories.append(trajectory)

                # Trajectories that just ended are emitted right away.
                commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, video_path, save_path, args)
                trajectories = valid_trajectories

                ## Assign bbox to trajectories, create new trajectories
                for bbox in bboxes:
                    intersection = 0
                    current_trajectory = None
                    for trajectory in trajectories:
                        initial_bbox = trajectory[0]
                        current_intersection = bb_intersection_over_union(initial_bbox, bbox)
                        if intersection < current_intersection and current_intersection > args.iou_with_initial:
                            intersection = current_intersection
                            current_trajectory = trajectory

                    ## Create new trajectory
                    if current_trajectory is None:
                        trajectories.append([bbox, bbox, i, i])
                    else:
                        current_trajectory[3] = i
                        current_trajectory[1] = join(current_trajectory[1], bbox)
    finally:
        # Release the reader even when the loop exits early or raises.
        video.close()

    # Emit the trajectories still open at the end exactly once. Doing this
    # per frame would queue an ffmpeg run for every frame past min_frames,
    # each overwriting the same output file.
    commands += compute_bbox_trajectories(trajectories, fps, frame_shape, video_path, save_path, args)

    for command in commands:
        os.system(command)
    return

if __name__ == '__main__':
    # Command-line entry point: crop the speaker/listener clips of each video.
    parser = argparse.ArgumentParser(description = "crop video")
    parser.add_argument('--num_people', type=int, default=2, help='the number of people in the conversation')
    parser.add_argument('--min_frame_period', type=int, default=125, help='min frame for a conversation clip')
    parser.add_argument('--seg_speaker_path', type=str, default='/users/zeyuzhu/dataset_project/Datasets/FallowShow/1_speaker_segmentation', help='seg_speaker_path')
    parser.add_argument('--video_selection_path', type=str, default='/users/zeyuzhu/dataset_project/Datasets/FallowShow/2_video_selection', help='video_selection_path')
    parser.add_argument('--croped_video_save_path', type=str, default='/users/zeyuzhu/dataset_project/Datasets/FallowShow/3_croped_video', help='crped_video_save_path')
    parser.add_argument("--bias", type=int, default=10,  help='intercept the middle of the start and end')

    # Options consumed by the cropping functions.
    parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))), help="Image shape")
    parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox")
    parser.add_argument("--increase", default=0.2, type=float, help='Increase bbox by this amount')
    parser.add_argument("--min_frames", type=int, default=25,  help='Minimum number of frames')
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")

    args = parser.parse_args()

    root_seg = args.seg_speaker_path
    # Keep only the entries of the segmentation root that are directories.
    # The check must be made on each entry, not on the root itself.
    file_name_list = [f for f in os.listdir(root_seg) if os.path.isdir(path.join(root_seg, f))]
    # NOTE(review): the discovered list is overridden below, so only the
    # videos named '0' .. '9' are processed -- confirm this is intended.
    file_name_list = [str(index) for index in range(10)]
    for file_name in tqdm(file_name_list):
        get_tracks_with_clip(file_name, args)
    
    