import os 
import torch as th 



class GroundingNetInput_ImgTextEmb:
    """Build grounding-tokenizer inputs from text and/or cropped-image embeddings.

    ``prepare`` picks the positive embeddings and the positional input
    (boxes or centers) out of a dataset batch. It also caches the batch size,
    the number of grounding slots, the embedding width, the device and the
    dtype. ``get_null_input`` uses that cached state to build an all-zero
    input with the same layout.
    """

    def __init__(self):
        # Becomes True only after prepare() has succeeded and cached its state.
        self.set = False

    @staticmethod
    def _select_embeddings(batch):
        """Return text, image, or last-dim-concatenated (text, image) embeddings per the batch flags."""
        use_text = batch['use_text_embedding']
        use_image = batch['use_image_embedding']
        if use_text and use_image:
            return th.concat((batch["text_embeddings"], batch["cropped_image_embeddings"]), dim=-1)
        if use_text:
            return batch["text_embeddings"]
        if use_image:
            return batch["cropped_image_embeddings"]
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            "at least one of 'use_text_embedding' / 'use_image_embedding' must be true"
        )

    def prepare(self, batch):
        """
        Prepare the grounding-tokenizer input from a dataset batch.

        Args:
            batch: mapping produced by the dataset. Keys read: 'masks',
                'use_text_embedding', 'use_image_embedding',
                'position_net_point_or_box', and, depending on those,
                'text_embeddings', 'cropped_image_embeddings', 'boxes' or
                'centers'. The selected positive embeddings must have a 3-D
                shape (unpacked as batch, max_box, in_dim).

        Returns:
            dict with keys 'boxes', 'centers', 'masks', 'positive_embeddings'.
            In 'box' mode only 'boxes' is set, in 'point' mode only 'centers'
            is set, and when the mode is None both are None.

        Raises:
            ValueError: if neither embedding type is enabled, or the position
                mode is not 'box', 'point' or None.
        """
        masks = batch['masks']
        positive_embeddings = self._select_embeddings(batch)

        point_or_box = batch['position_net_point_or_box']
        boxes = None
        centers = None
        if point_or_box == 'box':
            boxes = batch['boxes']
        elif point_or_box == 'point':
            centers = batch['centers']
        elif point_or_box is not None:
            raise ValueError(
                f"position_net_point_or_box must be 'box', 'point' or None, got {point_or_box!r}"
            )

        # Commit cached state only after everything above succeeded, so a
        # failed prepare() never leaves the object looking ready.
        self.batch, self.max_box, self.in_dim = positive_embeddings.shape
        self.device = positive_embeddings.device
        self.dtype = positive_embeddings.dtype
        self.point_or_box = point_or_box
        self.set = True

        return {"boxes": boxes, "centers": centers, "masks": masks, "positive_embeddings": positive_embeddings}

    def get_null_input(self, batch=None, device=None, dtype=None):
        """
        Build an all-zero grounding input matching the layout cached by prepare().

        This is used to drop the grounding input (training) or for guidance
        (inference).

        Args:
            batch: batch size; defaults to the size cached by prepare().
            device: target device; defaults to the cached device.
            dtype: target dtype; defaults to the cached dtype.

        Returns:
            dict with keys 'boxes', 'centers', 'masks', 'positive_embeddings'.
            Boxes are (batch, max_box, 4) zeros in 'box' mode and centers are
            (batch, max_box, 2) zeros in 'point' mode. The key for the unused
            mode is None.

        Raises:
            AssertionError: if prepare() has not been called successfully.
            ValueError: if the cached position mode is not 'box', 'point' or None.
        """
        assert self.set, "not set yet, cannot call this funcion"
        batch = self.batch if batch is None else batch
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype

        def zeros(*shape):
            # Allocate directly with the target dtype/device instead of making
            # a CPU float32 tensor and converting it afterwards.
            return th.zeros(*shape, dtype=dtype, device=device)

        boxes = None
        centers = None
        if self.point_or_box == "box":
            boxes = zeros(batch, self.max_box, 4)
        elif self.point_or_box == "point":
            centers = zeros(batch, self.max_box, 2)
        elif self.point_or_box is not None:
            raise ValueError(
                f"position_net_point_or_box must be 'box', 'point' or None, got {self.point_or_box!r}"
            )

        masks = zeros(batch, self.max_box)
        positive_embeddings = zeros(batch, self.max_box, self.in_dim)

        return {"boxes": boxes, "centers": centers, "masks": masks, "positive_embeddings": positive_embeddings}

