import numpy as np
import math
from collections import defaultdict
import torchvision.transforms as T
import os
import cv2

# Boolean connective tokens.
bool_token = ["and", "or", "not"]
# Maps the '=' symbol to its comparison operator string for positive atoms.
kw_preds = {
    '=': "=="
}
# Same mapping for negated atoms: '=' becomes the inequality operator string.
not_kw_preds = {
    '=': "!="
}

# Variable name -> variable id (ids start at 1).
var2vid = {
    'a': 1,
    'b': 2,
    'c': 3,
}

# Constant name -> constant id. The id is negative, presumably so it can never
# collide with the positive variable ids in var2vid — confirm with consumers.
const2cid = {
    'HAND': -1
}

# Arity of every action name: how many arguments it takes (values range 0-3).
action_arg_num = {"approach":1, "attach": 2, "bend-deform": 1, "bend-break": 1, "bury": 2, "close": 1, "cover": 2, "dig": 2, "drop-behind": 2, "drop-infront": 2, "drop-into": 2, "drop-nextto": 2, "drop-onto": 2, "put-into-fail-fit": 2, "fold": 1, "hit": 2, "hold": 1, "hold-behind": 2, "hold-infront": 2, "hold-nextto": 2, "hold-over": 2, "lay-side": 1, "let-roll": 1, "roll-slanted": 1, "roll-up-down": 1, "lift-surface": 1, "lift-surface-sliding": 1, "lift": 1, "lift-drop": 1, "lift-with": 2, "lift-end": 1, "lift-end-drop": 1, "move-camera-away": 1, "move-part": 2, "move-across-fall": 1, "move-across": 1, "move-away": 2, "move-closer": 2, "move-collide": 2, "move-pass": 2, "move-one-away": 2, "move-away-from": 1, "move-closer-to": 2, "move-down": 1, "move-towards": 1, "move-up": 1, "open": 1, "pick": 1, "pile": 0, "plug": 2, "plug-pull": 2, "poke-hole-substance": 1, "poke-hole-soft": 1, "poke-stack-collapse": 1, "poke-stack": 1, "poke-move": 1, "poke": 1, "poke-fall": 1, "poke-spin": 1, "pour-into": 2, "pour-into-overflow": 2, "pour-onto": 2, "pour-out": 2, "wipe-fail": 2, "twist-fail": 1, "tear-fail": 1, "close-fail": 1, "open-fail": 1, "pick-fail": 1, "poke-fail": 1, "pour-fail": 2, "put-behind-fail": 2, "put-into-fail": 2, "put-nextto-fail": 2, "put-onsurface-fail": 1, "put-onto-fail": 2, "put-under-fail": 2, "scoop-fail": 2, "spread-fail": 1, "sprinkle-fail": 1, "squeeze-fail": 1, "take-fail": 2, "take-out-fail": 2, "throw-fail": 1, "turn-fail": 1, "pull-behind": 2, "pull-right": 1, "pull-left": 1, "pull-onto": 2, "pull-out": 2, "pull-ends": 1, "pull-stretch": 1, "pull-break": 1, "push-right": 1, "push-left": 1, "push-off": 2, "push-onto": 2, "push-spin": 1, "push-edge": 1, "push-fall": 1, "push": 1, "push-with": 2, "put-many-onto": 0, "put-two": 2, "put-behind": 2, "put-infront": 2, "put-into": 2, "put-nextto": 2, "put-onsurface-noroll": 1, "put-onsurface": 1, "put-edge-fall": 2, "put-slanted": 1, "put-onto": 2, "put-fall": 2, "put-similar": 1, "put-slide": 1, 
"put-slanted-noslide": 1, "put-side": 1, "put-under": 2, "put-upright": 1, "put-three": 3, "remove-reveal": 2, "roll": 1, "scoop": 2, "show-photo": 1, "show-behind": 2, "show-nextto": 2, "show-on": 2, "show": 1, "show-empty": 1, "show-in": 2, "deflect": 2, "collide-deflect": 2, "collide-halt": 2, "drop-light": 1, "drop-heavy": 1, "spill-behind": 2, "spill-nextto": 2, "spill-on": 2, "spin-long": 1, "spin-short": 1, "spread": 2, "sprinkle": 2, "squeeze": 1, "stack": 0, "stuff": 2, "take-one": 1, "take-from": 2, "take-out": 2, "tear-pieces": 1, "tear": 1, "throw": 1, "throw-against": 2, "throw-catch": 1, "throw-fall": 1, "throw-surface": 1, "tilt": 2, "tilt-fall": 2, "tip": 1, "tip-fall-out": 2, "touch": 1, "attach-fail": 2, "bend-fail": 1, "pour-miss": 2, "turn-upsidedown": 1, "turn-camera-down": 1, "turn-camera-left": 1, "turn-camera-right": 1, "turn-camera-up": 1, "twist-water": 1, "twist": 1, "uncover": 1, "unfold": 1, "wipe": 2}
eps = 0.0001

# Inverse of action_arg_num: arity -> list of action names with that arity,
# listed in action_arg_num's insertion order. The loop names ``action`` and
# ``num`` are kept as-is because, at module level, they stay bound as module
# globals after the loop.
num_to_actions = {}
for action, num in action_arg_num.items():
    num_to_actions.setdefault(num, []).append(action)

# Predicate vocabularies.
# static_preds: property predicates; none of them appear in mutable_preds.
static_preds = ["is_bendable", "is_fluid", "is_holdable", "is_rigid", "is_spreadable", "is_tearable"]
# mutable_preds is exactly the union of binary_preds and unary_preds below.
mutable_preds = ["above", "attached", "behind", "broken", "close", "closed", "deformed", "empty", "far", "fits", "folded", "full", "has_hole", "high", "in", "infront", "left", "low", "nextto", "on", "onsurface", "open", "right", "stacked", "stretched", "torn", "touching", "twisted", "under", "upright", "visible"]
binary_preds = ["above", "attached", "behind", "in", "fits", "infront", "nextto", "on", "touching", "under"]
unary_preds = ["broken", "close", "closed", "deformed", "empty", "far", "folded", "full", "has_hole", "high", "left", "low", "onsurface", "open", "right", "stacked", "stretched", "torn", "twisted", "upright", "visible"]
# Names treated as non-probabilistic — presumably evaluated symbolically rather
# than scored by a model; confirm against the consumer of this list.
non_prob_preds = ["frame", "all_frames", "all_objects", "num_variables", "precondition", "effect", "variable_name", "positive_unary_atom", "negative_unary_atom", "positive_binary_atom", "negative_binary_atom", "inequality_constraint"]

# Per-channel normalization statistics. Their magnitudes suggest pixel values on
# a 0-255 scale. NOTE(review): cv2 decodes frames as BGR; confirm which channel
# order these statistics were computed in.
image_mean = (114.27205934, 111.62876119, 106.06272127)
image_std =  (66.80977908, 65.11730349, 65.42803173)
transform = T.Normalize(mean = image_mean, std = image_std)

# Normalized frame size: width norm_x, height norm_y (is_valid_bbox checks x
# against norm_x and y against norm_y).
norm_x = 420
norm_y = 240
channels = 3
timeout = 120 # Time out programs that take more than 2 minutes (120 s) to solve

def get_video_size(video_path):
    """Return ``(height, width)`` of the video at ``video_path``.

    Both values are the floats reported by OpenCV's frame-size capture
    properties. ``(-1, -1)`` is returned when the file cannot be opened.
    The capture handle is always released before returning.
    """
    vcap = cv2.VideoCapture(video_path)  # an int such as 0 would open a camera instead
    try:
        if not vcap.isOpened():
            return -1, -1
        width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)    # float width
        height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float height
        return height, width
    finally:
        # The original never released the capture, leaking a decoder handle per call.
        vcap.release()

def is_valid_bbox(x1, y1, x2, y2):
    """Return True when the box fits inside the normalized frame and has area.

    The box is rejected if either x coordinate exceeds ``norm_x``, either y
    coordinate exceeds ``norm_y``, or ``|y2 - y1| * |x2 - x1|`` is zero.
    Lower bounds are not checked, so negative coordinates are accepted.
    """
    beyond_frame = x1 > norm_x or x2 > norm_x or y1 > norm_y or y2 > norm_y
    if beyond_frame:
        return False
    area = abs(y2 - y1) * abs(x2 - x1)
    return not (area == 0)

# Shape of a normalized frame: (height, width, channels).
norm_video_shape = (norm_y, norm_x, channels)

# Names for the three action-argument slots — presumably one per argument
# position (a, b, c); action_arg_num arities go up to 3.
action_arg_names = ["action_a_name", "action_b_name", "action_c_name"]

def get_area(x1, x2, y1, y2):
    """Return the signed area of the box spanning ``[x1, x2] x [y1, y2]``.

    Note the argument order: both x coordinates come first. The result is
    negative when exactly one of the two spans is inverted.
    """
    width = x2 - x1
    height = y2 - y1
    return width * height

def is_same_bbox(bbox1, bbox2, threshold=1.4):
    """Decide whether two boxes plausibly outline the same object.

    Each box is a mapping with keys ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'``.
    Six displacements are measured and each is divided by the mean width
    (x terms) or mean height (y terms) of the two boxes:

    * the centre offset along x and y,
    * the offset of the ``(x1, y1)`` corners,
    * the offset of the ``(x2, y2)`` corners.

    Args:
        bbox1, bbox2: the boxes to compare.
        threshold: tolerance on the relative centre offsets. Corner offsets
            get ``1.5 * threshold``. ``0.1 * threshold`` is the "near-exact"
            tolerance used for the single-measure match.

    Returns:
        ``(same, sim)``. ``same`` is True when every measure is within
        tolerance, or when at least one measure is near-exact.
        ``sim = 2 - mean(relative offsets)``: 2.0 for identical boxes, lower
        for less similar ones, and possibly negative.
    """
    x11, y11, x12, y12 = bbox1['x1'], bbox1['y1'], bbox1['x2'], bbox1['y2']
    x21, y21, x22, y22 = bbox2['x1'], bbox2['y1'], bbox2['x2'], bbox2['y2']

    # Tiny epsilon keeps the normalizers non-zero for degenerate boxes.
    w1 = x12 - x11 + 0.00001
    h1 = y12 - y11 + 0.00001
    w2 = x22 - x21 + 0.00001
    h2 = y22 - y21 + 0.00001

    # An area/IoU criterion was considered and dropped: it cannot capture
    # boxes cut off at the image edges.

    # Centre offsets.
    dcx = abs((x21 + x22) / 2 - (x11 + x12) / 2)
    dcy = abs((y21 + y22) / 2 - (y11 + y12) / 2)
    # Corner offsets. The bottom edge compares y22 with y12 (the original
    # compared y22 with y21, i.e. measured bbox2's height instead).
    dx1 = abs(x21 - x11)
    dy1 = abs(y21 - y11)
    dx2 = abs(x22 - x12)
    dy2 = abs(y22 - y12)

    # Offsets relative to the mean size: 2 * d / (s1 + s2).
    dcx_percent = 2 * dcx / (w1 + w2)
    dcy_percent = 2 * dcy / (h1 + h2)
    dx1_percent = 2 * dx1 / (w1 + w2)
    dy1_percent = 2 * dy1 / (h1 + h2)
    dx2_percent = 2 * dx2 / (w1 + w2)
    dy2_percent = 2 * dy2 / (h1 + h2)

    assert dcx_percent >= 0 and dcy_percent >= 0
    small_threshold = 0.1 * threshold
    corner_threshold = 1.5 * threshold
    small_corner_threshold = 1.5 * small_threshold

    # Every measure within tolerance.
    overall_sim = (dcx_percent < threshold and dcy_percent < threshold and
                   dx1_percent < corner_threshold and dy1_percent < corner_threshold and
                   dx2_percent < corner_threshold and dy2_percent < corner_threshold)

    # At least one measure almost exactly equal.
    one_spec_sim = (dcx_percent < small_threshold or dcy_percent < small_threshold or
                    dx1_percent < small_corner_threshold or dy1_percent < small_corner_threshold or
                    dx2_percent < small_corner_threshold or dy2_percent < small_corner_threshold)

    sim = 2 - (dcx_percent + dcy_percent + dx1_percent + dy1_percent + dx2_percent + dy2_percent) / 6
    return bool(overall_sim or one_spec_sim), sim

def list2dict(object_ids):
    """Group positions by value: ``{value: [indices where it occurs]}``.

    Keys appear in first-occurrence order and each index list is ascending.
    """
    grouped = {}
    for position, key in enumerate(object_ids):
        grouped.setdefault(key, []).append(position)
    return grouped

def solve_conflicts(object_to_choices):
    """Greedily match objects to names so that each is used at most once.

    Args:
        object_to_choices: ``{object: {candidate name: score}}``. Higher
            scores mean a better match.

    Returns:
        ``(idx_to_name, remaining_idxes)``: the chosen ``{object: name}``
        assignment and the set of objects left without a name. Candidate
        pairs are taken in descending score order, skipping any pair whose
        object or name is already used. Ties keep input order. An empty
        input yields ``({}, set())``; the original returned a bare ``{}``,
        which callers could not unpack.
    """
    pairs = [
        (score, obj, name)
        for obj, candidates in object_to_choices.items()
        for name, score in candidates.items()
    ]
    # Sort on the score only (names/objects need not be comparable). Python's
    # sort is stable, also with reverse=True, so ties keep input order.
    pairs.sort(key=lambda pair: pair[0], reverse=True)

    idx_to_name = {}
    used_names = set()
    for _score, obj, name in pairs:
        if obj in idx_to_name or name in used_names:
            continue
        idx_to_name[obj] = name
        used_names.add(name)

    remaining_idxes = set(object_to_choices) - set(idx_to_name)
    return idx_to_name, remaining_idxes

def assign_oid(names):
    """Map each name to its 0-based position in ``names``.

    If a name repeats, its last position wins; keys keep first-occurrence order.
    """
    return {name: position for position, name in enumerate(names)}

def trace_object(bboxes_info):
    """Link per-frame detections into tracks and tag each with an integer id.

    ``bboxes_info`` is a list of frames. Each frame is a dict whose
    ``'labels'`` list holds detections; each detection has a
    ``'gt_annotation'`` (category name, a string) and a ``'box2d'`` dict with
    ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'``.

    Frames are processed in order:

    * Each detection is compared with ``is_same_bbox`` against the latest box
      of every track of its category that was last seen in an earlier frame.
    * Matched detections compete for those tracks through
      ``solve_conflicts``, so a track gets at most one detection per frame.
    * Every detection left unmatched starts a new track ``"<category>_<k>"``.

    Each detection dict gets an ``'id'`` key in place; tracks are numbered
    ``0..T-1`` in creation order.

    Returns:
        The same ``bboxes_info`` list.

    Raises:
        Exception: if a track ends up with two detections from one frame.
    """
    # track name -> [(frame_id, index into frame['labels'], box2d), ...]
    all_objects = {}
    # category -> names of its tracks, in creation order
    object_instances = {}

    def start_track(label_name, record):
        # New track "<category>_<k>", k = number of tracks of that category so far.
        new_name = label_name + '_' + str(len(object_instances[label_name]))
        object_instances[label_name].append(new_name)
        all_objects[new_name] = [record]

    for frame_id, frame in enumerate(bboxes_info):
        labels = frame['labels']
        # category -> [(frame_id, label index, box2d, {track name: similarity}), ...]
        waitlists = {}

        category_ids = [label['gt_annotation'] for label in labels]
        for label_name, label_idxs in list2dict(category_ids).items():
            tracks = object_instances.setdefault(label_name, [])
            for frame_oid in label_idxs:
                object_bbox = labels[frame_oid]['box2d']
                similarities = {}
                for name in tracks:
                    prev_frame_id, _, prev_bbox = all_objects[name][-1]
                    same_bbox, similarity = is_same_bbox(object_bbox, prev_bbox)
                    # A track whose latest box is from this very frame was just
                    # created here, so it cannot absorb another detection.
                    if same_bbox and prev_frame_id != frame_id:
                        similarities[name] = similarity

                if similarities:
                    waitlists.setdefault(label_name, []).append(
                        (frame_id, frame_oid, object_bbox, similarities))
                else:
                    start_track(label_name, (frame_id, frame_oid, object_bbox))

        for label_name, waitlist in waitlists.items():
            # Each waitlisted detection keeps its own candidates and scores.
            # The original mis-sliced a jointly normalized score array and
            # reused the last entry's candidate list for every entry. The
            # normalization itself is dropped: solve_conflicts only ranks
            # scores, and a non-positive sum (similarities may be negative)
            # would have inverted that ranking.
            object_to_choices = {
                idx: dict(similarities)
                for idx, (_, _, _, similarities) in enumerate(waitlist)
            }
            idx_to_name, remaining_idxes = solve_conflicts(object_to_choices)
            # No track may be assigned twice within one frame.
            assert len(set(idx_to_name.values())) == len(idx_to_name)

            for idx, name in idx_to_name.items():
                all_objects[name].append(waitlist[idx][:3])
            for idx in sorted(remaining_idxes):
                start_track(label_name, waitlist[idx][:3])

    name_to_idx = assign_oid(all_objects.keys())
    for obj_name, obj_info in all_objects.items():
        frames = [f for f, _, _ in obj_info]
        if len(set(frames)) != len(frames):
            # A track may hold at most one detection per frame.
            raise Exception('wrong')
        track_id = name_to_idx[obj_name]
        for f, oid, _ in obj_info:
            bboxes_info[f]['labels'][oid]['id'] = track_id

    # Every detection either started or joined a track, so all must be tagged.
    for frame in bboxes_info:
        for label in frame['labels']:
            assert 'id' in label

    return bboxes_info

# Directed graph stored as an adjacency list; getSCCs() finds its strongly
# connected components with Kosaraju's two-pass algorithm.
class Graph:
    """Directed graph over the integer vertices ``0 .. vertices - 1``.

    Vertex ids must lie in that range, because the traversals index a
    ``visited`` list of length ``V`` by vertex id. All traversals are
    iterative, so long paths do not hit Python's recursion limit.
    """

    def __init__(self, vertices):
        self.V = vertices  # number of vertices
        self.graph = defaultdict(list)  # vertex -> successors, in insertion order

    def addEdge(self, u, v):
        """Add the directed edge ``u -> v``."""
        self.graph[u].append(v)

    def DFSUtil(self, v, visited):
        """Depth-first search from ``v`` through vertices not yet visited.

        Marks every reached vertex in ``visited`` and returns the vertices in
        the order they were first reached (pre-order), starting with ``v``.
        """
        visited[v] = True
        reached = [v]
        # One iterator per vertex on the current DFS path. Resuming an iterator
        # continues that vertex's adjacency list where the descent left off.
        path = [iter(self.graph[v])]
        while path:
            for nxt in path[-1]:
                if not visited[nxt]:
                    visited[nxt] = True
                    reached.append(nxt)
                    path.append(iter(self.graph[nxt]))
                    break
            else:
                path.pop()
        return reached

    def fillOrder(self, v, visited, stack):
        """Depth-first search from ``v``, pushing vertices as they finish.

        Appends each vertex to ``stack`` once all of its unvisited successors
        are done (post-order), so the last vertex pushed finished last.
        """
        visited[v] = True
        path = [(v, iter(self.graph[v]))]
        while path:
            node, successors = path[-1]
            for nxt in successors:
                if not visited[nxt]:
                    visited[nxt] = True
                    path.append((nxt, iter(self.graph[nxt])))
                    break
            else:
                # Every successor handled: ``node`` is finished.
                path.pop()
                stack.append(node)

    def getTranspose(self):
        """Return a new Graph with every edge reversed."""
        g = Graph(self.V)
        for u in self.graph:
            for v in self.graph[u]:
                g.addEdge(v, u)
        return g

    def getSCCs(self):
        """Return the strongly connected components as lists of vertex ids.

        First pass: record vertices by DFS finishing time. Second pass: run
        DFS on the transposed graph in decreasing finishing time; each search
        tree found is one component.
        """
        stack = []
        visited = [False] * self.V
        for i in range(self.V):
            if not visited[i]:
                self.fillOrder(i, visited, stack)

        gr = self.getTranspose()

        visited = [False] * self.V
        groups = []
        while stack:
            i = stack.pop()
            if not visited[i]:
                groups.append(gr.DFSUtil(i, visited))

        return groups

if __name__ == "__main__":
    # Small demo of the SCC finder on a 5-vertex example graph.
    g = Graph(5)
    g.addEdge(1, 0)
    g.addEdge(0, 2)
    g.addEdge(2, 1)
    g.addEdge(0, 3)
    g.addEdge(3, 4)

    print("Following are strongly connected components in given graph")
    # The original announced the components but never printed them.
    for component in g.getSCCs():
        print(" ".join(str(vertex) for vertex in component))

def read_video(annotation, video_path):
    """Decode ``video_path`` and return the frames named by ``annotation``.

    ``annotation`` is a list of dicts. Each ``'name'`` is parsed as
    ``"<dir>/<number><ext>"``: the last 4 characters are dropped (presumably
    an extension such as ".jpg"), and ``<number>`` is a 1-based frame number.

    Returns:
        * The decoded frames for each annotation in annotation order,
          stopping at the first one that lies past the end of the video.
        * ``[]`` when ``annotation`` is empty.
        * ``None`` (after printing ``'invalid video'``) when the last
          annotated frame lies beyond the decoded video.
    """
    assert (os.path.exists(video_path))
    if not annotation:
        # No frames requested; the lookup of annotation[-1] below would fail.
        return []

    cap = cv2.VideoCapture(video_path)
    video = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            video.append(frame)
    finally:
        # Release the decoder handle even if reading raises (it was never released).
        cap.release()

    def frame_index(anno):
        # "<dir>/<number><ext>" -> 0-based frame index.
        return int(anno['name'][:-4].split('/')[1]) - 1

    if frame_index(annotation[-1]) >= len(video):
        print('invalid video')
        return None

    annotated_video = []
    for anno in annotation:
        frame_id = frame_index(anno)
        if frame_id >= len(video):
            break
        annotated_video.append(video[frame_id])

    return annotated_video
