import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class BeamHypotheses(object):
    """Bounded pool of finished beam-search hypotheses for one batch item.

    Holds at most ``num_beams`` entries, each the tuple
    ``(score, hyp, his_scores, his_feats)`` where
    ``score = sum_logprobs / len(hyp) ** length_penalty``. When an insertion
    overflows the pool, the lowest-scoring entry is evicted.
    """

    def __init__(self, num_beams, max_length, length_penalty):
        # Stored as max_length - 1 (original note: "ignoring bos_token");
        # nothing inside this class reads it.
        self.max_length = max_length - 1
        self.length_penalty = length_penalty
        self.num_beams = num_beams  # pool capacity (beam size)
        self.beams = []
        # Large sentinel so the first insertion's min() yields the real worst score.
        self.worst_score = 1e9

    def __len__(self):
        """Return how many hypotheses are currently held."""
        return len(self.beams)

    def add(self, hyp, sum_logprobs, his_scores, his_feats):
        """Insert a hypothesis if the pool has room or it beats the current worst.

        Args:
            hyp: sized sequence of token ids; its length normalizes the score.
            sum_logprobs: cumulative log-probability of ``hyp``.
            his_scores: score history stored alongside ``hyp``.
            his_feats: feature history stored alongside ``hyp``.
        """
        normalized = sum_logprobs / len(hyp) ** self.length_penalty
        has_room = len(self) < self.num_beams
        if not (has_room or normalized > self.worst_score):
            return
        self.beams.append((normalized, hyp, his_scores, his_feats))
        if len(self) <= self.num_beams:
            self.worst_score = min(normalized, self.worst_score)
            return
        # Over capacity: evict the lowest score (the earliest entry wins ties
        # for eviction), then take the worst score among the survivors.
        victim = min(range(len(self.beams)), key=lambda i: (self.beams[i][0], i))
        del self.beams[victim]
        self.worst_score = min(entry[0] for entry in self.beams)

    def is_done(self, best_sum_logprobs, cur_len):
        """Report whether the pool is full and cannot be improved any more.

        Returns False while fewer than ``num_beams`` hypotheses are held.
        Otherwise returns whether the worst held score is at least
        ``best_sum_logprobs / cur_len ** length_penalty``.

        Note: ``cur_len`` must be positive when ``length_penalty > 0``;
        ``0 ** p == 0`` would make the division raise ZeroDivisionError.
        """
        if len(self) < self.num_beams:
            return False
        best_reachable = best_sum_logprobs / cur_len ** self.length_penalty
        return self.worst_score >= best_reachable

# batch_size = 3
# num_beams = 2
# vocab_size = 8
# cur_len = 1
# embedding_size = 300
# hidden_size = 100
# max_length = 10
# sos_token_id = 0
# eos_token_id = 1
# pad_token_id = 2
# decoder = DecoderRNN(embedding_size, hidden_size, vocab_size)


def beam_search(model, input_ids, pre_kv_list, batch_size, num_beams, pre_len, max_length, global_only_image, temperature, length_penalty, bos_token_id, eos_token_id, pad_token_id):
    """Batched beam search over an incrementally decoded (KV-cached) model.

    Each step feeds only the last token of every live beam, plus the cached
    keys/values, to ``model.decode_forward``. The next token is scored with a
    temperature-scaled log-softmax, and the best ``num_beams`` continuations
    are kept per batch item. Beams that emit ``eos_token_id`` move into a
    per-batch ``BeamHypotheses`` pool.

    Args:
        model: decoder that must provide:
            - ``embed.word_embeddings.weight``, whose first dim is used as the
              vocabulary size;
            - ``embed_dims``;
            - ``repeat_cache(num_beams, kv)``;
            - ``decode_forward(last_ids, kv, position, global_only_image)``,
              returning ``(feats, logits, kv)``;
            - ``reorder_cache(beam_idx, kv)``.
        input_ids: prefix token ids with ``batch_size * num_beams`` rows. Row
            ``batch_idx * num_beams + beam_id`` belongs to that beam.
        pre_kv_list: prefix cache, expanded here with ``model.repeat_cache``.
        batch_size: number of batch items.
        num_beams: beam width.
        pre_len: offset added to the step counter before it is passed to
            ``decode_forward``. Presumably the cached prefix length — confirm.
        max_length: maximum number of decoding steps (new tokens per beam).
        global_only_image: forwarded unchanged to ``decode_forward``.
        temperature: logits are divided by this before the log-softmax.
        length_penalty: exponent on hypothesis length used to normalize scores.
        bos_token_id: unused; kept for interface compatibility.
        eos_token_id: token id that finishes a hypothesis; ``None`` disables.
        pad_token_id: filler for finished batch items and short outputs.

    Returns:
        ``(decoded, decoded_scores, decoded_feats)``, one row per batch item,
        with the first position dropped:
            - ``decoded``: token ids.
            - ``decoded_scores``: ``exp`` of consecutive differences of the
              cumulative log-prob history, i.e. per-token probabilities.
              Only meaningful inside each hypothesis, not in padding.
            - ``decoded_feats``: the decoder feature of the step that produced
              each token.
    """
    # Cumulative log-probs of every beam. Only beam 0 of each batch item starts
    # alive, so the first step does not produce num_beams identical copies.
    beam_scores = torch.zeros((batch_size, num_beams), device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)
    done = [False for _ in range(batch_size)]
    generated_hyps = [
        BeamHypotheses(num_beams, max_length, length_penalty=length_penalty)
        for _ in range(batch_size)
    ]
    cur_len = 0
    # The longest possible hypothesis is the prefix plus max_length new tokens.
    max_total_len = input_ids.shape[-1] + max_length
    vocab_size = model.embed.word_embeddings.weight.shape[0]
    pre_kv_list = model.repeat_cache(num_beams, pre_kv_list)

    # Per-position history of cumulative scores and decoder features,
    # kept aligned with the columns of input_ids.
    his_scores = torch.zeros_like(input_ids, dtype=torch.float)
    his_feats = torch.zeros_like(input_ids, dtype=torch.float).unsqueeze(-1).repeat(1, 1, model.embed_dims)

    while cur_len < max_length:
        feats, outputs, pre_kv_list = model.decode_forward(
            input_ids[:, -1:], pre_kv_list, pre_len + cur_len, global_only_image
        )
        next_token_logits = outputs[:, -1, :] / temperature
        scores = F.log_softmax(next_token_logits, dim=-1)

        next_scores = scores + beam_scores[:, None].expand_as(scores)
        next_scores = next_scores.view(batch_size, num_beams * vocab_size)
        # Take 2 * num_beams candidates. Each beam can contribute at most one
        # EOS, so at least num_beams non-EOS continuations always remain.
        next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

        feats = feats.view(batch_size, num_beams, -1)
        next_beams = next_tokens // vocab_size
        next_feats = torch.gather(feats, 1, next_beams.unsqueeze(-1).repeat(1, 1, model.embed_dims))

        next_batch_beam = []
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                # Finished batch item: keep the tensor shapes with dummy beams.
                next_batch_beam.extend([(0, pad_token_id, 0, torch.zeros_like(feats[0][0]))] * num_beams)
                continue
            next_sent_beam = []
            for beam_token_rank, (beam_token_id, beam_token_score, beam_token_feat) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_feats[batch_idx])
            ):
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size
                effective_beam_id = batch_idx * num_beams + beam_id
                if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                    # Only EOS among the top num_beams candidates finishes a beam.
                    if beam_token_rank >= num_beams:
                        continue
                    # The stored hypothesis does not include the EOS token.
                    generated_hyps[batch_idx].add(
                        input_ids[effective_beam_id].clone(),
                        beam_token_score.item(),
                        his_scores[effective_beam_id],
                        his_feats[effective_beam_id],
                    )
                else:
                    next_sent_beam.append((beam_token_score, token_id, effective_beam_id, beam_token_feat))
                if len(next_sent_beam) == num_beams:
                    break

            # Check once per batch item, using the same length that `add`
            # normalized by (hypotheses finished this step have
            # input_ids.shape[-1] tokens). This length is always >= 1, so
            # the length normalization never divides by zero.
            done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                next_scores[batch_idx].max().item(), input_ids.shape[-1]
            )
            next_batch_beam.extend(next_sent_beam)

        if all(done):
            break

        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])
        beam_feats = torch.stack([x[3] for x in next_batch_beam], dim=0)

        # Reorder rows to follow their parent beams, then append the new column.
        input_ids = torch.cat([input_ids[beam_idx, :], beam_tokens.unsqueeze(1)], dim=-1)
        his_scores = torch.cat([his_scores[beam_idx, :], beam_scores.unsqueeze(1)], dim=-1)
        his_feats = torch.cat([his_feats[beam_idx, :], beam_feats.unsqueeze(1)], dim=1)

        cur_len = cur_len + 1
        pre_kv_list = model.reorder_cache(beam_idx, pre_kv_list)

    # Batch items that never finished: offer every live beam to their pool.
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            generated_hyps[batch_idx].add(
                input_ids[effective_beam_id],
                beam_scores[effective_beam_id].item(),
                his_scores[effective_beam_id],
                his_feats[effective_beam_id],
            )

    # Best finished hypothesis per batch item. sorted() is stable, so among
    # equal scores the most recently added hypothesis is chosen.
    best, best_scores, best_feats = [], [], []
    for hypotheses in generated_hyps:
        _, best_hyp, best_s, best_f = sorted(hypotheses.beams, key=lambda x: x[0])[-1]
        best.append(best_hyp)
        best_scores.append(best_s)
        best_feats.append(best_f)
    sent_lengths = [len(hyp) for hyp in best]

    if min(sent_lengths) != max(sent_lengths):
        # Lengths differ: right-pad into a shared buffer with one spare column,
        # clipped to the longest possible hypothesis (prefix + max_length).
        sent_max_len = min(max(sent_lengths) + 1, max_total_len)
        output_batch_size = len(best)
        decoded = input_ids.new_full((output_batch_size, sent_max_len), pad_token_id)
        decoded_scores = torch.zeros((output_batch_size, sent_max_len), dtype=torch.float, device=input_ids.device)
        decoded_feats = torch.zeros(
            (output_batch_size, sent_max_len, model.embed_dims), dtype=torch.float, device=input_ids.device
        )
        for i, (hypo, length) in enumerate(zip(best, sent_lengths)):
            decoded[i, :length] = hypo
            decoded_scores[i, :length] = best_scores[i]
            decoded_feats[i, :length] = best_feats[i]
            if length < sent_max_len:
                # NOTE(review): the original wrote eos_token_id here and then
                # immediately overwrote it with 0. The effective value 0 is
                # kept; confirm whether EOS was intended.
                decoded[i, length] = 0
    else:
        decoded = torch.stack(best).type(torch.long)
        decoded_scores = torch.stack(best_scores).type(torch.float)
        decoded_feats = torch.stack(best_feats).type(torch.float)

    decoded = decoded[:, 1:]
    # Cumulative log-prob history -> per-step probabilities.
    decoded_scores = (decoded_scores[:, 1:] - decoded_scores[:, :-1]).exp()
    decoded_feats = decoded_feats[:, 1:]
    return decoded, decoded_scores, decoded_feats