import json, os
from tqdm import tqdm
import torch
from demo.common.config.config import load_config
from demo.common.dataset.transforms import HOIBOT_Transform
from demo.common.model.HOI4ABOT import HOI4ABOT
import cv2

import numpy as np


def prepare_viderowriter(fps, frame0, config):
    """Create the output video writer and record the video settings in ``config``.

    Args:
        fps: Frame rate of the output video. It is stored unchanged in
            ``config["VIDEO"]["fps"]`` but rounded to an integer for the writer.
        frame0: First captured frame (NumPy image indexed ``[row, col, ...]``);
            its size fixes the output resolution.
        config: Configuration dict. ``config["PATHS"]["result_video_file"]`` is the
            output path. The dict is mutated in place: ``config["VIDEO"]`` is (re)set.

    Returns:
        ``(video_writer, config)`` -- an ``mp4v`` ``cv2.VideoWriter`` and the same
        (updated) config object.
    """
    # Image arrays are (rows, cols[, channels]) == (height, width[, channels]).
    height, width = frame0.shape[:2]
    config["VIDEO"] = {"fps": fps, "height": height, "width": width}

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # Fractional frame rates are not supported here; round to the nearest integer.
    video_writer = cv2.VideoWriter(
        str(config["PATHS"]["result_video_file"]),
        fourcc,
        round(fps),
        (width, height),  # OpenCV expects frameSize as (width, height)
    )
    return video_writer, config


def get_new_camera_frame(video_cap, meta_info, transform=None, max_retries=None):
    """Read the next frame from ``video_cap``, retrying failed reads.

    Args:
        video_cap: Capture object exposing ``read() -> (success, frame)`` and
            ``isOpened()``, e.g. a ``cv2.VideoCapture``.
        meta_info: Bookkeeping dict holding integer ``"frame_count"`` and
            ``"frame_num"`` entries; both are incremented in place.
        transform: Optional callable applied to the raw frame; its result is stored
            in ``meta_info["additional"]`` (left untouched when ``transform`` is None).
        max_retries: How many additional reads to attempt after a failed read before
            giving up. ``None`` (default) keeps retrying for as long as the capture
            reports that it is open.

    Returns:
        ``(frame, None, None, meta_info)``. The two ``None`` slots are placeholders
        (callers name them as transformed depth images; presumably unused for this
        camera -- confirm before relying on them).

    Raises:
        RuntimeError: if a read fails while the capture is not open (a closed or
            missing device would otherwise spin forever), or if ``max_retries``
            is exceeded.
    """
    failed_reads = 0
    success, frame0 = video_cap.read()
    while not success:
        failed_reads += 1
        if not video_cap.isOpened():
            raise RuntimeError("Video capture is not open; cannot read a frame.")
        if max_retries is not None and failed_reads > max_retries:
            raise RuntimeError(f"Failed to read a frame after {failed_reads} attempts.")
        success, frame0 = video_cap.read()

    if transform is not None:
        meta_info["additional"] = transform(frame0)
    meta_info.update({"frame_count": meta_info["frame_count"] + 1, "frame_num": meta_info["frame_num"] + 1})
    return frame0, None, None, meta_info


def _summarize_hois(hoi_res, classes, anticipation_head, score_thres=0.2):
    """Summarize one HOI prediction into per-head text and cup-interaction confidences.

    Keys of ``hoi_res`` read here:
        - ``triplets_scores``: dict mapping a head name (``"detection"``,
          ``"anticipation"`` or ``"future_num_<k>"``) to a list of
          ``(score, pair_idx, interaction_idx)`` triplets.
        - ``pair_idxes``: list of ``(subject_idx, object_idx)`` pairs, indexed by ``pair_idx``.
        - ``pred_labels``: class index of each detection, looked up in
          ``classes["object_classes"]``.
        - ``ids``: tracking id of each detection.
    The prediction is also documented to carry ``video_name``, ``frame_id``, ``bboxes``
    (box format not verified), ``confidences`` and ``interaction_distribution``
    (P x 50 scores); those are not used here.

    Args:
        hoi_res: Prediction dict returned by ``HOI4ABOT.process_step``.
        classes: Mapping with ``"object_classes"`` and ``"interaction_classes"`` name lists.
        anticipation_head: ``k`` of the ``"future_num_<k>"`` head used for anticipation.
        score_thres: Triplets must score strictly above this to be considered.

    Returns:
        ``(text_drawn, toprint, interactions)``:
            - ``text_drawn``: head name -> header line, then one
              ``"<subj><id> - <interaction> - <obj><id>: <score> | "`` entry per kept
              triplet whose object name contains ``"cup"``, then a newline.
            - ``toprint``: head name -> whether any such entry was written.
            - ``interactions``: always has ``"hold"`` and ``"next_to"``. ``"hold"`` comes
              from the detection heads (``"future_num_0"``/``"detection"``), ``"next_to"``
              from the anticipation head, both only for objects named exactly ``"cup"``.
              The last qualifying triplet wins; a missing entry defaults to ``0``.
    """
    object_classes = classes["object_classes"]
    interaction_classes = classes["interaction_classes"]
    anticipation_heads = ("anticipation", f"future_num_{anticipation_head}")

    text_drawn, toprint, interactions = {}, {}, {}
    for future_num, triplets in hoi_res["triplets_scores"].items():
        is_detection = future_num in ("future_num_0", "detection")
        # "future_num_0" is a detection head, so it never counts as anticipation,
        # even when anticipation_head == 0.
        is_anticipation = future_num in anticipation_heads and future_num != "future_num_0"
        entries = [f" -------{future_num} ------- \n"]
        toprint[future_num] = False
        for score, idx_pair, interaction_pred in triplets:
            if score > score_thres:
                pair = hoi_res["pair_idxes"][idx_pair]
                subj_idx, obj_idx = pair[0], pair[1]
                subj_name = object_classes[hoi_res["pred_labels"][subj_idx]]
                obj_name = object_classes[hoi_res["pred_labels"][obj_idx]]
                subj_id = hoi_res["ids"][subj_idx]
                obj_id = hoi_res["ids"][obj_idx]
                interaction_name = interaction_classes[interaction_pred]
                if "cup" in obj_name:
                    entries.append(
                        f"{subj_name}{subj_id} - {interaction_name} - {obj_name}{obj_id}: {score:.2f} | ")
                    toprint[future_num] = True
                # Only the exact "cup" class contributes to the interaction confidences.
                if obj_name == "cup" and (
                        (interaction_name == "next_to" and is_anticipation)
                        or (interaction_name == "hold" and is_detection)):
                    interactions[interaction_name] = score
        entries.append("\n")
        text_drawn[future_num] = "".join(entries)

    interactions.setdefault("hold", 0)
    interactions.setdefault("next_to", 0)
    return text_drawn, toprint, interactions


@torch.no_grad()
def main(subst=None, vis=True, ros=False, anticipation_head=1):
    """Run live HOI detection/anticipation on camera 0 and record the results.

    Frames are read from ``cv2.VideoCapture(0)`` (requested at 1280x720), passed to
    ``HOI4ABOT.process_step`` and written to ``config["PATHS"]["result_video_file"]``.
    The loop ends when ``q`` is pressed in the preview window (``vis=True``), when ROS
    shuts down (``ros=True``) or when the process is interrupted (e.g. Ctrl+C). In every
    case the camera and the video writer are released, and ``hoi_wrapper.result_list`` /
    ``hoi_wrapper.hoi_list`` are written to ``config["PATHS"]["result_file"]`` /
    ``config["PATHS"]["hoi_file"]``.

    Args:
        subst: Overrides forwarded to ``load_config``.
        vis: Show an OpenCV preview window. It shows annotated frames when
            ``config["VERBOSE"]["show_detections"]`` is set and raw frames otherwise.
        ros: After every prediction, publish an ``InteractionArray`` on
            ``/interactions`` with the cup ``"hold"`` and ``"next_to"`` confidences.
        anticipation_head: ``k`` of the ``"future_num_<k>"`` head used for the
            anticipated ``"next_to"`` confidence.
    """
    if ros:
        import rospy
        from pour_sm.msg import Interaction, InteractionArray
        pub_conf = rospy.Publisher("/interactions", InteractionArray, queue_size=1)
        rospy.init_node("anticipation_node")

    fps = 3
    config = load_config(subst)
    hoi_wrapper = HOI4ABOT(config)
    hoi_wrapper.config['fps'] = fps
    config = hoi_wrapper.get_config()

    video_cap = cv2.VideoCapture(0)
    video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

    meta_info = {
        "video_count": 0,
        "video_path": "kinect_camera",
        "frame_count": 0,
        "frame_num": 0,
    }

    print("============ Inference Start... ============")

    transform = HOIBOT_Transform(img_size=(224, 224))
    # The first frame only sizes the output video; it is not passed to the model.
    frame0, _, _, meta_info = get_new_camera_frame(video_cap, meta_info, transform=transform)
    video_writer, config = prepare_viderowriter(fps, frame0, config)

    draw_hois = False  # overlay the HOI text on the recorded/displayed frames
    do_hois = True
    overlay_text = ""
    if vis:
        cv2.namedWindow('Video', cv2.WINDOW_AUTOSIZE)

    idx = 0
    tbar = tqdm(total=1.e+8)
    try:
        while True:
            if ros and rospy.is_shutdown():
                break
            idx += 1
            frame0, _, _, meta_info = get_new_camera_frame(video_cap, meta_info, transform=transform)
            if vis and not config["VERBOSE"]["show_detections"]:
                cv2.imshow("Video", frame0)
                # Mask to the low byte: the key code may carry extra high bits on some platforms.
                if (cv2.waitKey(1) & 0xFF) == ord('q'):
                    break

            meta_info["original_shape"] = frame0.shape
            frame_annotated, hoi_res, _, _ = hoi_wrapper.process_step(idx, frame0, meta_info, do_hois=do_hois)

            if hoi_res is not None:
                text_drawn, toprint, interactions = _summarize_hois(
                    hoi_res, hoi_wrapper.config["CLASSES"], anticipation_head)

                if ros:
                    msg = InteractionArray()
                    for name, score in interactions.items():
                        item = Interaction()
                        item.name = name
                        item.conf = score
                        msg.interactions.append(item)
                    pub_conf.publish(msg)

                overlay_text = "".join(text_drawn.values())
                # tqdm.update() takes an increment; advance the bar to the current frame index.
                tbar.update(idx - tbar.n)
                # Only heads that produced a cup entry are shown; "" clears the description.
                tbar.set_description("".join(text for fut, text in text_drawn.items() if toprint[fut]))

            if draw_hois:
                # White box in the top-right corner holding one text line per row.
                h, w, _ = frame_annotated.shape
                cv2.rectangle(frame_annotated, (int(w * 0.6), int(h * 0.)), (int(w * 1), int(h * 0.3)),
                              (255, 255, 255), -1)
                for i, line in enumerate(overlay_text.split('\n')):
                    cv2.putText(
                        frame_annotated,
                        line,
                        (int(w * 0.61), int(h * 0.03 * (i + 1))),
                        0,
                        0.8,
                        (0, 0, 0),
                        thickness=1,
                        lineType=cv2.LINE_AA,
                    )

            if vis and config["VERBOSE"]["show_detections"]:
                cv2.imshow("Video", frame_annotated)
                if (cv2.waitKey(1) & 0xFF) == ord('q'):
                    break
            video_writer.write(frame_annotated)
    finally:
        # Runs on "q", ROS shutdown and Ctrl+C alike, so the video file is finalized
        # and the collected results are always saved.
        tbar.close()
        video_writer.release()
        video_cap.release()
        if vis:
            cv2.destroyAllWindows()
        with config["PATHS"]["result_file"].open("w") as f:
            json.dump(hoi_wrapper.result_list, f)
        with config["PATHS"]["hoi_file"].open("w") as f:
            f.writelines(hoi_wrapper.hoi_list)


if __name__ == "__main__":
    # Imported lazily so that importing this module does not require the path config.
    from configs.paths import project_path

    # Config overrides forwarded to load_config() by main().
    main({
        "source": project_path,
        "future": 0,
        "print": True,
        "hoi_thres": 0.35,
    })