import torch
from torch import nn
from torch.nn import functional as F

from .layers import ConvNorm, LinearNorm
from .utils import get_mask_from_lengths
from ..beam_search import SequenceGenerator


class LocationLayer(nn.Module):
    """Turn stacked previous/cumulative attention weights into location features.

    A 2-channel 1-d convolution over the attention weights is followed by a
    bias-free linear projection to ``attention_dim``.
    """

    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super().__init__()
        # "Same" padding for an odd kernel so the time axis keeps its length.
        same_padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2, attention_n_filters,
            kernel_size=attention_kernel_size, padding=same_padding,
            bias=False, stride=1, dilation=1)
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim,
            bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        """Map (B, 2, T) attention weights to per-timestep features.

        The conv output is transposed to put time before channels, then
        projected by ``location_dense``.
        """
        conv_features = self.location_conv(attention_weights_cat)
        return self.location_dense(conv_features.transpose(1, 2))


class Attention(nn.Module):
    """Location-sensitive additive attention.

    Energies are ``v^T tanh(W q + V m + U f)``, where ``q`` is the attention
    RNN state, ``m`` the (pre-projected) encoder outputs and ``f`` the
    location features computed from previous/cumulative attention weights.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super().__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        # Padded positions get -inf energy, hence exactly zero weight.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: attention RNN hidden state (B, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: previous and cumulative attention weights
            stacked on dim 1 (B, 2, T_in)

        RETURNS
        -------
        energies: unnormalised alignment scores (B, T_in)
        """
        # unsqueeze(1) gives the query a singleton time axis so it broadcasts
        # against the per-timestep memory and location terms.
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))
        return energies.squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs (B, T_in, embedding_dim)
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask, True at padded positions, or None for no masking

        RETURNS
        -------
        attention_context: attention-weighted sum of ``memory`` (B, embedding_dim)
        attention_weights: softmax-normalised weights (B, T_in)
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            # Out-of-place so the masking is visible to autograd (writing
            # through `.data` would bypass its version tracking).
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Encoder(nn.Module):
    """Encoder module:
        - ``encoder_n_convolutions`` 1-d conv + BatchNorm layers, each
          followed by ReLU
        - Single-layer bidirectional LSTM (half width per direction)
    """

    def __init__(self, hparams):
        super().__init__()
        embedding_dim = hparams['encoder_embedding_dim']
        kernel_size = hparams['encoder_kernel_size']

        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(embedding_dim, embedding_dim,
                         kernel_size=kernel_size, stride=1,
                         padding=int((kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(embedding_dim))
            for _ in range(hparams['encoder_n_convolutions'])
        ])

        # Each direction gets half the width so the concatenated
        # bidirectional output is (about) encoder_embedding_dim wide.
        self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def _run_convolutions(self, x):
        """Apply every conv layer + ReLU to (B, C, T) input; return (B, T, C)."""
        for conv in self.convolutions:
            x = F.relu(conv(x))
        return x.transpose(1, 2)

    def forward(self, x, input_lengths):
        """Encode a padded batch, ignoring padding inside the LSTM.

        PARAMS
        ------
        x: embedded inputs (B, encoder_embedding_dim, T)
        input_lengths: 1-d tensor of valid lengths per batch element

        RETURNS
        -------
        outputs: (B, max(input_lengths), lstm output width), zero past each length
        """
        x = self._run_convolutions(x)

        # pack_padded_sequence requires the lengths to live on the CPU.
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs

    def inference(self, x):
        """Encode unpadded input (every timestep treated as valid)."""
        x = self._run_convolutions(x)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs


class Decoder(nn.Module):
    """Autoregressive token decoder with location-sensitive attention.

    Every step embeds the previous target token, runs an attention LSTM cell
    on it (concatenated with the previous attention context), attends over the
    encoder outputs, then runs a decoder LSTM cell. The decoder state
    concatenated with the attention context is projected to ``out_dim``
    target-vocabulary logits and to a single gate value.
    """

    # Token id fed as the first decoder input and used to prefix LM input.
    _GO_IDX = 2
    # Token id that ends greedy / sampled decoding in ``inference``.
    _STOP_IDX = 1

    def __init__(self, hparams, out_dim):
        super().__init__()
        self.hparams = hparams
        self.n_frames_per_step = hparams['n_frames_per_step']
        self.encoder_embedding_dim = hparams['encoder_embedding_dim']
        self.attention_rnn_dim = hparams['attention_rnn_dim']
        self.decoder_rnn_dim = hparams['decoder_rnn_dim']
        self.prenet_dim = hparams['prenet_dim']
        self.max_decoder_steps_ratio = hparams['max_decoder_steps_ratio']
        self.gate_threshold = hparams['gate_threshold']
        self.p_attention_dropout = hparams['p_attention_dropout']
        self.p_decoder_dropout = hparams['p_decoder_dropout']
        # Target-token embedding; its width is the attention RNN's token input.
        self.embedding = nn.Embedding(out_dim, hparams['prenet_dim'])
        self.tgt_dict_size = out_dim

        self.attention_rnn = nn.LSTMCell(
            hparams['prenet_dim'] + hparams['encoder_embedding_dim'],
            hparams['attention_rnn_dim'])

        self.attention_layer = Attention(
            hparams['attention_rnn_dim'], hparams['encoder_embedding_dim'],
            hparams['attention_dim'], hparams['attention_location_n_filters'],
            hparams['attention_location_kernel_size'])

        self.decoder_rnn = nn.LSTMCell(
            hparams['attention_rnn_dim'] + hparams['encoder_embedding_dim'],
            hparams['decoder_rnn_dim'], bias=True)

        self.linear_projection = LinearNorm(
            hparams['decoder_rnn_dim'] + hparams['encoder_embedding_dim'], out_dim)

        self.gate_layer = LinearNorm(
            hparams['decoder_rnn_dim'] + hparams['encoder_embedding_dim'], 1,
            bias=True, w_init_gain='sigmoid')

        self.infer_mode = hparams['infer_mode']

    def get_go_frame(self, memory):
        """Build the first decoder input: one go token per batch element.

        PARAMS
        ------
        memory: encoder outputs; only the batch size and device are used

        RETURNS
        -------
        decoder_input: LongTensor (B,) filled with the go-token id (2)
        """
        return torch.full((memory.size(0),), self._GO_IDX,
                          dtype=torch.long, device=memory.device)

    def initialize_decoder_states(self, memory, mask):
        """Create zeroed recurrent/attention state and cache the memory.

        Stores ``memory``, its projection through the attention memory layer,
        and ``mask`` on ``self`` for use by ``decode``.

        PARAMS
        ------
        memory: Encoder outputs (B, T_in, encoder_embedding_dim)
        mask: Mask for padded data if training, expects None for inference

        RETURNS
        -------
        dict of zero tensors keyed by state name; ``all_attention_weights``
        starts with an empty step axis (B, 0, T_in) and grows in ``decode``.
        """
        batch_size = memory.size(0)
        max_time = memory.size(1)

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

        return {
            'attention_hidden': memory.new_zeros(batch_size, self.attention_rnn_dim),
            'attention_cell': memory.new_zeros(batch_size, self.attention_rnn_dim),
            'decoder_hidden': memory.new_zeros(batch_size, self.decoder_rnn_dim),
            'decoder_cell': memory.new_zeros(batch_size, self.decoder_rnn_dim),
            'attention_weights': memory.new_zeros(batch_size, max_time),
            'all_attention_weights': memory.new_zeros(batch_size, 0, max_time),
            'attention_weights_cum': memory.new_zeros(batch_size, max_time),
            'attention_context': memory.new_zeros(batch_size, self.encoder_embedding_dim),
        }

    def parse_decoder_outputs(self, outputs, gate_outputs, alignments):
        """Stack per-step lists into batch-first tensors.

        PARAMS
        ------
        outputs: list over steps of per-step decoder outputs, each (B, ...)
        gate_outputs: list over steps of gate outputs, each (B, ...)
        alignments: list over steps of attention weights, each (B, T_in)

        RETURNS
        -------
        outputs: (B, T_out, ...)
        gate_outputs: (B, T_out, ...)
        alignments: (B, T_out + 1, T_in); an extra first step filled with
            0.001 is prepended
        """
        # (T_out, B, T_in) -> (B, T_out, T_in)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B, ...) -> (B, T_out, ...)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1).contiguous()
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        # Prepend one step on the T_out axis. NOTE(review): presumably so the
        # alignment length matches decoder inputs that include the go token.
        alignments = F.pad(alignments, [0, 0, 1, 0], value=0.001)
        return outputs, gate_outputs, alignments

    def decode(self, decoder_input, ds=None, memory=None):
        """Run one decoder step using the stored memory and the state dict.

        PARAMS
        ------
        decoder_input: embedded previous token (B, prenet_dim)
        ds: state dict from ``initialize_decoder_states``; when None a fresh
            one is built from ``memory`` with no attention mask
        memory: encoder outputs, required only when ``ds`` is None

        RETURNS
        -------
        decoder_output: target-vocabulary logits (B, out_dim)
        gate_prediction: gate output (B, 1)
        attention_weights: this step's attention weights (B, T_in)
        ds: the same state dict, updated in place
        """
        if ds is None:
            assert memory is not None
            ds = self.initialize_decoder_states(memory, mask=None)

        # Attention RNN consumes the token embedding plus the previous context.
        cell_input = torch.cat((decoder_input, ds['attention_context']), -1)
        ds['attention_hidden'], ds['attention_cell'] = self.attention_rnn(
            cell_input, (ds['attention_hidden'], ds['attention_cell']))
        ds['attention_hidden'] = F.dropout(
            ds['attention_hidden'], self.p_attention_dropout, self.training)

        # Location features are computed from previous and cumulative weights.
        attention_weights_cat = torch.stack(
            (ds['attention_weights'], ds['attention_weights_cum']), dim=1)
        ds['attention_context'], ds['attention_weights'] = self.attention_layer(
            ds['attention_hidden'], self.memory, self.processed_memory,
            attention_weights_cat, self.mask)
        ds['attention_weights_cum'] = ds['attention_weights_cum'] + ds['attention_weights']

        decoder_rnn_input = torch.cat(
            (ds['attention_hidden'], ds['attention_context']), -1)
        ds['decoder_hidden'], ds['decoder_cell'] = self.decoder_rnn(
            decoder_rnn_input, (ds['decoder_hidden'], ds['decoder_cell']))
        ds['decoder_hidden'] = F.dropout(
            ds['decoder_hidden'], self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (ds['decoder_hidden'], ds['attention_context']), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)
        gate_prediction = self.gate_layer(decoder_hidden_attention_context)

        ds['all_attention_weights'] = torch.cat(
            [ds['all_attention_weights'], ds['attention_weights'][:, None]], 1)
        return decoder_output, gate_prediction, ds['attention_weights'], ds

    def forward(self, memory, decoder_inputs, memory_lengths):
        """Decoder forward pass for training (teacher forcing).

        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: target token ids (B, T + 1); step t is fed token t
            and the last token is never fed
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        decoder_outputs: logits (B, T, out_dim)
        gate_outputs: gate outputs (B, T)
        alignments: attention weights (B, T + 1, T_in)
        """
        decoder_inputs = self.embedding(decoder_inputs.transpose(0, 1))  # [T+1, B, H]
        decoder_states = self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        decoder_outputs, gate_outputs, alignments = [], [], []
        for step in range(decoder_inputs.size(0) - 1):
            output, gate_output, attention_weights, decoder_states = self.decode(
                decoder_inputs[step], decoder_states)
            decoder_outputs.append(output.squeeze(1))
            gate_outputs.append(gate_output.squeeze(1))
            alignments.append(attention_weights)

        return self.parse_decoder_outputs(decoder_outputs, gate_outputs, alignments)

    def inference(self, memory, lm=None, lm_weight=0.5, lang_ids=None):
        """Decode tokens autoregressively, greedy or sampled per ``infer_mode``.

        PARAMS
        ------
        memory: Encoder outputs
        lm: optional language model; ``lm.forward_one_step(lang_ids, tokens,
            cache=...)`` must return (logits, cache)
        lm_weight: weight of the LM distribution; the decoder's gets
            ``1 - lm_weight``
        lang_ids: forwarded unchanged to ``lm.forward_one_step``

        RETURNS
        -------
        decoder_outputs: chosen token ids (B, T_out) for 'argmax'/'sampling'
        gate_outputs: gate outputs (B, T_out, 1)
        alignments: attention weights (B, T_out + 1, T_in)
        """
        decoder_input = self.get_go_frame(memory)
        decoder_states = self.initialize_decoder_states(memory, mask=None)
        max_steps = memory.shape[1] / self.max_decoder_steps_ratio

        decoder_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_output, gate_output, alignment, decoder_states = self.decode(
                self.embedding(decoder_input), decoder_states)
            decoder_output = F.softmax(decoder_output, -1)
            if lm is not None and lm_weight > 0:
                if not decoder_outputs:
                    # (B,) go tokens -> (B, 1): batch-first, like the
                    # branch below (was (1, B), wrong whenever B > 1).
                    lm_inp = decoder_input[:, None]
                else:
                    # (B, T) emitted tokens prefixed with the go token.
                    lm_inp = F.pad(torch.stack(decoder_outputs, 1), [1, 0],
                                   value=self._GO_IDX)
                lm_output, decoder_states['lm'] = lm.forward_one_step(
                    lang_ids, lm_inp, cache=decoder_states.get('lm'))
                lm_output = F.softmax(lm_output, -1)
                decoder_output = decoder_output * (1 - lm_weight) + lm_output * lm_weight
            if self.infer_mode == 'argmax':
                decoder_output = decoder_output.argmax(-1)
            elif self.infer_mode == 'sampling':
                decoder_output = torch.multinomial(decoder_output, 1)[:, 0]
            decoder_outputs.append(decoder_output)  # [T, B]
            gate_outputs.append(gate_output)
            alignments.append(alignment)
            # NOTE(review): only batch element 0 is checked for the stop token.
            if decoder_output[0].item() == self._STOP_IDX:
                break
            if len(decoder_outputs) >= max_steps:
                break
            decoder_input = decoder_output

        return self.parse_decoder_outputs(decoder_outputs, gate_outputs, alignments)

    def beam_search(self, encoder_outputs, lm=None, lm_weight=None, lang_ids=None):
        """Decode with ``SequenceGenerator`` configured by hparams.

        Returns the ``'tokens'`` of the first hypothesis of the first input
        with a leading batch axis added, plus two None placeholders so the
        result unpacks like ``inference``.
        """
        generator = SequenceGenerator(self, self.tgt_dict_size,
                                      **self.hparams['beam_search_config'])
        out = generator.generate(encoder_outputs, lm=lm, lm_weight=lm_weight,
                                 lang_ids=lang_ids)
        out = out[0][0]['tokens'][None, ...]
        return out, None, None


class RNNTranslator(nn.Module):
    """Sequence-to-sequence translator: source embedding -> Encoder -> Decoder.

    With ``src_dict_size > 0`` the source is token ids looked up in an
    embedding table; otherwise it is 80-dim frame features projected by a
    linear layer. An optional language embedding is added to every timestep.
    """

    def __init__(self, src_dict_size, tgt_dict_size, hparams):
        super().__init__()
        self.hp = hparams
        self.n_frames_per_step = hparams['n_frames_per_step']
        embed_dim = hparams['symbols_embedding_dim']
        self.embedding = (nn.Embedding(src_dict_size, embed_dim)
                          if src_dict_size > 0
                          else nn.Linear(80, embed_dim))
        if self.hp['use_lang_embed']:
            self.lang_embed = nn.Embedding(10, embed_dim)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams, tgt_dict_size)

    def _embed_source(self, lang, src):
        """Embed ``src`` as [B, H, T], adding the language embedding if enabled."""
        embedded = self.embedding(src).transpose(1, 2)
        if self.hp['use_lang_embed']:
            # [B, H] -> [B, H, 1] so it broadcasts over the time axis.
            embedded = embedded + self.lang_embed(lang)[..., None]
        return embedded

    def forward(self, lang, src, src_lengths, tgt, _):
        """Teacher-forced pass; returns the decoder's (outputs, gates, alignments)."""
        src_lengths = src_lengths.data
        encoder_outputs = self.encoder(self._embed_source(lang, src), src_lengths)
        return self.decoder(encoder_outputs, tgt, memory_lengths=src_lengths)

    def inference(self, lang, src, lm=None, lm_weight=0.5):
        """Decode ``src`` with beam search or step-wise inference per ``infer_mode``."""
        encoder_outputs = self.encoder.inference(self._embed_source(lang, src))
        run_decoder = (self.decoder.beam_search
                       if self.decoder.infer_mode == 'beam'
                       else self.decoder.inference)
        return run_decoder(encoder_outputs, lm, lm_weight=lm_weight, lang_ids=lang)
