from addict import Dict
from .rnn_mawm import MAWM

from .. import encoder as enc
from .. import decoder as dec
from .. import positional_rnn as prnn
from .. import space_cells as spc
from ..utils import MaskedReconstructor
from ...utils import CodepathNotReadyError


class MATWM(MAWM):
    """Multi-Agent Topological World Model.

    Unlike the ``MAWM`` super-class, the state encoder does not take positional embeddings as an input.
    Position information is instead injected into the primary RNN (``self.rnn``). The RNN is then queried
    at positions of interest (``self.rnn.query``) to obtain position-dependent representations that are
    fed to the decoder. An optional positional RNN (``self.positional_rnn``) predicts next-step positional
    embeddings.
    """
    # Flags (to talk to trainer)
    ROLLOUTS_IMPLEMENTED = True

    def __init__(self, *args, **kwargs):
        """Pop the MATWM-specific keyword and forward all other arguments to ``MAWM.__init__``.

        Keyword Args:
            use_local_repr_for_positional_rnn (bool): Controls what the positional RNN receives.
                If True (the default), it receives the primary RNN output queried at the current
                positions. If False, it receives the flattened hidden states of the primary RNN.
        """
        # Default matches the one in `MATWMWithRNN`. Previously, omitting the keyword raised a KeyError.
        use_local_repr_for_positional_rnn = kwargs.pop('use_local_repr_for_positional_rnn', True)
        super(MATWM, self).__init__(*args, **kwargs)
        self.use_local_repr_for_positional_rnn = use_local_repr_for_positional_rnn

    def _positional_rnn_input(self, rnn_output, position_embeddings):
        """Select the primary-RNN representation that is fed to the positional RNN.

        With ``use_local_repr_for_positional_rnn`` set, the primary RNN is queried at the current
        positions, which imposes an inductive bias. Otherwise it returns the global state (one
        hidden state from each cell, flattened), which removes that inductive bias and may or may
        not generalize.
        """
        if self.use_local_repr_for_positional_rnn:
            return self.rnn.query(rnn_output, position_embeddings)
        return self.rnn.flatten_hiddens(rnn_output)

    def forward(self, actions, positions, states, rnn_state=None, positional_rnn_state=None, target_positions=None):
        """Run the world model over a sequence and collect the intermediate results.

        Args:
            actions: Passed to ``embed_and_encode``.
            positions: Passed to ``embed_and_encode``.
            states: Passed to ``embed_and_encode``.
            rnn_state: Initial state passed as-is to ``self.rnn`` (None for its default).
            positional_rnn_state: Initial positional-RNN state, forwarded to ``reconstruct``.
            target_positions: Optional positions whose embeddings replace the input position embeddings
                as reconstruction targets. Only allowed when the model has no positional RNN.

        Returns:
            An ``addict.Dict`` of embeddings, encodings, RNN outputs/state, reconstructions and mask info.
            Time cropping keeps the two position-embedding fields aligned:
            ``predicted_position_embeddings`` drops the last step and ``target_position_embeddings``
            drops the first.
        """
        # Unlike the super-class, the encoder does not use the positional embeddings as an input...
        action_embeddings, position_embeddings, state_encoding, state_action_encoding = \
            self.embed_and_encode(actions, positions, states)
        # ... instead, we inject the position information into the RNN.
        rnn_output, *rnn_state = self.rnn(state_action_encoding, position_embeddings, rnn_state)
        if target_positions is not None:
            # Explicit targets are incompatible with predicting positions via the positional RNN.
            assert self.positional_rnn is None
            target_positional_embeddings = self.embed_positions(target_positions)
        else:
            target_positional_embeddings = position_embeddings
        recons, predicted_positional_embeddings, positional_rnn_output, mask_info = \
            self.reconstruct(target_positional_embeddings, action_embeddings, rnn_output, positional_rnn_state)
        outputs = Dict()
        outputs.action_embeddings = action_embeddings
        outputs.position_embeddings = position_embeddings
        outputs.state_encoding = state_encoding
        outputs.state_action_encoding = state_action_encoding
        outputs.rnn_output = rnn_output
        # `rnn_state` is the list of trailing RNN outputs collected by the star-unpacking above.
        outputs.rnn_state = rnn_state
        outputs.recons = recons
        outputs.mask_info = mask_info
        outputs.predicted_position_embeddings = (predicted_positional_embeddings[:, :-1]
                                                 if predicted_positional_embeddings is not None else None)
        outputs.target_position_embeddings = position_embeddings[:, 1:]
        outputs.positional_rnn_output = positional_rnn_output
        # NOTE(review): this is the *input* positional RNN state. `reconstruct` does not return an updated
        # state, so the positional RNN's new state is not exposed here -- confirm this is intended.
        outputs.positional_rnn_state = positional_rnn_state
        return outputs

    def predict_positional_embeddings(self, position_embeddings, action_embeddings, rnn_output, positional_rnn_state):
        """Run the positional RNN (if the model has one) to predict next-step positional embeddings.

        Returns:
            A ``(predicted_positional_embeddings, positional_rnn_output)`` pair. Both are None when
            ``self.positional_rnn`` is None.
        """
        if self.positional_rnn is None:
            return None, None
        rnn_output_for_prnn = self._positional_rnn_input(rnn_output, position_embeddings)
        # NOTE(review): the positional RNN's updated state (third output) is discarded; callers only get
        # the predictions and the raw output. Confirm this is intended.
        predicted_positional_embeddings, positional_rnn_output, _ = \
            self.positional_rnn(position_embeddings, action_embeddings, rnn_output_for_prnn, positional_rnn_state)
        return predicted_positional_embeddings, positional_rnn_output

    def decode(self, predicted_positional_embeddings, position_embeddings, rnn_output):
        """Decode next-step reconstructions by querying the RNN at (predicted or true) next positions.

        Predicted positional embeddings are used only if they exist and
        ``decode_with_predicted_positional_encodings`` is set; optionally they are detached first.
        Otherwise the true embeddings, shifted by one time step, are used.

        Returns:
            The decoder output (an NTAChw tensor, per the original comment).
        """
        # See super-class for detailed comments.
        if predicted_positional_embeddings is not None and self.decode_with_predicted_positional_encodings:
            assert self.positional_encoder.normalize and self.positional_rnn.normalize
            cropped_positional_embeddings = (predicted_positional_embeddings.detach()
                                             if self.detach_positional_encodings_before_decoding else
                                             predicted_positional_embeddings)[:, :-1]
        else:
            cropped_positional_embeddings = position_embeddings[:, 1:]
        # Query the RNN for position dependent representations (NTAC). But remember to temporally crop the
        # rnn_output, which happens to be an MNTC tensor.
        position_aware_rnn_outputs = self.rnn.query(rnn_output[:, :, :-1], cropped_positional_embeddings)
        # Pass through the generative model and hope for the best
        recons = self.decoder(position_aware_rnn_outputs)
        # recons is a NTAChw tensor
        return recons

    def one_step_rollout(self, state_in, is_initial=False):
        """Advance the model by a single time step during a rollout.

        Args:
            state_in: ``addict.Dict`` holding ``actions``, ``positions``, ``states``, ``goals`` and
                (optionally) ``rnn_state`` / ``positional_rnn_state``. If the model has no positional
                RNN, it must also hold ``next_positions``. NOTE: ``state_in.positions`` and
                ``state_in.goals`` are overwritten in place with their embeddings.
            is_initial: If True, positions/goals in ``state_in`` are raw values to be embedded.
                Otherwise they are already embeddings (as produced by a previous call).

        Returns:
            An ``addict.Dict`` with the next positions (embeddings), next states, RNN states, goals,
            the raw reconstruction and the RNN output. It can be fed back in after setting ``actions``.

        Raises:
            CodepathNotReadyError: If a masker is configured.
        """
        # These are required to roll-out
        if self.positional_rnn is not None:
            assert self.positional_rnn.normalize and self.positional_encoder.normalize
        # See method in super-class for more comments.
        # Embed and encode actions, positions and states at time t. These are all N1AC tensors.
        action_embeddings, position_embeddings, state_encoding, state_action_encoding, *remaining_embeddings = \
            self.embed_and_encode(actions=state_in.actions, positions=state_in.positions, states=state_in.states,
                                  goals=state_in.goals, position_is_embedding=(not is_initial),
                                  goal_is_embedding=(not is_initial))
        goal_embeddings = remaining_embeddings[0] if len(remaining_embeddings) > 0 else None
        state_in.positions = position_embeddings
        state_in.goals = goal_embeddings
        assert action_embeddings.shape[1] == position_embeddings.shape[1] == state_encoding.shape[1] == 1
        # Run through RNN with state_action encoding (N1AC) and positional embeddings (also N1AC)
        rnn_output, *rnn_state_next = self.rnn(state_action_encoding, position_embeddings,
                                               *(state_in.rnn_state or ()))
        if self.positional_rnn is not None:
            # Run through the positional RNN to predict the next positional embeddings.
            rnn_output_for_prnn = self._positional_rnn_input(rnn_output, position_embeddings)
            position_embeddings_next, _, *positional_rnn_state_next = \
                self.positional_rnn(position_embeddings, action_embeddings, rnn_output_for_prnn,
                                    *(state_in.positional_rnn_state or ()))
        else:
            # Without a positional RNN, the caller must supply the next positions explicitly.
            assert 'next_positions' in state_in
            position_embeddings_next = self.embed_positions(positions=state_in.next_positions)
            positional_rnn_state_next = ()
        # Use the next positional embedding to first query the RNN, and then predict the next state
        rnn_output_for_decoder = self.rnn.query(rnn_output, position_embeddings_next)
        reconstruction = self.decoder(rnn_output_for_decoder)
        if self.masker is not None:
            raise CodepathNotReadyError
            # noinspection PyUnreachableCode
            reconstruction, _ = self.masker(reconstruction)
        states_next = self.recon_to_state_processor(reconstruction, state_in)
        # Build output state
        state_out = Dict(actions=None, positions=position_embeddings_next, states=states_next,
                         rnn_state=rnn_state_next, positional_rnn_state=positional_rnn_state_next,
                         goals=goal_embeddings, recons=reconstruction, rnn_output=rnn_output)
        return state_out


class MATWMWithRNN(MATWM):
    """``MATWM`` assembled from concrete components.

    The components are a positional encoding, an optional action encoding, a tower state encoder, a
    space-cell RNN (``SpaceGRU`` / ``SpaceLSTM`` / ``SpaceRMC``), a decoder (``residual`` /
    ``transformer``) and an optional positional RNN.
    """
    def __init__(self, state_size=(5, 5), num_in_state_channels=23, num_out_state_channels=21, polar_states=False,
                 positional_encoding_dim=16, num_position_dims=2, positional_encoding_max_frequency=1000,
                 normalize_positional_encodings=True, trainable_positional_embeddings=False,
                 num_actions=5, action_encoding_dim=16, encoder_type='tower', encoder_capacity=128, repr_dim=128,
                 rnn_num_cells=10, rnn_cell_capacity=128, rnn_type='SpaceGRU', rnn_kwargs='DEFAULT',
                 decoder_type='residual', decoder_capacity=128, decoder_kwargs=None, mask_spec=None,
                 positional_rnn_type=None, positional_rnn_capacity=128, positional_rnn_kwargs=None,
                 decode_with_predicted_positional_encodings=False, detach_positional_encodings_before_decoding=True,
                 use_local_repr_for_positional_rnn=True):
        """Build all sub-modules and hand them to ``MATWM``.

        Notable arguments:
            action_encoding_dim: None or 0 disables the action encoder; the RNN input then has width
                ``repr_dim`` only.
            rnn_type: One of ``'SpaceGRU'``, ``'SpaceLSTM'``, ``'SpaceRMC'``. For ``'SpaceRMC'``,
                ``rnn_cell_capacity`` is ignored and taken from the constructed RNN instead.
            rnn_kwargs: A dict of extra RNN kwargs, a key into ``spc.KWARGSETS``, or None (no extra kwargs).
            decoder_type: ``'residual'`` or ``'transformer'``; ``decoder_kwargs`` is only used by the latter.
            mask_spec: Must be None; masking is not ready yet (raises ``CodepathNotReadyError``).
            positional_rnn_type: If None, no positional RNN is built.

        Raises:
            NotImplementedError: For an unknown ``encoder_type``, ``rnn_type`` or ``decoder_type``.
            CodepathNotReadyError: If ``mask_spec`` is given.
        """
        positional_encoder = enc.PositionalEncoding(positional_encoding_dim, num_position_dims,
                                                    positional_encoding_max_frequency,
                                                    normalize_positional_encodings,
                                                    trainable_positional_embeddings)
        # A falsy (None or 0) action encoding dim disables action encoding altogether.
        if not action_encoding_dim:
            action_encoder = None
            action_encoding_dim = 0
        else:
            action_encoder = enc.ActionEncoding(num_actions, action_encoding_dim)
        if encoder_type == 'tower':
            encoder = enc.ContextlessTowerEncoder(state_size, num_in_state_channels,
                                                  encoder_capacity, repr_dim, polar_states)
        else:
            raise NotImplementedError("Unknown encoder_type: {!r}".format(encoder_type))
        # The RNN consumes the state encoding concatenated with the action embedding.
        rnn_input_size = repr_dim + action_encoding_dim
        if rnn_kwargs is None:
            rnn_kwargs = {}
        elif isinstance(rnn_kwargs, str):
            rnn_kwargs = spc.KWARGSETS[rnn_kwargs]
        assert isinstance(rnn_kwargs, dict), "rnn_kwargs must be a dict, a key of spc.KWARGSETS, or None."
        if rnn_type == 'SpaceGRU':
            rnn = spc.SpaceGRU(input_size=rnn_input_size, num_cells=rnn_num_cells,
                               cell_hidden_size=rnn_cell_capacity, **rnn_kwargs)
            rnn_capacity = rnn_cell_capacity
        elif rnn_type == 'SpaceLSTM':
            rnn = spc.SpaceLSTM(input_size=rnn_input_size, num_cells=rnn_num_cells,
                                cell_hidden_size=rnn_cell_capacity, **rnn_kwargs)
            rnn_capacity = rnn_cell_capacity
        elif rnn_type == 'SpaceRMC':
            rnn = spc.SpaceRMC(input_size=rnn_input_size, num_cells=rnn_num_cells,
                               **rnn_kwargs)
            # The RMC determines its own cell size; override the requested capacity accordingly.
            rnn_capacity = rnn_cell_capacity = rnn.space_cells.cell_hidden_size
        else:
            raise NotImplementedError("Unknown rnn_type: {!r}".format(rnn_type))
        if mask_spec is not None:
            raise CodepathNotReadyError
            # noinspection PyUnreachableCode
            masker = MaskedReconstructor(mask_spec, channel_dim=3)
            num_out_state_channels += masker.num_mask_channels
        else:
            masker = None
        # The decoder consumes RNN query outputs, so its input width is given as `rnn_cell_capacity`.
        if decoder_type == 'residual':
            decoder = dec.ContextlessResidualDecoder(state_size, num_out_state_channels,
                                                     decoder_capacity, rnn_cell_capacity,
                                                     polar_states)
        elif decoder_type == 'transformer':
            decoder = dec.ContextlessTransformerDecoder(state_size, num_out_state_channels,
                                                        decoder_capacity, rnn_cell_capacity,
                                                        **(decoder_kwargs or {}))
        else:
            raise NotImplementedError("Unknown decoder_type: {!r}".format(decoder_type))
        if positional_rnn_type is not None:
            # NOTE(review): with use_local_repr_for_positional_rnn=False the positional RNN receives
            # `rnn.flatten_hiddens(...)` (hiddens from all cells), which may be wider than `rnn_capacity`
            # -- confirm PositionalRNN handles this.
            positional_rnn = prnn.PositionalRNN(positional_encoding_dim=positional_encoding_dim,
                                                action_encoding_dim=action_encoding_dim,
                                                world_rnn_capacity=rnn_capacity,
                                                hidden_size=positional_rnn_capacity,
                                                rnn_type=positional_rnn_type,
                                                **(positional_rnn_kwargs or {}))
        else:
            positional_rnn = None
        super(MATWMWithRNN, self).__init__(encoder, action_encoder, positional_encoder, None, rnn, decoder,
                                           positional_rnn, decode_with_predicted_positional_encodings,
                                           detach_positional_encodings_before_decoding,
                                           use_local_repr_for_positional_rnn=use_local_repr_for_positional_rnn,
                                           masker=masker)


