import heapq
import tensorflow as tf
import numpy as np

from vat.vat_model import create_masks

###########################################
# Beam search decoding                    #
###########################################
class Beam:
    """Fixed-capacity container that retains the highest-ranked hypotheses.

    Each entry is a ``(prob, complete, prefix)`` tuple stored in a min-heap
    (``self.heap``). Tuples are ordered lexicographically, so the entry with
    the smallest ``prob`` sits at the heap root and is the one discarded when
    the container grows beyond ``beam_size`` entries.
    """

    def __init__(self, beam_size):
        # Maximum number of entries kept after any call to ``add``.
        self.beam_size = beam_size
        # Min-heap of (prob, complete, prefix) tuples managed through ``heapq``.
        self.heap = []

    def add(self, prob, complete, prefix):
        """Insert a hypothesis and evict the smallest entry on overflow."""
        entry = (prob, complete, prefix)
        heapq.heappush(self.heap, entry)
        over_capacity = len(self.heap) > self.beam_size
        if over_capacity:
            # The root of the min-heap is the lowest-ranked entry.
            heapq.heappop(self.heap)

    def __iter__(self):
        """Iterate over the stored tuples in heap-array order (not sorted)."""
        return iter(self.heap)


def evaluate_beam_search(transformer, tokenizer, inp_sentence, beam_size, clip_len, variation=False, is_encoded=False):
    """Decode ``inp_sentence`` with beam search and return the best token ids.

    Hypotheses are ranked by the sum of their tokens' log-probabilities. This
    gives the same ordering as the product of probabilities but does not
    underflow to zero on long sequences.

    Args:
        transformer: Callable invoked as ``transformer(encoder_input,
            decoder_input, variation, enc_padding_mask, combined_mask,
            dec_padding_mask)``. It must return four values; the first is
            sliced as ``[:, -1:, :]`` to obtain the scores of the next token,
            which are softmax-normalised here.
        tokenizer: Object exposing ``vocab_size`` and ``encode``.
            ``vocab_size`` is used as the start-token id and
            ``vocab_size + 1`` as the end-token id.
        inp_sentence: The source sentence. When ``is_encoded`` is False it is
            passed to ``tokenizer.encode`` and wrapped in start/end tokens;
            otherwise it is used as the encoder token sequence unchanged.
        beam_size: Number of hypotheses kept per step; must be at least 1.
        clip_len: Maximum number of generated tokens. Decoding stops as soon
            as the best hypothesis is complete or has reached this length.
            At least one decoding step is always performed.
        variation: Forwarded unchanged to ``transformer``. False means to
            produce the mean representation (test setting).
        is_encoded: False means ``inp_sentence`` is a string.

    Returns:
        The token ids of the best hypothesis as a list, without the leading
        start token. The end token is never appended.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        # An empty beam would otherwise fail later with an opaque max() error.
        raise ValueError(f'beam_size must be at least 1, got {beam_size}')

    start_token = tokenizer.vocab_size
    end_token = tokenizer.vocab_size + 1

    prev_beam = Beam(beam_size)
    prev_beam.add(0.0, False, [start_token])  # log(1.0) == 0.0

    # prepare encoder input (input sentence)
    if not is_encoded:
        inp_sentence = [start_token] + tokenizer.encode(inp_sentence) + [end_token]
    encoder_input = tf.expand_dims(inp_sentence, 0)

    while True:
        curr_beam = Beam(beam_size)

        # Carry complete sentences over; extend incomplete ones by one token.
        for (prefix_score, complete, prefix) in prev_beam:
            if complete:
                curr_beam.add(prefix_score, True, prefix)
                continue

            # predict next word for given incomplete sentence and input sentence
            decoder_input = tf.expand_dims(prefix, 0)
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, decoder_input)
            predicts, _, _, _ = transformer(encoder_input, decoder_input, variation, enc_padding_mask,
                                            combined_mask, dec_padding_mask)
            # extract last token; float64 keeps the accumulated scores precise
            scores = np.array(predicts[:, -1:, :]).squeeze().astype(np.float64)

            # Numerically stable log-softmax: subtracting the maximum keeps
            # exp() from overflowing, and staying in log space avoids the
            # underflow that multiplying many small probabilities causes.
            shifted = scores - np.max(scores)
            log_probs = shifted - np.log(np.sum(np.exp(shifted)))

            indices = log_probs.argsort()[-beam_size:]  # top predictions
            for next_idx in indices:
                next_score = prefix_score + float(log_probs[next_idx])
                if next_idx == end_token:
                    # The end token closes the hypothesis and is not stored.
                    curr_beam.add(next_score, True, prefix)
                else:
                    curr_beam.add(next_score, False, prefix + [next_idx])

        (_, best_complete, best_prefix) = max(curr_beam)
        # ">=" rather than "==": with clip_len <= 0 an equality test would
        # never match and the loop could run forever.
        if best_complete or len(best_prefix) - 1 >= clip_len:
            return best_prefix[1:]

        prev_beam = curr_beam

###########################################
# Translation                             #
# Using beam search                       #
###########################################
def translate(transformer, tokenizer, sentence, max_len=50, beam_size=5, variation=False):
    """Translate ``sentence`` with beam search, print it and return it.

    Args:
        transformer: Model forwarded to ``evaluate_beam_search``.
        tokenizer: Object exposing ``vocab_size`` and ``decode``; it is also
            forwarded to ``evaluate_beam_search``.
        sentence: Source sentence; printed and passed on without
            ``is_encoded``, so it is treated as not yet encoded.
        max_len: Forwarded as ``clip_len``, the cap on generated tokens.
        beam_size: Number of hypotheses kept per decoding step.
        variation: Forwarded unchanged to ``evaluate_beam_search``.

    Returns:
        The decoded translation, so callers can use it programmatically
        instead of having to capture the printed output.
    """
    print(f'Input: {sentence}')
    result = evaluate_beam_search(transformer, tokenizer, sentence, beam_size=beam_size, clip_len=max_len,
                                  variation=variation)

    # Ids at or above vocab_size are the start/end markers used during
    # decoding, so they are dropped before calling tokenizer.decode.
    predicted_sentence = tokenizer.decode([i for i in result if i < tokenizer.vocab_size])
    print(f'Beam search translation: {predicted_sentence}')
    return predicted_sentence
