from typing import Any
import cv2
from ultralytics import YOLO
import numpy as np
import torch

class ObjectTracker():
    """Detect and track objects in frames with an Ultralytics YOLO model.

    Two entry points are offered:

    * ``tracker(frame)`` (``__call__``) returns parallel arrays restricted to
      detections whose detector class name appears in the HOI mapping, with
      labels expressed as indices into ``classes``.
    * ``detect_and_track(frame)`` returns a dict keyed by track id containing
      every tracked detection, without any class remapping.
    """

    # Kept immutable so the default can never be shared-and-mutated across
    # instances; each instance receives its own list copy.
    DEFAULT_CLASSES = (
        "person", "sofa", "table", "chair", "cup", "bottle", "backpack",
        "laptop", "dish", "fruits", "vegetables", "toy", "handbag",
    )

    def __init__(self, device="cuda:0", weights_path="../weights/yolo/yolov8x.pt",
                 tracker_name="botsort.yaml", classes_use=DEFAULT_CLASSES,
                 hoi2coco_mapper=None):
        """Load the model and build the detector-name -> HOI-class lookup.

        Args:
            device: Device string handed to ``model.to``.
            weights_path: Path of the YOLO weights file.
            tracker_name: Tracker configuration passed to ``model.track``.
            classes_use: Ordered HOI class names; a name's position in this
                sequence is the integer label returned by
                ``map_classes_to_ids``.
            hoi2coco_mapper: Mapping from an HOI class name to an iterable of
                detector class names that should be reported as that class.
                ``None`` is treated as an empty mapping.
        """
        self.model = YOLO(weights_path)
        self.model.to(device)
        self.device = device
        self.tracker = tracker_name
        self.categories = self.model.names
        self.classes = list(classes_use)
        self.hoi2coco_mapper = {} if hoi2coco_mapper is None else hoi2coco_mapper
        self.coco2hoi_mapper = {}
        for hoi_class, list_coco_names in self.hoi2coco_mapper.items():
            for coco_name in list_coco_names:
                if coco_name in self.coco2hoi_mapper:
                    # A detector name listed under two HOI classes: last one wins.
                    print(f"Error: overwriting for {coco_name} - {hoi_class}")
                self.coco2hoi_mapper[coco_name] = hoi_class

    def __call__(self, frame, draw=True):
        """Track objects in ``frame`` and keep only HOI-mapped detections.

        Args:
            frame: Image passed to ``model.track``.
            draw: When true, boxes and ``name: id`` captions are drawn onto
                ``frame``.

        Returns:
            ``(frame, results)`` where ``results`` is the
            ``(boxes, ids, labels, hoi_names, confs)`` tuple produced by
            ``map_classes_to_ids``.
        """
        results = self.model.track(frame, tracker=self.tracker, persist=True, verbose=False)
        results = self.adapt_results(results, as_dict=False)
        results = self.map_classes_to_ids(results)
        if draw:
            return self.draw_frames_inf(frame, results), results
        return frame, results

    @staticmethod
    def _draw_box(frame, box, label):
        """Draw one green rectangle with a red caption at its top-left corner."""
        # Plain Python ints: OpenCV's point parsing is stricter about scalar
        # types than NumPy indexing results guarantee.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            frame,
            label,
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2,
        )

    def draw_frame(self, frame, results):
        """Draw dict-style results (see ``adapt_results(as_dict=True)``).

        Returns:
            The same ``frame`` object, annotated.
        """
        for track_id, data in results.items():
            self._draw_box(frame, data["xyxy"], f"{data['category']}: {track_id}")
        return frame

    def draw_frames_inf(self, frame, results):
        """Draw tuple-style results ``(boxes, ids, cls, names, confs)``.

        Returns:
            The same ``frame`` object, annotated.
        """
        boxes, ids, _, names, _ = results
        for box, track_id, name in zip(boxes, ids, names):
            self._draw_box(frame, box, f"{name}: {track_id}")
        return frame

    def map_classes_to_ids(self, results):
        """Filter detections to HOI-mapped classes and relabel them.

        A detection is kept when its detector class name has an entry in
        ``coco2hoi_mapper`` *and* the mapped HOI name is listed in
        ``classes``. Detections mapped to an HOI name that is not in
        ``classes`` are dropped rather than raising ``ValueError``.

        Args:
            results: ``(boxes, ids, cls, coco_names, confs)`` as returned by
                ``adapt_results(as_dict=False)``.

        Returns:
            ``(boxes, ids, labels, hoi_names, confs)`` for the kept
            detections; ``labels`` is an int64 tensor of indices into
            ``classes`` (int64 even when nothing is kept).
        """
        boxes, ids, _, coco_names, confs = results

        # First occurrence wins, matching list.index semantics.
        label_of = {}
        for label, class_name in enumerate(self.classes):
            label_of.setdefault(class_name, label)

        labels = []
        kept = []
        hoi_names = []
        for i, coco_name in enumerate(coco_names):
            hoi_name = self.coco2hoi_mapper.get(coco_name)
            if hoi_name is None or hoi_name not in label_of:
                continue
            hoi_names.append(hoi_name)
            labels.append(label_of[hoi_name])
            kept.append(i)

        # Explicit integer dtype keeps NumPy fancy indexing valid for an empty selection.
        np_index = np.asarray(kept, dtype=np.intp)
        labels = torch.tensor(labels, dtype=torch.long)
        return boxes[np_index], ids[np_index], labels, hoi_names, confs[kept]

    def adapt_results(self, results, as_dict=False):
        """Convert the first element of a ``model.track`` result list.

        Frames whose ``boxes.id`` is ``None`` (nothing is being tracked)
        yield empty results instead of raising ``AttributeError``.

        Args:
            results: Sequence returned by ``model.track``; only ``results[0]``
                is read.
            as_dict: Select the return layout.

        Returns:
            If ``as_dict`` is false, ``(boxes, ids, cls, names, confs)``:
            ``boxes`` and ``ids`` are integer NumPy arrays, ``cls`` and
            ``confs`` are left exactly as the model returned them, ``names``
            is a list of class-name strings. If ``as_dict`` is true, a dict
            ``{track_id: {"xyxy", "category", "label", "ids", "confs"}}``.
        """
        res = results[0]
        boxes = res.boxes.xyxy.cpu().numpy().astype(int)
        cls = res.boxes.cls
        confs = res.boxes.conf

        if res.boxes.id is None:
            # No track ids on this frame: report nothing as tracked.
            if as_dict:
                return {}
            return boxes[:0], np.empty(0, dtype=int), cls[:0], [], confs[:0]

        ids = res.boxes.id.cpu().numpy().astype(int)
        names = [res.names[int(c)] for c in cls]
        if as_dict:
            return {
                id_: {
                    "xyxy": boxes[i],
                    "category": names[i],
                    "label": cls[i],
                    "ids": ids[i],
                    "confs": confs[i],
                }
                for i, id_ in enumerate(ids)
            }
        return boxes, ids, cls, names, confs

    def detect_and_track(self, frame, draw_in_frame=True):
        """Track every detected object in ``frame`` (no HOI filtering).

        Args:
            frame: Image passed to ``model.track``.
            draw_in_frame: When true, boxes and ``category: id`` captions are
                drawn onto ``frame``.

        Returns:
            ``(frame, results)`` where ``results`` is the dict produced by
            ``adapt_results(as_dict=True)``.
        """
        results = self.model.track(frame, tracker=self.tracker, persist=True, verbose=False)
        results = self.adapt_results(results, as_dict=True)

        if draw_in_frame:
            frame = self.draw_frame(frame, results)
        return frame, results

if __name__ == '__main__':
    # Webcam demo: track objects on the default camera and show the annotated,
    # mirrored stream until 'q' is pressed or a frame cannot be read.
    cap = cv2.VideoCapture(0)
    # The weights path must be given by keyword: ObjectTracker's first
    # positional parameter is `device`. The constructor iterates over
    # `hoi2coco_mapper`, so an explicit empty mapping is passed; the
    # detect_and_track path used below does not consult it.
    tracker = ObjectTracker(
        weights_path="weights/yolo/yolov8x.pt",
        tracker_name="botsort.yaml",
        hoi2coco_mapper={},
    )

    try:
        while True:
            ret, frame = cap.read()
            # Check the read succeeded before touching `frame`; it is not a
            # valid image when `ret` is false.
            if not ret:
                break
            frame = cv2.flip(frame, 1)

            frame, results = tracker.detect_and_track(frame, draw_in_frame=True)
            cv2.imshow("frame", frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Always free the camera and close the preview window, even on error.
        cap.release()
        cv2.destroyAllWindows()