import numpy as np
import torch
import torch.nn as nn


from models.utils import polyline_encoder
from einops import rearrange
import math

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions.

    Each position ``p`` is multiplied by ``dim // 2`` geometrically spaced
    frequencies (from ``1`` down to ``1 / theta``); the sines of those
    products are concatenated with their cosines along the last axis.

    Args:
        dim: Requested embedding width. The output has ``2 * (dim // 2)``
            columns, so an odd ``dim`` produces ``dim - 1`` columns.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed positions.

        Args:
            x: 1-D tensor of positions, shape ``[N]``.

        Returns:
            Tensor of shape ``[N, 2 * (dim // 2)]`` laid out as
            ``[sin(x * f_0), ..., sin(x * f_k), cos(x * f_0), ..., cos(x * f_k)]``.
        """
        half_dim = self.dim // 2
        # With a single frequency (dim == 2 or 3) ``half_dim - 1`` is zero and
        # the division below used to raise ZeroDivisionError. The only
        # frequency in that case is exp(0) == 1 whatever the step is, so
        # clamping the divisor to 1 is exact and leaves larger dims untouched.
        log_step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class MTREncoder(nn.Module):
    """Transformer encoder over per-agent trajectory features.

    ``forward`` runs three stages:

    1. A PointNet polyline encoder turns the input trajectories into one
       feature vector per agent.
    2. A per-agent positional term is added. It is an MLP over the
       concatenation of a learned agent query (index embedding + type
       embedding) and a sinusoidal embedding of the agent index.
    3. A stack of ``nn.TransformerEncoderLayer`` (``batch_first=True``)
       attends over axis 1 of the resulting features.

    Args:
        config: Model config. Attributes read here: ``D_MODEL``,
            ``NUM_INPUT_CONTEXT``, ``NUM_CHANNEL_IN_MLP_AGENT``,
            ``NUM_LAYER_IN_MLP_AGENT``, ``NUM_ATTN_HEAD``,
            ``NUM_ATTN_LAYERS``, ``NUM_OF_ATTN_NEIGHBORS``, ``DATA_TYPE``;
            ``config.get('DROPOUT_OF_ATTN', 0.1)`` is also called.
        use_pre_norm: Forwarded as ``norm_first`` to the transformer layer.
        device: Device that the polyline encoder and the sinusoidal
            positional MLP are moved to at construction. The remaining
            submodules are NOT moved here; call ``.to(...)`` on the whole
            module to place them.
    """

    def __init__(self, config, use_pre_norm, device):
        super().__init__()
        self.model_cfg = config
        dim = self.model_cfg.D_MODEL

        # NOTE: submodules are created in a fixed order and under fixed
        # attribute names so that seeded initialisation and state_dict keys
        # stay stable; do not reorder or rename.
        self.agent_polyline_encoder = self.build_polyline_encoder(
            in_channels=self.model_cfg.NUM_INPUT_CONTEXT,
            hidden_dim=self.model_cfg.NUM_CHANNEL_IN_MLP_AGENT,
            num_layers=self.model_cfg.NUM_LAYER_IN_MLP_AGENT,
            out_channels=dim
        ).to(device)

        # Sinusoidal embedding of the agent index followed by a small MLP.
        self.pos_encoding = nn.Sequential(
            SinusoidalPosEmb(dim, theta=10000),
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        ).to(device)

        # These three embeddings are not used anywhere in this class; they are
        # kept so the module's parameter set (and checkpoints) is unchanged.
        self.team_one_query_embedding = nn.Embedding(1, dim)
        self.team_two_query_embedding = nn.Embedding(1, dim)
        self.ball_query_embedding = nn.Embedding(1, dim)

        # Fuses [agent query, sinusoidal index encoding] (2 * dim) into dim.
        self.mlp_pe = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

        # ``self.layer`` is itself a registered submodule (its parameters
        # appear under the ``layer.`` prefix) in addition to being the
        # template handed to nn.TransformerEncoder.
        self.layer = nn.TransformerEncoderLayer(
            d_model=dim,
            dropout=self.model_cfg.get('DROPOUT_OF_ATTN', 0.1),
            nhead=self.model_cfg.NUM_ATTN_HEAD,
            dim_feedforward=dim * 4,
            norm_first=use_pre_norm,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.layer, num_layers=self.model_cfg.NUM_ATTN_LAYERS
        )
        self.num_out_channels = dim

        # One learned embedding per agent slot.
        self.max_agents = self.model_cfg.NUM_OF_ATTN_NEIGHBORS
        self.agent_index_embed = nn.Embedding(self.max_agents, dim)

        # One learned embedding per agent type id (ids 0..4 are used below).
        self.num_agent_types = 5
        self.agent_type_embed = nn.Embedding(self.num_agent_types, dim)

        # Per-agent type ids, registered as a buffer so they follow the module
        # across devices. For any other DATA_TYPE no buffer is registered and
        # ``agent_query_embedding`` raises a descriptive error.
        # [RR_paw, LR_paw, RF_paw, LF_paw, tail_root, head, neck, spine]
        if self.model_cfg.DATA_TYPE == 'rat':
            self.register_buffer(
                "kp_type_ids",
                torch.tensor([0, 0, 1, 1, 4, 2, 2, 3], dtype=torch.long)  # len=8
            )
        elif self.model_cfg.DATA_TYPE == 'babel':
            kp_type_ids = torch.tensor([
                0,  # 0 pelvis
                1, 1,  # 1-2
                0,  # 3 spine1
                1, 1,  # 4-5
                0,  # 6 spine2
                1, 1,  # 7-8
                0,  # 9 spine3
                1, 1,  # 10-11
                2,  # 12 neck
                3, 3,  # 13-14
                2,  # 15 head
                3, 3,  # 16-17
                4, 4, 4, 4  # 18-21
            ], dtype=torch.long)
            self.register_buffer("kp_type_ids", kp_type_ids)

    ### polyline encoder MLP
    def build_polyline_encoder(self, in_channels, hidden_dim, num_layers, num_pre_layers=1, out_channels=None):
        """Construct a ``PointNetPolylineEncoder`` with the given sizes.

        All arguments are forwarded unchanged as keyword arguments to
        ``polyline_encoder.PointNetPolylineEncoder``.
        """
        return polyline_encoder.PointNetPolylineEncoder(
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_pre_layers=num_pre_layers,
            out_channels=out_channels
        )

    def agent_query_embedding(self, A, device):
        """Return the learned query for the first ``A`` agents.

        The query is the sum of the agent-index embedding and the embedding
        of the agent's type id (taken from ``kp_type_ids``).

        Args:
            A: Number of agents.
            device: Device on which the index tensor is created.

        Returns:
            Tensor of shape ``[A, D]``.

        Raises:
            AttributeError: If ``kp_type_ids`` was never registered (the
                configured ``DATA_TYPE`` is not one of 'rat' / 'babel').
            ValueError: If ``A`` exceeds the number of index embeddings or
                the number of known type ids.
        """
        kp_type_ids = getattr(self, "kp_type_ids", None)
        if kp_type_ids is None:
            raise AttributeError(
                "MTREncoder has no 'kp_type_ids' buffer; DATA_TYPE must be "
                "'rat' or 'babel' for agent type ids to be defined"
            )

        # Fail with a clear message instead of an embedding index error or a
        # shape mismatch in the addition below.
        capacity = min(self.agent_index_embed.num_embeddings, kp_type_ids.numel())
        if A > capacity:
            raise ValueError(
                f"number of agents ({A}) exceeds the supported maximum ({capacity}): "
                f"{self.agent_index_embed.num_embeddings} index embeddings, "
                f"{kp_type_ids.numel()} type ids"
            )

        idx = torch.arange(A, device=device)  # [A]
        index_q = self.agent_index_embed(idx)  # [A, D]

        type_ids = kp_type_ids[:A].to(device)  # [A]
        type_q = self.agent_type_embed(type_ids)  # [A, D]

        return index_q + type_q  # [A, D]

    def forward(self, past_traj):
        """Encode past trajectories.

        Args:
            past_traj: Tensor documented by the original author as
                ``[batch size, number of agents, number of time frames, 6]``.
                NOTE(review): the last dimension presumably has to match
                ``NUM_INPUT_CONTEXT`` -- confirm against the polyline encoder.

        Returns:
            Output of the transformer encoder; same shape as the polyline
            encoder's output (annotated in the original as
            ``(num_center_objects, num_objects, C)``).
        """
        device = past_traj.device

        # Every point is treated as valid. ``ones_like`` already allocates on
        # the input's device, so no extra transfer is needed.
        past_traj_mask = torch.ones_like(past_traj[..., 0], dtype=torch.bool)

        # Kept from the original: makes sure the polyline encoder lives on the
        # input's device even if the module as a whole was not moved.
        self.agent_polyline_encoder.to(device=device)
        obj_polylines_feature = self.agent_polyline_encoder(past_traj, past_traj_mask)  # (num_center_objects, num_objects, C)
        A = obj_polylines_feature.shape[1]

        pos_encoding = self.pos_encoding(torch.arange(A, device=device))  # [A, D]
        agent_query = self.agent_query_embedding(A, device)  # [A, D]
        pos_encoding = self.mlp_pe(torch.cat([agent_query, pos_encoding], dim=-1))  # [A, D]

        # Out-of-place add: mutating the encoder's output in place can
        # invalidate tensors autograd saved for the backward pass.
        obj_polylines_feature = obj_polylines_feature + pos_encoding.unsqueeze(0)  # [B, A, D]

        return self.transformer_encoder(obj_polylines_feature)
