import os
import json
import random
import itertools
import glob
from PIL import Image
import cv2
import numpy as np
from common import trace_object
import shutil

# Palette assigned to object categories (see DataChecker.get_colormap).
# NOTE(review): the triples read like RGB colour names (e.g. 220, 20, 60 is
# "crimson" in RGB), but cv2 drawing calls use the frame's channel order,
# which is BGR for frames from cv2.VideoCapture — confirm the intended hues.
COLORS = [
    (255, 64, 64),
    (0, 0, 255),
    (127, 255, 0),
    (255, 97, 3),
    (220, 20, 60),
    (255, 185, 15),
    (255, 20, 147),
    (255, 105, 180),
    (60, 179, 113),
]

def construct_mini_dataset(label_dir, bbox_dir, num, save_path, phase="train"):
    """Randomly sample up to ``num`` labelled videos and attach their bounding boxes.

    Args:
        label_dir: directory containing ``train.json``, ``validation.json`` and
            ``test.json``; each is a JSON list of label dicts with an ``'id'`` key.
        bbox_dir: directory whose every file is a JSON object mapping a video id
            to that video's bounding boxes; all files are merged together.
        num: maximum number of labels to keep.
        save_path: where the resulting list is written as JSON.
        phase: ``"train"`` or ``"valid"``; any other value selects ``test.json``.

    Returns:
        The list (at most ``num`` long) of sampled label dicts, each given a
        ``'bboxes'`` entry.  Labels whose id has no bbox entry are dropped.
    """
    label_files = {"train": "train.json", "valid": "validation.json"}
    label_path = os.path.join(label_dir, label_files.get(phase, "test.json"))
    with open(label_path, 'r') as f:
        labels = json.load(f)

    # TODO: better sampling strategy for better data distribution
    # Oversample by 10% because labels without bboxes are discarded below; clamp
    # to the population so small label files don't make random.sample raise.
    sample_size = min(int(num * 1.1), len(labels))
    sampled_labels = random.sample(labels, sample_size)

    all_bboxes = {}
    for file_name in os.listdir(bbox_dir):
        with open(os.path.join(bbox_dir, file_name), 'r') as f:
            all_bboxes.update(json.load(f))

    new_labels = []
    for label in sampled_labels:
        if label['id'] not in all_bboxes:
            continue
        # The label dicts were loaded in this call, so mutating them is safe.
        label['bboxes'] = all_bboxes[label['id']]
        new_labels.append(label)

    new_labels = new_labels[:num]
    with open(save_path, 'w') as f:
        json.dump(new_labels, f)
    return new_labels


def annotate_frame(image, annotations, color_map, annotate_objects):
    """Draw each annotation's ``box2d`` rectangle on ``image``.

    Args:
        image: frame to draw on (passed to cv2 drawing calls).
        annotations: iterable of dicts with ``'box2d'`` (keys ``x1``, ``x2``,
            ``y1``, ``y2``) and ``'category'``.
        color_map: mapping of category -> colour triple used for the outline.
        annotate_objects: when true, also write the category name in black
            just inside the box's top-left corner.

    Returns:
        The annotated image (``image`` itself when ``annotations`` is empty).
    """
    for ann in annotations:
        box = ann['box2d']
        x1, x2, y1, y2 = box['x1'], box['x2'], box['y1'], box['y2']
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        image = cv2.rectangle(image, top_left, bottom_right, color_map[ann['category']], 4)
        if not annotate_objects:
            continue
        text_origin = (top_left[0] + 5, top_left[1] + 20)
        cv2.putText(image, ann['category'], text_origin, cv2.FONT_HERSHEY_DUPLEX, 0.8, (0, 0, 0), 1)
    return image

def annotate_frame_simp(image, annotation):
    """Outline a single annotation's ``box2d`` in the fixed colour (255, 64, 64).

    Returns the image produced by ``cv2.rectangle``.
    """
    corners = annotation['box2d']
    x1, x2, y1, y2 = (corners[key] for key in ('x1', 'x2', 'y1', 'y2'))
    pt1 = (int(x1), int(y1))
    pt2 = (int(x2), int(y2))
    return cv2.rectangle(image, pt1, pt2, (255, 64, 64), 4)

def annotate_frame_preds(image, bboxes, pred_frame_labels):
    """Draw object boxes and predicted predicates on a padded copy of a frame.

    The frame is padded with a white border (40 px top/left/right, 100 px
    bottom).  Every box in ``bboxes`` is outlined and tagged with its object id.
    Unary predictions are stacked inside their object's box.  Binary predictions
    are listed in the bottom border.

    Args:
        image: frame to annotate (presumably a cv2 uint8 array — confirm).
        bboxes: mapping of object id -> dict with ``'x1'``, ``'x2'``, ``'y1'``,
            ``'y2'`` in un-padded frame coordinates.
        pred_frame_labels: iterable of predictions, either
            ``(predicate, prob, obj, frame)`` (unary) or
            ``(predicate, prob, from_obj, to_obj, frame)`` (binary).  Entries of
            any other length are ignored.

    Returns:
        The padded, annotated image.

    Raises:
        KeyError: if a unary prediction names an object missing from ``bboxes``.
    """
    pad = 40            # top / left / right border width
    bottom_pad = 100    # bottom border that holds the binary-predicate list
    color = (255, 64, 64)
    font = cv2.FONT_HERSHEY_DUPLEX
    line_step = 15      # vertical spacing between stacked text lines

    # 255 is white for 8-bit frames (256 lies outside the uint8 range).
    image = cv2.copyMakeBorder(image, pad, bottom_pad, pad, pad,
                               cv2.BORDER_CONSTANT, value=[255, 255, 255])

    for oid, bbox in bboxes.items():
        top_left = (int(bbox['x1'] + pad), int(bbox['y1'] + pad))
        bottom_right = (int(bbox['x2'] + pad), int(bbox['y2'] + pad))
        image = cv2.rectangle(image, top_left, bottom_right, color, 4)
        cv2.putText(image, str(oid), top_left, font, fontScale=0.4, color=color, thickness=1)

    # Binary predicates are listed starting at the top of the bottom border,
    # derived from the actual frame height (y=280 for 240-px-tall frames).
    binary_top = image.shape[0] - bottom_pad
    unary_counts = {}
    binary_count = 0
    for pred in pred_frame_labels:
        if len(pred) == 4:
            predicate, prob, obj, _frame = pred
            unary_counts[obj] = unary_counts.get(obj, 0) + 1
            bbox = bboxes[obj]
            # Stack this object's predicates downward from its box's top-left corner.
            origin = (int(bbox['x1'] + pad) + 3,
                      int(bbox['y1'] + pad) + line_step * unary_counts[obj])
            cv2.putText(image, f"{prob:.3f}: {predicate}", origin, font,
                        fontScale=0.4, color=color, thickness=1)
        elif len(pred) == 5:
            predicate, prob, from_obj, to_obj, _frame = pred
            binary_count += 1
            text = f"{prob:.3f}: {predicate}({from_obj}, {to_obj})"
            cv2.putText(image, text, (20, binary_top + line_step * binary_count), font,
                        fontScale=0.4, color=color, thickness=1)

    return image

def annotate_frame_preds_large(image, bboxes, pred_frame_labels, action_name, frame_id):
    """Draw boxes and a full list of predictions on a frame with a tall bottom border.

    The frame is padded with a white border (40 px top/left/right, 600 px
    bottom).  ``action_name`` and ``frame_id`` are written as a header in the
    top border.  Each box in ``bboxes`` is outlined and tagged with its object
    id.  Every unary and binary prediction is listed, one per line, in the
    bottom border.

    Args:
        image: frame to annotate (presumably a cv2 uint8 array — confirm).
        bboxes: mapping of object id -> dict with ``'x1'``, ``'x2'``, ``'y1'``,
            ``'y2'`` in un-padded frame coordinates.
        pred_frame_labels: iterable of ``(predicate, prob, obj, frame)`` or
            ``(predicate, prob, from_obj, to_obj, frame)`` tuples; entries of any
            other length are skipped.
        action_name: text shown in the header.
        frame_id: frame identifier shown in the header.

    Returns:
        The padded, annotated image.
    """
    pad = 40            # top / left / right border width
    bottom_pad = 600    # tall bottom border that holds the prediction list
    list_gap = 40       # gap between the frame's bottom edge and the list
    color = (255, 64, 64)
    font = cv2.FONT_HERSHEY_DUPLEX
    line_step = 15      # vertical spacing between listed predictions

    # 255 is white for 8-bit frames (256 lies outside the uint8 range).
    image = cv2.copyMakeBorder(image, pad, bottom_pad, pad, pad,
                               cv2.BORDER_CONSTANT, value=[255, 255, 255])
    cv2.putText(image, f"{action_name}:     {frame_id}", (20, 10), font,
                fontScale=0.4, color=color, thickness=1)

    for oid, bbox in bboxes.items():
        top_left = (int(bbox['x1'] + pad), int(bbox['y1'] + pad))
        bottom_right = (int(bbox['x2'] + pad), int(bbox['y2'] + pad))
        image = cv2.rectangle(image, top_left, bottom_right, color, 1)
        cv2.putText(image, str(oid), top_left, font, fontScale=0.4, color=color, thickness=1)

    # Derived from the real frame height (y=320 for 240-px-tall frames).
    list_top = image.shape[0] - bottom_pad + list_gap
    line_no = 0
    for pred in pred_frame_labels:
        if len(pred) == 4:
            predicate, prob, obj, _frame = pred
            args = f"{obj}"
        elif len(pred) == 5:
            predicate, prob, from_obj, to_obj, _frame = pred
            args = f"{from_obj}, {to_obj}"
        else:
            continue
        line_no += 1
        cv2.putText(image, f"{prob:.3f}: {predicate}({args})",
                    (20, list_top + line_step * line_no), font,
                    fontScale=0.4, color=color, thickness=1)

    return image

def annotate_pred_video(video, out_vid_path, bboxes, pred_labels, action_name, fps=12):
    """Write an mp4v video of frames annotated with their predictions.

    Args:
        video: indexable sequence of frames; ``video[0]`` fixes the frame size.
        out_vid_path: output video path.
        bboxes: mapping of frame id -> per-frame bbox mapping.
        pred_labels: mapping of frame id -> that frame's predictions; frames are
            written in this mapping's iteration order.
        action_name: text drawn in each frame's header.
        fps: output frame rate (defaults to the previous fixed 12).

    Raises:
        IndexError: if ``video`` is empty.
    """
    frame_height, frame_width = video[0].shape[:2]
    # Must match the border added by annotate_frame_preds_large:
    # 40 px left + 40 px right, 40 px top + 600 px bottom.
    size = (frame_width + 80, frame_height + 640)
    writer = cv2.VideoWriter(out_vid_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    try:
        for frame_id, pred_frame_labels in pred_labels.items():
            annotated = annotate_frame_preds_large(
                video[frame_id], bboxes[frame_id], pred_frame_labels, action_name, frame_id)
            writer.write(annotated)
    finally:
        # Always finalise the output file, even if annotating a frame fails.
        writer.release()


def _read_all_video_frames(video_path):
    """Decode every frame of ``video_path`` with cv2, always releasing the capture."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames


def _annotation_frame_index(anno):
    """Return the 0-based frame index encoded in ``anno['name']``.

    The name is parsed as: drop the last 4 characters, split on '/', and take
    the second part as a 1-based frame number (presumably names look like
    '<video_id>/<frame>.jpg' — confirm).
    """
    return int(anno['name'][:-4].split('/')[1]) - 1


def annotate_video(video_path, out_vid_path, annotation, color_map, annotate_objects, fps=12):
    """Re-encode a video with ground-truth boxes drawn on each annotated frame.

    Nothing is written if the last annotation refers to a frame beyond the end
    of the video.  Annotated frames are emitted in ``annotation`` order, and
    writing stops at the first annotation whose frame is out of range.

    Args:
        video_path: input video; must exist.
        out_vid_path: output path (mp4v codec).
        annotation: non-empty list of per-frame dicts with ``'name'`` and
            ``'labels'`` (the latter passed to ``annotate_frame``).
        color_map: category -> colour mapping passed to ``annotate_frame``.
        annotate_objects: whether ``annotate_frame`` also writes category names.
        fps: output frame rate (defaults to the previous fixed 12).

    Raises:
        AssertionError: if ``video_path`` does not exist.
        IndexError: if no frame can be decoded or ``annotation`` is empty.
    """
    assert os.path.exists(video_path), f"video not found: {video_path}"
    video = _read_all_video_frames(video_path)
    if not video:
        raise IndexError(f"no frames could be decoded from {video_path}")

    frame_height, frame_width = video[0].shape[:2]

    # Checked before the writer is created, so no empty or broken output file
    # is left behind and no writer handle leaks.
    if _annotation_frame_index(annotation[-1]) >= len(video):
        return

    writer = cv2.VideoWriter(out_vid_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                             (frame_width, frame_height))
    try:
        for anno in annotation:
            frame_id = _annotation_frame_index(anno)
            if frame_id >= len(video):
                break
            writer.write(annotate_frame(video[frame_id], anno['labels'], color_map, annotate_objects))
    finally:
        writer.release()

class DataChecker():
    """Load a dataset JSON file and render debug videos of its bounding boxes."""

    def __init__(self, dataset_path, video_dir, video_save_dir):
        """
        Args:
            dataset_path: JSON list of datapoints (each with ``'id'`` and ``'bboxes'``).
            video_dir: directory holding the source ``<id>.webm`` videos.
            video_save_dir: existing directory for rendered ``<id>.mp4`` videos.
        """
        with open(dataset_path, 'r') as f:
            self.dataset = json.load(f)
        self.video_dir = video_dir
        self.video_save_dir = video_save_dir
        assert os.path.exists(self.video_save_dir)

    def get_colormap(self, meta):
        """Map every category appearing in ``meta`` to a colour from COLORS.

        ``meta`` is a list of per-frame dicts whose ``'labels'`` each carry a
        ``'category'``.  Categories are sorted (``np.unique``) before colours
        are assigned.  The palette is reused cyclically when there are more
        categories than colours.
        """
        categories = [label['category'] for frame in meta for label in frame['labels']]
        all_objects = np.unique(categories)
        return {obj: COLORS[i % len(COLORS)] for i, obj in enumerate(all_objects)}

    def annotate_video(self, n):
        """Render datapoint ``n`` as ``<video_save_dir>/<id>.mp4`` with labelled boxes."""
        datapoint = self.dataset[n]
        vid_id = datapoint['id']
        video_path = os.path.join(self.video_dir, f"{vid_id}.webm")
        output_video_path = os.path.join(self.video_save_dir, f"{vid_id}.mp4")
        color_map = self.get_colormap(datapoint['bboxes'])
        # Calls the module-level annotate_video function, not this method.
        annotate_video(video_path, output_video_path, datapoint['bboxes'], color_map, annotate_objects=True)

    def check_datapoint(self, dataset_path, n):
        """Load ``dataset_path`` and return its ``n``-th datapoint for inspection."""
        with open(dataset_path) as f:
            datapoints = json.load(f)
        return datapoints[n]

def add_object_tracking(dataset):
    """Call ``trace_object`` on every datapoint's ``'bboxes'``, printing each index.

    ``trace_object``'s return value is discarded, so presumably it updates the
    bboxes in place — confirm in ``common``.  Returns the same ``dataset``
    object that was passed in.
    """
    index = 0
    for datapoint in dataset:
        print(index)
        trace_object(datapoint['bboxes'])
        index += 1
    return dataset

def collect_videos(dataset, orig_video_dir, mini_video_dir):
    """Copy each datapoint's ``<id>.webm`` from ``orig_video_dir`` to ``mini_video_dir``.

    ``mini_video_dir`` must already exist.  Copies happen in dataset order and
    stop at the first failure.
    """
    file_names = (f"{dp['id']}.webm" for dp in dataset)
    for name in file_names:
        source = os.path.join(orig_video_dir, name)
        destination = os.path.join(mini_video_dir, name)
        shutil.copy(source, destination)

if __name__ == "__main__":
    # Run configuration: only ``seed`` is consumed below; the rest are
    # currently unused.
    prefix, phase = "raw", "train"
    seed, num = 1234, 10000
    random.seed(seed)

    # <directory two levels above this file's path>/data
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../data'))
    label_dir, video_dir, bbox_dir, video_save_dir = [
        os.path.join(data_dir, sub_dir)
        for sub_dir in ('labels', '20bn-something-something-v2', 'bboxes', 'debug-videos')
    ]

    mini_dataset_path = os.path.join(data_dir, "clean_train_10000.json")
    data_checker = DataChecker(mini_dataset_path, video_dir, video_save_dir)
    # Render datapoints 1..9 (index 0 is skipped).
    for idx in range(1, 10):
        data_checker.annotate_video(idx)

    print('here')
