from torch import nn
import numpy as np
import torch
import torch.nn.functional as F
from base import EncoderLayer
from base import _prepare_decoder_inputs, LayerNorm


class Config:
    """Lightweight attribute bag for model hyper-parameters.

    Every keyword argument passed to the constructor becomes an instance
    attribute of the same name. The class-level values below act as
    defaults when the corresponding keyword is not supplied.
    """

    # Defaults; an instance attribute of the same name shadows them.
    dropout = 0.0
    attention_dropout = 0.0

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

class PolyEmbed(nn.Module):
    """Embed a sequence of 2-D points into a single fixed-size vector.

    Each point is projected by a linear layer, a learnable class token is
    prepended, fixed sinusoidal position embeddings are added, and the
    sequence is passed through ``EncoderLayer`` blocks. The output at the
    class-token position is then projected to ``ouput_dim``.
    """

    def __init__(self, ouput_dim=256, max_position_embeddings=21, device='cuda'):
        """Build the embedder.

        Args:
            ouput_dim: Size of the returned embedding. (The spelling is kept
                so existing keyword callers continue to work.)
            max_position_embeddings: Number of rows in the position table,
                i.e. the maximum sequence length *including* the class token.
            device: Device on which the position table is created.
        """
        super().__init__()

        # NOTE(review): n_head and attention_dropout are not read in this
        # class; presumably EncoderLayer consumes them from the config.
        self.config = Config(
            max_position_embeddings=max_position_embeddings,
            n_embed=128,
            n_layer=1,
            n_head=8,
            ffn_dim=128,
        )
        self.device = device
        self.dropout = self.config.dropout
        self.embed_dim = self.config.n_embed

        self.ffn_dim = self.config.ffn_dim
        self.n_layer = self.config.n_layer
        self.max_len = self.config.max_position_embeddings

        # Projects each (x, y) point to the model width.
        self.embed_tokens = nn.Linear(2, self.config.n_embed)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # Fixed (non-trainable) sinusoidal table of shape (1, max_len, embed_dim).
        positions = torch.arange(self.max_len).unsqueeze(0)
        self.embed_pos = nn.Parameter(
            self.get_1d_embed(self.embed_dim, positions),
            requires_grad=False,
        )

        self.layers = nn.ModuleList(
            [EncoderLayer(self.config) for _ in range(self.config.n_layer)])

        self.layernorm_embedding = LayerNorm(self.embed_dim)

        self.layer_norm = LayerNorm(self.config.n_embed)

        # The decoder consumes the class-token state, whose width is
        # embed_dim (not ffn_dim); the two only coincide by configuration.
        self.decoder = nn.Linear(self.embed_dim, ouput_dim)

        torch.nn.init.normal_(self.cls_token, std=.02)

    def get_1d_embed(self, embed_dim, pos):
        """Return sinusoidal embeddings for a 2-D grid of positions.

        Args:
            embed_dim: Embedding width; must be even.
            pos: Tensor of shape (batch_n, num_b) holding position values.

        Returns:
            Float32 tensor of shape (batch_n, num_b, embed_dim) on
            ``self.device``. The first half of the last axis holds the sines
            and the second half the cosines.
        """
        pos = pos.cpu().numpy()
        assert embed_dim % 2 == 0
        half_dim = embed_dim // 2
        omega = np.arange(half_dim, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)
        batch_n, num_b = pos.shape

        flat_pos = pos.reshape(-1)  # (M,)
        angles = np.einsum('m,d->md', flat_pos, omega)  # (M, D/2), outer product
        angles = angles.reshape(batch_n, num_b, half_dim)

        emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=2)
        return torch.tensor(emb).to(self.device).float()

    def forward(self, source):
        """Compute one embedding vector per input sequence.

        Args:
            source: Tensor of shape [bsz, T, 2] with T + 1 no larger than the
                position table (T = 20 with the default settings).

        Returns:
            Tensor of shape [bsz, ouput_dim]: the projected encoder output at
            the class-token position.
        """
        bsz = source.shape[0]

        # Placeholder row standing in for the class token when building the
        # padding mask. Created on the input's device so the module keeps
        # working after being moved with .to()/.cpu()/.cuda().
        cls_placeholder = torch.ones(bsz, 1, 2, device=source.device)
        decoder_input_ids = torch.cat([cls_placeholder, source], dim=1)
        encoder_padding_mask, _ = _prepare_decoder_inputs(
            self.config,
            decoder_input_ids=decoder_input_ids,
        )

        # expand() broadcasts the class token over the batch without copying.
        cls_tokens = self.cls_token.expand(bsz, -1, -1)
        inputs_embeds = torch.cat([cls_tokens, self.embed_tokens(source)], dim=1)

        # Slice the position table to the actual length so sequences shorter
        # than max_len are supported instead of failing to broadcast.
        seq_len = inputs_embeds.size(1)
        x = inputs_embeds + self.embed_pos[:, :seq_len]
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        for encoder_layer in self.layers:
            x = encoder_layer(x, attn_mask=None, encoder_padding_mask=encoder_padding_mask)

        x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # Position 0 is the class token prepended above.
        embed = self.decoder(x[:, 0, :])

        return embed
