import sys

sys.path.insert(0, "../modules/object_tracking/yolov5")
sys.path.insert(0, "..")

import numpy as np
import shelve
import json
from pathlib import Path

from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

from modules.object_tracking import ObjectTracking

from common.dataset.vidhoi_dataset import VidHOIDataset
from common.dataset.data_io import FrameDatasetLoader
from common.dataset.transforms import YOLOv5Transform
from configs.paths import *

# Run inference on the first CUDA device when one is available, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Results are written under <dataset_dir>/VidHOI_detection; the parent
# directory must already exist because mkdir is called without parents=True.
output_folder = Path(dataset_dir, "VidHOI_detection")
output_folder.mkdir(exist_ok=True)

# YOLOv5 variant name; it selects the weights file and is part of the output file name.
yolov5_model_size = "yolov5l"

# Tracking mode:
#   "key" - only track objects in key frames
#   "all" - track objects in all frames, then only keep the key frames
tracking_mode = "all"

# Weights file is chosen by the configured YOLOv5 variant (vidor_<size>.pt).
yolo_weights_file = "../weights/yolov5/vidor_" + yolov5_model_size + ".pt"

# Combined detector + tracker used for every frame below.
# All paths are relative, so the script has to be started from its own directory.
object_tracking_module = ObjectTracking(
    yolo_weights_path=yolo_weights_file,
    deep_sort_model_dir="../weights/deep_sort/",
    config_path="../configs/head_and_track/object_tracking.yaml",
    device=device,
)

# Parameters handed to YOLOv5Transform: the stride comes from the loaded
# tracking module, the image size is fixed here.
yolov5_stride = object_tracking_module.yolov5_stride
img_size = 640

# Validation split of VidHOI, built from the validation frame annotations.
# The very large max_length / max_human_num values are presumably there to
# disable those limits - confirm against VidHOIDataset.
dataset_root = Path(dataset_dir)
vidhoi_val_dataset = VidHOIDataset(
    annotations_file=dataset_root / "annotations" / "val_frame_annots.json",
    frames_dir=dataset_root / "images",
    transform=YOLOv5Transform(img_size, yolov5_stride),
    min_length=1,
    max_length=999999,
    max_human_num=999999,
    annotation_mode="clip",
    train_ratio=0,
)
vidhoi_val_dataset.eval()

# batch_size=None turns off automatic batching, so the loader yields the
# dataset items one at a time, in order.
vidhoi_val_dataloader = DataLoader(
    vidhoi_val_dataset,
    batch_size=None,
    shuffle=False,
)

# Key-frame mode: run detection + tracking only on the frames the dataset
# yields for each clip, and store the result for every one of them.
if tracking_mode != "key":
    print("Skip, not key frame mode")
else:
    # Detections of all videos, keyed by video name.
    all_detections = {}
    progress = tqdm(vidhoi_val_dataloader)
    # One dataset item per video; the annotation element is not needed here.
    for clip_frames, _, clip_meta in progress:
        video_name = clip_meta['video_name']
        progress.set_description(f"{video_name}")
        progress.refresh()
        raw_frames = clip_meta["original_frames"]
        frame_ids = clip_meta["frame_ids"]
        last_idx = len(clip_frames) - 1
        # Start every video with a cleared tracker, warmed up on its first frame.
        object_tracking_module.clear()
        object_tracking_module.warmup(clip_frames[0].to(device), raw_frames[0])
        # Result entry for this video: one list element per processed frame.
        clip_detections = {
            field: []
            for field in ("bboxes", "ids", "labels", "confidences", "frame_ids")
        }
        # Detect and track frame by frame.
        for im_idx, (frame, raw_frame, frame_id) in enumerate(
            zip(clip_frames, raw_frames, frame_ids)
        ):
            progress.set_postfix_str(f"{im_idx}/{last_idx}: {frame_id}")
            progress.refresh()
            tracked = object_tracking_module.track_one(
                frame.to(device), raw_frame, draw=False
            )
            bboxes, ids, labels, _, confidences, _ = tracked
            # frame-based format, NOTE need to convert to [im_idx, x1, y1, x2, y2] later
            clip_detections["bboxes"].append([bbox.tolist() for bbox in bboxes])
            clip_detections["ids"].append(ids)
            clip_detections["labels"].append(labels)
            clip_detections["confidences"].append(confidences)
            clip_detections["frame_ids"].append(frame_id)
        all_detections[video_name] = clip_detections

# All-frame mode: feed every frame of each video through the tracker so the
# track state is updated on all frames, but only store the detections of
# frames that belong to the dataset's key-frame list.
if tracking_mode == "all":
    total_frame_num = 0
    # dict for all videos
    all_detections = {}
    # for each video, load all frames, only keep the detections in key frames
    # don't need annotation here
    t = tqdm(range(len(vidhoi_val_dataset)))
    for video_idx in t:
        video_name = vidhoi_val_dataset.video_name_list[video_idx]
        frame_ids = vidhoi_val_dataset.frame_ids_list[video_idx]
        # Build the lookup once per video: the membership test below runs for
        # every frame, and a set avoids rescanning the key-frame list each time.
        key_frame_ids = set(frame_ids)
        t.set_description(f"{video_name}")
        t.refresh()
        # entry for one video
        clip_detections = {
            "bboxes": [],
            "ids": [],
            "labels": [],
            "confidences": [],
            "frame_ids": [],
        }
        # load all frames
        video_frame_path = Path(dataset_dir) / "images" / video_name
        video_loader = FrameDatasetLoader(video_frame_path, YOLOv5Transform(img_size, yolov5_stride))
        for frame_idx, (frame, frame0, _, _, meta_info) in enumerate(video_loader):
            total_frame_num += 1
            if frame_idx == 0:
                # object tracking init: reset state and warm up on the first frame
                object_tracking_module.clear()
                object_tracking_module.warmup(frame.to(device), frame0)

            # Frame id = the 6 characters in front of a 4-character suffix
            # (e.g. ".jpg") at the end of the frame path.
            frame_id = str(meta_info["frame_path"])[-10:-4]
            clip_len = meta_info["frame_num"] - 1
            t.set_postfix_str(f"{frame_idx}/{clip_len}: {frame_id}")
            t.refresh()
            # The tracker is run on every frame, including frames whose result is discarded.
            bboxes, ids, labels, _, confidences, _ = object_tracking_module.track_one(frame.to(device), frame0, draw=False)
            # only store the detections in key frame set
            if frame_id in key_frame_ids:
                # frame-based format, NOTE need to convert to [im_idx, x1, y1, x2, y2] later
                bboxes = [bbox.tolist() for bbox in bboxes]
                clip_detections["bboxes"].append(bboxes)
                clip_detections["ids"].append(ids)
                clip_detections["labels"].append(labels)
                clip_detections["confidences"].append(confidences)
                clip_detections["frame_ids"].append(frame_id)
        all_detections[video_name] = clip_detections
    print(f"\nTotally {total_frame_num} frames")

elif tracking_mode == "key":
    print("Skip, not all frame mode")
else:
    # Neither branch ran, so there is nothing to save; fail here with a clear
    # message instead of a NameError on all_detections further down.
    raise ValueError(
        f'Unknown tracking_mode {tracking_mode!r}; expected "key" or "all"'
    )

# Write the detections of all videos as a single JSON document into the
# output folder; the file name carries the YOLOv5 variant that was used.
filename = output_folder.joinpath("val_trace_" + yolov5_model_size + "_deepsort.json")
filename.write_text(json.dumps(all_detections))