import os
from tqdm import tqdm

import cv2
import numpy as np
import supervision as sv

import torch
import torchvision

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from moviepy import ImageSequenceClip
from multiprocessing import Pool

def __init__():
    """Build the detection and segmentation models and publish them as globals.

    Sets the module-level names ``grounding_dino_model`` (a GroundingDINO
    ``Model``) and ``sam_predictor`` (a ``SamPredictor`` wrapping a ViT-H SAM
    moved to CUDA when available, otherwise CPU). Checkpoint/config paths are
    relative to the current working directory.
    """
    global grounding_dino_model, sam_predictor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # GroundingDINO: config file + checkpoint.
    dino_config = "./projects/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    dino_weights = "./projects/Grounded-Segment-Anything/groundingdino_swint_ogc.pth"
    grounding_dino_model = Model(
        model_config_path=dino_config,
        model_checkpoint_path=dino_weights,
    )

    # Segment-Anything: registry key selects the encoder variant.
    sam_variant = "vit_h"
    sam_weights = "./projects/Grounded-Segment-Anything/sam_vit_h_4b8939.pth"
    sam_model = sam_model_registry[sam_variant](checkpoint=sam_weights)
    sam_model.to(device=device)
    sam_predictor = SamPredictor(sam_model)


# Prompting SAM with detected boxes
def segment_TikTok_image(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Segment the largest detected box in ``image`` with SAM.

    Args:
        sam_predictor: SAM predictor used for box-prompted segmentation.
        image: Frame in BGR channel order; it is converted to RGB here before
            being handed to SAM, so callers must NOT pre-convert it.
        xyxy: Boxes as rows of ``(x1, y1, x2, y2)``. Only the box with the
            largest area is segmented.

    Returns:
        A ``uint8`` mask with 255 on the segmented object and 0 elsewhere
        (resolution as produced by SAM — presumably the input image size).
        If ``xyxy`` contains no boxes, SAM is not called and an all-zero mask
        of shape ``image.shape[:2]`` is returned.
    """
    boxes = np.asarray(xyxy).reshape(-1, 4)
    if boxes.shape[0] == 0:
        # np.argmax raises on an empty sequence; an empty mask is the natural
        # "nothing found" result.
        return np.zeros(image.shape[:2], dtype=np.uint8)

    # Pick the box with the largest area (vectorized over all rows).
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    largest_box = boxes[int(np.argmax(areas))]

    # SAM expects RGB; ask for several candidate masks and keep the best-scored one.
    sam_predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    masks, scores, _ = sam_predictor.predict(
        box=largest_box,
        multimask_output=True,
    )
    best_mask = masks[int(np.argmax(scores))]

    # Convert the boolean mask to a 0/255 binary image.
    return (best_mask * 255).astype(np.uint8)
    

def _iter_frames(cap):
    """Yield frames from an opened ``cv2.VideoCapture`` until a read fails."""
    while True:
        ok, frame = cap.read()
        if not ok:
            return
        yield frame


def _segment_frame(frame, frame_idx, classes, box_threshold, text_threshold, nms_threshold):
    """Return ``(mask, foreground_rgb)`` for one BGR frame; blank outputs when nothing is detected."""
    detections = grounding_dino_model.predict_with_classes(frame, classes, box_threshold, text_threshold)
    if len(detections) == 0:
        print(f"No objects detected in frame {frame_idx}")
        # Emit blank outputs instead of skipping, so the written videos stay
        # frame-aligned with the source video.
        return np.zeros(frame.shape[:2], dtype=np.uint8), np.zeros_like(frame)

    # NMS post process: drop overlapping boxes before picking one for SAM.
    keep = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        nms_threshold,
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[keep]
    detections.confidence = detections.confidence[keep]
    detections.class_id = detections.class_id[keep]

    # segment_TikTok_image performs the BGR->RGB conversion itself, so the
    # raw BGR frame is passed (pre-converting would feed SAM BGR data).
    mask = segment_TikTok_image(
        sam_predictor=sam_predictor,
        image=frame,
        xyxy=detections.xyxy,
    )

    # Zero out background pixels (uint8 * bool keeps uint8), then convert to RGB for moviepy.
    foreground = frame * (mask > 0)[:, :, None]
    return mask, cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB)


def _write_video(frames, video_path, fps):
    """Encode RGB ``frames`` to ``video_path`` with libx264, creating parent directories."""
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    ImageSequenceClip(frames, fps=fps).write_videofile(video_path, codec="libx264")


def process_video(source_video_path, target_video_dir, classes, box_threshold, text_threshold, nms_threshold):
    """Extract per-frame foreground masks and cut-outs from one video.

    Every frame goes through GroundingDINO (module global
    ``grounding_dino_model``) and then NMS. The largest remaining box is
    segmented with SAM via ``segment_TikTok_image``. Two videos named like the
    source file are written:

    * ``<target_video_dir>/fg_mask/<name>``: the 0/255 masks.
    * ``<target_video_dir>/fg/<name>``: frames with the background zeroed.

    Frames without detections produce an all-black mask and foreground, so
    both outputs have one frame per decoded source frame.

    Args:
        source_video_path: Path of the input video.
        target_video_dir: Root directory for the ``fg_mask`` / ``fg`` outputs.
        classes: Text class prompts for GroundingDINO.
        box_threshold: GroundingDINO box confidence threshold.
        text_threshold: GroundingDINO text threshold.
        nms_threshold: IoU threshold for ``torchvision.ops.nms``.

    Raises:
        IOError: If the source video cannot be opened.
    """
    cap = cv2.VideoCapture(source_video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {source_video_path}")

    fg_masks = []
    fg_images = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # CAP_PROP_FRAME_COUNT can be an estimate, so it only sizes the
        # progress bar; the loop itself runs until decoding stops.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = tqdm(_iter_frames(cap), total=frame_count, desc="Processing video frames", unit="frame")
        for frame_idx, frame in enumerate(frames):
            mask, foreground = _segment_frame(
                frame, frame_idx, classes, box_threshold, text_threshold, nms_threshold
            )
            fg_masks.append(mask)
            fg_images.append(foreground)
    finally:
        cap.release()

    if not fg_masks:
        print(f"No frames decoded from {source_video_path}; nothing written")
        return

    video_name = os.path.basename(source_video_path)
    _write_video(
        [cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) for mask in fg_masks],
        os.path.join(target_video_dir, "fg_mask", video_name),
        fps,
    )
    _write_video(fg_images, os.path.join(target_video_dir, "fg", video_name), fps)


if __name__ == "__main__":
    # Script entry point. Paths were previously hard-coded to "" (which made
    # os.makedirs fail); they are now command-line arguments, and the
    # thresholds keep their former values as defaults.
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract GroundingDINO + SAM foreground masks from a video or a directory of videos."
    )
    parser.add_argument("source", help="Input video file, or a directory whose video files are all processed.")
    parser.add_argument("target_dir", help="Output root; 'fg_mask/' and 'fg/' subdirectories are created inside it.")
    parser.add_argument("--classes", nargs="+", default=["person"], help="Text prompts for GroundingDINO.")
    parser.add_argument("--box-threshold", type=float, default=0.25)
    parser.add_argument("--text-threshold", type=float, default=0.25)
    parser.add_argument("--nms-threshold", type=float, default=0.8)
    args = parser.parse_args()

    os.makedirs(args.target_dir, exist_ok=True)

    # Load the (large) models only after the arguments have been validated.
    __init__()

    def process_single_video(video_path):
        """Process one video, reporting (not raising) failures so a batch run continues."""
        try:
            process_video(
                video_path, args.target_dir, args.classes,
                args.box_threshold, args.text_threshold, args.nms_threshold,
            )
        except Exception as e:
            print(f"Error processing video {video_path}: {e}")

    if os.path.isdir(args.source):
        video_exts = (".mp4", ".mov", ".avi", ".mkv", ".webm")
        for name in sorted(os.listdir(args.source)):
            if name.lower().endswith(video_exts):
                process_single_video(os.path.join(args.source, name))
    else:
        # A single explicit file: let errors propagate so the exit status reflects them.
        process_video(
            args.source, args.target_dir, args.classes,
            args.box_threshold, args.text_threshold, args.nms_threshold,
        )