class GazeResampler(nn.Module):
    """Refine a pooled gaze query by repeatedly cross-attending to image features.

    Each of the ``depth`` layers applies a ``GazeAttention`` block (the gaze
    query attends to the flattened image tokens) followed by a ``FeedForward``
    block, each wrapped in a residual connection. The final query is
    layer-normalised.
    """

    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_head=64,
        heads=8,
        num_latents=1,
        ff_mult=4,
    ):
        super().__init__()
        if depth < 1:
            # forward() returns the last layer's attention weights, so at least
            # one layer is required (depth=0 would fail there with an unbound name).
            raise ValueError(f"depth must be >= 1, got {depth}")
        # NOTE(review): not read by forward(); kept so existing checkpoints load.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        GazeAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, gaze_x):
        """Attend from the mean-pooled gaze features to the image features.

        Args:
            x (torch.Tensor): image features, shape (b, T, F, v, D).
            gaze_x (torch.Tensor): gaze features, shape (b, T, F_g, v_g, D).
                Its frame/spatial axes are mean-pooled into one query token,
                so F_g and v_g need not equal F and v.

        Returns:
            tuple: ``(attn_weights, gaze_out)`` where
                attn_weights: softmax weights from the LAST layer only,
                    shape (b, T, 1, F * v).
                gaze_out: layer-normalised refined gaze query, shape (b, T, 1, D).
        """
        # Flatten frame and spatial axes into a single token axis: (b, T, F*v, D).
        x = rearrange(x, "b T F v d -> b T (F v) d")
        gaze_x = rearrange(gaze_x, "b T F v d -> b T (F v) d")

        # Pool every gaze token into one query: (b, T, 1, D).
        gaze_x = gaze_x.mean(dim=2, keepdim=True)

        for attn, ff in self.layers:
            attn_weights, attn_out = attn(x, gaze_x)
            gaze_x = attn_out + gaze_x
            gaze_x = ff(gaze_x) + gaze_x
        return attn_weights, self.norm(gaze_x)


class GazeAttention(nn.Module):
    """Cross-attention where gaze features are queries and vision features are keys/values."""

    def __init__(self, *, dim, dim_head=64, heads=8, patches=256, n_queries=1):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_gaze = nn.LayerNorm(dim)

        # NOTE(review): q_init (and the patches / n_queries arguments) is not used
        # by either forward pass; kept so existing checkpoints still load.
        self.q_init = nn.Linear(patches, n_queries, bias=False)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)  # gaze features -> queries
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)  # vision features -> keys, values
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def _attend(self, q, k, v):
        """Scaled dot-product attention over the last two axes.

        Returns:
            tuple: ``(attn, out)``, where attn has shape (..., i, j) with rows
            summing to 1, and out has shape (..., i, d).
        """
        q = q * self.scale
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        # Subtract the row max for numerical stability. Softmax is shift-invariant,
        # so detaching the max does not change the result.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum("... i j, ... j d -> ... i d", attn, v)
        return attn, out

    def forward(self, vision_x, gaze_x):
        """Single-head cross-attention from gaze queries to vision tokens.

        Args:
            vision_x (torch.Tensor): vision features, shape (b, T, n1, D).
            gaze_x (torch.Tensor): gaze query features, shape (b, T, n2, D).

        Returns:
            tuple: ``(attn, out)`` where
                attn: attention weights, shape (b, T, n2, n1).
                out: attended features projected back to D, shape (b, T, n2, D).
        """
        vision_x = self.norm_media(vision_x)
        # The normalised gaze must be what feeds to_q; otherwise norm_gaze is dead.
        q = self.to_q(self.norm_gaze(gaze_x))
        k, v = self.to_kv(vision_x).chunk(2, dim=-1)
        # NOTE(review): q/k are inner_dim = dim_head * heads wide here, but
        # self.scale is the per-head dim_head ** -0.5, so logits are sqrt(heads)
        # larger than standard 1/sqrt(d_k) scaling. Confirm this is intended.
        attn, out = self._attend(q, k, v)
        return attn, self.to_out(out)

    def forward_multi_head(self, vision_x, gaze_x):
        """Multi-head variant of :meth:`forward`.

        Args:
            vision_x (torch.Tensor): vision features, shape (b, T, n1, D).
            gaze_x (torch.Tensor): gaze query features, shape (b, T, n2, D).

        Returns:
            tuple: ``(attn, out)`` where
                attn: attention weights averaged over heads, shape (b, T, n2, n1).
                    This is the same layout as :meth:`forward`, and each row
                    still sums to 1.
                out: attended features projected back to D, shape (b, T, n2, D).
        """
        h = self.heads
        vision_x = self.norm_media(vision_x)
        q = self.to_q(self.norm_gaze(gaze_x))
        k, v = self.to_kv(vision_x).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        attn, out = self._attend(q, k, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        # Average over the head axis: (b, h, t, i, j) -> (b, t, i, j). This
        # replaces a projection through an attribute that was never defined.
        return attn.mean(dim=1), self.to_out(out)



# TODO: change this to return a flattened per-patch column (e.g. 256x1) instead of a 2-D grid (e.g. 16x16).
def calculate_gaze_proportions(image, num_patches_vertical, num_patches_horizontal):
    """Split a gaze heatmap into a grid of patches and return each patch's share of the total.

    Args:
        image: array-like whose first two axes are height and width. Pixel
            intensity is treated as gaze mass. Any trailing axes (e.g. channels)
            are summed along with the pixels.
        num_patches_vertical: number of patch rows.
        num_patches_horizontal: number of patch columns.

    Returns:
        np.ndarray: float64 array of shape
        (num_patches_vertical, num_patches_horizontal), where entry [i, j] is
        sum(patch[i, j]) / sum(image). The array is all zeros if the image sums
        to zero or if a patch would be smaller than one pixel.

    Patches are floor(H / rows) x floor(W / cols) pixels. When H or W is not
    divisible, the leftover bottom rows / right columns belong to no patch but
    still count toward the total, so the result may then sum to less than 1.
    """
    image = np.asarray(image)
    total_gaze = np.sum(image)

    patch_height = image.shape[0] // num_patches_vertical
    patch_width = image.shape[1] // num_patches_horizontal

    gaze_proportions = np.zeros((num_patches_vertical, num_patches_horizontal))
    if total_gaze == 0 or patch_height == 0 or patch_width == 0:
        return gaze_proportions

    # Drop the remainder rows/cols, then view the image as
    # (rows, patch_h, cols, patch_w, rest) and sum everything except the grid axes.
    cropped = image[
        : patch_height * num_patches_vertical,
        : patch_width * num_patches_horizontal,
    ]
    patch_sums = cropped.reshape(
        num_patches_vertical, patch_height, num_patches_horizontal, patch_width, -1
    ).sum(axis=(1, 3, 4))

    gaze_proportions[:] = patch_sums / total_gaze
    return gaze_proportions