
from src.tree.node import Node
import torch.nn.functional as F
import torch
import heapq
from math import log, exp
from torch.nn.utils.rnn import pad_sequence
from collections import Counter

# Limits for score bounds. All token scores are natural-log probabilities.
LOWEST_PROB = 1e-12 # Small epsilon used to avoid log(0)
worst_token_score = log(LOWEST_PROB) # Default worst-case per-token score; the search functions derive their own value from the vocab / top-k size
BEST_TOKEN_SCORE = log(1.0) # Best-case per-token score: probability 1, i.e. 0.0
LENGTH_PENALTY = 1.0 # Alpha: exponent applied to the sequence length when normalising scores, can be tuned (from Google NMT paper)


def add_candidate(node, beam_candidates, best_score_lower_bound, node_score_upper_bound, score, max_beam_size):
    """
    Offer `node` to the bounded min-heap of beam candidates.

    A node is considered promising when its score upper bound is at least the
    best lower bound found so far. Promising nodes are pushed as
    `(score, node)` entries; once the heap already holds `max_beam_size`
    entries, the push is combined with a pop so the smallest entry (which may
    be the new one) is dropped and the heap does not grow.

    Returns True if the node was promising (and therefore offered to the heap),
    False if it was rejected by the bound check.
    """
    # Guard clause: reject nodes whose best possible outcome cannot reach the bound.
    # Written as a negated `>=` so that NaN bounds are rejected as well.
    if not (node_score_upper_bound >= best_score_lower_bound):
        return False

    entry = (score, node)
    has_room = len(beam_candidates) < max_beam_size

    # Plain push while there is room, push-and-evict-smallest once the heap is full
    heap_op = heapq.heappush if has_room else heapq.heappushpop
    heap_op(beam_candidates, entry)

    return True

def entropy(probs, normalise=False):
    """
    Computes entropy (in nats) over the last dimension of `probs`.

    Args:
        probs (Tensor): Probability tensor of shape (..., vocab_size)
        normalise (bool): If True, divide by maximum possible entropy (log vocab_size).
            When the last dimension has at most one entry the maximum entropy is 0,
            so the value is returned un-normalised instead of dividing by zero.

    Returns:
        Tensor: Entropy of shape probs.shape[:-1]
    """
    # Clamp only inside the log so zero-probability entries contribute 0 instead of NaN
    ent = -torch.sum(probs * torch.log(torch.clamp_min(probs, LOWEST_PROB)), dim=-1)

    vocab_size = probs.shape[-1]
    # log(1) == 0: normalising a single-outcome distribution would produce NaN
    if normalise and vocab_size > 1:
        max_entropy = torch.log(torch.tensor(vocab_size, dtype=probs.dtype, device=probs.device))
        ent = ent / max_entropy

    return ent


def _greedy_lower_bound_and_tree(model, tokenizer, input_ids, max_new_tokens, eos_token_id=None):
    """
    Run an initial greedy search to get a score lower bound and an initial path.

    The greedy continuation is generated with `model.generate`, then replayed
    token by token to build a chain of `Node`s under a root that holds the prompt.
    Only the first sequence of the batch is used.

    Args:
        model: Model exposing `generate` (called with Hugging Face style arguments).
        tokenizer: Tokenizer; only `pad_token_id` is read here.
        input_ids (Tensor): Prompt token ids; the last dimension is the sequence.
        max_new_tokens (int): Maximum number of tokens to generate.
        eos_token_id: End-of-sequence id, also used as pad id when the tokenizer has none.

    Returns:
        tuple: `(best_score_lower_bound, root, max_ent)` where
            - `best_score_lower_bound` is the length-normalised log-probability of the
              greedy path (`-inf` if nothing was generated),
            - `root` is the root `Node` of the tree containing the greedy path,
            - `max_ent` is the largest normalised entropy seen along the path.
    """
    best_score_lower_bound = float('-inf')

    with torch.inference_mode():
        greedy_outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_token_id,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # Strip the prompt so only newly generated tokens remain
        generated_tokens = greedy_outputs.sequences[:, input_ids.shape[-1]:]
        scores = greedy_outputs.scores

        total_log_prob = 0.0
        total_tokens = generated_tokens.size(1)

        # Start with root node
        root = Node(input_ids, 0)
        tree = root

        max_ent = 0

        for step, logits in enumerate(scores):
            probs = F.softmax(logits[0], dim=-1)
            token_id = generated_tokens[0, step].item()

            # Computed once and shared by the running total and the node score
            token_log_prob = log(probs[token_id].item() + LOWEST_PROB)
            total_log_prob += token_log_prob

            ent = entropy(probs, normalise=True).item()
            max_ent = max(max_ent, ent)

            # Create new node extending previous
            new_context = torch.cat([tree.context, torch.tensor([[token_id]], device=tree.context.device)], dim=-1)
            new_node = Node(new_context, tree.score + token_log_prob, tree)
            tree.add_child(new_node)
            tree = new_node

        if total_tokens > 0:
            # Same length normalisation as _estimate_bounds so the bounds stay comparable
            # when LENGTH_PENALTY is tuned away from 1.0
            best_score_lower_bound = total_log_prob / (total_tokens ** LENGTH_PENALTY)

    return best_score_lower_bound, root, max_ent


def _estimate_bounds(node, max_new_tokens, worst_token_score, best_token_score, eos_token_id):
    """
    Estimate lower/upper bounds on the final length-normalised score of `node`.

    Args:
        node: Node providing `path_length()`, `score` and `context`
            (the last entry of `context[0]` is the most recent token).
        max_new_tokens (int): Maximum number of generated tokens.
        worst_token_score (float): Per-token log-probability assumed for the worst continuation.
        best_token_score (float): Per-token log-probability assumed for the best continuation.
        eos_token_id: End-of-sequence token id.

    Returns:
        tuple: `(score_lower_bound, score, score_upper_bound)`. `score` is the node's
        current score divided by `path_length() ** LENGTH_PENALTY`. For a node whose
        last token is EOS all three values are equal; otherwise the bounds assume the
        sequence is extended to `max_new_tokens` with worst/best token scores.

    Raises:
        ZeroDivisionError: If `node.path_length()` is 0.
    """
    length_so_far = node.path_length()
    score = node.score / (length_so_far ** LENGTH_PENALTY)

    if node.context[0, -1].item() == eos_token_id:
        # This is a complete sequence, its score can no longer change
        return score, score, score

    # Estimate bounds assuming worst/best future paths up to the maximum length
    length_remaining = max_new_tokens - length_so_far
    max_length_norm = max_new_tokens ** LENGTH_PENALTY
    score_lower_bound = (node.score + length_remaining * worst_token_score) / max_length_norm
    score_upper_bound = (node.score + length_remaining * best_token_score) / max_length_norm

    return score_lower_bound, score, score_upper_bound

def entropy_guided_beam_search(model, tokenizer, input_ids, max_beam_size=10, max_new_tokens=5, eos_token_id=None):
    """
    Beam search whose per-node branching factor is driven by next-token entropy.

    A greedy pass first provides a score lower bound and an initial path. Then, step by
    step, every active node is expanded with between 1 and `max_beam_size` of its most
    likely next tokens: the higher the (relative) entropy of its next-token distribution,
    the more children it gets. Children whose score upper bound cannot reach the best
    known lower bound are discarded, and at most `max_beam_size` candidates are kept
    per step.

    Args:
        model: Causal LM called with `input_ids` / `attention_mask`, returning `.logits`.
        tokenizer: Tokenizer providing `vocab_size`, `eos_token_id` and `pad_token_id`.
        input_ids (Tensor): Prompt token ids (the first sequence is used for the tree root).
        max_beam_size (int): Maximum children per node and maximum candidates kept per step.
        max_new_tokens (int): Maximum number of generated tokens per path.
        eos_token_id: End-of-sequence id; defaults to `tokenizer.eos_token_id`.

    Returns:
        Node: Root of the search tree after `prune_incomplete`, with a `beam_sizes`
        attribute holding a `Counter` of the per-node beam sizes that were used.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Assume uniform vocab for worst case, as this is lowest prob the best token can have.
    # A distinct local name is needed: assigning to `worst_token_score` in this function
    # would make it local and turn the fallback read below into an UnboundLocalError.
    if tokenizer.vocab_size > 0:
        token_score_floor = log(1 / tokenizer.vocab_size)
    else:
        token_score_floor = worst_token_score

    # Start with greedy to get initial bounds and path
    best_score_lower_bound, tree, max_entropy_seen = _greedy_lower_bound_and_tree(model, tokenizer, input_ids, max_new_tokens, eos_token_id)
    beams = [tree]

    # Tracking beam sizes
    counter = Counter()
    pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    while beams:
        # Filter out finished beams
        active_nodes = [node for node in beams if node.path_length() < max_new_tokens and node.context[0, -1].item() != eos_token_id]

        if not active_nodes:
            break

        # Pad contexts to same length, use attention mask to ignore padding
        contexts = [node.context.squeeze(0) for node in active_nodes]
        padded_contexts = pad_sequence(contexts, batch_first=True, padding_value=pad_value)

        # Derive the mask from the true lengths rather than from the pad value, so genuine
        # tokens that happen to equal the pad id (e.g. EOS used as pad) are not masked out
        lengths = torch.tensor([ctx.size(0) for ctx in contexts], device=padded_contexts.device)
        positions = torch.arange(padded_contexts.size(1), device=padded_contexts.device)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        # Batch generate the outputs
        with torch.inference_mode():
            # NOTE(review): eos_token_id / pad_token_id are forwarded as in the original call;
            # confirm the model's forward() accepts these keyword arguments.
            outputs = model(
                input_ids=padded_contexts,
                attention_mask=attention_mask,
                eos_token_id=eos_token_id,
                pad_token_id=pad_value,
            )

            all_logits = outputs.logits  # All the logits for the current step
            # Logits at the last *real* token of each sequence (padding is on the right)
            batch_index = torch.arange(len(active_nodes), device=all_logits.device)
            last_positions = (lengths - 1).to(all_logits.device)
            logits = all_logits[batch_index, last_positions]

            probs = F.softmax(logits, dim=-1)
            entropies = entropy(probs, normalise=True)

            max_entropy_seen = max(max_entropy_seen, entropies.max().item())

            # Relative to largest seen entropy. If no uncertainty has been seen at all,
            # keep the zeros (beam size 1) instead of dividing by zero.
            if max_entropy_seen > 0:
                entropies = entropies / max_entropy_seen

            # Beam size per sample, at least 1 and never more than the vocabulary holds
            beam_sizes = (entropies * max_beam_size).round().long().clamp(min=1, max=probs.size(-1))
            beam_size_list = beam_sizes.tolist()

            # Tracking beamsizes
            counter.update(beam_size_list)

            # top-k is taken for the largest beam size; below each sample only uses its own size
            max_k = max(beam_size_list)
            top_k_probs, top_k_tokens = torch.topk(probs, k=max_k, dim=-1)

            # Convert to Python floats once instead of one .item() call per candidate
            top_k_log_probs = torch.log(top_k_probs + LOWEST_PROB).tolist()

            del outputs, all_logits, logits, probs, attention_mask, padded_contexts

        beam_candidates = []

        for i, node in enumerate(active_nodes):
            for k in range(beam_size_list[i]):
                token = top_k_tokens[i, k]

                new_context = torch.cat([node.context, token.view(1, 1)], dim=-1)
                new_score = node.score + top_k_log_probs[i][k]

                child = Node(new_context, new_score, node, length_penalty=LENGTH_PENALTY)
                score_lower_bound, score, score_upper_bound = _estimate_bounds(child, max_new_tokens, token_score_floor, BEST_TOKEN_SCORE, eos_token_id)

                # Update bound if this is the best seen so far
                best_score_lower_bound = max(best_score_lower_bound, score_lower_bound)

                added = add_candidate(
                    child, beam_candidates,
                    best_score_lower_bound=best_score_lower_bound,
                    node_score_upper_bound=score_upper_bound,
                    score=score,
                    max_beam_size=max_beam_size,
                )

                if added:
                    node.add_child(child)
                else:
                    # Drop the rejected child's back-reference; the expanded node itself
                    # must stay attached to its parent.
                    child.parent = None
                    # We can exit early, since sorted according to top k, the upper bounds are monotonic
                    break

        del top_k_probs, top_k_tokens, top_k_log_probs
        beams = [candidate for _, candidate in beam_candidates]

    # At the end, we only want to consider the completed paths
    tree.prune_incomplete(eos_token_id=eos_token_id, max_new_tokens=max_new_tokens)

    # For tracking some stats
    tree.beam_sizes = counter

    return tree

def entropy_guided_beam_search_api(client, tokenizer, prompt, max_beam_size=10, max_new_tokens=5, top_k=20,
                                   eos_token_id=None):
    """
    Entropy-guided beam search driven by an LLM API that returns top-k log-probabilities.

    Each active node's context is decoded to text and sent to `client.responses.create`
    for a single token. The returned top-k log-probabilities are renormalised into a
    distribution whose entropy (relative to the largest entropy seen so far) decides how
    many of the most likely tokens become children of that node.

    Args:
        client: API client exposing `responses.create(...)`.
        tokenizer: Tokenizer providing `encode`, `decode` and `eos_token_id`.
        prompt (str): Prompt text.
        max_beam_size (int): Maximum children per node and maximum beams kept per step.
        max_new_tokens (int): Maximum number of generated tokens per path.
        top_k (int): Number of top log-probabilities requested from the API.
        eos_token_id: End-of-sequence id; defaults to `tokenizer.eos_token_id`.

    Returns:
        Node: Root of the search tree after `prune_incomplete`, with a `beam_sizes`
        attribute holding a `Counter` of the per-node beam sizes that were used.

    Raises:
        AssertionError: If `max_beam_size` exceeds `top_k`.
    """
    assert max_beam_size <= top_k, "Beam size must be less than or equal to top-k"

    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Uniform distribution over the top-k tokens is the worst case for the best token
    worst_token_score = log(1 / top_k)
    best_score_lower_bound = float('-inf')

    # Start with root node
    root_context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0)  # shape [1, seq_len]

    root_node = Node(root_context, score=0)
    beams = [root_node]
    counter = Counter()
    max_entropy_seen = 0.0001  # avoid division by zero

    for _ in range(max_new_tokens):
        active_nodes = [node for node in beams if node.context[0, -1].item() != eos_token_id]
        if not active_nodes:
            break

        beam_candidates = []

        for node in active_nodes:
            # Convert context to string prompt
            context_text = tokenizer.decode(node.context[0].tolist())

            # Query LLM
            response = client.responses.create(
                model="", # Isnt actually used
                input=context_text,
                max_output_tokens=1,
                metadata={"logprobs": True, "top_logprobs": top_k}  # specify top-k logprobs
            )

            # Assume single choice (1 run)
            choice = response.choices[0]
            logprobs = choice["logprobs"][0]

            if not logprobs:
                # Nothing to expand for this node; it stays a leaf
                continue

            top_logprobs = torch.tensor([x["logprob"] for x in logprobs])
            top_token_ids = [x["token_id"] for x in logprobs]

            # From log probs to probs renormalised over the returned candidates
            probs = torch.exp(top_logprobs)
            probs /= probs.sum()

            # Entropy in nats, scaled relative to the largest entropy seen so far
            ent = entropy(probs)
            max_entropy_seen = max(max_entropy_seen, ent.item())
            ent /= max_entropy_seen

            # Determine beam size for this node; the API may return fewer than top_k entries
            beam_size = max(1, int((ent * max_beam_size).round().item()))
            beam_size = min(beam_size, probs.numel())
            counter.update([beam_size])

            # Take top beam_size tokens
            topk_probs, topk_indices = torch.topk(probs, k=beam_size)

            # Plain Python floats/ints keep node scores consistent with the tensor-based search
            topk_log_probs = torch.log(topk_probs + LOWEST_PROB).tolist()

            for log_prob, index in zip(topk_log_probs, topk_indices.tolist()):
                token_id = top_token_ids[index]
                token_tensor = torch.tensor([[token_id]], dtype=torch.long)

                new_context = torch.cat([node.context, token_tensor], dim=-1)  # shape [1, seq_len + 1]

                new_score = node.score + log_prob
                child = Node(new_context, score=new_score, parent=node, length_penalty=LENGTH_PENALTY)

                score_lower_bound, score, score_upper_bound = _estimate_bounds(child, max_new_tokens,
                                                                               worst_token_score,
                                                                               BEST_TOKEN_SCORE, eos_token_id)

                # Update bound if this is the best seen so far
                best_score_lower_bound = max(best_score_lower_bound, score_lower_bound)

                added = add_candidate(
                    child, beam_candidates,
                    best_score_lower_bound=best_score_lower_bound,
                    node_score_upper_bound=score_upper_bound,
                    score=score,
                    max_beam_size=max_beam_size,
                )

                if added:
                    node.add_child(child)
                else:
                    # Drop the rejected child's back-reference; the expanded node itself
                    # must stay attached to its parent.
                    child.parent = None
                    # We can exit early, since sorted according to top k, the upper bounds are monotonic
                    break

        # Sort candidates and keep top beams
        beam_candidates.sort(key=lambda x: x[0], reverse=True)
        beams = [c for _, c in beam_candidates[:max_beam_size]]

    # Prune incomplete paths
    root_node.prune_incomplete(eos_token_id, max_new_tokens)
    root_node.beam_sizes = counter

    return root_node
