import matplotlib.pyplot as plt

from modules.hoi4abot.hoibot.modules.prompt_encoder.prompt_transforms import ResizeLongestSide
from modules.hoi4abot.hoibot.modules.prompt_encoder.prompt_encoder import PromptEncoder
from modules.hoi4abot.hoibot.modules.patch_blender.patch_blender import PatchBlender

from ultralytics.vit import SAM
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange

class Bbox_Wrapper(nn.Module):
    """Build one bounding-box encoder selected by name and expose a single ``forward(batch)``.

    Supported ``box_encoder_type`` values (anything else raises ``RuntimeError``):
        - ``"SAM"``: :class:`SAMPromptEncoder` built from the ``sam_b.pt`` checkpoint.
        - ``"Box"``: :class:`Bbox_Encoder`.
        - ``"Linear"``: :class:`LinearPromptEncoder`.
        - ``"BinaryMask"``: :class:`BinaryMaskEncoder`.
        - ``"None"``: a plain ``nn.Linear(E, 2 * E)``, used only through :meth:`extend_projection`.
        - ``"Convolution"``: :class:`BoxConv_Encoder` (constructed, but not dispatched by :meth:`forward`).
    """

    def __init__(self, box_encoder_type, embedding_dimension, input_image_size, patch_size, device):
        super().__init__()
        self.box_encoder_type = box_encoder_type
        if box_encoder_type == "SAM":
            bbox_embedder = SAMPromptEncoder(embedding_dimension, "sam_b.pt", device, img_size=input_image_size)
        elif box_encoder_type == "Box":
            bbox_embedder = Bbox_Encoder(embedding_dimension=embedding_dimension, device=device,
                                         img_size=input_image_size)
        elif box_encoder_type == "Linear":
            bbox_embedder = LinearPromptEncoder(embedding_dimension)
        elif box_encoder_type == "BinaryMask":
            bbox_embedder = BinaryMaskEncoder(patch_size=patch_size, embedding_dimension=embedding_dimension,
                                              img_size=input_image_size)
        elif box_encoder_type == "None":
            bbox_embedder = nn.Linear(embedding_dimension, 2 * embedding_dimension)
        elif box_encoder_type == "Convolution":
            # NOTE(review): img_size is hard-coded to (224, 224) rather than input_image_size; the
            # conv channel widths of BoxConv_Encoder depend on it, so confirm intent before changing.
            bbox_embedder = BoxConv_Encoder(embedding_dimension=embedding_dimension, device=device,
                                            img_size=(224, 224))
        else:
            raise RuntimeError("Unrecognized bbox embedding type")
        self.bbox_embedder = bbox_embedder
        # Maps the flattened (tokens * E) prompt embedding back to E. This assumes the prompt-style
        # encoders emit two tokens per box (LinearPromptEncoder does so by construction).
        self.unify_positions = nn.Linear(2 * embedding_dimension, embedding_dimension)

    def isused(self):
        """Return True when an encoder type was configured (``box_encoder_type`` is not None)."""
        return self.box_encoder_type is not None

    def info_model(self):
        """Print the number of trainable parameters of this wrapper, in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{} {}] - {:.2f}M".format("Bbox_Encoder", self.box_encoder_type, total_params / 10 ** 6))

    def extend_projection(self, cls_tokens, patch_tokens):
        """For the ``"None"`` type, project both token sets through the ``E -> 2E`` linear layer.

        For every other type the inputs are returned unchanged.
        """
        if self.box_encoder_type == "None":
            cls_tokens = self.bbox_embedder(cls_tokens)
            patch_tokens = self.bbox_embedder(patch_tokens)
        return cls_tokens, patch_tokens

    def forward(self, batch):
        """Encode the boxes (or binary masks) contained in ``batch``.

        :param batch: dict.
            - For ``"SAM"``/``"Box"``/``"Linear"`` it must hold ``"bboxes"``, a 2-D tensor whose
              column 0 is dropped and whose remaining columns are passed as XYXY box coordinates
              (presumably ``[index, x1, y1, x2, y2]`` -- confirm with the data loader). The
              ``"point_labels"`` key is optional; it defaults to ones.
            - For ``"BinaryMask"`` it must hold ``"binary_masks"``.
        :return: ``[B, E]`` for the prompt-style encoders; for ``"BinaryMask"``, whatever
            :class:`BinaryMaskEncoder` returns.
        :raises NotImplementedError: for encoder types without a forward path (``"None"``,
            ``"Convolution"``). Previously this path printed a message and then failed with an
            ``UnboundLocalError``.
        """
        if self.box_encoder_type in ["SAM", "Box", "Linear"]:
            bboxes = batch["bboxes"][:, 1:]
            B, _ = bboxes.shape
            point_labels = batch["point_labels"] if "point_labels" in batch else torch.ones(B)
            bboxes = self.bbox_embedder(point_coords=None, point_labels=point_labels, boxes=bboxes, mask_input=None)
            # [B, tokens, E] -> [B, tokens * E] -> [B, E]
            bboxes = rearrange(bboxes, "b s e -> b (s e)")
            return self.unify_positions(bboxes)
        if self.box_encoder_type == "BinaryMask":
            return self.bbox_embedder(batch["binary_masks"])
        raise NotImplementedError(f"Box encoder type {self.box_encoder_type} is not implemented")


class BoxConv_Encoder(nn.Module):
    """Encode stacked binary spatial masks with a small CNN into one feature vector per sample.

    NOTE(review): ``embedding_dimension`` and ``device`` are accepted for signature parity with the
    other encoders but are not used here. The output width is ``img_size[0]`` (the channel count of
    the last convolution), not ``embedding_dimension``.
    """

    def __init__(self, embedding_dimension=384, device="cuda:0", img_size=(224, 224)):
        super().__init__()
        # Channel widths are derived from the first image side length.
        width = img_size[0]
        self.conv_spatial = nn.Sequential(
            nn.Conv2d(2, width // 2, kernel_size=7, stride=2, padding=3, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(width // 2, momentum=0.01),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(width // 2, width, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(width, momentum=0.01),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, batch):
        """Run the spatial CNN on ``batch["binary_masks"]``.

        :param batch: dict holding ``"binary_masks"``. A 4-D tensor is used as-is as a
            channel-first batch, which must be ``[B, 2, H, W]`` because the first conv takes
            2 input channels. Any other rank gets a singleton axis inserted at dim 1, as in
            the original implementation.
        :return: tensor of shape ``[B, img_size[0]]``. The batch dimension is preserved; the
            old ``reshape(-1)`` merged it into the feature axis.
        """
        masks = batch["binary_masks"]
        if masks.dim() != 4:
            masks = masks[:, None]
        # AdaptiveAvgPool2d(1) leaves [B, C, 1, 1]; drop the spatial singleton axes.
        return self.conv_spatial(masks).flatten(1)
class Bbox_Encoder(nn.Module):
    """Embed point and/or box prompts with the project's ``PromptEncoder``.

    Prompts are rescaled with ``ResizeLongestSide`` (target = longest side of ``img_size``)
    before being embedded.
    """

    def __init__(self, embedding_dimension=384, device="cuda:0", img_size=(224, 224)):
        super().__init__()
        self.device = device
        self.original_size = img_size
        longest_side = max(img_size)
        self.transform = ResizeLongestSide(longest_side)
        encoder = PromptEncoder(embed_dim=embedding_dimension, input_image_size=img_size)
        # nn.Module.to moves the module in place and returns it.
        self.prompt_encoder = encoder.to(self.device)

    def forward(self, point_coords=None, point_labels=None, boxes=None, mask_input=None):
        """Convert the raw prompts and return whatever ``PromptEncoder`` produces for them."""
        coords, labels, box_prompt = self.to_prompt_input(
            point_coords=point_coords, point_labels=point_labels, box=boxes
        )
        points = None if coords is None else (coords, labels)
        return self.prompt_encoder(points=points, boxes=box_prompt, masks=mask_input)

    def to_prompt_input(self, point_coords=None, point_labels=np.array([1]), box=None):
        """Rescale raw prompts into the resized frame and add a leading batch axis to points.

        :param point_coords: ``N x 2`` point prompts in (X, Y) pixels, or None. This is
            presumably a torch tensor, since the ``*_torch`` transform variants are used.
        :param point_labels: length-``N`` labels (1 = foreground, 0 = background). Required
            whenever ``point_coords`` is given.
        :param box: XYXY box prompt(s), or None. Passed to ``apply_boxes_torch`` without
            adding a batch axis.
        :return: ``(coords, labels, boxes)``. ``coords`` is ``1 x N x 2``, ``labels`` is
            ``1 x N`` int32 on ``self.device``, and ``boxes`` is the rescaled box. Each entry
            is None when the matching input is None.
        """
        coords_out = labels_out = boxes_out = None
        if point_coords is not None:
            assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
            resized = self.transform.apply_coords_torch(point_coords, self.original_size)
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_out = resized[None, :, :]
            labels_out = labels[None, :]
        if box is not None:
            boxes_out = self.transform.apply_boxes_torch(box, self.original_size)
        return coords_out, labels_out, boxes_out

class BinaryMaskEncoder(nn.Module):
    """Patchify binary masks with ``PatchBlender`` and project each token linearly to the embedding size."""

    def __init__(self, patch_size, embedding_dimension=384, device="cuda:0", img_size=(224, 224)):
        super().__init__()
        # Stored for parity with the other encoders; this class does not move itself to it.
        self.device = device
        self.original_size = img_size
        # NOTE(review): patch_size ** 2 is the per-patch element count if patch_size is a side
        # length, despite the attribute name; it is the expected input width of the projection.
        in_features = patch_size ** 2
        self.num_patches = in_features
        self.patch_blender = PatchBlender(img_size=img_size, patch_size=patch_size, flatten_embedding=True)
        self.project_bboxes = nn.Linear(in_features, embedding_dimension)

    def forward(self, binary_masks, patchified_binary_mask=None):
        """Project patchified masks, computing them from ``binary_masks`` unless precomputed.

        :param binary_masks: raw masks fed to ``PatchBlender``; ignored when
            ``patchified_binary_mask`` is given.
        :param patchified_binary_mask: optional precomputed ``PatchBlender`` output.
        :return: the ``project_bboxes`` projection (last axis = ``embedding_dimension``).
        """
        if patchified_binary_mask is None:
            patchified_binary_mask = self.patch_blender(binary_masks)
        return self.project_bboxes(patchified_binary_mask)

class SAMPromptEncoder(nn.Module):
    """Embed point/box prompts with a pretrained SAM prompt encoder and project the sparse tokens.

    The SAM prompt encoder's sparse embeddings (width 256) are mapped to ``embedding_dimension``
    by a linear layer. Dense (mask) embeddings are discarded.
    """

    def __init__(self, embedding_dimension, sam_model="sam_b.pt", device="cuda:0", img_size=(1024, 1024)):
        super().__init__()
        self.device = device
        sam_model = SAM(sam_model).model
        if img_size is None:
            img_size = sam_model.image_encoder.img_size
        # A scalar size (e.g. the SAM image encoder's img_size) previously broke max() below.
        self.original_size = self._as_hw(img_size)
        self.transform = ResizeLongestSide(max(self.original_size))
        self.prompt_encoder = sam_model.prompt_encoder
        self.prompt_encoder.to(self.device)
        # 256 = width of the sparse tokens produced by SAM's prompt encoder. Keep the projection on
        # the same device as the prompt encoder so forward() works without a later .to().
        self.projection_layer = nn.Linear(256, embedding_dimension).to(self.device)

    @staticmethod
    def _as_hw(size):
        """Return ``(size, size)`` for a scalar size; return any other value unchanged."""
        if isinstance(size, int):
            return (size, size)
        return size

    def forward(self, point_coords=None, point_labels=None, boxes=None, mask_input=None):
        """Encode the prompts and return the projected sparse embeddings ``[..., embedding_dimension]``."""
        point_coords, point_labels, boxes = self.to_prompt_input(point_coords=point_coords, point_labels=point_labels, box=boxes)
        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts; the dense (mask) embeddings are not used.
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        sparse_embeddings = self.projection_layer(sparse_embeddings)

        return sparse_embeddings

    def to_prompt_input(self, point_coords=None, point_labels=np.array([1]), box=None):
        """Rescale raw prompts to the resized frame and turn them into batched tensors.

        :param point_coords: ``N x 2`` point prompts in (X, Y) pixels, or None. Presumably a
            numpy array, since the non-torch ``apply_coords`` transform is used.
        :param point_labels: length-``N`` labels (1 = foreground, 0 = background). Required
            whenever ``point_coords`` is given.
        :param box: XYXY box prompt(s), e.g. ``K x 4``, or None. Presumably a numpy array
            (non-torch ``apply_boxes``).
        :return: ``(coords, labels, boxes)`` on ``self.device``. ``coords`` is float
            ``1 x N x 2``, ``labels`` is int32 ``1 x N``, and ``boxes`` is float with a
            leading axis of 1 added (``1 x K x 4`` for ``K x 4`` input). Each entry is None
            when the matching input is None.
        """
        coords_torch, labels_torch, box_torch = None, None, None
        if point_coords is not None:
            assert (point_labels is not None), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        return coords_torch, labels_torch, box_torch

class LinearPromptEncoder(nn.Module):
    """Embed each box corner with a shared linear layer (no positional encoding, no resizing).

    A box ``[x1, y1, x2, y2]`` is split into two ``(x, y)`` points. Each point is projected
    from 2 to ``embedding_dimension`` features, giving two tokens per box.
    """

    def __init__(self, embedding_dimension):
        super().__init__()

        self.projection_layer = nn.Linear(2, embedding_dimension)

    def forward(self, point_coords=None, point_labels=None, boxes=None, mask_input=None):
        """Project box corners to embeddings.

        :param boxes: 4 values per entry along everything after dim 0 (e.g. ``[B, 4]``); a tensor
            or array-like. Only ``boxes`` is used; the other arguments exist for signature
            parity with the SAM-based encoders.
        :return: ``[B, 2, embedding_dimension]``.
        """
        weight = self.projection_layer.weight
        # Match the layer's dtype/device so integer pixel boxes or CPU tensors are accepted.
        boxes = torch.as_tensor(boxes).to(device=weight.device, dtype=weight.dtype)
        boxes = boxes.reshape([boxes.shape[0], 2, 2])
        sparse_embeddings = self.projection_layer(boxes)

        return sparse_embeddings

    def to_prompt_input(self, point_coords=None, point_labels=np.array([1]), box=None):
        """Turn raw prompts into batched tensors without rescaling.

        This class owns no ``ResizeLongestSide`` transform, ``original_size`` or ``device``.
        The previous copy of the SAM implementation therefore raised ``AttributeError`` for
        any non-None prompt. Coordinates are now passed through unchanged, matching
        :meth:`forward`, which consumes raw boxes. Tensors are placed on the device of
        ``projection_layer``.

        :param point_coords: ``N x 2`` point prompts in (X, Y), or None.
        :param point_labels: length-``N`` labels (1 = foreground, 0 = background). Required
            whenever ``point_coords`` is given.
        :param box: XYXY box prompt(s), e.g. ``K x 4``, or None.
        :return: ``(coords, labels, boxes)``. ``coords`` is float ``1 x N x 2``, ``labels``
            is int32 ``1 x N``, and ``boxes`` is float with a leading axis of 1 added. Each
            entry is None when the matching input is None.
        """
        device = self.projection_layer.weight.device
        coords_torch, labels_torch, box_torch = None, None, None
        if point_coords is not None:
            assert (point_labels is not None), "point_labels must be supplied if point_coords is supplied."
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            box_torch = box_torch[None, :]
        return coords_torch, labels_torch, box_torch


if __name__ == '__main__':
    # Demo: embed one box prompt with SAMPromptEncoder and visualise each token as a 16x16 map.
    # NOTE(review): this helper import path differs from the package prefix used at the top of
    # the file (modules.hoi4abot...); confirm it resolves in your environment.
    from hoi.model.hoibot.modules.utils import read_image, show_image

    image = read_image()
    point_coords = np.array([[300, 55]])
    point_labels = np.array([1])
    boxes = np.array([[170, 0, 425, 250]])
    show_image(image, point_coords, point_labels, boxes)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    h, w = image.shape[:2]
    # embedding_dimension comes first in the signature; 256 keeps every token reshapeable to 16x16.
    prompt_encoder = SAMPromptEncoder(256, "sam_b.pt", device, img_size=(h, w)).to(device)
    with torch.no_grad():
        # forward() returns only the projected sparse embeddings: [B, tokens, 256].
        sparse_embeddings = prompt_encoder(point_coords=None, point_labels=point_labels, boxes=boxes, mask_input=None)
    print(sparse_embeddings.shape)
    for feat in sparse_embeddings[0]:
        # Draw into an explicit figure/axes so the colorbar lands next to its own image.
        fig, ax = plt.subplots()
        image_artist = ax.matshow(feat.reshape(16, 16).detach().cpu().numpy(), interpolation='nearest')
        fig.colorbar(image_artist, ax=ax)
        plt.show()

