import numpy as np
import torch
import torch.nn as nn

import transformers
from .bert import BertModel


class PreferenceTransformer(nn.Module):
    """Transformer that maps a (state, action) trajectory segment to a scalar score.

    States and actions are embedded by separate linear layers, offset by a
    learned timestep embedding, interleaved as ``(s_1, a_1, s_2, a_2, ...)``,
    layer-normalised and passed to ``BertModel`` via ``inputs_embeds``.  The
    final hidden states are projected to one scalar per timestep and summed
    over the segment.

    Args:
        state_dim: Size of the last dimension of the state input.
        act_dim: Size of the last dimension of the action input.
        hidden_size: Embedding width, also forwarded to ``BertConfig``.
        max_length: Window length used by the trajectory-scoring helpers.
            When ``None`` the whole trajectory is scored in a single pass.
        max_ep_len: Number of rows in the timestep embedding table; timestep
            indices must be smaller than this value.
        device: Stored on ``self.device`` for callers that read it.  The
            methods of this class place tensors on the device of the inputs
            or of the module's own weights instead.
        **kwargs: Extra keyword arguments forwarded to
            ``transformers.BertConfig``.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        max_length=None,
        max_ep_len=1000,
        device='cuda',
        **kwargs
    ):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.max_length = max_length
        self.hidden_size = hidden_size
        self.device = device
        self.max_ep_len = max_ep_len
        config = transformers.BertConfig(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            hidden_size=hidden_size,
            **kwargs
        )

        # NOTE(review): an earlier comment here said the local model differs
        # from the stock Huggingface one only in that positional embeddings
        # are removed (timestep embeddings are added below instead) --
        # confirm against the `.bert` module.
        self.transformer = BertModel(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(self.hidden_size, 1)

    def forward(
        self,
        states,
        actions,
        timesteps,
        attention_mask=None,
        transition_returns=False,
    ):
        """Score a batch of trajectory segments.

        Args:
            states: Tensor of shape ``(batch, seq, state_dim)``.
            actions: Tensor of shape ``(batch, seq, act_dim)``.
            timesteps: Integer tensor of shape ``(batch, seq)`` indexing the
                timestep embedding table.
            attention_mask: Optional ``(batch, seq)`` tensor, 1 where a step
                may be attended to and 0 otherwise.  Defaults to all ones.
            transition_returns: Accepted for interface compatibility; this
                implementation does not use it.

        Returns:
            Tensor of shape ``(batch, 1)``: the per-token scalar outputs
            summed over state/action tokens and over all timesteps.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # Default: every step can be attended to.  Build the mask on the
            # inputs' device so the module works wherever it has been moved.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings

        # interleave so that the sequence looks like (s_1, a_1, s_2, a_2, ...)
        stacked_inputs = (
            torch.stack((state_embeddings, action_embeddings), dim=1)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, 2 * seq_length, self.hidden_size)
        )

        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = (
            torch.stack((attention_mask, attention_mask), dim=1)
            .permute(0, 2, 1)
            .reshape(batch_size, 2 * seq_length)
        )

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs["last_hidden_state"]

        # reshape x so that the second dimension corresponds to the original
        # states (0), or actions (1); i.e. x[:,0,t] is the token for s_t
        x = x.reshape(
            batch_size,
            seq_length,
            2,
            self.hidden_size,
        ).permute(0, 2, 1, 3)

        # Sum state and action tokens, project each timestep to a scalar,
        # then sum the scalars over the segment.
        x = self.output(x.sum(dim=1)).sum(dim=1)

        return x

    @staticmethod
    def _to_device_tensor(value, device, as_long=False):
        """Convert a numpy array to a tensor (optionally int64) and move it to ``device``."""
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
            if as_long:
                value = value.long()
        return value.to(device)

    def _predict_traj_return(self, observations, actions, timesteps, masks):
        """Accumulate sigmoid window scores over a batch of trajectories.

        The trajectory is cut into consecutive windows of ``self.max_length``
        steps (one window covering everything when ``max_length`` is unset).
        Each window contributes ``sigmoid(forward(window))`` multiplied by the
        mask value of the window's first step, so windows that start in the
        padded region contribute zero.

        Args:
            observations: ``(batch, T, state_dim)`` tensor or numpy array.
            actions: ``(batch, T, act_dim)`` tensor or numpy array.
            timesteps: ``(batch, T)`` integer tensor or numpy array.
            masks: ``(batch, T)`` tensor or numpy array of 0/1 validity flags.

        Returns:
            numpy array of shape ``(batch,)`` with the accumulated scores.
        """
        # Use the device the weights actually live on rather than the
        # constructor's `device` string, which may be stale.
        device = self.embed_state.weight.device
        observations = self._to_device_tensor(observations, device)
        actions = self._to_device_tensor(actions, device)
        timesteps = self._to_device_tensor(timesteps, device, as_long=True)
        masks = self._to_device_tensor(masks, device, as_long=True)

        # Never step past the data that is actually present: slicing beyond
        # the end yields empty windows, for which `mask[..., 0]` fails.
        horizon = min(observations.shape[-2], self.max_ep_len)
        window = self.max_length if self.max_length else horizon

        pref = np.zeros(len(observations))
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, horizon, window) if window else ():
                    stop = start + window
                    observation = observations[..., start:stop, :]
                    action = actions[..., start:stop, :]
                    timestep = timesteps[..., start:stop]
                    mask = masks[..., start:stop]

                    score = torch.sigmoid(self(observation, action, timestep, mask))
                    pref += (score[..., 0] * mask[..., 0]).cpu().numpy()
        finally:
            # Restore whichever mode the caller had, not unconditionally train().
            self.train(was_training)

        return pref

    def predict_traj_return(self, trajs, tlens):
        """Score each trajectory in ``trajs``.

        Args:
            trajs: Iterable of 2-D arrays/tensors whose rows are
                ``state_dim`` state features followed by ``act_dim`` action
                features.
            tlens: Iterable of valid lengths, one per trajectory; steps at
                index ``tlen`` and beyond are treated as padding.

        Returns:
            numpy array with one accumulated score per trajectory.
        """
        device = self.embed_state.weight.device
        prefs = []

        for tlen, traj in zip(tlens, trajs):
            observations = traj[:, :self.state_dim].reshape(1, -1, self.state_dim)
            actions = traj[:, self.state_dim:].reshape(1, -1, self.act_dim)
            timesteps = torch.arange(
                len(traj), dtype=torch.long, device=device
            ).reshape(1, -1)
            masks = torch.ones_like(timesteps)

            # Padded steps get timestep 0 and mask 0.
            timesteps[0, tlen:] = 0
            masks[0, tlen:] = 0

            prefs.append(
                self._predict_traj_return(observations, actions, timesteps, masks)[0]
            )

        return np.array(prefs)


        
