import cv2
import os
from glob import glob

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
root = './outputs/'   # base directory that holds one sub-directory per experiment
exp_name = 'test8'    # experiment sub-directory to visualize

# Switches selecting which visualizations are rendered below.
vis_track = True
vis_det = True

# Rendered frames and videos are written under <root>/vis_video/<exp_name>/.
save_dir = os.path.join(root, 'vis_video', exp_name)


# Category names indexed by integer class id (index 0 is 'person', 80 entries).
# NOTE(review): presumably the detector's label order (standard COCO-80); confirm.
class_names = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

# Input frames plus per-frame result files. Each list is sorted
# lexicographically so that entry i of every list refers to the same frame.
# NOTE(review): that pairing relies on zero-padded file names; confirm.
imgs = sorted(glob('/Datasets/egotracks/v2/egotracks/clips_frames/dets_v2/test/3c9331fc-90eb-4069-875c-bfc9a97443af/raw/*jpg'))
tracks = sorted(glob(os.path.join(root, exp_name, 'tracks/*.txt')))
dets = sorted(glob(os.path.join(root, exp_name, 'dets/*.txt')))


if vis_track:
    # Draw every tracked box from tracks/*.txt onto the frame at the same
    # sorted index, then save each annotated frame as a JPEG and append it
    # to an MP4 video.
    track_save = os.path.join(save_dir, 'track')
    os.makedirs(track_save, exist_ok=True)

    # Only the first `max_track_frames` frames are rendered.
    max_track_frames = 100
    if len(imgs) < len(tracks):
        print(f'warning: {len(tracks)} track files but only {len(imgs)} images; '
              f'extra track files are ignored')
    n_frames = min(len(tracks), len(imgs), max_track_frames)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # The writer is opened lazily with the first frame's size. cv2.VideoWriter
    # silently discards frames whose size differs from the size it was opened
    # with, so a hard-coded (1920, 1080) produced an empty video at any other
    # resolution.
    out = None
    try:
        for i, (img_path, track_path) in enumerate(zip(imgs, tracks)):
            if i >= max_track_frames:
                break
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None instead of raising
                print(f'warning: could not read {img_path}, skipping frame {i}')
                continue
            with open(track_path, 'r') as f:
                for line in f:
                    fields = line.strip().split(',')
                    if len(fields) < 6:
                        continue  # blank or malformed line
                    # Columns as consumed here: [1] track id, [2:6] box corners
                    # x1,y1,x2,y2, [-4] score, [-3] class id.
                    # Parse through float() so values like "12.5" are accepted.
                    x1, y1, x2, y2 = (int(float(v)) for v in fields[2:6])
                    track_id = fields[1]
                    # Raw class id is shown; class_names lookup is deliberately
                    # not applied to tracker output.
                    class_name = int(float(fields[-3]))
                    score = fields[-4]
                    # Distinct, deterministic color per track id.
                    tid = int(float(track_id))
                    color = (tid * 10 % 255, tid * 20 % 255, tid * 30 % 255)
                    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                    # Label: track id, class id and score.
                    cv2.putText(img, f'{track_id}_{class_name}_score{score}', (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            if out is None:
                height, width = img.shape[:2]
                out = cv2.VideoWriter(os.path.join(track_save, 'track_output.mp4'),
                                      fourcc, 10, (width, height))
            out.write(img)
            # Also keep the annotated frame as an individual image.
            cv2.imwrite(os.path.join(track_save, f'track_{i}.jpg'), img)
            print(f'{i+1}/{n_frames}')
    finally:
        # Release even on error so the MP4 container is finalized.
        if out is not None:
            out.release()

if vis_det:
    # Draw every detection box from dets/*.txt onto the frame at the same
    # sorted index, then save each annotated frame as a JPEG and append it
    # to an MP4 video.
    det_save = os.path.join(save_dir, 'det')
    os.makedirs(det_save, exist_ok=True)

    if len(imgs) < len(dets):
        print(f'warning: {len(dets)} det files but only {len(imgs)} images; '
              f'extra det files are ignored')
    n_frames = min(len(dets), len(imgs))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # The writer is opened lazily with the first frame's size. cv2.VideoWriter
    # silently discards frames whose size differs from the size it was opened
    # with, so a hard-coded (1920, 1080) produced an empty video at any other
    # resolution.
    out = None
    try:
        for i, (img_path, det_path) in enumerate(zip(imgs, dets)):
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None instead of raising
                print(f'warning: could not read {img_path}, skipping frame {i}')
                continue
            with open(det_path, 'r') as f:
                for line in f:
                    fields = line.strip().split(',')
                    if len(fields) < 6:
                        continue  # blank or malformed line
                    # Columns as consumed here: [1] id, [2:6] box corners
                    # x1,y1,x2,y2, [-4] score, [-3] class id.
                    x1, y1, x2, y2 = (int(float(v)) for v in fields[2:6])
                    track_id = fields[1]
                    cls_id = int(float(fields[-3]))
                    # Fall back to the raw id for ids outside class_names.
                    # Plain indexing would raise for large ids and silently
                    # wrap negative ids to the wrong name.
                    if 0 <= cls_id < len(class_names):
                        class_name = class_names[cls_id]
                    else:
                        class_name = str(cls_id)
                    score = fields[-4]
                    # All detections share one color: (255, 0, 0) is blue in
                    # OpenCV's BGR channel order.
                    cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    # Label: id, class name and score.
                    cv2.putText(img, f'{track_id}_{class_name}_score{score}', (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            if out is None:
                height, width = img.shape[:2]
                out = cv2.VideoWriter(os.path.join(det_save, 'det_output.mp4'),
                                      fourcc, 10, (width, height))
            out.write(img)
            # Also keep the annotated frame as an individual image.
            cv2.imwrite(os.path.join(det_save, f'det_{i}.jpg'), img)
            print(f'{i+1}/{n_frames}')
    finally:
        # Release even on error so the MP4 container is finalized.
        if out is not None:
            out.release()

