import os
import sys
import argparse
from tqdm import tqdm
from typing import List
import supervision as sv
import csv
import json

import cv2
import numpy as np

import torch
import torchvision

sys.path.append("../submodules/sam/GroundingDINO")
from groundingdino.util.inference import Model
sys.path.append("../submodules/sam/segment_anything")
from segment_anything import sam_model_registry, SamPredictor

# Model files live in ../submodules/sam relative to THIS file. Resolving them from
# __file__ (instead of the current working directory) lets the script run from any
# directory; launching it from its own directory resolves to the same files as before.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_SAM_SUBMODULE_DIR = os.path.normpath(os.path.join(_THIS_DIR, "..", "submodules", "sam"))

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = os.path.join(
    _SAM_SUBMODULE_DIR, "GroundingDINO", "groundingdino", "config", "GroundingDINO_SwinT_OGC.py"
)
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(_SAM_SUBMODULE_DIR, "groundingdino_swint_ogc.pth")

# Segment-Anything checkpoint (the encoder version must match the checkpoint file)
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = os.path.join(_SAM_SUBMODULE_DIR, "sam_vit_h_4b8939.pth")

# Default GroundingDINO box / text thresholds used by the extractor methods.
BOX_THRESHOLD = 0.225
TEXT_THRESHOLD = 0.225
# IoU threshold handed to torchvision.ops.nms when de-duplicating boxes.
NMS_THRESHOLD = 0.8

# Context vocabulary grouped by theme; flattened into the GroundingDINO text
# prompt used in "sam" mode. Some names appear under more than one theme
# (e.g. "Fish", "Ice", "Fox", "Raccoon").
SAMPLE_OBJECT_DICT = {
    "Natural_Environments": {
        "Water_Bodies": [
            "Pond", "Lake", "River", "Ocean", "Stream", "Marsh", "Swamp",
            "Creek", "Brook", "Wetland", "Estuary"
        ],
        "Landscapes": [
            "Forest", "Desert", "Grassland", "Mountain", "Hill", "Field",
            "Prairie", "Savannah", "Tundra", "Urban Park", "Orchard"
        ]
    },
    "Vegetation": {
        "Trees": [
            "Oak", "Pine", "Maple", "Birch", "Palm",
            "Willow", "Redwood", "Sequoia", "Cypress", "Fir"
        ],
        "Plants": [
            "Reed", "Bush", "Shrub", "Grass", "Flower",
            "Cattail", "Lily Pad", "Moss", "Fern", "Bamboo", "Vine"
        ]
    },
    "Human_Made_Structures": {
        "Water_Adjacent": [
            "Pier", "Dock", "Boat", "Lighthouse",
            "Jetty", "Ship", "Kayak", "Floating Market"
        ],
        "Land_Based": [
            "House", "Barn", "Fence", "Road", "Bridge",
            "Playground", "Campsite", "Windmill", "Railroad", "Pathway"
        ]
    },
    "Weather_Conditions": {
        "Sky_Conditions": [
            "Cloudy", "Sunny", "Rainy", "Stormy",
            "Overcast", "Misty", "Foggy", "Dusk", "Dawn"
        ],
        "Seasonal_Changes": [
            "Snow", "Frost", "Ice",
        ]
    },
    "Miscellaneous_Objects": {
        "Feeding_Sources": [
            "Fish", "Seeds", "Insects", "Fruit",
            "Nuts", "Berries", "Worms", "Algae"
        ],
        "Predators_and_Threats": [
            "Cat", "Dog", "Hawk", "Human",
            "Fox", "Raccoon", "Eagle", "Snake"
        ],
        "Natural_Phenomena_and_Objects": [
            "Wave", "Ice", "Rock", "Mud",
            "Boulder", "Sand", "Shell", "Coral", "Driftwood"
        ],
        "Human_Activities_and_Items": [
            "Fishing Gear", "Picnic Area", "Trash", "Firepit"
        ],
        "Other_Animals": [
            "Frog", "Turtle", "Fish", "Duck", "Goose", "Swan", "Beaver", "Otter",
            "Rabbit", "Squirrel", "Chipmunk", "Deer", "Bear", "Raccoon", "Fox",
            "Coyote", "Wolf", "Bobcat", "Mountain Lion", "Elk", "Moose", "Bison",
            "Horse", "Cow", "Sheep", "Goat", "Pig", "Chicken", "Turkey",
        ]
    }
}

# Flat prompt list: "Bird" first (class index 0), then every themed name in
# dictionary order. Names listed under several themes are kept only at their
# first occurrence, so each class appears once in the prompt and maps to a single
# class index (a repeated entry could never be the index a match resolves to).
# NOTE(review): GroundingDINO's text encoder has a bounded token budget; confirm
# this long prompt is not being truncated.
SAMPLE_OBJECTS = list(dict.fromkeys(
    ["Bird"]
    + [
        name
        for theme in SAMPLE_OBJECT_DICT.values()
        for names in theme.values()
        for name in names
    ]
))


class ConceptExtractor:
    """Common interface for concept extractors.

    The base class keeps no state, and its ``extract`` is a placeholder that
    returns None; subclasses provide the real implementation.
    """

    def __init__(self) -> None:
        """Create the extractor; the base class needs no setup."""

    def extract(self, dataset_path: str, save_dir: str = None):
        """Extract concepts from ``dataset_path``, optionally writing to ``save_dir``.

        Base-class placeholder: does nothing and returns None.
        """
        return None

class LLAVAConceptExtractor(ConceptExtractor):
    """Placeholder for a LLaVA-based concept extractor (per its name); no behaviour yet."""

    def __init__(self) -> None:
        """Initialise through the base class; nothing model-specific is loaded yet."""
        super(LLAVAConceptExtractor, self).__init__()

    def extract(self, dataset_path: str, save_dir: str = None):
        """Not implemented yet: ignores its arguments and returns None."""
        return None

class GroundingDinoSamExtractor(ConceptExtractor):
    """Detect and segment open-vocabulary concepts with GroundingDINO + SAM.

    GroundingDINO turns the text prompts in ``object_list`` into bounding boxes;
    ``segment`` additionally runs SAM on every box that survives NMS to obtain a
    mask, while ``detect`` only reports which class names were found.

    Args:
        object_list: Class names used as the text prompt. The position of a name
            in this list is the ``class_id`` reported for it in detections.
        device: Torch device (or device string) for both models. Defaults to
            CUDA when available, otherwise CPU.
    """

    def __init__(self, object_list: List[str], device=None) -> None:
        super().__init__()
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Hand the same device to GroundingDINO; without it the model ignores the
        # CUDA/CPU choice made above and uses its own default device.
        self.grounding_dino_model = Model(
            model_config_path=GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
            device=str(self.device),
        )
        self.sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
        self.sam.to(device=self.device)
        self.sam_predictor = SamPredictor(self.sam)
        self.classes = object_list

    @staticmethod
    def _read_bgr(image_path: str) -> np.ndarray:
        """Load ``image_path`` with OpenCV (BGR channel order), raising if it cannot be read."""
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")
        return image

    def _class_name(self, class_id) -> str:
        """Map a detection's class id to its name; an unmatched detection (id None) becomes "unknown"."""
        return "unknown" if class_id is None else self.classes[class_id]

    def _labels(self, detections) -> List[str]:
        """Build one "<class> <confidence>" caption per detection for the box annotator."""
        # Zip the arrays directly instead of unpacking the Detections iterator,
        # whose tuple layout differs between supervision versions.
        return [
            f"{self._class_name(class_id)} {confidence:0.2f}"
            for confidence, class_id in zip(detections.confidence, detections.class_id)
        ]

    def segment(self, image_path: str, save_dir: str = None, box_threshold: float = BOX_THRESHOLD, text_threshold: float = TEXT_THRESHOLD, nms_threshold: float = NMS_THRESHOLD):
        """Detect ``self.classes`` in one image and segment every detection with SAM.

        Args:
            image_path: Path of the image to process.
            save_dir: If given, ``<name>_annotated.jpg`` (raw GroundingDINO boxes)
                and ``<name>_segmented.jpg`` (masks and boxes after NMS) are
                written here, creating the directory if needed. If None, nothing
                is written to disk.
            box_threshold: GroundingDINO box-confidence threshold.
            text_threshold: GroundingDINO text-matching threshold.
            nms_threshold: IoU threshold for non-maximum suppression.

        Returns:
            The NMS-filtered detections with ``mask`` holding one SAM mask per
            box (an empty ``(0, H, W)`` array when nothing was detected).

        Raises:
            ValueError: if the image cannot be read.
        """
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        # GroundingDINO's Model takes the raw OpenCV (BGR) frame and converts it
        # itself; SAM's set_image expects RGB by default. Annotations are drawn on
        # the BGR frame so cv2.imwrite stores the true colours.
        image_bgr = self._read_bgr(image_path)
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        # detect objects
        # NOTE(review): phrases that match no class come back with class_id None
        # (possibly because GroundingDINO lower-cases the prompt while class names
        # here are capitalised); labels render those as "unknown". Confirm.
        detections = self.grounding_dino_model.predict_with_classes(
            image=image_bgr,
            classes=self.classes,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        if save_dir is not None:
            annotated_frame = sv.BoxAnnotator().annotate(
                scene=image_bgr.copy(), detections=detections, labels=self._labels(detections)
            )
            cv2.imwrite(os.path.join(save_dir, f"{image_name}_annotated.jpg"), annotated_frame)

        # NMS post process: keep the highest-confidence box among heavy overlaps.
        keep = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            nms_threshold,
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[keep]
        detections.confidence = detections.confidence[keep]
        detections.class_id = detections.class_id[keep]

        # segment objects: one SAM mask per remaining box (best-scoring candidate).
        height, width = image_rgb.shape[:2]
        masks = []
        if len(detections.xyxy) > 0:
            # set_image is only called when there is at least one box to segment.
            self.sam_predictor.set_image(image_rgb)
            for box in detections.xyxy:
                box_masks, scores, _ = self.sam_predictor.predict(box=box, multimask_output=False)
                masks.append(box_masks[np.argmax(scores)])
        # Keep a (N, H, W) shape even for N == 0 so mask consumers get a consistent array.
        detections.mask = np.array(masks) if masks else np.zeros((0, height, width), dtype=bool)

        if save_dir is not None:
            annotated_image = sv.MaskAnnotator().annotate(scene=image_bgr.copy(), detections=detections)
            annotated_image = sv.BoxAnnotator().annotate(
                scene=annotated_image, detections=detections, labels=self._labels(detections)
            )
            cv2.imwrite(os.path.join(save_dir, f"{image_name}_segmented.jpg"), annotated_image)

        return detections

    def extract(self, dataset_path: str, save_dir: str = None):
        """Dataset-level extraction; not implemented yet (returns None)."""
        pass

    def detect(self, image_path: str, box_threshold: float = BOX_THRESHOLD, text_threshold: float = TEXT_THRESHOLD) -> List[str]:
        """Return the distinct class names GroundingDINO finds in one image.

        Class index 0 is never reported: in this script that slot holds "Bird"
        (SAMPLE_OBJECTS) or the "MISMATCH" placeholder (urbancars list).
        Detections whose phrase matched no class (``class_id`` None) are skipped
        too. The names are returned sorted so results are identical across runs
        (a plain set's order depends on string hash randomization).

        Raises:
            ValueError: if the image cannot be read.
        """
        detections = self.grounding_dino_model.predict_with_classes(
            image=self._read_bgr(image_path),  # raw BGR frame, as in segment()
            classes=self.classes,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        found = {
            self.classes[class_id]
            for class_id in detections.class_id
            if class_id is not None and class_id != 0
        }
        return sorted(found)
    
def parse_args():
    """Build the command-line parser for this script and return the parsed namespace."""
    option_specs = (
        ("--device", dict(type=str, default="cuda:0")),
        ("--mode", dict(type=str, choices=["sam"], help="Mode to run the extractor")),
        ("--dataset", dict(type=str, choices=["indoor", "urbancars"], default=None, help="Dataset to extract concepts from")),
        ("--image_dir", dict(type=str, help="Directory containing images to label")),
        ("--image_path", dict(type=str, help="Path to one single image to label")),
        ("--save_dir", dict(type=str, help="Directory to save the labels")),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# File extensions treated as images when scanning --image_dir. Other files (for
# example the urbancars metadata.csv) would otherwise make cv2.imread fail.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

# Concept vocabulary for urbancars. Index 0 is a placeholder: 
# GroundingDinoSamExtractor.detect never reports class index 0.
_URBANCARS_OBJECTS = ["MISMATCH", "alley", "crosswalk", "city", "gas station", "garage", "driveway", "forest", "field", "sand", "fireplug", "stop", "sign", "parking meter", "traffic light", "cow", "horse", "sheep"]


def _collect_image_paths(args) -> List[str]:
    """Return the images named on the command line (--image_dir takes precedence over --image_path)."""
    if args.image_dir:
        return [
            os.path.join(args.image_dir, name)
            for name in sorted(os.listdir(args.image_dir))
            if name.lower().endswith(_IMAGE_EXTENSIONS)
        ]
    if args.image_path:
        return [args.image_path]
    raise ValueError("Either image_dir or image_path must be provided.")


def _urbancars_sort_key(record: dict) -> str:
    """Sort key for urbancars results.

    Filenames whose "_"-separated prefix is not exactly 4 characters long get a
    leading "0", presumably so 3-digit ids sort before 4-digit ones (confirm).
    """
    filename = record["filename"]
    return filename if len(filename.split("_")[0]) == 4 else "0" + filename


def _run_sam(args, image_paths: List[str]) -> None:
    """Segment each image with the SAMPLE_OBJECTS vocabulary and save visualisations to --save_dir."""
    # segment() only writes output files when given a directory, so check before
    # spending time on model loading.
    if args.save_dir is None:
        raise ValueError("save_dir must be provided for sam mode.")
    extractor = GroundingDinoSamExtractor(object_list=SAMPLE_OBJECTS, device=args.device)
    for image_path in tqdm(image_paths):
        extractor.segment(image_path, save_dir=args.save_dir)


def _run_urbancars(args) -> None:
    """Detect concepts in split-"0" urbancars images and write <save_dir>/urbancars_concepts.json.

    Reads ``<image_dir>/metadata.csv``: a header row followed by rows of
    ``filename, split, y, place``.
    """
    # Validate all inputs up front, before loading models or running detection.
    if not args.image_dir:
        raise ValueError("image_dir must be provided for urbancars.")
    if args.save_dir is None:
        raise ValueError("save_dir must be provided for urbancars.")
    metadata_path = os.path.join(args.image_dir, "metadata.csv")
    if not os.path.exists(metadata_path):
        raise ValueError("metadata.csv must be provided for urbancars in args.image_dir.")

    extractor = GroundingDinoSamExtractor(object_list=_URBANCARS_OBJECTS, device=args.device)

    extraction_results = []
    with open(metadata_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for filename, split, y, place in tqdm(reader):
            if split != "0":
                continue
            image_path = os.path.join(args.image_dir, filename)
            if not os.path.exists(image_path):
                raise ValueError(f"Image {image_path} does not exist.")
            concepts = extractor.detect(image_path, box_threshold=0.225, text_threshold=0.3)
            extraction_results.append({
                "filename": filename,
                "split": split,
                "y": y,
                "place": place,
                "concepts": concepts,
            })

    extraction_results.sort(key=_urbancars_sort_key)

    # The output directory may not exist yet; create it so open() succeeds.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, "urbancars_concepts.json")
    with open(save_path, "w") as f:
        json.dump(extraction_results, f)


if __name__ == "__main__":
    args = parse_args()
    image_paths = _collect_image_paths(args)

    if args.dataset is None:
        print("No dataset provided. Running in instance-based mode.")

    if args.mode == "sam" and args.dataset is None:
        _run_sam(args, image_paths)

    if args.dataset == "urbancars":
        _run_urbancars(args)
    elif args.dataset == "indoor":
        # Accepted by --dataset but not implemented; fail loudly instead of silently doing nothing.
        raise NotImplementedError("Concept extraction for the 'indoor' dataset is not implemented.")