import os
import cv2
import numpy as np
import torch
import torchvision
import supervision as sv
from torchvision.ops import box_convert
from groundingdino.util.inference import load_image, predict, annotate, load_model
from segment_anything import sam_model_registry, SamPredictor
import time

# Expose only physical GPU 1 to this process (it then appears as cuda:0).
# This is set before the first CUDA query below (torch.cuda.is_available()).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Paths and configurations
# DEVICE is used for SAM below. Replace the PATH_TO_* placeholders before running.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GROUNDING_DINO_CONFIG_PATH = "PATH_TO_GDINO_CONFIG"
GROUNDING_DINO_CHECKPOINT_PATH = "PATH_TO_GDINO_WEIGHTS"
SAM_ENCODER_VERSION = "vit_h"  # key into sam_model_registry; should match the SAM checkpoint's architecture
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Score thresholds forwarded to GroundingDINO's predict() in process_images.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
# IoU threshold for the class-agnostic torchvision NMS in process_images.
NMS_THRESHOLD = 0.8
# Target categories. They are joined with spaces into the GroundingDINO caption,
# and predicted phrases are matched to them by prefix in phrases2classes.
# The trailing "." is presumably GroundingDINO's category separator, so keep it.
CLASSES = ["car.", "truck.", "pedestrian.", "cyclist.", "tram.", "van."]  # Modify classes as needed

# Initialize models (this runs at import time).
# NOTE(review): DEVICE is not passed to load_model here. Confirm which device
# load_model configures the GroundingDINO model for.
grounding_dino_model = load_model(GROUNDING_DINO_CONFIG_PATH, GROUNDING_DINO_CHECKPOINT_PATH)
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)


def phrases2classes(phrases, classes):
	"""Map predicted phrases to indices into ``classes``.

	A phrase maps to the first class whose name starts with the phrase's
	first whitespace-separated word, lower-cased. For example, ``"car"``
	matches ``"car."``.

	Args:
		phrases: Iterable of phrase strings, e.g. as returned by GroundingDINO.
		classes: Sequence of class names to match against.

	Returns:
		``np.ndarray`` of ints, one per phrase. Phrases that are empty,
		whitespace-only, or match no class get ``-1``.
	"""
	class_ids = []
	for phrase in phrases:
		words = phrase.lower().split()
		if not words:
			# An empty/blank phrase has no first word. Indexing split()[0]
			# would raise IndexError, so treat it as "no match".
			class_ids.append(-1)
			continue
		first_word = words[0]
		class_id = next((i for i, cls in enumerate(classes) if cls.startswith(first_word)), -1)
		class_ids.append(class_id)
	# An explicit int dtype keeps the result usable as an index even when empty
	# (np.array([]) would otherwise be float64).
	return np.array(class_ids, dtype=int)


def post_process_result(source_h, source_w, boxes, logits):
	"""Turn GroundingDINO box outputs into a ``sv.Detections`` in pixel space.

	Args:
		source_h: Height of the source image in pixels.
		source_w: Width of the source image in pixels.
		boxes: Tensor of boxes in ``cxcywh`` format, relative to the image
			size (they are scaled by width/height here).
		logits: Tensor of per-box scores, stored as ``confidence``.

	Returns:
		``sv.Detections`` with absolute ``xyxy`` boxes and ``confidence``.
	"""
	pixel_scale = torch.Tensor([source_w, source_h, source_w, source_h])
	pixel_boxes = boxes * pixel_scale
	corner_boxes = box_convert(boxes=pixel_boxes, in_fmt="cxcywh", out_fmt="xyxy")
	return sv.Detections(xyxy=corner_boxes.numpy(), confidence=logits.numpy())


def segment(sam_predictor, image, xyxy):
	"""Predict one SAM mask per bounding box.

	Args:
		sam_predictor: A ``SamPredictor``, or any object with the same
			``set_image(image)`` / ``predict(box=..., multimask_output=...)``
			interface.
		image: Image passed to ``sam_predictor.set_image``. Its first two
			dimensions are taken as ``(H, W)``.
		xyxy: Boxes as ``[x0, y0, x1, y1]`` rows, one per detection.

	Returns:
		Array with one mask per box. Each mask is the highest-scoring of the
		candidates returned with ``multimask_output=True``. With no boxes, an
		empty boolean array of shape ``(0, H, W)`` is returned.
	"""
	if len(xyxy) == 0:
		# Skip the costly image embedding when there is nothing to segment,
		# and return a correctly shaped empty mask stack rather than a 1-D
		# np.array([]).
		height, width = image.shape[:2]
		return np.zeros((0, height, width), dtype=bool)

	sam_predictor.set_image(image)
	result_masks = []
	for box in xyxy:
		masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
		result_masks.append(masks[np.argmax(scores)])
	return np.array(result_masks)


def _select_detections(detections, index):
	"""Keep only the ``xyxy``/``confidence``/``class_id`` entries picked by ``index``.

	The three arrays are filtered in lockstep and the same object is returned.
	Masks are not filtered. The callers below use this before masks are set.
	"""
	detections.xyxy = detections.xyxy[index]
	detections.confidence = detections.confidence[index]
	detections.class_id = detections.class_id[index]
	return detections


def process_images(image_dir, output_dir):
	"""Detect, segment and annotate every image in ``image_dir``.

	Files ending in ``.png``/``.jpg``/``.jpeg`` (case-insensitive) are processed
	in sorted order. Each one goes through these steps:

	1. GroundingDINO runs with the ``CLASSES`` caption.
	2. Detections whose phrase matches no class are dropped.
	3. Class-agnostic NMS is applied.
	4. Each surviving box is segmented with SAM.

	Two files are written to ``output_dir`` per image:

	* ``annotated_<image_name>``: the image with masks, boxes and labels drawn.
	* ``<image stem>.npy``: a dict with keys ``image``, ``boxes``, ``logits``,
	  ``labels``, ``masks``, ``width`` and ``height``. ``np.save`` pickles the
	  dict, so read it back with ``np.load(path, allow_pickle=True).item()``.

	Args:
		image_dir: Directory containing the input images.
		output_dir: Output directory, created if missing.
	"""
	os.makedirs(output_dir, exist_ok=True)

	# The annotators take no per-image configuration, so build them once.
	box_annotator = sv.BoxAnnotator()
	mask_annotator = sv.MaskAnnotator()
	caption = " ".join(CLASSES)

	for image_name in sorted(os.listdir(image_dir)):
		if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
			continue
		start = time.time()

		image_path = os.path.join(image_dir, image_name)
		print(f"Processing {image_path}...")

		# Load the image. image_source is used for SAM and drawing; image is the
		# GroundingDINO input.
		image_source, image = load_image(image_path)
		height, width = image_source.shape[:2]

		# GroundingDINO predictions. Pass DEVICE explicitly instead of relying
		# on predict()'s default device.
		boxes, logits, phrases = predict(
			model=grounding_dino_model,
			image=image,
			caption=caption,
			box_threshold=BOX_THRESHOLD,
			text_threshold=TEXT_THRESHOLD,
			device=str(DEVICE),
		)

		# Post-process results
		detections = post_process_result(source_h=height, source_w=width, boxes=boxes, logits=logits)
		detections.class_id = phrases2classes(phrases=phrases, classes=CLASSES)

		# phrases2classes yields -1 for phrases matching no class. Indexing
		# CLASSES with -1 would silently label them as the last class, so drop them.
		detections = _select_detections(detections, np.flatnonzero(detections.class_id >= 0))

		# Class-agnostic Non-Maximum Suppression (NMS)
		nms_idx = torchvision.ops.nms(
			torch.from_numpy(detections.xyxy),
			torch.from_numpy(detections.confidence),
			NMS_THRESHOLD,
		).numpy()
		detections = _select_detections(detections, nms_idx)

		# Generate one mask per remaining box using SAM
		detections.mask = segment(sam_predictor=sam_predictor, image=image_source, xyxy=detections.xyxy)

		labels = [
			f"{CLASSES[class_id]} {confidence:0.2f}"
			for class_id, confidence in zip(detections.class_id, detections.confidence)
		]

		# Draw on a copy so that annotators drawing in place cannot modify
		# image_source.
		annotated_image = mask_annotator.annotate(scene=image_source.copy(), detections=detections)
		annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
		# Convert from the RGB order used above to the BGR order cv2.imwrite expects.
		annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)

		annotated_image_path = os.path.join(output_dir, f"annotated_{image_name}")
		cv2.imwrite(annotated_image_path, annotated_image)
		print(f"Annotated image saved at: {annotated_image_path}")

		# Save results as a .npy file. The dict is pickled into an object array.
		npy_data = {
			"image": image_name,
			"boxes": detections.xyxy,
			"logits": detections.confidence,
			"labels": [CLASSES[class_id] for class_id in detections.class_id],
			"masks": detections.mask,
			"width": width,
			"height": height
		}
		npy_file_path = os.path.join(output_dir, f"{os.path.splitext(image_name)[0]}.npy")
		np.save(npy_file_path, npy_data)
		print(f"Results saved at: {npy_file_path}")
		end = time.time()
		print(f"Elapsed time: {end - start:.4f} seconds")


if __name__ == "__main__":
	image_directory = "PATH_TO_DATA"  # Path to input image directory
	output_directory = "PATH_TO_OUTPUT"  # Path to output directory
	# Pass the input directory defined above (the previous name was undefined).
	process_images(image_directory, output_directory)
