import torch
import torch.nn as nn

from transfer.models.model import TrajectoryModel


class MLPModel(TrajectoryModel):
    """MLP policy that predicts the next action from a fixed-length history.

    Every timestep contributes up to ``self.n`` tokens (return-to-go, state,
    previous action), each embedded to ``hidden_size``. For every position in
    the window the model sees the zero-left-padded prefix of tokens that ends
    at that position's state, flattened into a single vector, and maps it
    through an MLP with a final ``Tanh`` (so predicted actions lie in
    ``[-1, 1]``).

    NOTE(review): ``self.state_dim``, ``self.act_dim`` and ``self.max_length``
    are assumed to be set by ``TrajectoryModel.__init__`` -- confirm.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        n_layer,
        dropout=0.1,
        max_length=1,
        use_returns=True,
        use_actions=True,
        **kwargs
    ):
        """Build the token embeddings and the MLP trunk.

        Args:
            state_dim: Size of one state vector.
            act_dim: Size of one action vector.
            hidden_size: Embedding width and width of every hidden layer.
            n_layer: Number of hidden ``Linear`` layers (the first one consumes
                the flattened history).
            dropout: Dropout probability applied after every ``ReLU``.
            max_length: Number of timesteps in the input window; ``forward``
                only accepts sequences of exactly this length.
            use_returns: Feed returns-to-go tokens to the model.
            use_actions: Feed past-action tokens to the model.
            **kwargs: Accepted and ignored.
        """
        super().__init__(state_dim, act_dim, max_length=max_length)
        self.use_returns = use_returns
        self.use_actions = use_actions

        self.hidden_size = hidden_size

        # All three embeddings are always created (even if a modality is
        # disabled) so parameter names and initialisation order stay stable.
        self.embed_return = torch.nn.Linear(1, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        # Number of tokens per timestep: state + optional action + optional return.
        self.n = 1 + int(self.use_actions) + int(self.use_returns)

        layers = [nn.Linear(max_length * hidden_size * self.n, hidden_size)]
        for _ in range(n_layer - 1):
            layers.extend([nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_size, hidden_size)])
        layers.extend(
            [
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, self.act_dim),
                nn.Tanh(),
            ]
        )

        self.model = nn.Sequential(*layers)

    def mask_sequence(self, sequence):
        """Return, for every timestep, the zero-left-padded token prefix it may see.

        For a token sequence of length ``L``, prefix ``i`` is the window
        ``(0, ..., 0, token_0, ..., token_i)`` of length ``L``: the first
        prefix holds ``L - 1`` pads and one token, the last holds all ``L``
        tokens. Only the prefixes that end on a state token are kept, so with
        tokens ``(R_1, s_1, a_1, R_2, s_2, a_2, ...)`` the result contains
        ``(R_1, s_1)``, ``(R_1, s_1, a_1, R_2, s_2)``, ... This guarantees the
        model never sees the action it has to predict.

        Args:
            sequence: Tensor of shape ``(batch, L, dim)`` with ``L`` equal to
                ``self.n`` times the number of timesteps.

        Returns:
            Tensor of shape ``(batch, timesteps, L, dim)`` with the same dtype
            and device as ``sequence``.
        """
        batch_size, seq_length, seq_dim = sequence.shape

        # new_zeros inherits dtype and device, so half/double inputs are not
        # silently promoted to (or mixed with) float32 padding.
        padding = sequence.new_zeros((batch_size, seq_length, seq_dim))
        padded = torch.cat([padding, sequence], dim=1)  # (batch, 2L, dim)

        offsets = torch.arange(seq_length, device=sequence.device)
        # Index of the state token inside one timestep's group of n tokens;
        # only prefixes ending on a state are needed.
        first_state = self.n - int(self.use_actions) - 1
        prefix_ends = offsets[first_state :: self.n]

        # Window for prefix i covers padded[i + 1 : i + 1 + L]; its last
        # element is padded[L + i] == sequence[i].
        window_index = prefix_ends.view(-1, 1) + offsets.view(1, -1) + 1
        return padded[:, window_index]

    def forward(self, states, actions, rewards, returns_to_go, timesteps=None, attention_mask=None):
        """Predict one action per timestep of a full-length window.

        Args:
            states: ``(batch, max_length, state_dim)`` tensor.
            actions: ``(batch, max_length, act_dim)`` tensor; only read when
                ``use_actions`` is set.
            rewards: Unused.
            returns_to_go: ``(batch, max_length, 1)`` tensor; only read when
                ``use_returns`` is set.
            timesteps: Unused.
            attention_mask: Unused.

        Returns:
            ``(None, action_preds, None)`` where ``action_preds`` has shape
            ``(batch, max_length, act_dim)``.

        Raises:
            ValueError: If the sequence length differs from ``max_length``
                (the first linear layer has a fixed input width).
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        if seq_length != self.max_length:
            raise ValueError(
                f"MLPModel expects sequences of length {self.max_length}, got {seq_length}"
            )

        # Token order inside one timestep: (return, state, action).
        embeddings = []
        if self.use_returns:
            embeddings.append(self.embed_return(returns_to_go))
        embeddings.append(self.embed_state(states))
        if self.use_actions:
            embeddings.append(self.embed_action(actions))

        # Interleave tokens so the sequence looks like
        # (R_1, s_1, a_1, R_2, s_2, a_2, ...).
        stacked_inputs = torch.stack(embeddings, dim=2).reshape(
            batch_size, self.n * seq_length, self.hidden_size
        )
        masked_inputs = self.mask_sequence(stacked_inputs)

        flat_inputs = masked_inputs.reshape(batch_size * seq_length, -1)
        action_preds = self.model(flat_inputs).reshape(batch_size, seq_length, self.act_dim)

        return None, action_preds, None

    def _left_pad_to_max_length(self, tensor):
        """Keep the last ``max_length`` steps of a ``(1, T, D)`` tensor.

        Shorter inputs are left-padded with zeros; the result is float32.
        """
        tensor = tensor[:, -self.max_length :]
        missing = self.max_length - tensor.shape[1]
        if missing > 0:
            padding = tensor.new_zeros((1, missing, tensor.shape[2]))
            tensor = torch.cat([padding, tensor], dim=1)
        return tensor.to(dtype=torch.float32)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps=None, **kwargs):
        """Predict the action for the most recent timestep of one trajectory.

        Each history is reshaped to a batch of one, truncated to the last
        ``max_length`` steps and zero-left-padded if shorter. Every tensor is
        padded independently, so histories of different lengths are handled.

        Args:
            states: State history, reshapeable to ``(1, T, state_dim)``.
            actions: Action history, reshapeable to ``(1, T, act_dim)``.
            rewards: Unused.
            returns_to_go: Returns-to-go history, reshapeable to ``(1, T, 1)``.
            timesteps: Unused.
            **kwargs: Forwarded to ``forward``.

        Returns:
            Tensor of shape ``(act_dim,)`` with the predicted action.
        """
        states = self._left_pad_to_max_length(states.reshape(1, -1, self.state_dim))
        actions = self._left_pad_to_max_length(actions.reshape(1, -1, self.act_dim))
        returns_to_go = self._left_pad_to_max_length(returns_to_go.reshape(1, -1, 1))

        _, action_preds, _ = self.forward(states, actions, None, returns_to_go, **kwargs)
        return action_preds[0, -1]
