import math

import torch
from torch import nn

class HeteroRelativePoseEmbedding(nn.Module):
    """Add a fixed sinusoidal encoding of a binned (angle, distance) pair to features.

    A sin/cos lookup table is built once at construction time. Row ``m`` holds
    the encoding of bin index ``m``. Two variants are stored along the first
    axis and selected by ``infra`` at call time:

    * ``emb[0]`` uses the even-indexed frequencies of ``omega``;
    * ``emb[1]`` uses the odd-indexed frequencies.

    Args:
        dim: Channel count ``C`` of the features this module is applied to.
            The channel slicing is only one-to-one when ``dim`` is a multiple
            of 4.
        learnable: If true, the fixed encoding is passed through
            ``nn.Linear(dim, dim)`` before being added: ``lin_i`` when
            ``infra == 1``, ``lin_v`` otherwise.
        per_degree: Angular bin width. The table has rows for bin indices
            ``0 .. 360 / per_degree``.
        per_dist: Distance bin width. The table has rows for bin indices
            ``0 .. 70 / per_dist`` (presumably 70 is a maximum range; units
            not visible here -- confirm).
    """

    def __init__(self, dim, learnable, per_degree, per_dist):
        super().__init__()
        embed_dim = dim // 2

        # Inverse frequencies 1 / 10000^(k / embed_dim), k = 0 .. embed_dim-1.
        omega = torch.arange(0, embed_dim, dtype=torch.float)
        omega /= embed_dim
        omega = 1.0 / 10000**omega

        # With the +1, the index 360 / per_degree (resp. 70 / per_dist) is
        # itself a valid row of the table.
        max_degree = 360. / per_degree + 1.
        max_distance = 70. / per_dist + 1.

        # One shared set of positions, long enough for both angle and distance.
        positions = torch.arange(0., max(max_degree, max_distance))
        emb_out = torch.einsum("m,d->md", positions, omega)
        sin_emb = torch.sin(emb_out)
        cos_emb = torch.cos(emb_out)

        # Size the table from the real number of positions. arange() yields
        # ceil(max) entries while int(max) floors, so sizing with int() made
        # the assignments below fail whenever the bin count was fractional.
        emb = torch.zeros((2, positions.shape[0], embed_dim))
        emb[0, :, 0::2] = sin_emb[:, 0::2]
        emb[0, :, 1::2] = cos_emb[:, 0::2]
        emb[1, :, 0::2] = sin_emb[:, 1::2]
        emb[1, :, 1::2] = cos_emb[:, 1::2]

        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module,
        # but stays out of state_dict, so existing checkpoints load unchanged.
        self.register_buffer("emb", emb, persistent=False)

        self.learnable = learnable
        if self.learnable:
            self.lin_i = nn.Linear(dim, dim)
            self.lin_v = nn.Linear(dim, dim)

    def forward(self, x, infra, degree, dist):
        """Add the encoding for one (infra, degree, dist) triple to ``x`` in place.

        Args:
            x: 3-D tensor whose last dimension is the channel count ``C``
                (presumably ``(H, W, C)`` -- TODO confirm). It is modified
                in place.
            infra: Index into the table's first axis (0 or 1). It also picks
                ``lin_i`` when equal to 1 and ``lin_v`` otherwise.
            degree: Integer angle-bin index into the table.
            dist: Integer distance-bin index into the table.

        Returns:
            ``x`` itself, after the in-place addition.
        """
        _, _, C = x.shape

        degree_emb = self.emb[infra, degree]
        dist_emb = self.emb[infra, dist]

        # Interleave channels in groups of four:
        # [deg_sin_k, deg_cos_k, dist_sin_k, dist_cos_k] for k = 0, 1, ...
        # Built on the table's device/dtype so no host round-trip is needed.
        emb = self.emb.new_zeros(C)
        emb[0::4], emb[1::4] = degree_emb[0::2], degree_emb[1::2]
        emb[2::4], emb[3::4] = dist_emb[0::2], dist_emb[1::2]
        emb = emb.to(x.device)

        if self.learnable:
            proj = self.lin_i if infra == 1 else self.lin_v
            emb = proj(emb)

        # Broadcast the (C,) vector over the two leading dimensions of x.
        x += emb.unsqueeze(0).unsqueeze(0)
        return x
    
class HRPE(nn.Module):
    """Apply ``HeteroRelativePoseEmbedding`` to every (batch, step) slice of ``x``.

    Args:
        dim, learnable, per_degree, per_dist: Forwarded unchanged to
            ``HeteroRelativePoseEmbedding``. ``per_degree`` and ``per_dist``
            are also kept here to turn raw angle/distance values into bin
            indices.
    """

    def __init__(self, dim, learnable, per_degree, per_dist):
        super().__init__()
        self.per_degree = per_degree
        self.per_dist = per_dist
        self.emb = HeteroRelativePoseEmbedding(dim, learnable, per_degree, per_dist)

    def make_embed_angel(self, deg):
        """Wrap ``deg`` into [0, 360) and return it as a fractional bin index.

        Accepts Python numbers or tensors. The previous ``deg += 360;
        deg %= 360`` form mutated ``deg`` in place when it was a tensor (for
        example an element view of ``hrpe_comp``). This version never modifies
        its argument and produces the same values.

        NOTE(review): ``forward`` does not call this helper, so raw angles are
        used there unnormalized.
        """
        return ((deg + 360) % 360) / self.per_degree

    def forward(self, x, infra, hrpe_comp):
        """Add the per-(b, l) relative-pose encoding to ``x`` in place.

        Args:
            x: 5-D tensor ``(B, L, _, _, C)``. Each ``x[b, l]`` is updated.
            infra: Indexable by ``[b, l]``. Each entry is converted with
                ``int()`` and selects the embedding variant.
            hrpe_comp: Indexable as ``[:, :, k]``. Channel 0 holds the
                distance and channel 1 the relative angle.

        Returns:
            ``x``, modified in place.

        Bin indices are ``int(value / per_*)``, i.e. truncated toward zero.
        A negative angle therefore yields a negative index, which tensor
        indexing counts from the end of the table.
        """
        B, L, _, _, _ = x.shape

        # Elementwise division once for the whole grid; this is the same
        # per-element arithmetic as dividing inside the loop.
        dist_bins = hrpe_comp[:, :, 0] / self.per_dist
        degree_bins = hrpe_comp[:, :, 1] / self.per_degree

        for b in range(B):
            for t in range(L):
                x[b, t] = self.emb(
                    x[b, t],
                    int(infra[b, t]),
                    int(degree_bins[b, t]),
                    int(dist_bins[b, t]),
                )
        return x
    
class RelativePoseEmbedding(nn.Module):
    """Add a fixed sinusoidal encoding of a binned (angle, distance) pair to features.

    Row ``m`` of the lookup table holds interleaved ``[sin, cos]`` values of
    ``m * omega`` for ``dim // 4`` frequencies. Angle and distance indices
    share the same table.

    Args:
        dim: Channel count ``C`` of the features this module is applied to.
            The channel slicing is only one-to-one when ``dim`` is a multiple
            of 4.
        per_degree: Angular bin width. The table has rows for indices
            ``0 .. 360 / per_degree - 1``.
        per_dist: Distance bin width. The table has rows for indices
            ``0 .. 75 / per_dist - 1`` (presumably 75 is a maximum range;
            units not visible here -- confirm).

    NOTE(review): unlike ``HeteroRelativePoseEmbedding`` there is no ``+1``
    on the bin counts. The index ``360 / per_degree`` (an angle of exactly
    360) is therefore out of range unless the distance range makes the table
    longer.
    """

    def __init__(self, dim, per_degree, per_dist):
        super().__init__()
        embed_dim = dim // 4

        # Inverse frequencies 1 / 10000^(k / embed_dim), k = 0 .. embed_dim-1.
        omega = torch.arange(0, embed_dim, dtype=torch.float)
        omega /= embed_dim
        omega = 1.0 / 10000**omega

        max_degree = 360. / per_degree
        max_distance = 75. / per_dist

        # One shared set of positions, long enough for both angle and distance.
        positions = torch.arange(0., max(max_degree, max_distance))
        emb_out = torch.einsum("m,d->md", positions, omega)

        # Size the table from the real number of positions. arange() yields
        # ceil(max) entries while int(max) floors, so sizing with int() made
        # the assignments below fail whenever the bin count was fractional.
        emb = torch.zeros((positions.shape[0], embed_dim * 2))
        emb[:, 0::2] = torch.sin(emb_out)
        emb[:, 1::2] = torch.cos(emb_out)

        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module,
        # but stays out of state_dict, so existing checkpoints load unchanged.
        self.register_buffer("emb", emb, persistent=False)

        # NOTE(review): `lin` is never used in forward(). It is kept so
        # checkpoints containing lin.weight / lin.bias still load with
        # strict=True. Under DistributedDataParallel it may require
        # find_unused_parameters=True.
        self.lin = nn.Linear(dim, dim)

    def forward(self, x, degree, dist):
        """Add the encoding for one (degree, dist) pair to ``x`` in place.

        Args:
            x: 3-D tensor whose last dimension is the channel count ``C``
                (presumably ``(H, W, C)`` -- TODO confirm). It is modified
                in place.
            degree: Integer angle-bin index into the table.
            dist: Integer distance-bin index into the table.

        Returns:
            ``x`` itself, after the in-place addition.
        """
        _, _, C = x.shape

        degree_emb = self.emb[degree]
        dist_emb = self.emb[dist]

        # Interleave channels in groups of four:
        # [deg_sin_k, deg_cos_k, dist_sin_k, dist_cos_k] for k = 0, 1, ...
        # Built on the table's device/dtype so no host round-trip is needed.
        emb = self.emb.new_zeros(C)
        emb[0::4], emb[1::4] = degree_emb[0::2], degree_emb[1::2]
        emb[2::4], emb[3::4] = dist_emb[0::2], dist_emb[1::2]

        # Broadcast the (C,) vector over the two leading dimensions of x.
        x += emb.to(x.device).unsqueeze(0).unsqueeze(0)
        return x
    
class RPE(nn.Module):
    """Apply ``RelativePoseEmbedding`` to every (batch, step) slice of ``x``.

    Args:
        dim, per_degree, per_dist: Forwarded unchanged to
            ``RelativePoseEmbedding``. ``per_degree`` and ``per_dist`` are
            also kept here to turn raw angle/distance values into bin indices.
    """

    def __init__(self, dim, per_degree, per_dist):
        super().__init__()
        self.per_degree = per_degree
        self.per_dist = per_dist
        self.emb = RelativePoseEmbedding(dim, per_degree, per_dist)

    def make_embed_angel(self, deg):
        """Wrap ``deg`` into [0, 360) and return it as a fractional bin index.

        Accepts Python numbers or tensors. The previous ``deg += 360;
        deg %= 360`` form mutated ``deg`` in place when it was a tensor (for
        example an element view of ``hrpe_comp``). This version never modifies
        its argument and produces the same values.

        NOTE(review): ``forward`` does not call this helper, so raw angles are
        used there unnormalized.
        """
        return ((deg + 360) % 360) / self.per_degree

    def forward(self, x, hrpe_comp):
        """Add the per-(b, l) relative-pose encoding to ``x`` in place.

        Args:
            x: 5-D tensor ``(B, L, _, _, C)``. Each ``x[b, l]`` is updated.
            hrpe_comp: Indexable as ``[:, :, k]``. Channel 0 holds the
                distance and channel 1 the relative angle.

        Returns:
            ``x``, modified in place.

        Bin indices are ``int(value / per_*)``, i.e. truncated toward zero.
        A negative angle therefore yields a negative index, which tensor
        indexing counts from the end of the table.
        """
        B, L, _, _, _ = x.shape

        # Elementwise division once for the whole grid; this is the same
        # per-element arithmetic as dividing inside the loop.
        dist_bins = hrpe_comp[:, :, 0] / self.per_dist
        degree_bins = hrpe_comp[:, :, 1] / self.per_degree

        for b in range(B):
            for t in range(L):
                x[b, t] = self.emb(
                    x[b, t],
                    int(degree_bins[b, t]),
                    int(dist_bins[b, t]),
                )
        return x