import torch
import torch.nn as nn
import transformers
from diffuser.models.bert import BertModel
import warnings
# Ignore DeprecationWarning messages. Note that this installs a process-wide
# filter, so it affects every module imported after this one as well.
warnings.simplefilter(action='ignore', category=DeprecationWarning)

# Device string picked once, at import time: "cuda" when a CUDA device is
# available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class EncoderTransformer(nn.Module):
    """Encode a (state, action) trajectory segment into a fixed-size representation.

    States and actions are embedded separately, a learned timestep embedding
    is added to both, and the two streams are interleaved as
    ``(s_1, a_1, s_2, a_2, ...)`` before being passed through a BERT encoder.
    The encoder outputs are summed over all tokens and projected by a head
    selected with ``repre_type``:

    - ``'vec'``: one linear head producing a single vector.
    - ``'dist'``: two linear heads producing a mean and a clamped second
      output (named ``std`` in the code).
    - ``'vq_vec'``: a linear head followed by a ``VectorQuantizer``.

    Args:
        state_dim: Size of the last dimension of ``states``.
        act_dim: Size of the last dimension of ``actions``.
        hidden_size: Transformer hidden size (also the embedding size).
        output_size: Output size of the projection head(s). For ``'vq_vec'``
            it is additionally used as the number of codebook entries.
        max_ep_len: Number of rows in the timestep embedding table;
            ``timesteps`` values must be smaller than this.
        repre_type: One of ``'vec'``, ``'dist'`` or ``'vq_vec'``.
        **kwargs: Forwarded unchanged to ``transformers.BertConfig``.

    Raises:
        ValueError: If ``repre_type`` is not one of the supported values.
    """

    _REPRE_TYPES = ('vec', 'dist', 'vq_vec')

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        output_size,
        max_ep_len=4096,
        repre_type='vec',
        **kwargs
    ):
        # Fail fast: an unknown type used to build a model without any output
        # head whose forward() silently returned None.
        if repre_type not in self._REPRE_TYPES:
            raise ValueError(
                f"Unsupported repre_type {repre_type!r}; "
                f"expected one of {self._REPRE_TYPES}"
            )

        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim

        self.hidden_size = hidden_size
        config = transformers.BertConfig(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            hidden_size=hidden_size,
            **kwargs
        )
        self.output_size = output_size

        # BertModel is the project-local variant from diffuser.models.bert.
        # NOTE(review): an earlier comment (written for GPT2Model) said the
        # positional embeddings are removed because timestep embeddings are
        # added here instead -- confirm this also holds for this BertModel.
        self.transformer = BertModel(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)
        self.repre_type = repre_type
        if self.repre_type == 'vec':
            self.to_phi = nn.Linear(self.hidden_size, self.output_size)
        elif self.repre_type == 'dist':
            self.to_phi_mean = nn.Linear(self.hidden_size, self.output_size)
            self.to_phi_std = nn.Linear(self.hidden_size, self.output_size)
        elif self.repre_type == 'vq_vec':
            self.to_phi = nn.Linear(self.hidden_size, self.output_size)
            # The quantizer views the to_phi output as rows of length 64, so
            # batch * output_size must be divisible by 64.
            self.vq_embed = VectorQuantizer(n_e=self.output_size, e_dim=64, beta=0.25)

    def forward(self, states, actions, timesteps, attention_mask=None):
        """Encode a batch of trajectory segments.

        Args:
            states: Tensor whose first two dimensions are (batch, seq) and
                whose last dimension is ``state_dim``.
            actions: Tensor with the same leading dimensions and last
                dimension ``act_dim``.
            timesteps: Integer indices into the timestep embedding table,
                expected to be shaped (batch, seq).
            attention_mask: Optional (batch, seq) mask, 1 where a step may be
                attended to and 0 otherwise. Defaults to all ones, created on
                the same device as ``states``.

        Returns:
            - ``'vec'``: tensor of shape (batch, output_size).
            - ``'dist'``: tuple ``(mean, std)`` where ``std`` is the second
              head's output clamped to ``[-5, 2]``.
            - ``'vq_vec'``: tuple ``(z_e, z_q, z_out)`` -- the continuous
              projection, its quantized version, and the straight-through
              output of the quantizer.

        Raises:
            ValueError: If ``self.repre_type`` is not a supported value.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # Build the mask on the inputs' device rather than a global
            # default, so the module works wherever it has been moved to.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings

        # Interleave the two streams so the sequence reads
        # (s_1, a_1, s_2, a_2, ...): (batch, seq, 2, hidden) -> (batch, 2*seq, hidden)
        stacked_inputs = torch.stack(
            (state_embeddings, action_embeddings), dim=2
        ).reshape(batch_size, 2 * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # Each step contributes two tokens, so repeat its mask entry twice in
        # the same interleaved order: (batch, seq, 2) -> (batch, 2*seq)
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask), dim=2
        ).reshape(batch_size, 2 * seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs["last_hidden_state"]  # (batch, 2*seq, hidden)

        # Un-interleave: x[:, 0, t] is the state token and x[:, 1, t] the
        # action token of step t.
        x = x.reshape(batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)

        # Pool by summing over time and then over the two modalities.
        # NOTE(review): masked (padded) positions are included in this sum.
        x = x.sum(dim=2).sum(dim=1)  # (batch, hidden)

        if self.repre_type == 'vec':
            return self.to_phi(x)
        elif self.repre_type == 'dist':
            # NOTE(review): the [-5, 2] range suggests this is a log-std
            # rather than a std -- confirm against the caller.
            std = torch.clamp(self.to_phi_std(x), min=-5, max=2)
            return self.to_phi_mean(x), std
        elif self.repre_type == 'vq_vec':
            z_e = self.to_phi(x)
            z_q, z_out = self.vq_embed(z_e)
            return z_e, z_q, z_out

        # Only reachable if repre_type was changed after construction.
        raise ValueError(
            f"Unsupported repre_type {self.repre_type!r}; "
            f"expected one of {self._REPRE_TYPES}"
        )

class VectorQuantizer(nn.Module):
    """
    Discretization bottleneck part of the VQ-VAE. Thanks to https://github.com/MishaLaskin/vqvae.

    Inputs:
    - n_e : number of embeddings (codebook entries)
    - e_dim : dimension of each embedding
    - beta : commitment cost of the VQ loss term beta * ||z_e(x)-sg[e]||^2.
      It is only stored here; this module does not compute a loss.
    """

    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Replace every ``e_dim``-sized chunk of ``z`` by its nearest codebook
        entry (squared Euclidean distance).

        Args:
            z: Continuous encoder output. It is viewed as rows of length
               ``e_dim``, so its number of elements must be divisible by
               ``e_dim``.

        Returns:
            Tuple ``(z_q, z_out)``, both with the shape of ``z``:
            - z_q : the selected codebook vectors; gradients reach the
              codebook but not ``z``.
            - z_out : straight-through output, equal in value to ``z_q`` but
              passing gradients to ``z`` as if it were the identity.
        """
        z = z.contiguous()
        z_flattened = z.view(-1, self.e_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, codebook.t())
        )

        # Index of the closest codebook entry for each row.
        min_encoding_indices = torch.argmin(d, dim=1)

        # Direct lookup instead of a dense one-hot matmul: same vectors and
        # same gradient path to the codebook, and it follows the device of
        # the input rather than a global default device.
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Straight-through estimator: forward value is z_q, backward pass
        # treats the quantization as identity with respect to z.
        z_out = z + (z_q - z).detach()

        return z_q, z_out.contiguous()
