import numpy as np
import torch
import torch.nn as nn
from typing import Any, Optional, Tuple, Type
from nets.model_zoo.ReCOT.backbone import build_position_encoding
import math

class Prompt_Encoder(nn.Module):
    """
    Encode one point prompt per sample into sparse prompt embeddings.

    Each point is positionally encoded with ``PositionEmbeddingSine`` and a
    learned embedding is added to it. A zero "padding" point is appended after
    the real one, so a batch of ``B`` points yields a ``B x 2 x embed_dim``
    tensor.

    Args:
        embed_dim: Channel width of the produced embeddings. The sine encoder
            emits ``4 * (embed_dim // 4)`` channels, so ``embed_dim`` has to be
            divisible by 4 for the learned embedding to be addable to it.
        input_image_size: ``(H, W)`` forwarded to the positional encoder.
    """

    def __init__(self,
                 embed_dim: int,
                 input_image_size: Tuple[int, int] = (512, 512)):
        super().__init__()
        self.embed_dim = embed_dim
        # For an even num_pos_feats the sine encoder returns 2 * num_pos_feats
        # channels per point, hence half of embed_dim is requested here.
        self.pe_layer = PositionEmbeddingSine(self.embed_dim // 2)
        self.input_image_size = input_image_size

        # Two learned embeddings are registered; only index 0 is applied in
        # `_embed_points` (index 1 is kept so existing checkpoints still load).
        self.num_point_embeddings: int = 2
        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)

    def _embed_points(self, points: torch.Tensor, pad: bool) -> torch.Tensor:
        """
        Positionally encode points and tag the first one with a learned embedding.

        Args:
            points: ``B x N x 2`` point coordinates.
            pad: When True, a zero point is appended along dim 1 before encoding.

        Returns:
            ``B x (N [+ 1]) x embed_dim`` embeddings.
        """
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            points = torch.cat([points, padding_point], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)

        # Only the first point of each sample receives the learned embedding;
        # the (1, embed_dim) weight broadcasts over the batch dimension.
        point_embedding[:, 0, :] += self.point_embeddings[0].weight
        return point_embedding

    def _get_batch_size(
            self,
            points: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.

        Returns 1 when no points are supplied.
        """
        if points is None:
            return 1
        return points.shape[0]

    def _get_device(self) -> torch.device:
        """Return the device holding this module's learned embeddings."""
        return self.point_embeddings[0].weight.device

    def forward(
            self,
            points: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Embed a batch of single-point prompts.

        Args:
            points: ``B x 2`` tensor holding one point per sample, or ``None``
                when there is no point prompt.

        Returns:
            ``B x 2 x embed_dim`` sparse embeddings (encoded point followed by
            the encoded padding point), or an empty ``1 x 0 x embed_dim``
            tensor when ``points`` is ``None``.
        """
        # Guard before touching the tensor: the None case previously crashed on
        # `.unsqueeze` even though the method is declared to accept it.
        if points is not None:
            points = points.unsqueeze(1)  # B x 2 -> B x 1 x 2

        bs = self._get_batch_size(points)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is None:
            return sparse_embeddings

        # Box prompts are not supported by this encoder, so the point is always
        # followed by a padding point.
        point_embeddings = self._embed_points(points, pad=True)
        return torch.cat([sparse_embeddings, point_embeddings], dim=1)

class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    A fixed ``2 x num_pos_feats`` Gaussian matrix (stored as a buffer, so it
    follows the module across devices and is saved with the state dict)
    projects 2-D coordinates; the sine and cosine of the projection are
    concatenated, giving ``2 * num_pos_feats`` output channels.

    Args:
        num_pos_feats: Number of random frequencies.
        scale: Multiplier for the Gaussian matrix. ``None`` or a non-positive
            value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1  # map [0, 1] -> [-1, 1]
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """
        Generate positional encoding for a grid of the specified size.

        Args:
            size: ``(H, W)`` of the grid.

        Returns:
            ``C x H x W`` tensor with ``C = 2 * num_pos_feats``, on the device
            of the Gaussian matrix buffer.
        """
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # cumsum - 0.5 yields 0.5, 1.5, ... so each cell is encoded at its centre.
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Positionally encode points that are not normalized to [0,1].

        Args:
            coords_input: ``B x N x 2`` tensor of ``(x, y)`` coordinates. It is
                not modified.
            image_size: ``(H, W)`` used to normalize ``y`` and ``x`` respectively.

        Returns:
            ``B x N x C`` float tensor with ``C = 2 * num_pos_feats``.
        """
        if coords_input.is_floating_point():
            coords = coords_input.clone()
        else:
            # Integer pixel coordinates must be promoted *before* the in-place
            # normalisation below, otherwise the quotients are truncated to 0.
            coords = coords_input.to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]

        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class PositionEmbeddingSine(nn.Module):
    """
    Standard Sine-Cosine Positional Encoding as used in Transformer models.

    For every one of the ``num_pos_feats // 2`` frequencies the encoding stores
    ``[sin(y), cos(y), sin(x), cos(x)]``, so the channel count is
    ``4 * (num_pos_feats // 2)`` (``2 * num_pos_feats`` when it is even).

    Args:
        num_pos_feats: Controls the number of frequencies (see above).
        temperature: Base of the geometric frequency progression.
        normalize: When True, ``forward`` rescales grid indices to
            ``[0, scale]`` before encoding.
        scale: Multiplier applied to normalized coordinates; defaults to
            ``2 * pi``.
    """

    def __init__(self, num_pos_feats: int = 64, temperature: float = 10000, normalize: bool = True,
                 scale: Optional[float] = None) -> None:
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def _frequency_divisors(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return the ``num_pos_feats // 2`` per-frequency divisors on ``device``."""
        dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device)
        # `dim_t // 2` makes consecutive pairs of entries share one exponent.
        return self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

    def _sin_cos(self, x_embed: torch.Tensor, y_embed: torch.Tensor) -> torch.Tensor:
        """Encode same-shaped x / y tensors of shape ``S`` into an ``S x C`` tensor."""
        dim_t = self._frequency_divisors(x_embed.device)

        pos_x = x_embed.unsqueeze(-1) / dim_t
        pos_y = y_embed.unsqueeze(-1) / dim_t

        pos_x = torch.stack((pos_x.sin(), pos_x.cos()), dim=-1)
        pos_y = torch.stack((pos_y.sin(), pos_y.cos()), dim=-1)

        # S x F/2 x 4 -> S x C: merge the frequency and sin/cos axes.
        return torch.cat((pos_y, pos_x), dim=-1).flatten(-2)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """
        Generate positional encoding for a grid of the specified size.

        Args:
            size: ``(H, W)`` of the grid.

        Returns:
            ``C x H x W`` float32 tensor created on the default device.
        """
        h, w = size
        y_embed = torch.arange(h, dtype=torch.float32).unsqueeze(1).repeat(1, w)
        x_embed = torch.arange(w, dtype=torch.float32).unsqueeze(0).repeat(h, 1)

        if self.normalize:
            # max(..., 1) avoids 0 / 0 = NaN on an axis of length one.
            y_embed = y_embed / max(h - 1, 1) * self.scale
            x_embed = x_embed / max(w - 1, 1) * self.scale

        # The encoded grid is H x W x C; the sin/cos axes are flattened inside
        # `_sin_cos`, which the 3-axis permute below relies on.
        return self._sin_cos(x_embed, y_embed).permute(2, 0, 1)  # C x H x W

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Positionally encode arbitrary coordinates.

        Args:
            coords_input: ``... x 2`` tensor of ``(x, y)`` coordinates (e.g.
                ``B x N x 2``). It is not modified.
            image_size: Currently unused; coordinates are multiplied by
                ``scale`` exactly as given rather than divided by the image
                size first.
                NOTE(review): presumably callers pass coordinates already
                normalized to [0, 1] -- confirm.

        Returns:
            ``... x C`` tensor (``B x N x C`` for a ``B x N x 2`` input).
        """
        # Scale out of place: writing the product back into an integer tensor
        # would silently truncate it.
        x_embed = coords_input[..., 0] * self.scale
        y_embed = coords_input[..., 1] * self.scale

        return self._sin_cos(x_embed, y_embed)  # B x N x C