import json, os
from tqdm import tqdm
import torch
from demo.common.config.config import load_config
from demo.common.dataset.transforms import HOIBOT_Transform
from demo.common.model.HOI4ABOT import HOI4ABOT
import cv2

import rospy
import tf
import tf.transformations
#from demo_pour.msg import object, objectArray
import numpy as np
from pykinect_azure import k4a_float2_t, K4A_CALIBRATION_TYPE_COLOR
import pykinect_azure as pykinect

def prepare_viderowriter(fps, frame0, config):
    """Create the result-video writer and record the video geometry in ``config``.

    Args:
        fps: Target frame rate. Stored unchanged in ``config["VIDEO"]["fps"]``;
            the writer itself receives ``round(fps)`` (decimal rates are not
            handled).
        frame0: Reference frame whose ``shape`` is read as
            ``(height, width, ...)``, i.e. rows first, columns second.
        config: Mapping that provides ``config["PATHS"]["result_video_file"]``.
            Its ``"VIDEO"`` entry is (over)written in place.

    Returns:
        Tuple ``(video_writer, config)``.
    """
    # Image arrays are indexed (rows, cols): axis 0 is the height and axis 1
    # the width. Use the same reading for the metadata and for the writer so
    # that the two can never disagree.
    height, width = frame0.shape[0], frame0.shape[1]
    config["VIDEO"] = {"fps": fps, "height": height, "width": width}

    # cv2.VideoWriter takes the frame size as (width, height).
    video_writer = cv2.VideoWriter(
        str(config["PATHS"]["result_video_file"]),
        cv2.VideoWriter_fourcc(*"mp4v"),
        round(fps),
        (width, height),
    )
    return video_writer, config

def get_center(bbox):
    """Return the integer center ``(x, y)`` of a box given by two corners.

    ``bbox`` is indexed as ``[x0, y0, x1, y1]``. Each midpoint is computed
    with true division and then truncated by ``int()``.
    """
    # axis 0 pairs bbox[0] with bbox[2] (x), axis 1 pairs bbox[1] with bbox[3] (y).
    center = tuple(int((bbox[axis] + bbox[axis + 2]) / 2) for axis in (0, 1))
    return center

def publish_points(boxes_xyxy, names, depth, dataset, pub, t_mtx, frame_annotated):
    """Publish the 3D center of every detected box, mapped through ``t_mtx``.

    For each ``(box, name)`` pair the box center is marked on
    ``frame_annotated``, its depth value is read from ``depth`` and
    ``dataset.convert_2d_to_3d`` turns the pixel into a 3D point (color camera
    as both source and target). The points are then transformed with
    ``t_mtx`` and published as a list of named objects.

    Args:
        boxes_xyxy: Iterable of boxes accepted by ``get_center``.
        names: Labels matching ``boxes_xyxy`` one to one. Not modified.
        depth: 2D array indexed as ``depth[row, col]``.
        dataset: Object exposing ``convert_2d_to_3d(pixel, depth, src, dst)``.
        pub: Publisher exposing ``publish``.
        t_mtx: 4x4 homogeneous transform applied to the points.
        frame_annotated: Image on which the centers are drawn.

    Returns:
        None. Nothing is published when there are no boxes.
    """
    points = []
    obj_names = []
    for box, label in zip(boxes_xyxy, names):
        px, py = get_center(box)
        print((px, py))

        frame_annotated = cv2.circle(frame_annotated, (int(px), int(py)), 4, (255, 0, 0), -1)
        # Row index first: the image row is the y coordinate.
        rgb_depth = depth[py, px]

        pixels = k4a_float2_t((px, py))
        print("pixels", pixels)

        pos3d_color = dataset.convert_2d_to_3d(
            pixels, rgb_depth, K4A_CALIBRATION_TYPE_COLOR, K4A_CALIBRATION_TYPE_COLOR
        )
        xyz = pos3d_color.xyz
        # Scaled by 1/1000 (presumably millimetres -> metres; confirm against
        # the sensor documentation) and made homogeneous for the 4x4 transform.
        points.append([xyz.x / 1000, xyz.y / 1000, xyz.z / 1000, 1])
        # Collect labels locally: appending to `names` would grow the caller's
        # list while it is being iterated.
        obj_names.append(label)

    if not points:
        return

    # (N, 4) -> (4, N) so the transform applies to all points at once.
    points_robot = (t_mtx @ np.array(points).T).T

    objects = []
    for p, label in zip(points_robot, obj_names):
        # NOTE(review): `object` is meant to be the message class whose import
        # is commented out at the top of the file. With the builtin `object`
        # the attribute assignments below raise AttributeError.
        o = object()
        o.name = label
        o.x = p[0]
        o.y = p[1]
        o.z = p[2]
        objects.append(o)

    pub.publish(objects)

def get_new_kinect_frame(video_cap, meta_info, transform=None):
    """Grab the next capture and advance the frame counters in ``meta_info``.

    Captures are requested from ``video_cap.update()`` until one reports a
    successful color image. The transformed depth image and its colored
    variant are read from that same capture; their success flags are ignored.

    Args:
        video_cap: Device exposing ``update()`` that returns a capture.
        meta_info: Dict holding ``"frame_count"`` and ``"frame_num"``; both are
            incremented in place, and ``"additional"`` is set when a
            ``transform`` is given.
        transform: Optional callable applied to the color frame.

    Returns:
        Tuple ``(color_frame, depth_image, colored_depth_image, meta_info)``.
    """
    # Keep polling until the device hands back a usable color frame.
    while True:
        capture = video_cap.update()
        got_color, frame0 = capture.get_color_image()
        if got_color:
            break

    _, depth_image = capture.get_transformed_depth_image()
    _, colored_depth_image = capture.get_transformed_colored_depth_image()

    if transform is not None:
        meta_info["additional"] = transform(frame0)

    # Read both counters before writing either one back.
    next_count = meta_info["frame_count"] + 1
    next_num = meta_info["frame_num"] + 1
    meta_info["frame_count"] = next_count
    meta_info["frame_num"] = next_num

    return frame0, depth_image, colored_depth_image, meta_info
            

def _format_hoi_line(hoi_wrapper, hoi_res, future_num, score, idx_pair, interaction_pred):
    """Render one (subject, interaction, object) prediction as a text line."""
    classes = hoi_wrapper.config["CLASSES"]
    pair = hoi_res["pair_idxes"][idx_pair]
    subj_idx, obj_idx = pair[0], pair[1]

    subj_name = classes["object_classes"][hoi_res["pred_labels"][subj_idx]]
    subj_id = hoi_res["ids"][subj_idx]
    obj_name = classes["object_classes"][hoi_res["pred_labels"][obj_idx]]
    obj_id = hoi_res["ids"][obj_idx]
    interaction_name = classes["interaction_classes"][interaction_pred]

    return f"FUTURE {future_num}: {subj_name}{subj_id} - {interaction_name} - {obj_name}{obj_id}: {score:.2f} \n"


@torch.no_grad()
def main(subst=None, vis=True):
    """Run HOI inference on the live Kinect stream and store the results.

    Frames are read from the Kinect device, passed through
    ``HOI4ABOT.process_step`` and written to the result video. The loop ends
    when ``q`` is pressed in the preview window or on Ctrl-C; afterwards the
    detections and HOI lists collected by the wrapper are written to
    ``config["PATHS"]["result_file"]`` and ``config["PATHS"]["hoi_file"]``.

    Args:
        subst: Overrides forwarded to ``load_config``.
        vis: Whether to show frames in an OpenCV window.
    """
    do_robot = False
    pub = None
    t_mtx = None
    if do_robot:
        rospy.init_node("anticipation_node")
        # NOTE(review): `objectArray` comes from the message import that is
        # commented out at the top of the file; restore it before enabling
        # the robot path.
        pub = rospy.Publisher("/objects", objectArray, queue_size=1)

        listener = tf.TransformListener()
        rospy.sleep(1)
        # rospy.Time(0) asks for the latest available transform. Asking for
        # "now" requires a transform stamped at this exact instant, which is
        # normally not in the buffer yet.
        pos, ori = listener.lookupTransform("panda_link0", "rgb_camera_link", rospy.Time(0))

        # Homogeneous camera -> robot-base transform: rotation from the
        # quaternion, translation in the last column.
        t_mtx = np.array(tf.transformations.quaternion_matrix(ori))
        t_mtx[:3, 3] = np.array(pos)

        print(t_mtx)

    fps = 3
    config = load_config(subst)
    hoi_wrapper = HOI4ABOT(config)
    hoi_wrapper.config['fps'] = fps
    config = hoi_wrapper.get_config()

    pykinect.initialize_libraries()
    device_config = pykinect.default_configuration
    device_config.color_resolution = pykinect.K4A_COLOR_RESOLUTION_1080P
    device_config.depth_mode = pykinect.K4A_DEPTH_MODE_NFOV_2X2BINNED
    video_cap = pykinect.start_device(config=device_config)

    meta_info = {
        "video_count": 0,
        "video_path": "kinect_camera",
        "frame_count": 0,
        "frame_num": 0,
    }

    transform = HOIBOT_Transform(img_size=(224, 224))
    # The first frame is only used to size the video writer.
    frame0, transformed_depth_image, transformed_colored_depth_image, meta_info = get_new_kinect_frame(
        video_cap, meta_info, transform=transform
    )
    video_writer, config = prepare_viderowriter(fps, frame0, config)

    draw_hois = False
    do_hois = True
    text_drawn = ""
    if vis:
        cv2.namedWindow('Video', cv2.WINDOW_AUTOSIZE)

    print("============ Inference Start... ============")
    idx = 0
    tbar = tqdm(total=1.e+8)

    try:
        while True:
            idx += 1
            frame0, transformed_depth_image, transformed_colored_depth_image, meta_info = get_new_kinect_frame(
                video_cap, meta_info, transform=transform
            )
            if vis and not config["VERBOSE"]["show_detections"]:
                cv2.imshow("Video", frame0)
                if cv2.waitKey(1) == ord('q'):
                    break

            meta_info["original_shape"] = frame0.shape
            frame_annotated, hoi_res, boxes_xyxy, names = hoi_wrapper.process_step(
                idx, frame0, meta_info, do_hois=do_hois
            )

            if hoi_res is not None:
                # hoi_res keys, as described by the original author:
                #   video_name: video id / path to save
                #   frame_id: frame number of the last frame
                #   bboxes: list of N boxes, N being the number of elements
                #       tracked/detected in the image; each a list of
                #       [x, z, h, w] (layout still needs to be verified)
                #   pred_labels: list(int), class label per detection
                #   confidences: list(float), score per tracked element
                #   pair_idxes: list(tuple), all human-object pairs; with one
                #       person and N-1 objects there are P = N-1 pairs
                #   interaction_distribution: P x 50 scores, one row per
                #       pair over the 50 HOI categories
                #   ids: list(int), tracking ids
                #   triplets_scores: per future horizon, top-k tuples of
                #       (score, idx_pair, interaction_pred)
                interaction_classes = hoi_wrapper.config["CLASSES"]["interaction_classes"]
                text_drawn = ""
                for future_num, triplet_score in hoi_res["triplets_scores"].items():
                    # Accumulate, so every horizon stays in the summary
                    # instead of only the last one.
                    text_drawn += f" -------{future_num} ------- \n"
                    for score, idx_pair, interaction_pred in triplet_score:
                        interaction_name = interaction_classes[interaction_pred]
                        is_confident_towards = interaction_name == "towards" and score > 0.2
                        if is_confident_towards or interaction_pred not in hoi_wrapper.spatial_class_idxes:
                            text_drawn += _format_hoi_line(
                                hoi_wrapper, hoi_res, future_num, score, idx_pair, interaction_pred
                            )

                        if interaction_name == "hold" and score > 0.35 and do_robot:
                            print("Doing the robot")
                            # One-shot trigger: publish only for the first
                            # confident "hold".
                            do_robot = False
                            # NOTE(review): assumes the pykinect device
                            # exposes `calibration` with convert_2d_to_3d;
                            # confirm against the pykinect_azure version used.
                            publish_points(
                                boxes_xyxy,
                                names,
                                transformed_depth_image,
                                video_cap.calibration,
                                pub,
                                t_mtx,
                                frame_annotated,
                            )

                # tqdm.update() takes an increment, so advance by the gap
                # between the frame index and the bar's current counter.
                tbar.update(idx - tbar.n)
                tbar.set_description(f"{text_drawn}")

            if draw_hois:
                # White panel in the top-right corner, then one text row per line.
                h, w, _ = frame_annotated.shape
                cv2.rectangle(frame_annotated, (int(w*0.6), int(h*0.)), (int(w*1), int(h*0.3)), (255, 255, 255), -1)
                for i, line in enumerate(text_drawn.split('\n')):
                    cv2.putText(
                        frame_annotated,
                        line,
                        (int(w*0.61), int(h*0.03*(i+1))),
                        0,
                        0.8,
                        (0, 0, 0),
                        thickness=1,
                        lineType=cv2.LINE_AA,
                    )

            if vis and config["VERBOSE"]["show_detections"]:
                cv2.imshow("Video", frame_annotated)
                if cv2.waitKey(1) == ord('q'):
                    break
            video_writer.write(frame_annotated)
    except KeyboardInterrupt:
        # Without a preview window Ctrl-C is the only way to stop the loop;
        # treat it as a normal stop so the results below still get saved.
        print("Interrupted, saving results...")
    finally:
        # Always finalize the video file and the progress bar, even on error.
        tbar.close()
        video_writer.release()
        if vis:
            cv2.destroyAllWindows()

    # Store the detections and the HOI lines collected by the wrapper.
    with config["PATHS"]["result_file"].open("w") as f:
        json.dump(hoi_wrapper.result_list, f)
    with config["PATHS"]["hoi_file"].open("w") as f:
        f.writelines(hoi_wrapper.hoi_list)


if __name__ == "__main__":
    # Imported here so it is only resolved when the file runs as a script.
    from configs.paths import project_path

    future = 0
    source = project_path + "/output"
    # Configuration overrides handed to main().
    subst = {
        "source": source,
        "future": future,
        "print": True,
        "hoi_thres": 0.35,
    }
    main(subst=subst)
