import torch
import torch.nn as nn
import transformers
# NOTE: this runs at import time, so merely importing this module re-seeds
# the process-wide random number generators with 0. Anything that seeds
# earlier is overridden; anything that needs a different seed must set it
# after this import.
transformers.set_seed(0)
from transformers import GPT2Config, GPT2Model
# NOTE(review): `embed` is not referenced in the visible part of this file;
# it looks like a leftover interactive-debugging helper.
from IPython import embed
# Preferred compute device: CUDA when available, otherwise CPU. It is not
# used by the code visible here, so it is presumably consumed by importers.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Transformer(nn.Module):
    """GPT-2 backbone that turns a sequence of transitions into action predictions.

    Every sequence position is a single vector built by concatenating a state,
    a preferred action, a non-preferred action and a next state. That vector
    is linearly embedded, passed through a ``GPT2Model`` and finally projected
    to ``action_dim`` outputs per position.

    Args:
        config: Mapping that must provide the keys ``'test'``, ``'horizon'``,
            ``'n_embd'``, ``'n_layer'``, ``'n_head'``, ``'state_dim'``,
            ``'action_dim'`` and ``'dropout'``.
    """

    def __init__(self, config):
        super().__init__()

        # Keep the raw settings and mirror each one as a plain attribute.
        self.config = config
        self.test = config['test']
        self.horizon = config['horizon']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.dropout = config['dropout']

        # NOTE(review): the backbone is always built with a single attention
        # head; ``self.n_head`` is stored above but never passed on. Confirm
        # whether that is intentional before wiring it through, because doing
        # so would change the behaviour of already-trained models.
        drop_p = self.dropout  # one rate shared by all three dropout sites
        gpt2_config = GPT2Config(
            n_positions=4 * (self.horizon + 1),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=drop_p,
            embd_pdrop=drop_p,
            attn_pdrop=drop_p,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)

        # One transition = state + preferred action + non-preferred action
        # + next state, hence two state-sized and two action-sized slots.
        transition_dim = 2 * self.state_dim + 2 * self.action_dim
        self.embed_transition = nn.Linear(transition_dim, self.n_embd)

        # Output head: start with tiny weights and a zero bias so the initial
        # action predictions are close to zero.
        self.pred_actions = nn.Linear(self.n_embd, self.action_dim)
        nn.init.normal_(self.pred_actions.weight, mean=0.0, std=1e-3)
        nn.init.zeros_(self.pred_actions.bias)

    def forward(self, x):
        """Predict actions for a batch of query states given their context.

        Args:
            x: Mapping with the entries ``'query_states'``, ``'zeros'``,
                ``'context_states'``, ``'preferred_actions'``,
                ``'non_preferred_actions'`` and ``'context_next_states'``.
                Assumed shapes (TODO confirm against the data loader):
                ``query_states`` and ``zeros`` are ``(batch, features)``, the
                remaining entries are ``(batch, T, features)``. ``zeros``
                must be at least ``max(state_dim, action_dim)`` wide because
                it is sliced to both widths below.

        Returns:
            With ``self.test`` truthy, the prediction at the last sequence
            position, shape ``(batch, action_dim)``. Otherwise the predictions
            at every position except the first (query) one, shape
            ``(batch, T, action_dim)``.
        """
        # Position 0 carries the query state; it has no action or next state
        # yet, so those slots are filled from the caller-supplied padding.
        query = x['query_states'][:, None, :]
        pad = x['zeros'][:, None, :]
        action_pad = pad[:, :, :self.action_dim]
        state_pad = pad[:, :, :self.state_dim]

        # Prepend the query position to each context stream (time axis = 1).
        states = torch.cat([query, x['context_states']], dim=1)
        preferred = torch.cat([action_pad, x['preferred_actions']], dim=1)
        non_preferred = torch.cat(
            [action_pad, x['non_preferred_actions']], dim=1)
        next_states = torch.cat([state_pad, x['context_next_states']], dim=1)

        # Fuse the four streams feature-wise into one vector per position.
        transitions = torch.cat(
            [states, preferred, non_preferred, next_states], dim=2)

        embedded = self.embed_transition(transitions)
        backbone_out = self.transformer(inputs_embeds=embedded)
        preds = self.pred_actions(backbone_out['last_hidden_state'])

        return preds[:, -1, :] if self.test else preds[:, 1:, :]

