import numpy as np
import torch
import torch.nn as nn
from utils.nets import BasicNeuralNetwork
from transformers import BertConfig
from agent.bertmodel import BertModel
from agent.params import PARAMS
from utils.common import get_device
from typing import Optional
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from agent.args import DEVICE3

# Module-wide alias for the device imported from agent.args; BertTransformer
# places each of its submodules on this device when it is constructed.
DEVICE = DEVICE3

# MAX_STEP_LEN = PARAMS['pre_steps'] + PARAMS['post_steps'] + 1
# HIDDEN_SIZE = PARAMS['transformer']['hidden_size']

# MASK = torch.ones((HIDDEN_SIZE, ), dtype=torch.float32, device=DEVICE) * -100
# EMPTY = torch.ones((HIDDEN_SIZE, ), dtype=torch.float32, device=DEVICE) * -200


class BertTransformer(BasicNeuralNetwork):
    """Transformer encoder over a fixed-length sequence of state vectors.

    Each state is linearly embedded, summed with a learned per-timestep
    embedding, layer-normalised and passed through a ``TransformerEncoder``.
    No attention mask is given to the encoder, so every position attends to
    every other position. Two heads read the encoder output:

    * ``to_centroids`` -- log-probabilities over ``num_centers`` classes;
    * ``to_offsets``   -- one ``act_dim``-sized vector per class.

    Sizes are read from ``PARAMS``: ``PARAMS["transformer"]`` supplies
    ``hidden_size``, ``max_step_len``, ``nhead`` and ``num_layers``;
    ``PARAMS["num_centers"]`` supplies the number of classes.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
    ):
        """Build the encoder, embeddings and output heads on ``DEVICE``.

        Args:
            state_dim: Size of the last dimension of the ``states`` tensor
                accepted by :meth:`forward`.
            act_dim: Size of each per-class offset vector produced by
                ``to_offsets``.
        """
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.hidden_size = PARAMS["transformer"]["hidden_size"]
        self.max_step_len = PARAMS["transformer"]["max_step_len"]
        self.num_centers = PARAMS["num_centers"]

        # NOTE: the construction order below is kept stable so that seeded
        # parameter initialisation stays reproducible.
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=PARAMS["transformer"]["nhead"],
                dim_feedforward=self.hidden_size * 4,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=PARAMS["transformer"]["num_layers"],
        ).to(DEVICE)

        # One learned embedding per sequence position.
        self.embed_timestep = nn.Embedding(self.max_step_len, self.hidden_size).to(
            DEVICE
        )
        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size).to(DEVICE)

        self.embed_ln = nn.LayerNorm(self.hidden_size).to(DEVICE)

        # Head 1: log-probabilities over the ``num_centers`` classes.
        self.to_centroids = nn.Sequential(
            nn.Linear(self.hidden_size, self.num_centers), nn.LogSoftmax(dim=-1)
        ).to(DEVICE)

        # Head 2: a flat (num_centers * act_dim) vector, reshaped in forward().
        self.to_offsets = nn.Linear(
            self.hidden_size, self.num_centers * self.act_dim
        ).to(DEVICE)

        # Single learned vector substituted for the embedding of masked steps.
        self.mask = nn.Embedding(1, self.hidden_size).to(DEVICE)

    def forward(
        self,
        states: torch.Tensor,
        masks: Optional[torch.Tensor] = None,
    ):
        """Encode ``states`` and return the two head outputs.

        Args:
            states: Tensor of shape ``(batch, max_step_len, state_dim)``.
            masks: Optional boolean tensor of shape ``(batch, max_step_len)``.
                Where it is ``True`` the state embedding is replaced by the
                learned mask vector before the timestep embedding is added.
                Every row must contain the same number of ``True`` entries.

        Returns:
            A ``(centroids, offsets)`` tuple computed from every encoder
            output position except the last one:

            * ``centroids`` -- ``(batch, max_step_len - 1, num_centers)``
              log-softmax scores;
            * ``offsets`` -- ``(batch, max_step_len - 1, num_centers,
              act_dim)``.

        Raises:
            AssertionError: If the shape or dtype requirements above are not
                met.
        """
        batch_size, seq_length, state_dim = states.shape
        assert state_dim == self.state_dim
        if masks is not None:
            assert masks.shape == states.shape[:-1]
            assert masks.dtype == torch.bool
            # Same number of masked steps in every sequence of the batch.
            assert len(masks.sum(dim=-1).unique()) == 1

        assert seq_length == self.max_step_len

        state_embeddings = self.embed_state(states)

        if masks is not None:
            # Row 0 of the one-entry embedding table is the mask vector;
            # reading the weight directly avoids building an index tensor on
            # a hard-coded device and keeps the same gradient path.
            state_embeddings[masks] = self.mask.weight[0]

        # Build the position indices on the device the activations live on,
        # so the module keeps working after being moved with ``.to(...)``.
        positions = torch.arange(self.max_step_len, device=state_embeddings.device)
        time_embeddings = self.embed_timestep(positions)

        # Timestep embeddings act as positional embeddings; they broadcast
        # over the batch dimension.
        state_embeddings = state_embeddings + time_embeddings

        inputs = self.embed_ln(state_embeddings)

        # The encoder consumes the embeddings directly (batch_first layout).
        x = self.transformer(inputs)

        # Drop the final position; both heads see max_step_len - 1 steps.
        x = x[:, :-1]

        centroids = self.to_centroids(x)
        offsets = self.to_offsets(x).reshape(
            (batch_size, seq_length - 1, self.num_centers, self.act_dim)
        )

        return centroids, offsets
