import torch
import torch.nn as nn
import torch.nn.functional as F

from .discriminative_listener import DiscriminativeListener
from ..networks import choose_architecture, layer_init, hasnan, BetaVAE
from ..utils import StraightThroughGumbelSoftmaxLayer


class AttentionLSTMCNNListener(DiscriminativeListener):
    """
    Discriminative listener that grounds a received sentence onto visual stimuli
    through an attention-driven visual decoder.

    Pipeline, as implemented below:
        1. A CNN encoder (possibly a BetaVAE/MONet, possibly followed by a
           feature-converter MLP) embeds every stimulus, followed by batch-norm.
        2. An optional temporal LSTM runs over the `nbr_stimulus` axis; only its
           final output is kept. This is followed by batch-norm.
        3. The sentence is embedded by a bias-free linear layer and encoded by an LSTM.
        4. For `visual_decoder_nbr_steps` steps:
           - an MLP attention model scores every textual token against the
             current visual-decoder output;
           - soft attention (or straight-through Gumbel-Softmax hard attention)
             yields a textual context;
           - an LSTM visual decoder followed by an MLP turns that context into a
             visually-projected textual feature.
        5. At every step, the decision logits are the dot products between the sum
           of all projected textual features so far and each stimulus' feature.
    """

    def __init__(self, kwargs, obs_shape, vocab_size=100, max_sentence_length=10, agent_id='l0', logger=None):
        """
        :param kwargs: dict of hyperparameters (architecture, encoder/decoder sizes, dropout probabilities, ...).
                       N.B.: `kwargs['temporal_encoder_nbr_hidden_units']` is overwritten in place
                       when no temporal encoder is used.
        :param obs_shape: tuple defining the shape of the stimulus following `(nbr_distractors+1, nbr_stimulus, *stimulus_shape)`
                          where, by default, `nbr_distractors=1` and `nbr_stimulus=1` (static stimuli).
        :param vocab_size: int defining the size of the vocabulary of the language.
        :param max_sentence_length: int defining the maximal length of each sentence the speaker can utter.
        :param agent_id: str defining the ID of the agent over the population.
        :param logger: None or some kind of logger able to accumulate statistics per agent.
        """
        super().__init__(obs_shape, vocab_size, max_sentence_length, agent_id, logger, kwargs)
        self.use_sentences_one_hot_vectors = True
        self.kwargs = kwargs

        cnn_input_shape = self.obs_shape[2:]
        # Relational-module (MHDPA) hyperparameters, forwarded to `choose_architecture`:
        mhdpa_nbr_head = self.kwargs.get('mhdpa_nbr_head', 4)
        mhdpa_nbr_rec_update = self.kwargs.get('mhdpa_nbr_rec_update', 1)
        mhdpa_nbr_mlp_unit = self.kwargs.get('mhdpa_nbr_mlp_unit', 512)
        mhdpa_interaction_dim = self.kwargs.get('mhdpa_interaction_dim', 128)

        ## Visual Encoder:
        if 'cnn_encoder' in self.kwargs:
            # A pre-built encoder is provided by the caller:
            self.cnn_encoder = self.kwargs['cnn_encoder']
        else:
            self.cnn_encoder = choose_architecture(architecture=self.kwargs['architecture'],
                                                   kwargs=self.kwargs,
                                                   input_shape=cnn_input_shape,
                                                   feature_dim=self.kwargs['cnn_encoder_feature_dim'],
                                                   nbr_channels_list=self.kwargs['cnn_encoder_channels'],
                                                   kernels=self.kwargs['cnn_encoder_kernels'],
                                                   strides=self.kwargs['cnn_encoder_strides'],
                                                   paddings=self.kwargs['cnn_encoder_paddings'],
                                                   fc_hidden_units_list=self.kwargs['cnn_encoder_fc_hidden_units'],
                                                   dropout=self.kwargs['dropout_prob'],
                                                   MHDPANbrHead=mhdpa_nbr_head,
                                                   MHDPANbrRecUpdate=mhdpa_nbr_rec_update,
                                                   MHDPANbrMLPUnit=mhdpa_nbr_mlp_unit,
                                                   MHDPAInteractionDim=mhdpa_interaction_dim)

        self.use_feat_converter = self.kwargs.get('use_feat_converter', False)
        if self.use_feat_converter:
            self.feat_converter_input = self.cnn_encoder.get_feature_shape()

        if 'BetaVAE' in self.kwargs['architecture'] or 'MONet' in self.kwargs['architecture']:
            # Buffers filled by `_sense` and emptied by `_tidyup`:
            self.VAE_losses = list()
            self.compactness_losses = list()
            self.buffer_cnn_output_dict = dict()

            if 'cnn_encoder' not in self.kwargs:
                # Rebuild the encoder without the convolution-specific arguments:
                self.cnn_encoder = choose_architecture(architecture=self.kwargs['architecture'],
                                                       kwargs=self.kwargs,
                                                       input_shape=cnn_input_shape,
                                                       feature_dim=self.kwargs['cnn_encoder_feature_dim'],
                                                       dropout=self.kwargs['dropout_prob'])
            # N.B: with a VAE, we want to learn the weights in any case:
            if 'agent_learning' in self.kwargs:
                assert('transfer_learning' not in self.kwargs['agent_learning'])

            # Whether the latent features are detached before the feature converter:
            self.vae_detached_featout = bool(self.kwargs['vae_detached_featout'])

            self.VAE = self.cnn_encoder

            # The VAE latent vector is always mapped through the feature converter:
            self.use_feat_converter = True
            self.feat_converter_input = self.cnn_encoder.latent_dim
        elif 'agent_learning' in self.kwargs and 'transfer_learning' in self.kwargs['agent_learning']:
            self.cnn_encoder.detach_conv_maps = True

        self.encoder_feature_shape = self.cnn_encoder.get_feature_shape()
        if self.use_feat_converter:
            converter_hidden_size = self.kwargs['cnn_encoder_feature_dim']*2
            self.featout_converter = nn.Sequential(
                nn.Linear(self.feat_converter_input, converter_hidden_size),
                nn.ReLU(),
                nn.Linear(converter_hidden_size, self.kwargs['feat_converter_output_size']),
                nn.ReLU(),
            )
            self.encoder_feature_shape = self.kwargs['feat_converter_output_size']

        self.cnn_encoder_normalization = nn.BatchNorm1d(num_features=self.encoder_feature_shape)

        ## Temporal Encoder:
        # The temporal encoder consumes the per-stimulus features produced by `_sense`.
        # Those are feature-converter outputs when the converter is used, so the input
        # size must be `self.encoder_feature_shape` rather than the raw CNN feature shape.
        temporal_encoder_input_dim = self.encoder_feature_shape
        if self.kwargs['temporal_encoder_nbr_rnn_layers'] > 0:
            self.temporal_feature_encoder = layer_init(nn.LSTM(input_size=temporal_encoder_input_dim,
                                              hidden_size=self.kwargs['temporal_encoder_nbr_hidden_units'],
                                              num_layers=self.kwargs['temporal_encoder_nbr_rnn_layers'],
                                              batch_first=True,
                                              dropout=self.kwargs['dropout_prob'],
                                              bidirectional=False))
        else:
            self.temporal_feature_encoder = None
            print("WARNING: Symbol processing :: the number of hidden units is being reparameterized to fit to convolutional features.")
            # Without a temporal encoder, `_reason` flattens the `nbr_stimulus` features of each stimulus:
            self.kwargs['temporal_encoder_nbr_hidden_units'] = self.kwargs['nbr_stimulus']*self.encoder_feature_shape

        self.normalization = nn.BatchNorm1d(num_features=self.kwargs['temporal_encoder_nbr_hidden_units'])

        ## Textual Encoder:
        self.textual_embedding_size = self.kwargs['symbol_embedding_size']
        self.textual_embedder = nn.Sequential(
            nn.Linear(self.vocab_size, self.textual_embedding_size, bias=False),
            nn.Dropout(p=self.kwargs['embedding_dropout_prob'])
        )

        self.textual_encoder_input_dim = self.textual_embedding_size
        self.textual_encoder_hidden_size = self.kwargs['textual_encoder_nbr_hidden_units']
        self.textual_encoder = nn.LSTM(input_size=self.textual_encoder_input_dim,
                                       hidden_size=self.textual_encoder_hidden_size,
                                       num_layers=self.kwargs['textual_encoder_nbr_rnn_layers'],
                                       batch_first=True,
                                       dropout=self.kwargs['dropout_prob'],
                                       bidirectional=False)
        # Learnable initial hidden state, expanded over layers and batch in `_reason`:
        self.textual_encoder_learnable_initial_state = nn.Parameter(
            torch.zeros(1, 1, self.textual_encoder_hidden_size)
        )

        ## Multi-modal Attention:
        self.decoder_nbr_steps = self.kwargs['visual_decoder_nbr_steps']
        self.visual_decoder_hidden_size = self.kwargs['visual_decoder_nbr_hidden_units']

        # Attention Model: scores each (textual token, visual decoder output) pair.
        self.attention_model_input_size = self.textual_encoder_hidden_size+self.visual_decoder_hidden_size
        self.attention_model_hidden_size = self.kwargs['attention_model_hidden_size']
        self.attention_model = nn.Sequential(
            nn.Linear(
                self.attention_model_input_size,
                self.attention_model_hidden_size,
                bias=True
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.attention_model_hidden_size,
                self.attention_model_hidden_size,
                bias=True
            ),
            nn.ReLU(inplace=True),
            nn.Linear(
                self.attention_model_hidden_size,
                1,
                bias=True
            )
        )

        self.textual_context_dim = self.textual_encoder_hidden_size

        self.st_gs = StraightThroughGumbelSoftmaxLayer(
            inv_tau0=self.kwargs['attention_st_gs_inv_tau0'],
            input_dim=self.textual_encoder_hidden_size
        )

        ## Visual Decoder:
        # Its input is [textual context, previous visual decoder MLP output]:
        self.visual_decoder_input_dim = self.textual_context_dim+self.encoder_feature_shape
        self.visual_decoder = nn.LSTM(input_size=self.visual_decoder_input_dim,
                                      hidden_size=self.visual_decoder_hidden_size,
                                      num_layers=self.kwargs['visual_decoder_nbr_rnn_layers'],
                                      batch_first=True,
                                      dropout=self.kwargs['dropout_prob'],
                                      bidirectional=False)
        self.visual_decoder_learnable_initial_state = nn.Parameter(
            torch.zeros(1, 1, self.visual_decoder_hidden_size)
        )

        self.visual_decoder_mlp = nn.Sequential(
            nn.Linear(
                self.visual_decoder_hidden_size,
                self.encoder_feature_shape,
                bias=True
            ),
            nn.Dropout(p=self.kwargs['visual_decoder_mlp_dropout_prob'])
        )
        self.visual_decoder_mlp_learnable_initial_output = nn.Parameter(
            torch.zeros(1, self.encoder_feature_shape)
        )

        # Used by the full-focus scheme to bring the textual context
        # to the size of the visual decoder hidden state:
        if self.textual_context_dim != self.visual_decoder_hidden_size:
            self.textual_context2decoder_converter = nn.Linear(
                self.textual_context_dim,
                self.visual_decoder_hidden_size
            )
        else:
            self.textual_context2decoder_converter = None

        # Disabled alternative, kept for reference:
        # self.not_target_logits_per_token = nn.Parameter(torch.ones((1, self.kwargs['max_sentence_length'], 1)))

        self.projection_normalization = None #nn.BatchNorm1d(num_features=self.kwargs['max_sentence_length']*self.kwargs['symbol_processing_nbr_hidden_units'])

        self.reset_weights()

    def reset_weights(self):
        """
        Re-initialise the textual and attention/decoder sub-modules with `layer_init`.

        Also clears the cached temporal embeddings and calls `self._reset_rnn_states()`.
        The visual encoder and temporal encoder are not re-initialised here.
        """
        for module in (
            self.textual_embedder,
            self.textual_encoder,
            self.attention_model,
            self.st_gs,
            self.visual_decoder,
            self.visual_decoder_mlp,
        ):
            module.apply(layer_init)
        if self.textual_context2decoder_converter is not None:
            self.textual_context2decoder_converter.apply(layer_init)

        self.embedding_tf_final_outputs = None
        self._reset_rnn_states()

    def _tidyup(self):
        """Clear the per-forward caches (and the VAE loss/log buffers, if any)."""
        self.embedding_tf_final_outputs = None

        if isinstance(self.cnn_encoder, BetaVAE):
            self.VAE_losses = list()
            self.compactness_losses.clear()
            self.buffer_cnn_output_dict = dict()

    def _compute_tau(self, tau0, h):
        """
        Not implemented for this listener.

        :raises NotImplementedError: always.
        """
        # Former implementation, kept for reference (`self.tau_fc` is not defined in this class):
        # invtau = 1.0 / (self.tau_fc(h).squeeze() + tau0)
        # return invtau
        raise NotImplementedError

    def _sense(self, experiences, sentences=None):
        r"""
        Infers features from the experiences that have been provided.

        :param experiences: Tensor of shape `(batch_size, *self.obs_shape)`.
                        Make sure to shuffle the stimuli so that the order does not give away the target.
        :param sentences: None or Tensor of shape `(batch_size, max_sentence_length, vocab_size)` containing the padded sequence of (potentially one-hot-encoded) symbols.
                        Unused here.

        :returns:
            features: Tensor of shape `(batch_size, nbr_distractors+1, nbr_stimulus, feature_dim)`.
        """
        batch_size = experiences.size(0)
        nbr_distractors_po = experiences.size(1)
        # Merge the (batch, distractors+1, nbr_stimulus) axes into a single batch axis:
        experiences = experiences.view(-1, *(experiences.size()[3:]))
        features = []
        total_size = experiences.size(0)
        mini_batch_size = min(self.kwargs['cnn_encoder_mini_batch_size'], total_size)
        is_vae = isinstance(self.cnn_encoder, BetaVAE)
        for stin in torch.split(experiences, split_size_or_sections=mini_batch_size, dim=0):
            if is_vae:
                cnn_output_dict = self.cnn_encoder.compute_loss(stin)
                if 'VAE_loss' in cnn_output_dict:
                    self.VAE_losses.append(cnn_output_dict['VAE_loss'])

                if hasattr(self.cnn_encoder, 'compactness_losses') and self.cnn_encoder.compactness_losses is not None:
                    self.compactness_losses.append(self.cnn_encoder.compactness_losses.cpu())

                # Accumulate every VAE output (on CPU) for logging below:
                for key in cnn_output_dict:
                    if key not in self.buffer_cnn_output_dict:
                        self.buffer_cnn_output_dict[key] = list()
                    self.buffer_cnn_output_dict[key].append(cnn_output_dict[key].cpu())

                if self.kwargs['vae_use_mu_value']:
                    featout = self.cnn_encoder.mu
                else:
                    featout = self.cnn_encoder.z

                if self.vae_detached_featout:
                    featout = featout.detach()

                featout = self.featout_converter(featout)
            else:
                featout = self.cnn_encoder(stin)
                if self.use_feat_converter:
                    featout = self.featout_converter(featout)

            features.append(featout)

        self.features = self.cnn_encoder_normalization(torch.cat(features, dim=0))

        self.features = self.features.view(batch_size, nbr_distractors_po, self.kwargs['nbr_stimulus'], -1)
        # (batch_size, nbr_distractors+1 / ? (descriptive mode depends on the role of the agent), nbr_stimulus, feature_dim)

        if is_vae:
            self.VAE_losses = torch.cat(self.VAE_losses).contiguous()

            for key in self.buffer_cnn_output_dict:
                self.log_dict[key] = torch.cat(self.buffer_cnn_output_dict[key]).mean()

            self.log_dict['kl_capacity'] = torch.Tensor([100.0*self.cnn_encoder.EncodingCapacity/self.cnn_encoder.maxEncodingCapacity])
            if len(self.compactness_losses):
                self.log_dict['unsup_compactness_loss'] = torch.cat(self.compactness_losses).mean()

        return self.features

    def _reason(self, sentences, features):
        """
        Reasons about the features and sentences to yield the target-prediction logits.

        :param sentences:   Tensor of shape `(batch_size, max_sentence_length, vocab_size)`
                            containing the padded sequence of (potentially one-hot-encoded) symbols.
                            NOTE: max_sentence_length may be different from self.max_sentence_length
                            as the padding is done per batch and only cares about the maximal
                            sentence length of said batch.
        :param features: Tensor of shape `(batch_size, nbr_distractors+1, nbr_stimulus, feature_dim)`.

        :returns:
            - decision_logits: Tensor of shape `(batch_size, visual_decoder_nbr_steps, nbr_distractors+1)`
              containing, for each decoder step, the target-prediction logits.
            - temporal features: Tensor of shape
              `(batch_size, nbr_distractors+1, kwargs['temporal_encoder_nbr_hidden_units'])`.
        """
        batch_size = features.size(0)
        nbr_distractors_po = features.size(1)
        temporal_hidden_units = self.kwargs['temporal_encoder_nbr_hidden_units']
        # (batch_size, nbr_distractors+1, nbr_stimulus, feature_dim)

        ## Temporal Encoding:
        if self.temporal_feature_encoder is not None:
            features = features.reshape(-1, *(features.size()[2:]))
            # (batch_size*(nbr_distractors+1), nbr_stimulus, feature_dim)
            rnn_outputs = []
            total_size = features.size(0)
            mini_batch_size = min(self.kwargs['temporal_encoder_mini_batch_size'], total_size)
            for featin in torch.split(features, split_size_or_sections=mini_batch_size, dim=0):
                outputs, _ = self.temporal_feature_encoder(featin)
                rnn_outputs.append(outputs)
            outputs = torch.cat(rnn_outputs, dim=0)
            # Use the runtime number of stimuli rather than `self.obs_shape`,
            # which may differ (e.g. in descriptive mode):
            outputs = outputs.view(batch_size, nbr_distractors_po, *outputs.shape[1:])
            # (batch_size, nbr_distractors+1, nbr_stimulus, temporal_encoder_nbr_hidden_units)

            # Caring only about the final output:
            embedding_tf_final_outputs = outputs[:, :, -1, :].contiguous()
            embedding_tf_final_outputs = embedding_tf_final_outputs.reshape((-1, temporal_hidden_units))
        else:
            # `temporal_encoder_nbr_hidden_units == nbr_stimulus*encoder_feature_shape` (set in __init__):
            embedding_tf_final_outputs = features.reshape((-1, temporal_hidden_units))
        self.embedding_tf_final_outputs = self.normalization(embedding_tf_final_outputs)
        self.embedding_tf_final_outputs = self.embedding_tf_final_outputs.reshape((batch_size, nbr_distractors_po, -1))
        # (batch_size, nbr_distractors+1, temporal_encoder_nbr_hidden_units)

        ## Textual Embedding:
        # (batch_size, sentences_length, vocab_size)
        sentences_length = sentences.shape[1]
        embedded_symbols = self.textual_embedder(sentences.reshape((-1, self.vocab_size)))
        # (batch_size*sentences_length, symbol_embedding_size)
        embedded_sentences = embedded_symbols.view((batch_size, -1, self.kwargs['symbol_embedding_size']))
        # (batch_size, sentences_length, symbol_embedding_size)

        ## Textual Encoding:
        init_textual_encoder_state = self.textual_encoder_learnable_initial_state.expand(
            self.kwargs['textual_encoder_nbr_rnn_layers'],
            batch_size,
            -1
        )
        # (num_layers*num_directions, batch_size, textual_encoder_hidden_size)
        textual_encoder_state = (
            init_textual_encoder_state.contiguous(),
            torch.zeros_like(init_textual_encoder_state)
        )
        textual_encoder_outputs, _ = self.textual_encoder(
            embedded_sentences,
            textual_encoder_state
        )
        # (batch_size, sentences_length, textual_encoder_hidden_size)

        # Batch Normalization (disabled unless `projection_normalization` is set):
        if self.projection_normalization is not None:
            textual_encoder_outputs = self.projection_normalization(textual_encoder_outputs.reshape((batch_size, -1)))
            textual_encoder_outputs = textual_encoder_outputs.reshape((batch_size, -1, self.textual_encoder_hidden_size))

        ## Visual Decoder initial state (Eq 10):
        init_visual_decoder_state = self.visual_decoder_learnable_initial_state.expand(
            self.kwargs['visual_decoder_nbr_rnn_layers'],
            batch_size,
            -1
        )
        # (num_layers*num_directions, batch_size, visual_decoder_hidden_size)
        visual_decoder_state = (
            init_visual_decoder_state.contiguous(),
            torch.zeros_like(init_visual_decoder_state)
        )

        visual_decoder_output = self.visual_decoder_learnable_initial_state.expand(
            batch_size,
            -1,
            -1
        )
        # (batch_size, 1, visual_decoder_hidden_size)

        visual_decoder_mlp_output = self.visual_decoder_mlp_learnable_initial_output.expand(
            batch_size,
            -1
        )
        # (batch_size, encoder_feature_shape)

        stimuli_features = self.embedding_tf_final_outputs
        # (batch_size, nbr_distractors+1, temporal_encoder_nbr_hidden_units)
        # NOTE: the dot product below requires
        # temporal_encoder_nbr_hidden_units == encoder_feature_shape.

        decision_logits = []
        per_step_visual_decoder_mlp_outputs = []
        for _ in range(self.decoder_nbr_steps):
            visual_decoder_hidden_state = visual_decoder_output.expand(
                -1,
                sentences_length,
                -1
            )
            # (batch_size, sentences_length, visual_decoder_hidden_size)
            attention_model_inputs = torch.cat([textual_encoder_outputs, visual_decoder_hidden_state], dim=-1)
            # (batch_size, sentences_length, textual_encoder_hidden_size+visual_decoder_hidden_size)
            # (Eq 6 (and below)): Attention Model + Softmax
            att = self.attention_model(attention_model_inputs).reshape(batch_size, sentences_length).softmax(dim=-1)
            # (batch_size, sentences_length)

            # Soft or Hard Attention via Straight-Through Gumbel-Softmax:
            # NOTE(review): `st_gs` receives softmax probabilities rather than raw logits — confirm intended.
            if not self.kwargs['attention_listener_soft_attention']:
                att = self.st_gs(logits=att, param=visual_decoder_hidden_state)
            # (batch_size, sentences_length)

            # (Eq 5) Context computation:
            context = (att.unsqueeze(-1) * textual_encoder_outputs).sum(1)
            # (batch_size, textual_context_dim=textual_encoder_hidden_size)

            ## Decoder:
            # (Eq 10) Forward pass through the decoder:
            visual_decoder_input = torch.cat(
                [
                    context.unsqueeze(1),
                    visual_decoder_mlp_output.unsqueeze(1)
                ],
                dim=-1
            )
            # (batch_size, 1, textual_context_dim+encoder_feature_shape)
            visual_decoder_outputs, next_visual_decoder_states = self.visual_decoder(
                visual_decoder_input,
                visual_decoder_state
            )
            # (batch_size, 1, visual_decoder_hidden_size)
            # (num_layers*num_directions, batch_size, visual_decoder_hidden_size)
            # (Eq 11 - listener)
            visual_decoder_mlp_output = self.visual_decoder_mlp(visual_decoder_outputs.reshape(batch_size, -1))
            # (batch_size, encoder_feature_shape)
            per_step_visual_decoder_mlp_outputs.append(visual_decoder_mlp_output)

            ## Bookkeeping:
            visual_decoder_output = visual_decoder_outputs
            # (batch_size, 1, visual_decoder_hidden_size)
            visual_decoder_state = next_visual_decoder_states
            # Visual Decoder's Full Focus Scheme: the next step starts from the
            # hidden state modulated by the current textual context.
            if self.kwargs['attention_full_focus_scheme']:
                if self.textual_context2decoder_converter is not None:
                    context = self.textual_context2decoder_converter(context.reshape(-1, self.textual_context_dim))
                # (1, batch_size, visual_decoder_hidden_size): broadcasts over the
                # leading num_layers axis of the LSTM hidden state.
                context = context.reshape(1, batch_size, -1)
                # Assumes LSTMs, i.e. state = (hidden, cell): only the hidden state is focused.
                visual_decoder_state = (
                    visual_decoder_state[0]*context,
                    visual_decoder_state[1]
                )

            ## Visual Feature Projection:
            # Summing on all the previous visual decoder outputs:
            visually_projected_textual_features = torch.stack(
                per_step_visual_decoder_mlp_outputs
            ).sum(dim=0).unsqueeze(1).expand(
                -1,
                nbr_distractors_po,
                -1
            )
            # (batch_size, nbr_distractors+1, encoder_feature_shape)
            decision_logits_until_timestep = (visually_projected_textual_features*stimuli_features).sum(-1)
            # (batch_size, nbr_distractors+1)
            decision_logits.append(decision_logits_until_timestep.unsqueeze(1))
            # (batch_size, 1, nbr_distractors+1)
        decision_logits = torch.cat(decision_logits, dim=1)
        # (batch_size, nbr_decoder_steps, nbr_distractors+1)

        # TODO: why would this be needed already?? Apparently in case of descriptive mode, cf obverter...
        # Disabled alternative, kept for reference:
        # not_target_logit = self.not_target_logits_per_token.repeat(batch_size, 1, 1)
        # if decision_logits.is_cuda: not_target_logit = not_target_logit.cuda()
        # decision_logits = torch.cat([decision_logits, not_target_logit], dim=-1 )

        return decision_logits, self.embedding_tf_final_outputs

    def _utter(self, features, sentences):
        """
        Reasons about the features and the listened sentences, if multi_round, to yield the sentences to utter back.

        Not implemented for this listener.

        :param features: Tensor of shape `(batch_size, *self.obs_shape[:2], feature_dim)`.
        :param sentences: None, or Tensor of shape `(batch_size, max_sentence_length, vocab_size)` containing the padded sequence of (potentially one-hot-encoded) symbols.

        :returns: (as described for this interface)
            - word indices: Tensor of shape `(batch_size, max_sentence_length, 1)` of type `long` containing the indices of the words that make up the sentences.
            - logits: Tensor of shape `(batch_size, max_sentence_length, vocab_size)` containing the padded sequence of logits.
            - sentences: Tensor of shape `(batch_size, max_sentence_length, vocab_size)` containing the padded sequence of one-hot-encoded symbols.
            - temporal features: Tensor of shape `(batch_size, (nbr_distractors+1)*temporal_feature_dim)`.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError
