# Reference:
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py

import math

import numpy as np
import torch
from torch import nn


def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PositionalEncoding(nn.Module):
    def __init__(self,
                 dim: int,
                 scale: float = math.pi * 2,
                 temperature: float = 10000,
                 normalize: bool = True,
                 channel_last: bool = True,
                 transpose_output: bool = False):
        super().__init__()
        dim = int(np.ceil(dim / 4) * 2)
        self.dim = dim
        inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.normalize = normalize
        self.scale = scale
        self.eps = 1e-6
        self.channel_last = channel_last
        self.transpose_output = transpose_output

        self.cached_penc = None  # the cache is irrespective of the number of objects

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: A 4/5d tensor of size 
            channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
            channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
        :return: positional encoding tensor that has the same shape as the input if the input is 4d
                 if the input is 5d, the output is broadcastable along the k-dimension
        """
        if len(tensor.shape) != 4 and len(tensor.shape) != 5:
            raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')

        if len(tensor.shape) == 5:
            # take a sample from the k dimension
            num_objects = tensor.shape[1]
            tensor = tensor[:, 0]
        else:
            num_objects = None

        if self.channel_last:
            batch_size, h, w, c = tensor.shape
        else:
            batch_size, c, h, w = tensor.shape

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            if num_objects is None:
                return self.cached_penc
            else:
                return self.cached_penc.unsqueeze(1)

        self.cached_penc = None

        pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
        if self.normalize:
            pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
            pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale

        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_x = get_emb(sin_inp_x)

        emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
        emb[:, :, :self.dim] = emb_x
        emb[:, :, self.dim:] = emb_y

        if not self.channel_last and self.transpose_output:
            # cancelled out
            pass
        elif (not self.channel_last) or (self.transpose_output):
            emb = emb.permute(2, 0, 1)

        self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        if num_objects is None:
            return self.cached_penc
        else:
            return self.cached_penc.unsqueeze(1)


if __name__ == '__main__':
    pe = PositionalEncoding(8).cuda()
    input = torch.ones((1, 8, 8, 8)).cuda()
    output = pe(input)
    # print(output)
    print(output[0, :, 0, 0])
    print(output[0, :, 0, 5])
    print(output[0, 0, :, 0])
    print(output[0, 0, 0, :])
