from torch import nn
import numpy as np
import torch
import torch.nn.functional as F
from base import EncoderLayer
from base import _prepare_decoder_inputs, LayerNorm



class Config:
    """Simple hyper-parameter container.

    Any keyword argument given to the constructor becomes an instance
    attribute. ``dropout`` and ``attention_dropout`` fall back to the
    class-level value 0.0 unless a keyword of the same name overrides them.
    """

    # Class-level fallbacks; an instance attribute of the same name shadows them.
    dropout = 0.0
    attention_dropout = 0.0

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

class AutoPoly(nn.Module):
    """Latent-conditioned transformer over a sequence of 2-D points.

    A learned start token is prepended to the linearly embedded ``source``
    points (teacher forcing). When ``gen`` is True, the start token is used
    alone. Every position is concatenated with the per-sample latent vector
    and projected back to the model width. Fixed sinusoidal position
    embeddings are then added. The result runs through ``n_layer``
    ``EncoderLayer`` blocks using the masks returned by
    ``_prepare_decoder_inputs`` and is finally mapped to ``out_dim`` values
    per position.

    Args:
        latent_dim: Width of the conditioning vector passed to ``forward``.
        out_dim: Number of values predicted at each sequence position.
        device: Device on which the position-embedding table is created in
            ``get_1d_embed``.
    """

    def __init__(self, latent_dim = 128, out_dim = 3, device = 'cuda'):
        super().__init__()

        # Hyper-parameters are fixed here. The same object is handed to every EncoderLayer.
        self.config = Config(
            max_position_embeddings=21,
            n_embed=128,
            n_layer=3,
            n_head=8,
            ffn_dim=128,
        )
        self.device = device
        self.dropout = self.config.dropout
        self.embed_dim = self.config.n_embed

        self.ffn_dim = self.config.ffn_dim
        self.n_layer = self.config.n_layer
        self.max_len = self.config.max_position_embeddings
        self.latent_dim = latent_dim

        # Lifts each 2-value point to the model width.
        self.embed_tokens = nn.Linear(2, self.embed_dim)

        self.start_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # Fixed (non-trainable) sinusoidal table of shape (1, max_len, embed_dim).
        self.embed_pos = nn.Parameter(
            self.get_1d_embed(self.embed_dim, torch.arange(self.max_len).unsqueeze(0)),
            requires_grad=False,
        )
        # Fuses [token embedding, latent] back down to embed_dim.
        self.en_fc = nn.Linear(self.embed_dim + self.latent_dim, self.embed_dim)

        self.layers = nn.ModuleList(
            [EncoderLayer(self.config) for _ in range(self.n_layer)])

        self.layernorm_embedding = LayerNorm(self.embed_dim)
        self.layer_norm = LayerNorm(self.embed_dim)

        # The decoder consumes the final hidden states, whose width is embed_dim
        # (not ffn_dim; the two only coincide for the current config).
        self.decoder = nn.Linear(self.embed_dim, out_dim)

        torch.nn.init.normal_(self.start_token, std=.02)

    def get_1d_embed(self, embed_dim, pos):
        """Build sinusoidal position embeddings.

        Args:
            embed_dim: Embedding width. It must be even, because the first half
                holds sines and the second half holds cosines.
            pos: Tensor of positions with shape (batch, length); ``__init__``
                passes ``torch.arange(max_len).unsqueeze(0)``.

        Returns:
            Float32 tensor of shape (batch, length, embed_dim) on ``self.device``.

        Raises:
            ValueError: If ``embed_dim`` is odd.
        """
        if embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even, got {embed_dim}")
        pos = pos.cpu().numpy()
        batch_n, num_b = pos.shape

        half = embed_dim // 2
        omega = np.arange(half, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,) geometrically decreasing frequencies

        out = np.einsum('m,d->md', pos.reshape(-1), omega)  # (B*L, D/2) outer product
        out = out.reshape(batch_n, num_b, half)

        emb = np.concatenate([np.sin(out), np.cos(out)], axis=2)  # (B, L, D)
        return torch.tensor(emb).to(self.device).float()

    def forward(self, latent, source = None, gen = False):
        """Compute per-position predictions for a start token plus a point prefix.

        Args:
            latent: Conditioning vectors of shape [bsz, latent_dim].
            source: Points to teacher-force, shape [bsz, T, 2]. Required when
                ``gen`` is False; must be None when ``gen`` is True.
            gen: If True, only the start token is fed (equivalent to T = 0).

        Returns:
            Tensor of shape [bsz, T + 1, out_dim], assuming each EncoderLayer
            preserves the (T + 1, bsz, embed_dim) layout. Position 0 corresponds
            to the start token.

        Raises:
            ValueError: If ``gen`` and ``source`` are inconsistent, or if
                T + 1 exceeds ``max_position_embeddings``.
        """
        if gen:
            if source is not None:
                raise ValueError("source must be None when gen=True")
        elif source is None:
            raise ValueError("source is required when gen=False")

        seq_len = 1 if gen else source.shape[1] + 1
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} (start token + source) exceeds "
                f"max_position_embeddings={self.max_len}"
            )

        bsz = latent.shape[0]

        # Placeholder row standing in for the start token. Presumably this makes
        # _prepare_decoder_inputs treat it as non-padding (TODO confirm).
        start_ids = torch.ones(bsz, 1, 2, device=latent.device)
        decoder_input_ids = start_ids if gen else torch.cat([start_ids, source], dim=1)

        encoder_padding_mask, causal_mask = _prepare_decoder_inputs(
            self.config,
            decoder_input_ids=decoder_input_ids,
        )

        start = self.start_token.expand(bsz, -1, -1)
        if gen:
            inputs_embeds = start
        else:
            inputs_embeds = torch.cat([start, self.embed_tokens(source)], dim=1)

        # Attach the latent to every position, then project back to embed_dim.
        latent_rows = latent.unsqueeze(1).expand(-1, seq_len, -1)
        inputs_embeds = self.en_fc(torch.cat([inputs_embeds, latent_rows], dim=-1))
        x = inputs_embeds + self.embed_pos[:, :seq_len, :]
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C (the layout handed to the encoder layers)
        x = x.transpose(0, 1)

        for encoder_layer in self.layers:
            x = encoder_layer(x, attn_mask=causal_mask, encoder_padding_mask=encoder_padding_mask)

        x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return self.decoder(x)