import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor
import numpy as np

from typing import Callable, Optional, Tuple, Union
from einops import reduce,rearrange

from modules.hoi4abot.hoibot.modules.transformer.modules_attn.Attention import Attention
from einops import rearrange, reduce, repeat

def make_2tuple(x):
    """
    Normalise a size argument to a pair.

    :param x: either an int, or a tuple that must already have exactly two entries
    :return: ``(x, x)`` for an int, or ``x`` itself for a 2-tuple
    """
    if not isinstance(x, tuple):
        # Scalar case: the same value is used for both dimensions.
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x

class MOABlender(nn.Module):
    """
    MOA blender
    https://arxiv.org/pdf/2304.08114v1.pdf

    Average-pools a per-object pixel mask into one coverage value per pooled cell,
    prepends a learnable token to the patch tokens, runs a single ``Attention``
    layer with a mask derived from ``log(coverage + eps)`` and returns the output
    at the prepended token's position.

    The class also bundles helpers to rasterise bounding boxes into masks, to
    convert/rescale boxes, and to plot masks for debugging.
    """
    def __init__(
        self,
        embed_dim,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        flatten_embedding: bool = True,
        normalize: bool=True,
        attn_drop=0.2, num_heads = 8, proj_drop = 0, qk_scale = False, qkv_bias=True,device="cuda:0",
        eps = 0.0
    ) -> None:
        """
        :param embed_dim: token embedding size; used for the prepended token and the attention layer
        :param img_size: mask size as an int or an (H, W) tuple
        :param patch_size: int or 2-tuple; see the note on ``width_patch`` below for how it is used
        :param flatten_embedding: if True, ``forward`` reshapes the pooled mask to [B, -1]
        :param normalize: stored on the module; not read anywhere in this class
        :param attn_drop, num_heads, proj_drop, qk_scale, qkv_bias: forwarded to ``Attention``
        :param device: device on which ``_create_atten_mask`` allocates the attention mask
        :param eps: added to the pooled mask before the log, so that empty cells do not give -inf
        """
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        # Pooling window in pixels. It is derived from the height entries only, so the
        # window is square even for non-square configurations. With this window the
        # pooled mask has img_size // width_patch cells per side (16 x 16 for the
        # defaults 224 / 16).
        # NOTE(review): that cell count differs from `num_patches` above (14 * 14 for
        # the defaults) - confirm which of the two the patch tokens are expected to match.
        self.width_patch = int(image_HW[0] / patch_HW[0])
        self.avg_pool = nn.AvgPool2d(kernel_size=self.width_patch, stride=self.width_patch)

        self.flatten_embedding = flatten_embedding
        self.normalize = normalize

        # Learnable token prepended to the patch tokens; its attention output is what forward returns.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.num_heads = num_heads
        self.atten_layer = Attention(embed_dim, attn_drop=attn_drop,
            num_heads=num_heads, proj_drop=proj_drop, qk_scale=qk_scale, qkv_bias=qkv_bias
        )

        self.device = device
        self.eps = eps

    def forward(self, patch_tokens, bin_masks: Tensor) -> Tensor:
        """
        :param patch_tokens: [S, N, E] tokens; N must equal the number of pooled mask cells
        :param bin_masks: [S, H, W]
                -> S = B *T* NumObjects
                -> Per each batch, objects and frame it contains a binary masks [H,W] where 1 indicates that the object is in that pixel
        :return: [S, E] attention output taken at the prepended token
        """
        B, H, W = bin_masks.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        # The 3-D mask is fed to AvgPool2d as-is; each pooled cell is the mean of its
        # window, i.e. the covered fraction when the mask is binary.
        mask = self.avg_pool(bin_masks)

        if self.flatten_embedding:
            mask = mask.reshape(B, -1)  # [S, cells_h * cells_w]
        # NOTE(review): with flatten_embedding=False the mask stays 3-D and the 2-D
        # shape unpack in _create_atten_mask fails.

        patch_tokens = self._prepend_token(patch_tokens)
        atten_mask = self._create_atten_mask(mask)

        res = self.atten_layer(patch_tokens, src_mask=atten_mask)

        # Position 0 is the prepended token.
        return res[:, 0]

    def _prepend_token(self, tokens):
        """
        Put one copy of ``cls_token`` in front of every token sequence.

        :param tokens: [B, N, E]
        :return: [B, N + 1, E]
        """
        B = tokens.shape[0]
        prepend_token = repeat(self.cls_token, "b t e -> (repeat b) t e", repeat=B)

        return torch.cat([prepend_token, tokens], dim=1)

    def _create_atten_mask(self, mask):
        """
        Build the log-space mask handed to the attention layer as ``src_mask``.

        All entries are log(1) = 0 except row 0 (the prepended token), whose entries
        towards the patch positions are log(mask + eps).
        Presumably ``Attention`` adds ``src_mask`` to its logits - confirm against
        that module.

        :param mask: [B, T] pooled mask values
        :return: [B, num_heads, T + 1, T + 1]
        """
        B, T = mask.shape
        # Allocated on the device given at construction time, not on mask.device.
        atten_mask = torch.ones((B, T+1, T+1), device=self.device)
        atten_mask[:, 0, 1:] = mask + self.eps

        return repeat(torch.log(atten_mask), "b t1 t2 -> b n t1 t2", n=self.num_heads)

    def create_patch_batch(self, bboxes):
        """
        Rasterise a batch of boxes into binary masks.

        :param bboxes: [B, 4] boxes as XYXY pixel coordinates (cast to int here)
        :return: [B, *img_size] tensor with ones inside each box, zeros elsewhere
        """
        B = bboxes.shape[0]
        bboxes = bboxes.int()
        bin_mask = torch.zeros(B, *self.img_size)
        for i, bbox in enumerate(bboxes):
            # Rows are indexed by y (entries 1 and 3), columns by x (entries 0 and 2).
            bin_mask[i, bbox[..., 1]:bbox[..., 3], bbox[..., 0]:bbox[..., 2]] = 1
        return bin_mask

    def create_patch(self, bbox):
        """
        Rasterise a single box into a binary mask.

        :param bbox: [4] . Bbox[XYXY]: (top left corner, bottom right corner), integer-valued
        :return: [*img_size] tensor with ones inside the box, zeros elsewhere
        """
        bin_mask = torch.zeros(*self.img_size)
        bin_mask[bbox[..., 1]:bbox[..., 3], bbox[..., 0]:bbox[..., 2]] = 1
        return bin_mask

    def visualize_patches(self, patch, show_=False):
        """
        Plot a 2-D grid of values as a colour mesh with each value printed in its cell.

        :param patch: 2-D tensor; text labels are drawn for the first
                      patch_size[0] x patch_size[1] entries
        :param show_: if True, show the figure and close all figures afterwards
        :return:
        """
        plt.figure()
        plt.pcolormesh(patch, edgecolors="k", linewidth=1)
        ax = plt.gca()
        ax.set_aspect("equal")
        ax.invert_yaxis()
        plt.colorbar()
        for y in range(self.patch_size[0]):
            for x in range(self.patch_size[1]):
                plt.text(x + 0.5, y + 0.5, "%.2f" % patch[y, x].item(),
                         horizontalalignment='center',
                         verticalalignment='center',
                         )
        if show_:
            plt.show()
            plt.close("all")


    def visualize_patched_image(self, img, show_=False):
        """
        Plot an image as a patch_size[0] x patch_size[0] grid of subplots, each showing
        one width_patch x width_patch tile.

        :param img: 2-D image, values drawn on a fixed [0, 1] colour scale
        :param show_: if True, show the figure and close all figures afterwards
        """
        patch_size = self.patch_size[0]
        _fig, axs = plt.subplots(patch_size, patch_size, figsize = (2,2), sharex=True, sharey=True)
        for i in range(patch_size):
            for j in range(patch_size):
                axs[i,j].matshow(img[(i*self.width_patch): ((i+1)*self.width_patch), (j*self.width_patch): ((j+1)*self.width_patch)],  vmin=0, vmax=1)
                axs[i,j].set_xticks([])
                axs[i,j].set_yticks([])
        if show_:
            plt.show()
            plt.close("all")


    def visualize_image(self, img,  show_=False):
        """
        Plot a 2-D image with ``matshow``.

        :param img: 2-D image
        :param show_: if True, show the figure and close all figures afterwards
        """
        plt.matshow(img)
        if show_:
            plt.show()
            plt.close("all")

    def adapt_bbox(self, bbox, from_size=(640, 640), to_size = (224,224)):
        """
        Rescale XYXY boxes from one image size to another and cast them to integers.

        x coordinates are scaled by to_size[0] / from_size[0] and y coordinates by
        to_size[1] / from_size[1]. The input is not modified.

        :param bbox: [..., 4] torch tensor or numpy array in XYXY order
        :param from_size: size the boxes currently refer to
        :param to_size: size the boxes should refer to
        :return: tuple ``(None, boxes)``; the first entry is always None, the second is
                 an int32 tensor/array of the same kind as the input
        """
        x_scale = to_size[0] / from_size[0]
        y_scale = to_size[1] / from_size[1]
        is_tensor = isinstance(bbox, torch.Tensor)
        y = bbox.clone() if is_tensor else np.copy(bbox)
        y[..., 0] = x_scale * (bbox[..., 0])
        y[..., 1] = y_scale * (bbox[..., 1])
        y[..., 2] = x_scale * (bbox[..., 2])
        y[..., 3] = y_scale * (bbox[..., 3])
        # ndarray has no .int(); astype(np.int32) is the matching cast for arrays.
        return None, (y.int() if is_tensor else y.astype(np.int32))


    def xywhn2xyxy(self, bbox, w=640, h=640, padw=0, padh=0):
        """
        Convert boxes from normalized [x center, y center, width, height] to pixel
        [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

        :param bbox: [..., 4] torch tensor or numpy array; not modified
        :param w: image width used to de-normalize x values
        :param h: image height used to de-normalize y values
        :param padw: offset added to x coordinates
        :param padh: offset added to y coordinates
        :return: converted copy with the same type as the input
        """
        y = bbox.clone() if isinstance(bbox, torch.Tensor) else np.copy(bbox)
        y[..., 0] = w * (bbox[..., 0] - bbox[..., 2] / 2) + padw  # top left x
        y[..., 1] = h * (bbox[..., 1] - bbox[..., 3] / 2) + padh  # top left y
        y[..., 2] = w * (bbox[..., 0] + bbox[..., 2] / 2) + padw  # bottom right x
        y[..., 3] = h * (bbox[..., 1] + bbox[..., 3] / 2) + padh  # bottom right y
        return y

    def xyxy2xywhn(self, bbox, w=640, h=640, clip=False, eps=0.0):
        """
        Convert boxes from pixel [x1, y1, x2, y2] to normalized
        [x center, y center, width, height] where xy1=top-left, xy2=bottom-right.

        :param bbox: [..., 4] torch tensor or numpy array
        :param w: image width used to normalize x values
        :param h: image height used to normalize y values
        :param clip: if True, ``bbox`` is first clipped IN PLACE to (h - eps, w - eps)
        :param eps: margin subtracted from the image size when clipping
        :return: converted copy with the same type as the input
        """
        if clip:
            self.clip_coords(bbox, (h - eps, w - eps))  # warning: inplace clip
        y = bbox.clone() if isinstance(bbox, torch.Tensor) else np.copy(bbox)
        y[..., 0] = ((bbox[..., 0] + bbox[..., 2]) / 2) / w  # x center
        y[..., 1] = ((bbox[..., 1] + bbox[..., 3]) / 2) / h  # y center
        y[..., 2] = (bbox[..., 2] - bbox[..., 0]) / w  # width
        y[..., 3] = (bbox[..., 3] - bbox[..., 1]) / h  # height
        return y

    def clip_coords(self, boxes, shape):
        """
        Clip xyxy bounding boxes in place to an image shape.

        :param boxes: [..., 4] torch tensor or numpy array, modified in place
        :param shape: (height, width); x values are clipped to [0, width] and
                      y values to [0, height]
        :return: None
        """
        # Ellipsis indexing (instead of [:, k]) also accepts a single 1-D box, in line
        # with the converters above.
        if isinstance(boxes, torch.Tensor):  # faster individually
            boxes[..., 0].clamp_(0, shape[1])  # x1
            boxes[..., 1].clamp_(0, shape[0])  # y1
            boxes[..., 2].clamp_(0, shape[1])  # x2
            boxes[..., 3].clamp_(0, shape[0])  # y2
        else:  # np.array (faster grouped)
            boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
            boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


if __name__ == '__main__':
    # Smoke test: rasterise a few boxes, blend random patch tokens on CPU and
    # report the output shape and the number of trainable parameters.
    img_size, patch_size = 224, 16
    B, E = 4, 384
    # Only referenced by the disabled visualisation snippet further down.
    visualize_img = True
    visualize_patch = True
    bboxes = torch.tensor([
        [0, 10, 40, 50],
        [20, 30, 60, 70],
        [10, 20, 30, 40],
        [100, 110, 110, 120],
    ])

    patch_blender = MOABlender(
        E,
        img_size=img_size,
        patch_size=patch_size,
        normalize=True,
        flatten_embedding=True,
        device='cpu',
        eps=1e-6,
    )

    # One binary mask per box, written into a pre-allocated [B, H, W] buffer.
    bin_masks = torch.zeros(B, img_size, img_size)
    for idx in range(len(bboxes)):
        bin_masks[idx] = patch_blender.create_patch(bboxes[idx])

    patch_tokens = torch.rand(4, patch_size ** 2, E)

    obj_tokens = patch_blender(patch_tokens, bin_masks)
    print("Final_tokens", obj_tokens.shape)

    trainable = [p.numel() for p in patch_blender.parameters() if p.requires_grad]
    total_params = sum(trainable)
    print("[{}] - {:.2f}M".format("Patch Blender", total_params / 10 ** 6))

    # print("Patch dimensions", patches.shape)

    # if visualize_img:
    #     patch_blender.visualize_image(imgs[0])
    #     patch_blender.visualize_patched_image(imgs[0])
    # if visualize_patch:
    #     patch_blender.visualize_patches(patches[0])
    # plt.show()




