import torch
import torch.nn as nn
from ...utils import flatten_spatial, CodepathNotReadyError
from addict import Dict
from contextlib import contextmanager

from .. import encoder as enc
from .. import decoder as dec
from .. import positional_rnn as prnn
from .. import rims
from .. import rmc
from .. import tto
from .. import space_cells as spc
from ..utils import MaskedReconstructor
from ..external.block_wrapper import BlocksWrapper


def _sigmoid_recon_to_state(x, *_, **__):
    """Default reconstruction-to-state processor: element-wise sigmoid of ``x``.

    Any extra positional/keyword arguments (``one_step_rollout`` passes the rollout state) are
    accepted and ignored. This is a module-level function rather than a lambda default so that
    a model holding it as an attribute remains picklable (e.g. via ``torch.save(model)``).
    """
    return x.sigmoid()


class MAWM(nn.Module):
    """Multi-agent world model.

    Per-agent states, positions and actions are embedded/encoded, aggregated over the agent
    axis, and fed to a batch-first world RNN. The RNN output, combined with positional
    embeddings (either provided or predicted by an optional positional RNN), is decoded into
    reconstructions of the next local states.

    Axis letters used in comments: N = batch, T = time, A = agents, C = channels,
    h/w = spatial.
    """

    # Flags
    ROLLOUTS_IMPLEMENTED = True

    def __init__(self, encoder, action_encoder, positional_encoder, aggregator, rnn, decoder,
                 positional_rnn=None, decode_with_predicted_positional_encodings=False,
                 detach_positional_encodings_before_decoding=False,
                 recon_to_state_processor=_sigmoid_recon_to_state, masker=None):
        """
        Parameters
        ----------
        encoder : nn.Module
            Called as ``encoder(states_or_state_actions, position_embeddings)``.
        action_encoder : nn.Module or None
            Maps actions to per-agent embeddings. If None, actions are either unavailable
            (empty tensor) or concatenated with the states and encoded jointly by ``encoder``.
        positional_encoder : nn.Module
            Maps positions (and goals) to embeddings.
        aggregator : nn.Module
            Reduces the per-agent state-action encodings over the agent axis.
        rnn : nn.Module
            World RNN; must expose ``batch_first == True``.
        decoder : nn.Module
            Called as ``decoder(rnn_output, position_embeddings)`` to produce reconstructions.
        positional_rnn : nn.Module, optional
            Predicts positional embeddings; must expose ``batch_first == True`` if given.
        decode_with_predicted_positional_encodings : bool
            Decode with the positional RNN's predictions instead of the given embeddings.
        detach_positional_encodings_before_decoding : bool
            Detach predicted embeddings before decoding, so the positional RNN is not trained
            by the reconstruction loss.
        recon_to_state_processor : callable
            Called as ``fn(reconstructions, state_in)`` in ``one_step_rollout`` to turn
            reconstructions into states. Defaults to an element-wise sigmoid.
        masker : callable, optional
            Not supported yet: when set, reconstruction raises ``CodepathNotReadyError``.
        """
        super().__init__()
        self.encoder = encoder
        self.action_encoder = action_encoder
        self.positional_encoder = positional_encoder
        self.aggregator = aggregator
        self.rnn = rnn
        self.decoder = decoder
        self.positional_rnn = positional_rnn
        self.decode_with_predicted_positional_encodings = decode_with_predicted_positional_encodings
        self.detach_positional_encodings_before_decoding = detach_positional_encodings_before_decoding
        self.recon_to_state_processor = recon_to_state_processor
        self.masker = masker
        # Validate: all time-series code below indexes time on dim 1.
        assert self.rnn.batch_first
        if self.positional_rnn is not None:
            assert self.positional_rnn.batch_first

    def embed_and_encode(self, actions, positions, states, goals=None,
                         position_is_embedding=False, goal_is_embedding=False):
        """Embed actions, positions and goals, and encode the (state, action) pairs.

        The ops in this method are pointwise w.r.t. time and agents.

        Parameters
        ----------
        actions : torch.Tensor
            Actions with leading dims (N, T, A). Embedded by ``action_encoder`` if present;
            otherwise either empty (actions unavailable) or a NTAChw tensor that is
            concatenated with ``states`` along the channel axis.
        positions : torch.Tensor
            Positions, or position embeddings if ``position_is_embedding`` is True.
        states : torch.Tensor
            NTAChw local states.
        goals : torch.Tensor, optional
            NA2 goals, or NAC goal embeddings if ``goal_is_embedding`` is True. Ignored when
            None or empty.
        position_is_embedding, goal_is_embedding : bool
            Skip the positional encoder for positions / goals.

        Returns
        -------
        tuple
            ``(action_embeddings, position_embeddings, state_encoding, state_action_encoding)``,
            with ``goal_embeddings`` appended as a fifth element when goals were embedded.
        """
        # Embed actions to a NTAC tensor
        if self.action_encoder is not None:
            action_embeddings = self.action_encoder(actions)
        else:
            # Without an action encoder, states and actions are encoded jointly by `self.encoder`,
            # so a placeholder NTA0 action-embedding tensor is enough.
            N, T, A, *_ = actions.shape
            action_embeddings = torch.empty(N, T, A, 0, dtype=actions.dtype, device=actions.device)
        # Embed positions
        position_embeddings = self.embed_positions(positions, position_is_embedding)
        # Embed goals from NA2 to NAC if required
        if goals is not None and goals.numel() > 0:
            if goal_is_embedding:
                goal_embeddings = goals
            else:
                # goals is NA2; convert to NTA2 (T=1), then to NTAC and then back to NAC
                goal_embeddings = self.positional_encoder(goals[:, None, :, :])[:, 0, :, :]
        else:
            goal_embeddings = None
        if self.action_encoder is not None:
            # Encode state to a NTAC11 tensor given a state NTAChw tensor and the positional embeddings
            state_encoding = self.encoder(states, position_embeddings)
            # Flatten state_encoding to NTAC and cat along channels with the action embeddings
            state_action_encoding = torch.cat([flatten_spatial(state_encoding), action_embeddings], dim=3)
        else:
            # Action encoder is None. This could mean that actions are not available, or it could mean that
            # the (state) encoder jointly encodes states and actions.
            if actions.numel() == 0:
                # `actions` is an empty tensor, so we assume they're not available
                state_actions = states
            else:
                # We assume `actions` is a NTAChw tensor (like the state), and concatenate it
                # with state along the channel axis.
                state_actions = torch.cat([states, actions], dim=3)
            state_action_encoding = self.encoder(state_actions, position_embeddings)
            # We use the same embedding for state and state_actions.
            state_encoding = state_action_encoding = flatten_spatial(state_action_encoding)
        if goal_embeddings is None:
            return action_embeddings, position_embeddings, state_encoding, state_action_encoding
        return action_embeddings, position_embeddings, state_encoding, state_action_encoding, goal_embeddings

    def embed_positions(self, positions, position_is_embedding=False):
        """Return positional embeddings for ``positions``.

        If ``position_is_embedding`` is True, ``positions`` is assumed to already be an
        embedding and is returned unchanged.
        """
        if not position_is_embedding:
            return self.positional_encoder(positions)
        return positions

    def forward(self, actions, positions, states, rnn_state=None, positional_rnn_state=None, target_positions=None):
        """Run the model over full sequences and reconstruct next-step states.

        Parameters
        ----------
        actions, positions, states : torch.Tensor
            Batch-first sequences, see ``embed_and_encode``.
        rnn_state : optional
            Initial state passed to the world RNN.
        positional_rnn_state : optional
            Initial state passed to the positional RNN.
        target_positions : torch.Tensor, optional
            If given, their embeddings are used for decoding instead of those of
            ``positions``. Not supported together with a positional RNN.

        Returns
        -------
        addict.Dict
            Intermediate tensors: embeddings, encodings, RNN input/output/state,
            reconstructions (``recons``), ``mask_info``, predicted position embeddings
            (time steps ``[:-1]``), target position embeddings (time steps ``[1:]``) and
            positional RNN output/state.
        """
        action_embeddings, position_embeddings, state_encoding, state_action_encoding = \
            self.embed_and_encode(actions, positions, states)
        # Aggregate along the agent axis to obtain a NTC tensor
        rnn_input = self.aggregator(state_action_encoding)
        # Process to another NTC tensor and states.
        rnn_output, *rnn_state = self.rnn(rnn_input, rnn_state)
        if target_positions is not None:
            assert self.positional_rnn is None, "Positional RNN not supported when target_positions are provided."
            target_positional_embeddings = self.embed_positions(target_positions)
        else:
            target_positional_embeddings = position_embeddings
        # Obtain reconstruction (use positional embeddings if required)
        recons, predicted_positional_embeddings, positional_rnn_output, mask_info = \
            self.reconstruct(target_positional_embeddings, action_embeddings, rnn_output, positional_rnn_state)
        # Collect diagnostics
        outputs = Dict()
        outputs.action_embeddings = action_embeddings
        outputs.position_embeddings = position_embeddings
        outputs.provided_target_position_embeddings = target_positional_embeddings
        outputs.state_encoding = state_encoding
        outputs.state_action_encoding = state_action_encoding
        outputs.rnn_input = rnn_input
        outputs.rnn_output = rnn_output
        outputs.rnn_state = rnn_state
        outputs.recons = recons
        outputs.mask_info = mask_info
        # Prediction at time t targets the embedding at t + 1, hence the opposite slices.
        outputs.predicted_position_embeddings = (predicted_positional_embeddings[:, :-1]
                                                 if predicted_positional_embeddings is not None else None)
        outputs.target_position_embeddings = position_embeddings[:, 1:]
        outputs.positional_rnn_output = positional_rnn_output
        # NOTE(review): this is the *input* positional RNN state; `predict_positional_embeddings`
        # discards the updated state. Confirm this is intended.
        outputs.positional_rnn_state = positional_rnn_state
        return outputs

    def reconstruct(self, position_embeddings, action_embeddings, rnn_output, positional_rnn_state):
        """Predict positional embeddings (if there is a positional RNN) and decode.

        Returns
        -------
        tuple
            ``(recons, predicted_positional_embeddings, positional_rnn_output, mask_info)``;
            the positional entries are None without a positional RNN.

        Raises
        ------
        CodepathNotReadyError
            If a masker is configured (masking is not supported yet).
        """
        predicted_positional_embeddings, positional_rnn_output = \
            self.predict_positional_embeddings(position_embeddings, action_embeddings, rnn_output, positional_rnn_state)
        recons = self.decode(predicted_positional_embeddings, position_embeddings, rnn_output)
        if self.masker is not None:
            raise CodepathNotReadyError
            # noinspection PyUnreachableCode
            recons, mask_info = self.masker(recons)
        else:
            mask_info = Dict()
        return recons, predicted_positional_embeddings, positional_rnn_output, mask_info

    def predict_positional_embeddings(self, position_embeddings, action_embeddings, rnn_output, positional_rnn_state):
        """Run the positional RNN, if any.

        Returns
        -------
        tuple
            ``(predicted_positional_embeddings, positional_rnn_output)``, both None when there
            is no positional RNN. The positional RNN's new state is not returned.
        """
        if self.positional_rnn is None:
            return None, None
        # The positional RNN takes in the current RNN output (at time t), the previous action_embedding
        # (at time t - 1) and the previous positional_embedding (at time t - 1) to generate the current
        # positional_embedding (at time t).
        predicted_positional_embeddings, positional_rnn_output, _ = \
            self.positional_rnn(position_embeddings, action_embeddings, rnn_output, positional_rnn_state)
        return predicted_positional_embeddings, positional_rnn_output

    def decode(self, predicted_positional_embeddings, position_embeddings, rnn_output):
        """Decode the RNN output into a N(T - 1)AChw reconstruction.

        The RNN output at time t (i.e. with input at time t), taken together with the position
        at time t + 1, should decode to the state at time t + 1. So the RNN output is trimmed
        to ``[:, :-1]`` and the given position embeddings are shifted to ``[:, 1:]``. If
        configured, the positional RNN's predictions (``[:, :-1]``) are used instead.
        """
        if predicted_positional_embeddings is not None and self.decode_with_predicted_positional_encodings:
            # For this to work, check if both position embeddings and the predicted embeddings are normalized
            assert self.positional_encoder.normalize and self.positional_rnn.normalize
            # Detaching keeps the positional RNN from training on the reconstruction loss.
            if self.detach_positional_encodings_before_decoding:
                predicted_positional_embeddings = predicted_positional_embeddings.detach()
            return self.decoder(rnn_output[:, :-1], predicted_positional_embeddings[:, :-1])
        return self.decoder(rnn_output[:, :-1], position_embeddings[:, 1:])

    def one_step_rollout(self, state_in, is_initial=False):
        """Advance the model by one time step.

        The steps are:
          1. Use the states, positions and actions at time t to get the world state at t + 1.
          2. Use that world state together with the positions and actions at time t to
             predict the position at t + 1. Without a positional RNN,
             ``state_in.next_positions`` is embedded instead.
          3. Decode the (predicted) position at t + 1 with the world state at t + 1 into the
             local state at t + 1.

        Parameters
        ----------
        state_in : addict.Dict
            Holds ``actions``, ``positions``, ``states``, ``goals`` and optionally
            ``rnn_state``, ``positional_rnn_state`` and ``next_positions``. Tensors have
            a time dimension of size 1. ``state_in.positions`` and ``state_in.goals`` are
            overwritten in place with their embeddings.
        is_initial : bool
            If True, positions and goals in ``state_in`` are raw and get embedded. Otherwise
            they are assumed to be embeddings, as returned by a previous call.

        Returns
        -------
        addict.Dict
            The next rollout state: ``positions`` (embeddings), ``states`` (via
            ``recon_to_state_processor``), ``rnn_state``, ``positional_rnn_state``, ``goals``,
            ``recons``, ``rnn_output``; ``actions`` is None.
        """
        if self.positional_rnn is not None:
            assert self.positional_rnn.normalize and self.positional_encoder.normalize
        # Embed and encode actions, positions and states at time t. These are all N1AC tensors.
        action_embeddings, position_embeddings, state_encoding, state_action_encoding, *remaining_embeddings = \
            self.embed_and_encode(actions=state_in.actions, positions=state_in.positions, states=state_in.states,
                                  goals=state_in.goals, position_is_embedding=(not is_initial),
                                  goal_is_embedding=(not is_initial))
        goal_embeddings = remaining_embeddings[0] if len(remaining_embeddings) > 0 else None
        state_in.positions = position_embeddings
        state_in.goals = goal_embeddings
        assert action_embeddings.shape[1] == position_embeddings.shape[1] == state_encoding.shape[1] == 1
        # Reduce over agents to a N1C tensor
        rnn_input = self.aggregator(state_action_encoding)
        # Run through RNN (with the previous hidden state) to get the world state at time t + 1 (shape N1C).
        # `or ()` also covers a missing key, which addict returns as an empty (falsy) Dict.
        rnn_output, *rnn_state_next = self.rnn(rnn_input, *(state_in.rnn_state or ()))
        if self.positional_rnn is not None:
            position_embeddings_next, _, *positional_rnn_state_next = \
                self.positional_rnn(position_embeddings, action_embeddings, rnn_output,
                                    *(state_in.positional_rnn_state or ()))
        else:
            assert 'next_positions' in state_in, \
                "state_in.next_positions is required when there is no positional RNN."
            position_embeddings_next = self.embed_positions(state_in.next_positions)
            positional_rnn_state_next = ()
        # Use the next positional embedding to predict the next state
        reconstructions = self.decoder(rnn_output, position_embeddings_next)
        if self.masker is not None:
            raise CodepathNotReadyError
            # noinspection PyUnreachableCode
            reconstructions, _ = self.masker(reconstructions)
        states_next = self.recon_to_state_processor(reconstructions, state_in)
        return Dict(actions=None, positions=position_embeddings_next, states=states_next,
                    rnn_state=rnn_state_next, positional_rnn_state=positional_rnn_state_next,
                    goals=goal_embeddings, recons=reconstructions, rnn_output=rnn_output)

    def register_recon_to_state_processor(self, fn):
        """Set the reconstruction-to-state processor used in rollouts; returns ``self``."""
        self.recon_to_state_processor = fn
        return self

    @contextmanager
    def disable_positional_rnn(self):
        """Temporarily remove the positional RNN within a ``with`` block.

        The original module is restored on exit, including when the body raises.
        """
        saved_positional_rnn = self.positional_rnn
        self.positional_rnn = None
        try:
            yield
        finally:
            self.positional_rnn = saved_positional_rnn


class MAWMWithRNN(MAWM):
    """MAWM assembled from configuration keywords.

    Builds the positional encoder, optional action encoder, state encoder, agent aggregator,
    world RNN, decoder and optional positional RNN from the keyword arguments, then passes
    them to ``MAWM``.

    Notable options
    ---------------
    action_encoding_dim : int or None
        ``None`` or ``0`` disables the action encoder.
    aggregation_mode : {'sum', 'drop_sum', 'fold'}
        How per-agent encodings are combined before the world RNN.
    rnn_type : {'LSTM', 'GRU', 'RIM', 'Blocks', 'RMC', 'TTO'}
        For 'RIM', ``rnn_kwargs`` must contain ``num_rims`` and ``rim_hidden_size``; the
        effective capacity is their product. For 'RMC', the capacity is taken from the built
        module's ``hidden_size``.
    mask_spec : optional
        Not supported yet: passing a value raises ``CodepathNotReadyError``.
    positional_rnn_type : optional
        If not None, a ``PositionalRNN`` of this type is built.

    Raises
    ------
    NotImplementedError
        For an unknown ``encoder_type``, ``aggregation_mode``, ``rnn_type`` or ``decoder_type``.
    """

    def __init__(self, *,
                 # Basic
                 state_size=(5, 5), num_in_state_channels=23, num_out_state_channels=21, polar_states=False,
                 num_agents=10,
                 # Positional Embeddings
                 positional_encoding_dim=16, num_position_dims=2, positional_encoding_max_frequency=1000,
                 normalize_positional_encodings=True, trainable_positional_encodings=False,
                 # Actions
                 num_actions=5, action_encoding_dim=16,
                 # Encoder
                 encoder_type='tower', encoder_capacity=128, repr_dim=128,
                 # Aggregator
                 aggregation_mode='sum', aggregation_keep_proba=0.5, aggregation_drop_full_trajectory=False,
                 aggregation_normalize_at_eval=True,
                 # Main RNN
                 rnn_capacity=128, rnn_type='LSTM', rnn_depth=1, rnn_kwargs=None,
                 # Decoder
                 decoder_type='residual', decoder_capacity=128, mask_spec=None,
                 # Positional RNN
                 positional_rnn_type=None, positional_rnn_capacity=128, positional_rnn_kwargs=None,
                 decode_with_predicted_positional_encodings=False, detach_positional_encodings_before_decoding=False):
        # Normalise optional kwargs once so later indexing/unpacking never sees None.
        rnn_kwargs = rnn_kwargs or {}
        positional_rnn_kwargs = positional_rnn_kwargs or {}

        positional_encoder = enc.PositionalEncoding(positional_encoding_dim, num_position_dims,
                                                    positional_encoding_max_frequency,
                                                    normalize_positional_encodings,
                                                    trainable_positional_encodings)
        if action_encoding_dim is None or action_encoding_dim == 0:
            action_encoder = None
            action_encoding_dim = 0
        else:
            action_encoder = enc.ActionEncoding(num_actions, action_encoding_dim)

        if encoder_type == 'tower':
            encoder = enc.TowerEncoder(state_size, num_in_state_channels,
                                       positional_encoding_dim, encoder_capacity,
                                       repr_dim, polar_states)
        else:
            raise NotImplementedError(f"Unknown encoder_type: {encoder_type!r}")

        # Per agent, MAWM.embed_and_encode concatenates the state encoding (repr_dim channels)
        # with the action embedding (action_encoding_dim channels) before aggregation.
        per_agent_dim = repr_dim + action_encoding_dim
        if aggregation_mode == 'sum':
            aggregator = enc.SumAggregator()
            rnn_input_size = per_agent_dim
        elif aggregation_mode == 'drop_sum':
            aggregator = enc.DropSumAggregator(keep_proba=aggregation_keep_proba,
                                               full_trajectory_drop=aggregation_drop_full_trajectory,
                                               normalize_at_eval=aggregation_normalize_at_eval)
            rnn_input_size = per_agent_dim
        elif aggregation_mode == 'fold':
            aggregator = enc.FoldAggregator()
            # FoldAggregator receives only the concatenated per-agent encoding and is built
            # without any size info, so the action channels are folded along with the state
            # channels. NOTE(review): assumes it concatenates agents along channels; confirm
            # against enc.FoldAggregator.
            rnn_input_size = per_agent_dim * num_agents
        else:
            raise NotImplementedError(f"Unknown aggregation_mode: {aggregation_mode!r}")

        if rnn_type == 'LSTM':
            rnn = nn.LSTM(rnn_input_size, rnn_capacity, batch_first=True, **rnn_kwargs)
        elif rnn_type == 'GRU':
            rnn = nn.GRU(rnn_input_size, rnn_capacity, batch_first=True, **rnn_kwargs)
        elif rnn_type == 'RIM':
            rnn = rims.MultiRIM(input_size=rnn_input_size, batch_first=True, **rnn_kwargs)
            rnn_capacity = rnn_kwargs['num_rims'] * rnn_kwargs['rim_hidden_size']
        elif rnn_type == 'Blocks':
            rnn = BlocksWrapper(ntokens=rnn_input_size, nhid=rnn_capacity, nout=rnn_capacity, batch_first=True,
                                **rnn_kwargs)
        elif rnn_type == 'RMC':
            rnn = rmc.RelationalMemory(input_size=rnn_input_size, **rnn_kwargs)
            rnn_capacity = rnn.hidden_size
        elif rnn_type == 'TTO':
            rnn = tto.TimeTravellingOracle(rnn_input_size, rnn_capacity, **rnn_kwargs)
        else:
            raise NotImplementedError(f"Unknown rnn_type: {rnn_type!r}")

        if mask_spec is not None:
            raise CodepathNotReadyError
            # noinspection PyUnreachableCode
            masker = MaskedReconstructor(mask_spec, channel_dim=3)
            num_out_state_channels += masker.num_mask_channels
        else:
            masker = None

        if decoder_type == 'residual':
            # The decoder consumes the world-RNN output, hence `rnn_capacity` (possibly updated above).
            decoder = dec.ResidualDecoder(state_size, num_out_state_channels,
                                          positional_encoding_dim, decoder_capacity,
                                          rnn_capacity, polar_states)
        else:
            raise NotImplementedError(f"Unknown decoder_type: {decoder_type!r}")

        if positional_rnn_type is not None:
            positional_rnn = prnn.PositionalRNN(positional_encoding_dim=positional_encoding_dim,
                                                action_encoding_dim=action_encoding_dim,
                                                world_rnn_capacity=rnn_capacity,
                                                hidden_size=positional_rnn_capacity,
                                                rnn_type=positional_rnn_type,
                                                **positional_rnn_kwargs)
        else:
            positional_rnn = None

        super(MAWMWithRNN, self).__init__(encoder, action_encoder, positional_encoder, aggregator, rnn, decoder,
                                          positional_rnn, decode_with_predicted_positional_encodings,
                                          detach_positional_encodings_before_decoding, masker=masker)
