import torch
import torch.nn as nn
from einops import rearrange


class PositionalRNN(nn.Module):
    """Recurrent model that reconstructs per-agent positional encodings.

    For every agent, the positional encoding, the action encoding and the
    world state are concatenated along the channel axis and run through an
    LSTM or GRU over the time axis. The RNN outputs are then linearly decoded
    back to ``positional_encoding_dim`` channels.

    Args:
        positional_encoding_dim: Channel size of the positional encodings;
            also the output size of the decoder.
        action_encoding_dim: Channel size of the action encodings.
        world_rnn_capacity: Channel size of the world state.
        hidden_size: Hidden size of the recurrent layer.
        rnn_type: Either ``'LSTM'`` or ``'GRU'``.
        normalize: If true, L2-normalize the decoded output along the last
            axis.
        detach_world_state: If true, detach the world state before it is
            concatenated, so no gradient flows into it from this module.
        eps: Lower bound applied to the norm when ``normalize`` is set.
            Note that the default ``10e-5`` is ``1e-4``.

    Raises:
        NotImplementedError: If ``rnn_type`` is neither ``'LSTM'`` nor
            ``'GRU'``.
    """

    def __init__(self, positional_encoding_dim, action_encoding_dim, world_rnn_capacity,
                 hidden_size, rnn_type='LSTM', normalize=False, detach_world_state=False, eps=10e-5):
        super().__init__()
        self.positional_encoding_dim = positional_encoding_dim
        self.action_encoding_dim = action_encoding_dim
        self.world_rnn_capacity = world_rnn_capacity
        # The RNN consumes all three inputs concatenated along the channel axis.
        self.input_size = positional_encoding_dim + action_encoding_dim + world_rnn_capacity
        self.hidden_size = hidden_size
        self.rnn_type = rnn_type
        self.normalize = normalize
        self.detach_world_state = detach_world_state
        self.eps = eps
        self.batch_first = True
        if self.rnn_type == 'LSTM':
            rnn_cls = nn.LSTM
        elif self.rnn_type == 'GRU':
            rnn_cls = nn.GRU
        else:
            raise NotImplementedError(
                "Unsupported rnn_type {!r}; expected 'LSTM' or 'GRU'.".format(self.rnn_type))
        self.rnn = rnn_cls(input_size=self.input_size, hidden_size=self.hidden_size,
                           batch_first=self.batch_first)
        self.rnn.flatten_parameters()
        self.decoder = nn.Linear(self.hidden_size, self.positional_encoding_dim)

    def forward(self, positional_encodings, action_encodings, world_state, hx=None):
        """Run the RNN over time independently for every (batch, agent) pair.

        Args:
            positional_encodings: Tensor of shape
                ``(N, T, A, positional_encoding_dim)``.
            action_encodings: Tensor of shape
                ``(N, T, A, action_encoding_dim)``.
            world_state: Tensor of shape ``(N, T, A, world_rnn_capacity)``, or
                ``(N, T, world_rnn_capacity)``, in which case it is shared
                across the agent axis.
            hx: Optional initial hidden state passed straight to the RNN. The
                RNN sees a batch of size ``N * A``, so ``hx`` must be laid out
                accordingly.

        Returns:
            A tuple ``(positional_recons, rnn_output, hx)`` where
            ``positional_recons`` has shape
            ``(N, T, A, positional_encoding_dim)``, ``rnn_output`` has shape
            ``(N, T, A, hidden_size)`` and ``hx`` is the final hidden state
            as returned by the RNN.

        Raises:
            AssertionError: If an input shape is inconsistent with the others
                or with the configured channel sizes.
            ValueError: If ``world_state`` is neither 3- nor 4-dimensional
                (or ``positional_encodings`` is not 4-dimensional).
        """
        # Read shapes
        _n, _t, _a, _p = positional_encodings.shape
        assert _p == self.positional_encoding_dim, (
            "positional_encodings has {} channels, expected {}.".format(_p, self.positional_encoding_dim))
        expected_action_shape = (_n, _t, _a, self.action_encoding_dim)
        assert action_encodings.shape == expected_action_shape, (
            "action_encodings has shape {}, expected {}.".format(
                tuple(action_encodings.shape), expected_action_shape))
        # World state can either be NTAC or NTC. If it's NTC, expand to NTAC
        if world_state.dim() == 3:
            expected_world_shape = (_n, _t, self.world_rnn_capacity)
            assert world_state.shape == expected_world_shape, (
                "world_state has shape {}, expected {}.".format(
                    tuple(world_state.shape), expected_world_shape))
            world_state = world_state[:, :, None, :].expand(_n, _t, _a, self.world_rnn_capacity)
        elif world_state.dim() == 4:
            # Check the channel axis against the configured capacity as well;
            # a mismatch would otherwise only surface inside the RNN.
            expected_world_shape = (_n, _t, _a, self.world_rnn_capacity)
            assert world_state.shape == expected_world_shape, (
                "world_state has shape {}, expected {}.".format(
                    tuple(world_state.shape), expected_world_shape))
        else:
            raise ValueError(
                "world_state must be 3- or 4-dimensional, got {} dimensions.".format(world_state.dim()))
        # Detach world_state if required
        if self.detach_world_state:
            world_state = world_state.detach()
        # Concatenate everything along the channel axis
        rnn_input = torch.cat([positional_encodings, action_encodings, world_state], dim=-1)
        # Fold the agent axis into the batch axis so each agent is its own sequence
        rnn_input = rearrange(rnn_input, 'n t a c -> (n a) t c')
        # Run the RNN
        rnn_output, hx = self.rnn(rnn_input, hx)
        # Unfold rnn_output to NTAC
        rnn_output = rearrange(rnn_output, '(n a) t c -> n t a c', n=_n, a=_a)
        # Project to another NTAC tensor
        positional_recons = self.decoder(rnn_output)
        if self.normalize:
            # Clamp the norm from below to avoid dividing by (near) zero.
            norms = torch.norm(positional_recons, p=2, dim=-1, keepdim=True).clamp(min=self.eps)
            positional_recons = positional_recons / norms
        return positional_recons, rnn_output, hx


class PositionalGRU(PositionalRNN):
    """``PositionalRNN`` with the recurrent layer fixed to a GRU.

    Args:
        positional_encoding_dim: Channel size of the positional encodings.
        action_encoding_dim: Channel size of the action encodings.
        world_rnn_capacity: Channel size of the world state.
        hidden_size: Hidden size of the GRU.
        normalize: Forwarded unchanged to ``PositionalRNN``.
        detach_world_state: Forwarded unchanged to ``PositionalRNN``.
        eps: Forwarded unchanged to ``PositionalRNN``.
    """

    def __init__(self, positional_encoding_dim, action_encoding_dim, world_rnn_capacity, hidden_size,
                 normalize=False, detach_world_state=False, eps=10e-5):
        # Defaults mirror PositionalRNN so existing four-argument calls behave as before.
        super(PositionalGRU, self).__init__(
            positional_encoding_dim, action_encoding_dim, world_rnn_capacity, hidden_size,
            rnn_type='GRU', normalize=normalize, detach_world_state=detach_world_state, eps=eps)


class PositionalLSTM(PositionalRNN):
    """``PositionalRNN`` with the recurrent layer fixed to an LSTM.

    Args:
        positional_encoding_dim: Channel size of the positional encodings.
        action_encoding_dim: Channel size of the action encodings.
        world_rnn_capacity: Channel size of the world state.
        hidden_size: Hidden size of the LSTM.
        normalize: Forwarded unchanged to ``PositionalRNN``.
        detach_world_state: Forwarded unchanged to ``PositionalRNN``.
        eps: Forwarded unchanged to ``PositionalRNN``.
    """

    def __init__(self, positional_encoding_dim, action_encoding_dim, world_rnn_capacity, hidden_size,
                 normalize=False, detach_world_state=False, eps=10e-5):
        # Defaults mirror PositionalRNN so existing four-argument calls behave as before.
        super(PositionalLSTM, self).__init__(
            positional_encoding_dim, action_encoding_dim, world_rnn_capacity, hidden_size,
            rnn_type='LSTM', normalize=normalize, detach_world_state=detach_world_state, eps=eps)