#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import cv2
import torch
from pathlib import Path
import json
from tqdm import tqdm
from collections import deque, defaultdict

from modules.hoi4abot.ModelWrapper import Model_Wrapper
from demo.common.model.modules.objecttracker import ObjectTracker

from demo.common.utils.image_processing import convert_annotation_frame_to_video
from demo.common.utils.model_utils import (
    bbox_pair_generation,
    concat_separated_head,
    construct_sliding_window,
    generate_sliding_window_mask,
)

from demo.common.utils.metrics_utils import generate_triplets_scores

class HOI4ABOT():
    def __init__(self, config, use_yolo=False):
        self.config = self.prepare_config(config)
        self.device = config["PARAMS"]["device"]
        self.object_tracker, self.hoi_model = self.load_modules(config)
        self.ishydra = self.hoi_model.modelname == "HYDRA"
        self.isboth = self.hoi_model.modelname == "BOTH"


        # results for one video
        self.detection_dict = defaultdict(list)
        self.gaze_list = []
        self.hoi_list = []
        self.result_list = []
        # FIFO queues for sliding window
        self.frames_queue = deque(maxlen=self.config["PARAMS"]["sttran_sliding_window"])
        self.frame_ids_queue = deque(maxlen=self.config["PARAMS"]["sttran_sliding_window"])
        # iteration over the video, get object traces and human gazes
        # NOTE only store the detections and gazes in keyframes into files, but the video contains all frames
        self.hx_memory = {}

    def get_config(self):
        """Return the configuration dict held by this instance (not a copy)."""
        return self.config

    def prepare_config(self, config):
        """Derive the ``TEST`` section of the configuration from ``CLASSES``.

        The class counts and the per-head loss/index tables are computed from
        ``config["CLASSES"]`` and stored under ``config["TEST"]``.

        Args:
            config: Configuration dict whose ``CLASSES`` entry provides
                ``object_classes``, ``action_class_idxes``,
                ``spatial_class_idxes`` and ``interaction_classes``.

        Returns:
            The same ``config`` object, modified in place.
        """
        classes = config['CLASSES']
        object_names = classes["object_classes"]
        action_idxes = classes["action_class_idxes"]
        spatial_idxes = classes["spatial_class_idxes"]
        interaction_names = classes["interaction_classes"]

        interaction_count = len(interaction_names)

        config["TEST"] = {
            "num_object_classes": len(object_names),
            "num_interaction_classes": interaction_count,
            "num_action_classes": len(action_idxes),
            "num_spatial_classes": len(spatial_idxes),
            # The loss covers every interaction class.
            "num_interaction_classes_loss": interaction_count,
            "loss_type_dict": {"spatial_head": "bce", "action_head": "bce"},
            "class_idxes_dict": {
                "spatial_head": spatial_idxes,
                "action_head": action_idxes,
            },
            "loss_gt_dict": {
                "spatial_head": "spatial_gt",
                "action_head": "action_gt",
            },
        }
        return config

    def load_modules(self, config):
        """Instantiate the object tracker and the HOI model.

        ``self.device`` must already be set: the HOI model is moved onto it.

        Args:
            config: Configuration dict. Reads ``CLASSES.object_classes`` and
                ``CLASSES.hoi2coco_mapper`` for the tracker and, for models
                other than "HYDRA"/"BOTH", ``PATHS.model_path`` for the
                checkpoint to load.

        Returns:
            ``(object_tracker, hoi_model)`` with the HOI model in eval mode.
        """
        classes = config["CLASSES"]
        # NOTE(review): the tracker device and weight path are hard-coded and
        # ignore config["PARAMS"]["device"] -- confirm this is intended.
        object_tracker = ObjectTracker(
            device="cuda:0",
            weights_path="../weights/yolo/yolov8x.pt",
            tracker_name="botsort.yaml",
            classes_use=classes["object_classes"],
            hoi2coco_mapper=classes["hoi2coco_mapper"],
        )

        hoi_model = Model_Wrapper(config)
        hoi_model.to(self.device)
        if hoi_model.modelname not in ["HYDRA", "BOTH"]:
            # Deserialize onto the CPU so a checkpoint saved on a CUDA device
            # that is not available here still loads; load_state_dict then
            # copies the weights into the parameters already on self.device.
            state_dict = torch.load(config['PATHS']["model_path"], map_location="cpu")
            incompatibles = hoi_model.load_state_dict(state_dict)
            print(f"HOIBOT AND YOLO loaded. Incompatible keys {incompatibles}")
        else:
            # No checkpoint is loaded here for the HYDRA / BOTH variants.
            print("GET READY! HYDRA IS LOADED")

        hoi_model.eval()
        return object_tracker, hoi_model


    def generate_sliding_windows(self, meta_info):
        """Push the current frame into the window and gather its detections.

        Args:
            meta_info: Per-frame metadata; ``"additional"`` is appended to the
                frame queue and ``"frame_count"`` to the frame-id queue.

        Returns:
            ``(bboxes, ids, pred_labels, confidences, frames)`` where the
            first four come from ``convert_annotation_frame_to_video`` applied
            to the most recent stored detections and ``frames`` is the stacked
            content of the frame queue on ``self.device``.
        """
        self.frames_queue.append(meta_info["additional"])
        self.frame_ids_queue.append(meta_info["frame_count"])
        frames = torch.stack(list(self.frames_queue)).to(self.device)

        # Only the last `window_len` stored detections belong to the window.
        window_len = self.config["PARAMS"]["sttran_sliding_window"]
        recent_detections = [
            self.detection_dict[key][-window_len:]
            for key in ("bboxes", "ids", "labels", "confidences")
        ]
        bboxes, ids, pred_labels, confidences = convert_annotation_frame_to_video(
            *recent_detections
        )
        return bboxes, ids, pred_labels, confidences, frames


    def fill_entry_inference(self, frames, detected, meta_info):
        # from meta_info
        original_shape = meta_info["original_shape"]

        # Results from object detection/ground-truth and pair generation
        pred_labels = torch.LongTensor(detected["pred_labels"])
        bboxes = torch.Tensor(detected["bboxes"])
        pair_idxes = torch.LongTensor(detected["pair_idxes"])
        ids = detected["ids"]
        confidences = torch.Tensor(detected["confidences"])
        im_idxes = torch.LongTensor(detected["im_idxes"])
        pair_human_ids = torch.zeros_like(im_idxes)
        pair_object_ids = torch.zeros_like(im_idxes)

        # frames
        if type(frames) is list:
            frames = torch.stack(frames, axis=0)

        if len(im_idxes) > 0:
            for pair_idx, (im_idx, pair) in enumerate(zip(im_idxes, pair_idxes)):
                human_id = ids[pair[0]]
                object_id = ids[pair[1]]
                pair_human_ids[pair_idx] = human_id
                pair_object_ids[pair_idx] = object_id

        # Extract bboxes
        # scale up/down bounding boxes, from original frame to transformed frame
        img_size = frames.shape[-1]
        bbox_normalized, bbox_new_size = self.hoi_model.model.blender.patch_blender.adapt_bbox(bboxes[:, 1:], from_size=(original_shape[1], original_shape[0]), to_size=(img_size, img_size))
        bboxes[:, 1:] = bbox_new_size
        binary_masks = self.hoi_model.model.blender.patch_blender.create_patch_batch(bboxes[:, 1:])

        entry = {
            "pred_labels": pred_labels,  # labels from object detector
            "bboxes": bboxes,  # bboxes from object detector
            "ids": ids,  # object id from object detector
            "confidences": confidences,  # score from object detector
            "pair_idxes": pair_idxes,  # subject-object pairs, generated after object detector
            "pair_human_ids": pair_human_ids,  # human id in each pair, from ground-truth
            "pair_object_ids": pair_object_ids,  # object id in each pair, from ground-truth
            "im_idxes": im_idxes,  # each pair belongs to which frame index
            "bboxes_normalized": bboxes,  # ground-truth bboxes, no need for training
            "binary_masks":binary_masks,
            "frames": frames,
        }
        return entry

    def predict_hois(self, idx, meta_info):
        """Predict human-object interactions for the current sliding window.

        The current frame is pushed into the sliding window, human-object
        pairs are generated from the stored detections, the HOI model is run
        when at least one valid window exists, and one result entry is
        appended to ``self.result_list``.

        Args:
            idx: Index of the current frame. Used in the logged HOI string
                and, in "anticipation" sampling mode, to derive
                ``future_frame_id``.
            meta_info: Per-frame metadata forwarded to
                ``generate_sliding_windows`` and ``fill_entry_inference``
                (they read ``"additional"``, ``"frame_count"`` and
                ``"original_shape"``).

        Returns:
            ``(window_result, triplets_scores)``. ``window_result`` merges the
            frame annotation (``video_name``, ``frame_id`` and optionally
            ``future_frame_id``) with the detections of the output frame;
            ``triplets_scores`` maps each future key to the output of
            ``generate_triplets_scores``. When the output frame has no
            human-object pair, the detection fields that could not be filled
            stay ``[]`` and ``triplets_scores`` is ``{}``. ``(None, None)`` is
            returned when nothing was detected or no pair could be generated.
        """
        params = self.config["PARAMS"]

        bboxes, ids, pred_labels, confidences, frames = self.generate_sliding_windows(meta_info)
        if len(bboxes) == 0:
            print("[Error?] No objects detected")
            return None, None

        # Generate human-object pairs
        pair_idxes, im_idxes = bbox_pair_generation(bboxes, pred_labels, 0)
        if len(im_idxes) == 0:
            print("[Error?] No bbox pair generation")
            return None, None

        detected = {
            "bboxes": bboxes,
            "pred_labels": pred_labels,
            "ids": ids,
            "confidences": confidences,
            "pair_idxes": pair_idxes,
            "im_idxes": im_idxes,
        }
        entry = self.fill_entry_inference(frames, detected, meta_info)

        windows = construct_sliding_window(entry)
        entry, windows, windows_out, out_im_idxes, _ = generate_sliding_window_mask(entry, windows, None, "pair")

        # Raw distributions keyed by future label. Stays empty when the model
        # is not run, so the code below never touches an unbound name.
        distributions = {}

        # only do model forward if any valid window exists
        if len(windows) > 0:
            # everything the model needs goes to the target device
            entry["pair_idxes"] = entry["pair_idxes"].to(self.device)
            entry["pred_labels"] = entry["pred_labels"].to(self.device)
            entry["windows"] = windows.to(self.device)
            entry["windows_out"] = windows_out.to(self.device)
            entry["binary_masks"] = entry["binary_masks"].to(self.device)
            entry["exist_mask"] = entry["exist_mask"].to(self.device) if "exist_mask" in entry else []
            entry["change_mask"] = entry["change_mask"].to(self.device) if "change_mask" in entry else []
            entry["frames"] = entry["frames"].to(self.device)
            entry["bboxes"] = entry["bboxes"].to(self.device)
            entry["out_im_idxes"] = torch.stack(out_im_idxes, dim=0).to(self.device)

            # forward
            entry = self.hoi_model(entry)
            if self.ishydra or self.isboth:
                # One set of heads per anticipated future.
                for f in self.hoi_model.model.future_nums:
                    head_anticipation = f"future_num_{f}" if self.ishydra else f
                    distributions[head_anticipation] = self._interaction_distribution(
                        entry[head_anticipation]
                    )
            else:
                distributions[f"future_num_{params['future']}"] = self._interaction_distribution(entry)

        # process output
        frame_ids = list(self.frame_ids_queue)
        has_output_frame = len(out_im_idxes) > 0
        if has_output_frame:
            out_im_idx = out_im_idxes[0]
            frame_id = frame_ids[out_im_idx]
        else:
            # No output frame was produced for this window. Annotate the
            # (empty) result with the frame that was just queued instead of
            # leaving the annotation undefined.
            out_im_idx = None
            frame_id = frame_ids[-1]

        # window-wise result entry
        window_anno = {
            "video_name": self.config["PATHS"]["video_name"],  # video name
            "frame_id": frame_id,  # this frame id
        }
        if params["sampling_mode"] == "anticipation":
            # NOTE(review): the bound check adds `future` to idx while the id
            # adds `fps * future` -- confirm both use the intended unit.
            if idx + params["future"] >= params["frame_num"]:
                window_anno["future_frame_id"] = ""
            else:
                window_anno["future_frame_id"] = f"{idx + self.config['VIDEO']['fps'] * params['future']:06d}"

        window_prediction = {
            "bboxes": [],
            "pred_labels": [],
            "confidences": [],
            "pair_idxes": [],
            "interaction_distribution": [],
            "ids": [],
        }
        # Distributions as nested lists; only filled when pairs exist in the
        # output frame, and the only thing the triplet scoring iterates over.
        listed_distributions = {}

        # case 1, nothing detected in the full clip: result stays all []
        if has_output_frame and len(entry["bboxes"]) > 0:
            det_out_idxes = entry["bboxes"][:, 0] == out_im_idx
            # case 2, nothing detected in this window: result stays all []
            if det_out_idxes.any():
                # NOTE det_idx_offset is the first bbox index in this window
                det_idx_offset = det_out_idxes.nonzero(as_tuple=True)[0][0]
                window_prediction["bboxes"] = entry["bboxes"][det_out_idxes, 1:].detach().cpu().numpy().tolist()
                window_prediction["pred_labels"] = entry["pred_labels"][det_out_idxes].detach().cpu().numpy().tolist()
                # "confidences" is not moved to the device above, so align the
                # mask with whatever device that tensor lives on.
                conf_tensor = entry["confidences"]
                window_prediction["confidences"] = (
                    conf_tensor[det_out_idxes.to(conf_tensor.device)].detach().cpu().numpy().tolist()
                )
                window_prediction["ids"] = np.array(entry["ids"])[det_out_idxes.detach().cpu().numpy()].tolist()

                pair_out_idxes = entry["im_idxes"] == out_im_idx
                # case 3, no human-object pair in the output frame:
                # pair_idxes and interaction_distribution stay []
                if pair_out_idxes.any():
                    # case 4, have everything; pair indices become relative
                    # to the first bbox of the output frame
                    pair_idxes = entry["pair_idxes"][pair_out_idxes] - det_idx_offset
                    window_prediction["pair_idxes"] = pair_idxes.cpu().numpy().tolist()
                    for k, v in distributions.items():
                        listed_distributions[k] = v.detach().cpu().numpy().tolist()
                    window_prediction["interaction_distribution"] = listed_distributions

        window_result = {**window_anno, **window_prediction}
        self.result_list.append(window_result)

        # HOI triplets, only considering interaction scores (detection
        # confidences are replaced by 1.0)
        triplets_scores = {}
        for future_num, distribution in listed_distributions.items():
            triplets_scores[future_num] = generate_triplets_scores(
                window_result["pair_idxes"],
                [1.0] * len(window_result["confidences"]),
                distribution,
                multiply=True,
                top_k=100,
                thres=params["hoi_thres"],
            )

        object_classes = self.config["CLASSES"]["object_classes"]
        interaction_classes = self.config["CLASSES"]["interaction_classes"]
        s_hois = f"Frame {idx}:"
        for future_num, triplet_score in triplets_scores.items():
            for score, idx_pair, interaction_pred in triplet_score:
                subj_idx, obj_idx = window_result["pair_idxes"][idx_pair][:2]
                subj_name = object_classes[window_result["pred_labels"][subj_idx]]
                subj_id = window_result["ids"][subj_idx]
                obj_name = object_classes[window_result["pred_labels"][obj_idx]]
                obj_id = window_result["ids"][obj_idx]
                interaction_name = interaction_classes[interaction_pred]
                s_hois += f"FUTURE {future_num}: {subj_name}{subj_id} - {interaction_name} - {obj_name}{obj_id}: {score} | "
            # NOTE(review): the cumulative string is stored once per future
            # key (kept as in the original) -- confirm this is intended.
            self.hoi_list.append(s_hois)
        return window_result, triplets_scores

    def _interaction_distribution(self, outputs):
        """Activate the head logits in ``outputs`` in place and concatenate them.

        Heads whose loss type is "ce" get a softmax over the last dimension,
        all others a sigmoid; the activated heads are then merged with
        ``concat_separated_head``.
        """
        test_cfg = self.config["TEST"]
        loss_types = test_cfg["loss_type_dict"]
        for head_name, loss_type in loss_types.items():
            if loss_type == "ce":
                outputs[head_name] = torch.softmax(outputs[head_name], dim=-1)
            else:
                outputs[head_name] = torch.sigmoid(outputs[head_name])
        # in inference the number of predictions is taken from the first head
        len_preds = len(outputs[next(iter(loss_types))])
        return concat_separated_head(
            outputs, len_preds, loss_types, test_cfg["class_idxes_dict"], self.device, True
        )

    def process_step(self, idx, frame0, meta_info, do_hois=True):
        # object tracking
        frame_annotated, results = self.object_tracker(frame0.copy(), draw=True)
        bboxes, ids, labels, names, confs = results

        # store result for every second
        if idx % self.config['VIDEO']['fps'] == 0:
            bboxes = [bbox.tolist() for bbox in bboxes]
            self.detection_dict["bboxes"].append(bboxes)
            self.detection_dict["ids"].append(ids.tolist())
            self.detection_dict["labels"].append(labels.tolist())
            self.detection_dict["confidences"].append(confs.tolist())
            self.detection_dict["frame_ids"].append(meta_info["frame_count"])

        # predict HOIs every second
        if idx % self.config['VIDEO']['fps'] == 0 and do_hois:
            # print("[HOI Prediction]")
            window_results, triplets_scores = self.predict_hois(idx, meta_info)
            if type(window_results)==dict:
                window_results["triplets_scores"] = triplets_scores
            return frame_annotated, window_results, bboxes, names
        return frame_annotated, None, bboxes, names

    def process_video(self, dataset):
        """Iterate over ``dataset`` with a progress bar.

        The progress bar total is ``config["PARAMS"]["frame_num"]``. Each batch
        is unpacked into five items; in the code visible here the loop only
        stores the shape of ``frame0`` under ``meta_info["original_shape"]``.

        NOTE(review): no per-frame processing (e.g. ``process_step``) is
        called in the visible body -- confirm whether this method is
        unfinished or continues elsewhere.
        """
        t = tqdm(enumerate(iter(dataset)), total=self.config["PARAMS"]["frame_num"])
        # t = iter(dataset)
        for idx, batch in t:
            # Only `frame0` and `meta_info` are used from the batch here.
            frame, frame0, _, _, meta_info = batch
            meta_info["original_shape"] = frame0.shape


