import copy
import random

# typing 
from typing import List, Tuple
import time
import torch
import torch.nn.functional as F

# TODO
# from transformers import LlamaTokenizer
# tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/")

TOPK = 10  # topk for sparse tree

from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def cfg_logit_process(combined_logits, cfg_scale=4.0, cfg_interval=-1):
    """Combine stacked conditional/unconditional logits with classifier-free guidance.

    The first half (along dim 0, or dim 1 for 4-D input) is taken as the
    conditional logits and the second half as the unconditional ones; the
    result is ``uncond + (cond - uncond) * cfg_scale``.

    Args:
        combined_logits: logits with the conditional and unconditional halves
            stacked along the batch dimension (dim 1 when the tensor is 4-D).
        cfg_scale: guidance scale.
        cfg_interval: accepted for API compatibility; not used.

    Returns:
        The guided logits, with half the batch size and the input's layout.
    """
    four_dim = combined_logits.dim() == 4
    # For 4-D input the halves live on dim 1; move them to dim 0 for splitting.
    stacked = combined_logits.transpose(0, 1) if four_dim else combined_logits
    half = len(stacked) // 2
    cond, uncond = torch.split(stacked, half, dim=0)
    guided = uncond + (cond - uncond) * cfg_scale
    return guided.transpose(0, 1) if four_dim else guided
def calculate_jsd(tensor1, tensor2):
    """Return the element-wise Jensen-Shannon divergence terms of two distributions.

    No reduction is performed: each element is
    ``0.5 * (KL(P || M) + KL(Q || M))`` evaluated at that element, with
    ``M = (P + Q) / 2``. Where ``M`` is zero the result may be NaN.
    """
    log_mixture = (0.5 * (tensor1 + tensor2)).log()
    # F.kl_div(input, target) computes target * (log(target) - input) element-wise.
    kl_p = F.kl_div(log_mixture, tensor1, reduction='none')
    kl_q = F.kl_div(log_mixture, tensor2, reduction='none')
    return 0.5 * (kl_p + kl_q)

def calculate_tvd(tensor1, tensor2):
    """Return element-wise total-variation terms ``0.5 * |P - Q|`` (no reduction)."""
    difference = tensor1 - tensor2
    return 0.5 * torch.abs(difference)


class Timer:
    """Context manager that prints the wall-clock time spent in its body.

    CUDA is synchronized on entry and exit (when CUDA is available) so that
    queued GPU work is included in the measurement. The measured duration is
    also stored on ``elapsed``.

    Example::

        with Timer("step") as t:
            ...
        # prints "step took <seconds> seconds"; t.elapsed holds the value
    """

    def __init__(self, name):
        self.name = name
        self.start = None
        self.elapsed = None

    @staticmethod
    def _sync():
        # torch.cuda.synchronize() raises on machines without CUDA; skip it there
        # so the timer (and callers such as generate_tree_buffers) work on CPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync()
        self.start = time.perf_counter()
        # Return self so `with Timer(...) as t` gives access to `t.elapsed`.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._sync()
        self.elapsed = time.perf_counter() - self.start
        print(f'{self.name} took {self.elapsed} seconds')
        # Returning None lets any exception raised in the body propagate.


def prepare_logits_processor(
        temperature: float = 0.0,
        repetition_penalty: float = 0.0,
        top_p: float = 0.0,
        top_k: int = 0
) -> LogitsProcessorList:
    """Build the Hugging Face logits-processor chain used for sampling.

    A non-positive (<= 1e-5) temperature means greedy decoding, in which case
    an empty list is returned. Otherwise the warpers are appended in this order:
    temperature (skipped when exactly 1.0), repetition penalty (> 1.0),
    top-p (in [1e-8, 1.0)), top-k (> 0).
    """
    processors = LogitsProcessorList()
    # Written as `not (t > eps)` so that NaN also yields the empty list.
    if not temperature > 1e-5:
        return processors
    if temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processors.append(TopKLogitsWarper(top_k))
    return processors


# test_processor = prepare_logits_processor(
#         0.0, 0.0, -1, 1
#     )


def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]:
    """
    Return ``path`` right-padded with ``pad_value`` up to ``length`` items.

    Parameters:
    - path (list): The list to pad; it is not modified.
    - length (int): Target length of the result.
    - pad_value (optional, default=-2): Value used for the padding slots.

    Returns:
    - list: A new list equal to ``path`` followed by the padding.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    When ``path`` already has ``length`` or more items nothing is appended;
    the result is still a new list equal to ``path``.
    """
    missing = length - len(path)
    # range() of a non-positive count is empty, so no padding is added then.
    return path + [pad_value for _ in range(missing)]


def generate_tree_buffers(tree_choices, device="cuda"):
    """Build the static buffers that describe a speculative-decoding draft tree.

    Node 0 is the root. Node k >= 1 is
    ``sorted(tree_choices, key=lambda x: (len(x), x))[k - 1]``, so
    ``tree_len = len(tree_choices) + 1``.

    Args:
        tree_choices: list of paths, each a list of ints (a child index at each
            depth). Every proper prefix of a path must itself be present in
            ``tree_choices``. Otherwise the ``.index`` lookups below raise
            ValueError, and a missing depth level breaks ``depth_counts``
            indexing.
        device: device the returned tensors are moved to.

    Returns:
        dict of tensors on ``device``:
        - "tree_attn_mask": ``(1, 1, tree_len, tree_len)`` float mask. Row k has
          ones at column 0 (root), at k itself, and at each ancestor of k.
        - "tree_indices": ``(tree_len,)`` long. Entry 0 is 0 (the root). For a
          node at depth index ``i`` the entry is
          ``choice[-1] + TOPK * (i + bias) + 1``, where ``bias`` counts parent
          changes seen so far. ``generate_candidates`` uses it to index a
          flattened ``[root_token, tree_logits.view(-1)...]`` vector.
        - "tree_position_ids": ``(tree_len,)`` long depth of each node (root = 0).
        - "retrieve_indices": ``(num_leaves, max_depth + 1)`` long. Each row is the
          root-to-leaf list of node indices (column 0 = root = 0), padded
          with -1.

    NOTE(review): all of the work runs inside ``Timer("sort")``, which prints
    the elapsed time on every call and calls ``torch.cuda.synchronize()``.
    """
    def custom_sort(lst):
        # Sort key for a retrieve_indices row: padding entries (< 0) are replaced
        # by `maxitem` so a padded (shorter) row sorts after rows that continue
        # the same prefix. `maxitem` is assigned further down and is looked up
        # through the closure only when sorted() calls this key.
        sort_keys = []
        for i in range(len(lst)):
            sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
        return sort_keys
    with Timer("sort"):

        # Breadth-first order: shorter paths first, lexicographic within a depth.
        sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
        tree_len = len(sorted_tree_choices) + 1

        # depth_counts[d - 1] = number of nodes at depth d. This relies on every
        # depth between 1 and the maximum being present.
        depth_counts = []
        prev_depth = 0
        for path in sorted_tree_choices:
            depth = len(path)
            if depth != prev_depth:
                depth_counts.append(0)
            depth_counts[depth - 1] += 1
            prev_depth = depth

        # Attention mask: every node sees itself (eye) and the root (column 0) ...
        tree_attn_mask = torch.eye(tree_len, tree_len)
        tree_attn_mask[:, 0] = 1
        start = 0
        for i in range(len(depth_counts)):
            for j in range(depth_counts[i]):
                cur_tree_choice = sorted_tree_choices[start + j]
                # retrieve ancestor position
                if len(cur_tree_choice) == 1:
                    continue
                # ... plus each of its ancestors (+1 because node 0 is the root).
                ancestor_idx = []
                for c in range(len(cur_tree_choice) - 1):
                    ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
                tree_attn_mask[j + start + 1, ancestor_idx] = 1
            start += depth_counts[i]

        # tree_indices: where each node's token is found in the flattened
        # candidate vector. `bias` increases whenever the parent changes and is
        # never reset between depths; `inlayer_bias` is reset per depth.
        # NOTE(review): p_indices and b_indices are computed here but not
        # returned or used afterwards in this function.
        tree_indices = torch.zeros(tree_len, dtype=torch.long)
        p_indices = [0 for _ in range(tree_len - 1)]
        b_indices = [[] for _ in range(tree_len - 1)]
        tree_indices[0] = 0
        start = 0
        bias = 0
        for i in range(len(depth_counts)):
            inlayer_bias = 0
            b = []
            for j in range(depth_counts[i]):
                cur_tree_choice = sorted_tree_choices[start + j]
                cur_parent = cur_tree_choice[:-1]
                if j != 0:
                    if cur_parent != parent:
                        bias += 1
                        inlayer_bias += 1
                        parent = cur_parent
                        b = []
                else:
                    parent = cur_parent
                tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1
                p_indices[start + j] = inlayer_bias
                if len(b) > 0:
                    b_indices[start + j] = copy.deepcopy(b)
                else:
                    b_indices[start + j] = []
                b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1)
            start += depth_counts[i]

        p_indices = [-1] + p_indices
        # Position id of each node = its depth (root stays 0).
        tree_position_ids = torch.zeros(tree_len, dtype=torch.long)
        start = 0
        for i in range(len(depth_counts)):
            tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
            start += depth_counts[i]

        # Walk paths deepest-first. Any path already recorded as a prefix of an
        # earlier (deeper) path is skipped, so one row is produced per leaf.
        retrieve_indices_nest = []
        retrieve_paths = []
        for i in range(len(sorted_tree_choices)):
            cur_tree_choice = sorted_tree_choices[-i - 1]
            retrieve_indice = []
            if cur_tree_choice in retrieve_paths:
                continue
            else:
                for c in range(len(cur_tree_choice)):
                    retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]))
                    retrieve_paths.append(cur_tree_choice[:c + 1])
            retrieve_indices_nest.append(retrieve_indice)
        max_length = max([len(x) for x in retrieve_indices_nest])
        retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
        retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
        # Shift to tree-node numbering (root = 0); the -2 padding becomes -1.
        retrieve_indices = retrieve_indices + 1
        # Prepend the root column.
        retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices],
                                     dim=1)

        # Larger than any real index; used by custom_sort for padding entries.
        maxitem = retrieve_indices.max().item() + 5



        retrieve_indices = retrieve_indices.tolist()
        retrieve_indices = sorted(retrieve_indices, key=custom_sort)
        retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)



    # Aggregate the generated buffers into a dictionary
    tree_buffers = {
        "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": tree_indices,
        "tree_position_ids": tree_position_ids,
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    # (every value above is already a tensor, so the else-branch is not taken here).
    tree_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in tree_buffers.items()
    }

    return tree_buffers


def initialize_tree0(input_ids, model, past_key_values, logits_processor):
    """Run ``model`` once with ``output_orig=True`` and repackage its results.

    ``model`` is expected to return an 8-tuple
    ``(draft_tokens, retrieve_indices, tree_mask, tree_position_ids, outputs,
    logits, hidden_state, sample_token)``. The same values are returned without
    ``outputs``.
    """
    results = model(
        input_ids,
        past_key_values=past_key_values,
        output_orig=True,
        logits_processor=logits_processor,
    )
    (
        drafted,
        retrieve,
        mask,
        positions,
        _outputs,
        orig_logits,
        hidden,
        first_token,
    ) = results
    return drafted, retrieve, mask, positions, orig_logits, hidden, first_token

def initialize_tree(cond_idx, model, past_key_values, logits_processor, cfg_scale, cfg_interval, attention_mask = None):
    """Prefill the target model, choose the first token and draft the first tree.

    Args:
        cond_idx: conditioning input forwarded to ``model`` as ``cond_idx``.
            When ``attention_mask`` is None it is also prepended (after
            ``unsqueeze(-1)``) to the chosen token to form the draft input ids.
        model: callable returning ``(outputs, orig_logits, hidden_states)`` that
            exposes ``ea_layer.topK_genrate`` and ``base_model.lm_head``.
        past_key_values: KV cache forwarded to ``model``.
        logits_processor: ``None`` for greedy selection; otherwise a callable
            ``(input_ids, scores) -> scores`` applied before sampling.
        cfg_scale, cfg_interval: forwarded to ``cfg_logit_process`` and the
            draft layer.
        attention_mask: optional mask forwarded to ``model``.

    Returns:
        ``(draft_tokens, retrieve_indices, tree_mask, tree_position_ids, orig,
        hidden_states, token)``. ``token`` holds the chosen token(s) stacked
        twice along dim 0, once for each CFG half.
    """
    outputs, orig, hidden_states = model(
        cond_idx=cond_idx, past_key_values=past_key_values, output_orig=True, attention_mask=attention_mask
    )

    # Guided logits at the last position.
    logits = cfg_logit_process(orig[:, -1], cfg_scale, cfg_interval)

    if logits_processor is not None:
        logits = logits_processor(None, logits)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        token = torch.multinomial(probabilities, 1)
    else:
        # Per-row argmax with shape (batch, 1), matching the multinomial branch.
        # A bare torch.argmax(logits) would index the flattened batch*vocab tensor.
        token = torch.argmax(logits, dim=-1, keepdim=True)
    # Same token for the conditional and unconditional halves of the CFG batch.
    token = torch.cat([token, token], dim=0)
    if attention_mask is None:
        input_ids = torch.cat((cond_idx.unsqueeze(-1), token.to(cond_idx.device)), dim=1)
    else:
        # NOTE(review): 120 is a hard-coded prefix length standing in for the
        # conditioning tokens — confirm it matches the caller's condition length.
        zero_padding = torch.zeros((token.shape[0], 120), dtype=torch.long, device=token.device)
        input_ids = torch.cat((zero_padding, token.to(cond_idx.device)), dim=1)

    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor, cfg_scale, cfg_interval)
    return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token

def initialize_tree_sample(cond_idx, model, past_key_values, logits_processor, cfg_scale, cfg_interval):
    """Prefill the target model, choose the first token and draft via ``sample_genrate``.

    Same as ``initialize_tree`` without an attention mask, except that the draft
    tree is produced by ``model.ea_layer.sample_genrate``.

    Args:
        cond_idx: conditioning input. Forwarded to ``model`` and prepended
            (after ``unsqueeze(-1)``) to the chosen token for the draft input.
        model: callable returning ``(outputs, orig_logits, hidden_states)`` that
            exposes ``ea_layer.sample_genrate`` and ``base_model.lm_head``.
        past_key_values: KV cache forwarded to ``model``.
        logits_processor: ``None`` for greedy selection; otherwise a callable
            ``(input_ids, scores) -> scores`` applied before sampling.
        cfg_scale, cfg_interval: forwarded to ``cfg_logit_process`` and the
            draft layer.

    Returns:
        ``(draft_tokens, retrieve_indices, tree_mask, tree_position_ids, orig,
        hidden_states, token)``. ``token`` holds the chosen token(s) stacked
        twice along dim 0.
    """
    outputs, orig, hidden_states = model(
        cond_idx=cond_idx, past_key_values=past_key_values, output_orig=True
    )

    logits = cfg_logit_process(orig[:, -1], cfg_scale, cfg_interval)

    if logits_processor is not None:
        logits = logits_processor(None, logits)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        token = torch.multinomial(probabilities, 1)
    else:
        # Per-row argmax with shape (batch, 1), matching the multinomial branch.
        # A bare torch.argmax(logits) would index the flattened batch*vocab tensor.
        token = torch.argmax(logits, dim=-1, keepdim=True)
    # Same token for the conditional and unconditional halves of the CFG batch.
    token = torch.cat([token, token], dim=0)
    input_ids = torch.cat((cond_idx.unsqueeze(-1), token.to(cond_idx.device)), dim=1)

    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.sample_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor, cfg_scale, cfg_interval)
    return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token


def reset_tree_mode(
        model,
):
    """Clear the tree-attention state on ``model.base_model.model``.

    Sets both ``tree_mask`` and ``tree_mode`` to ``None``.
    """
    inner_model = model.base_model.model
    for attribute in ("tree_mask", "tree_mode"):
        setattr(inner_model, attribute, None)


def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Zero the ``current_length`` counters of a per-layer key/value cache.

    For every layer, entries 0 and 1 (key and value) have their
    ``current_length`` tensor filled with 0 in place.

    Args:
    - passed_key_values (list): One item per layer; ``item[0]`` and ``item[1]``
      must expose a ``current_length`` tensor.

    Returns:
    - passed_key_values: The same object, after resetting (mutated in place).
    """
    for layer_cache in passed_key_values:
        for kv_slot in (0, 1):
            layer_cache[kv_slot].current_length.fill_(0)
    return passed_key_values


def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor):
    """Lay out the drafted tokens as tree nodes and as root-to-leaf paths.

    The candidate vector is ``[sample_token[0], *tree_logits.view(-1)]``.
    ``tree_indices`` picks one entry per tree node. ``retrieve_indices`` then
    gathers those nodes into paths, where index -1 selects an appended -1
    padding token.

    Args:
        tree_logits: drafted token ids (flattened with ``view(-1)``).
        tree_indices: node -> candidate-vector index mapping.
        retrieve_indices: ``(num_paths, path_len)`` node indices, -1 = padding.
        sample_token: the already chosen root token; ``sample_token[0]`` is used.
        logits_processor: accepted for API compatibility; not used.

    Returns:
        ``(cart_candidates, tree_candidates)``: per-path token ids, and the
        per-node token ids with a leading batch dimension of 1.
    """
    root = sample_token.to(tree_indices.device)[0]
    flat_candidates = torch.cat([root, tree_logits.view(-1)], dim=-1)
    node_tokens = flat_candidates[tree_indices]

    # Sentinel -1 at the end so padded retrieve indices (-1) map to -1 tokens.
    sentinel = torch.full((1,), -1, dtype=torch.long, device=node_tokens.device)
    cart_candidates = torch.cat([node_tokens, sentinel], dim=0)[retrieve_indices]

    return cart_candidates, node_tokens.unsqueeze(0)


def tree_decoding(
        model,
        tree_candidates,
        past_key_values,
        tree_position_ids,
        input_ids,
        retrieve_indices,
        cfg_scale,
        cfg_interval,
        attention_mask = None
):
    """Score every draft-tree node with the target model in one forward pass.

    Args:
        model: callable returning ``(outputs, logits, hidden_state)``.
        tree_candidates: per-node token ids, ``(batch, num_nodes)``.
        past_key_values: KV cache forwarded to ``model``.
        tree_position_ids: per-node depth; offset by the current prefix length.
        input_ids: the accepted sequence so far (only its length is used).
        retrieve_indices: node indices of each root-to-leaf path.
        cfg_scale, cfg_interval: forwarded to ``cfg_logit_process``.
        attention_mask: optional mask. If given, it is extended with ones
            (dtype long) to cover the prefix plus all tree nodes.

    Returns:
        ``(logits, hidden_state, outputs)``. ``logits`` is the guided logits of
        batch row 0 gathered along every path, i.e. ``guided[0, retrieve_indices]``.
    """
    prefix_length = input_ids.shape[1]
    if attention_mask is not None:
        missing = prefix_length + tree_candidates.shape[1] - attention_mask.shape[1]
        extension = torch.ones(
            (attention_mask.shape[0], missing),
            dtype=torch.long,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([attention_mask, extension], dim=1)

    outputs, raw_tree_logits, hidden_state = model(
        input_ids=tree_candidates,
        output_orig=True,
        past_key_values=past_key_values,
        position_ids=tree_position_ids + prefix_length,
        attention_mask=attention_mask,
    )
    guided = cfg_logit_process(raw_tree_logits, cfg_scale, cfg_interval)
    return guided[0, retrieve_indices], hidden_state, outputs






def evaluate_posterior(
        logits: torch.Tensor,
        candidates: torch.Tensor,
        logits_processor,
):
    """
    Verify drafted candidate paths against the target model and choose the best one.

    With ``logits_processor is None`` a drafted token is accepted only while it
    equals the target argmax (greedy). Otherwise each depth is checked with
    rejection sampling. A drafted token ``x`` is accepted with probability
    ``p(x)`` under the processed target distribution. Each rejected token is
    zeroed out of that distribution, which is then renormalised.

    Args:
    - logits (torch.Tensor): Target-model logits gathered along each candidate
      path, indexed ``logits[path, position, vocab]``. Position ``s`` scores the
      token at position ``s + 1``.
    - candidates (torch.Tensor): Candidate token paths ``(num_paths, path_len)``.
      Column 0 is the root token and -1 marks padding.
    - logits_processor: ``None`` for greedy decoding. Otherwise a callable
      ``(input_ids, scores) -> scores`` applied before every softmax.

    Returns:
    - best_candidate (torch.Tensor): Index of the chosen candidate path.
    - accept_length: Number of accepted drafted tokens (a tensor in the greedy
      branch, an int in the sampling branch).
    - Third value. Greedy: the raw logits at the next position of the chosen
      path. Sampling: a probability vector for drawing the next token.
    """
    if logits_processor is None:
        # A drafted token matches if it equals the argmax of the preceding position.
        posterior_mask = (
                candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        # cumprod stops counting at the first mismatch along each path.
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length, logits[best_candidate, accept_length]

    accept_length = 1  # includes the root token
    accept_cand = candidates[0][:1]
    best_candidate = 0
    # Bound before the loop: with candidates.shape[1] == 1 the loop body never
    # runs, and the check after the loop would otherwise raise NameError.
    adjustflag = False
    for i in range(1, candidates.shape[1]):
        # The previous depth accepted nothing, so stop verifying.
        if i != accept_length:
            break
        adjustflag = False
        # Only paths that share the accepted prefix compete at this depth.
        is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
        fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
        gt_logits = logits[fi, i - 1][None]
        gt_logits = logits_processor(None, gt_logits)[0]
        gtp = torch.softmax(gt_logits, dim=0)
        tried = set()
        for j in range(candidates.shape[0]):
            if not is_eq[j]:
                continue
            x = candidates[j, i]
            xi = x.item()
            # Skip padding and tokens already tried at this depth.
            if xi in tried or xi == -1:
                continue
            tried.add(xi)
            r = random.random()
            px = gtp[xi]
            qx = 1.0  # the draft proposal is treated as deterministic
            acp = px / qx
            if r <= acp:
                accept_cand = torch.cat((accept_cand, x[None]), dim=0)
                accept_length += 1
                best_candidate = j
                break
            # Rejected: remove x from the target distribution and renormalise.
            gtp[xi] = 0
            gtp = gtp / gtp.sum()
            adjustflag = True

    if adjustflag and accept_length != candidates.shape[1]:
        # Sample the next token from the residual distribution after rejections.
        sample_p = gtp
    else:
        # Apply the same processing as in verification, so the next token is
        # drawn from the processed target distribution, not the raw logits.
        gt_logits = logits[best_candidate, accept_length - 1][None]
        gt_logits = logits_processor(None, gt_logits)[0]
        sample_p = torch.softmax(gt_logits, dim=0)
    return torch.tensor(best_candidate), accept_length - 1, sample_p
    
def evaluate_posterior_with_nearest_latent(
        logits: torch.Tensor,
        candidates: torch.Tensor,
        logits_processor,
        nearest_latent,
        adaptive_func, 
        coeff_a,
        coeff_b,
        warmup_steps,
        current_length,
):
    """
    Verify drafted candidate paths, relaxing acceptance with "nearest latent" tokens.

    Works like ``evaluate_posterior``, with one change once past the warm-up
    point. The target probability of a drafted token is augmented with the
    probability of up to ``int(coeff_b)`` of its listed neighbour tokens
    (``nearest_latent``), as long as a divergence / mass budget ``coeff_a`` is
    not exceeded.

    Args:
    - logits (torch.Tensor): Target-model logits along each candidate path,
      indexed ``logits[path, position, vocab]``. Position ``s`` scores the token
      at position ``s + 1``.
    - candidates (torch.Tensor): Candidate token paths ``(num_paths, path_len)``.
      Column 0 is the root token and -1 marks padding.
    - logits_processor: ``None`` selects the greedy branch. Otherwise a callable
      ``(input_ids, scores) -> scores`` applied before softmax (sampling branch).
    - nearest_latent: integer tensor. ``nearest_latent[t, :k]`` are the first ``k``
      neighbour token ids of token ``t`` (presumably ordered nearest-first).
    - adaptive_func: divergence selector. Greedy branch: ``== 'tvd'`` uses
      ``calculate_tvd``, anything else uses ``calculate_jsd``. Sampling branch:
      substring tests ``"jsd" in`` / ``"tvd" in``. If neither matches, no
      relaxation happens.
    - coeff_a: divergence threshold (greedy branch, "jsd" sampling) or
      cumulative-neighbour-mass threshold ("tvd" sampling).
    - coeff_b: number of neighbours searched (cast with ``int``).
    - warmup_steps, current_length: relaxation is disabled while
      ``current_length`` (greedy) or ``current_length + i`` (sampling at depth
      ``i``) is below ``warmup_steps``.

    Returns:
    - greedy: ``(best_candidate, accept_length, logits[best_candidate, accept_length])``.
    - sampling: ``(torch.tensor(best_candidate), accepted_count, sample_p)`` where
      ``sample_p`` is a probability vector for the next token.

    NOTE(review): in the sampling branch ``adjustflag`` is only bound inside the
    loop, so ``candidates.shape[1] == 1`` raises NameError. The fallback
    ``sample_p`` is also a softmax of raw logits without ``logits_processor``,
    unlike the processed ``gtp`` used during verification.
    """
    # Greedy decoding (no logits processor)
    if logits_processor is None:
        device = logits.device
        batch_size, seq_len, vocab_size = logits.size()  # unused below
        # Drafted tokens to verify; "batch" here is the candidate-path dimension.
        candidates_verify = candidates[:, 1:]  # Shape: (num_paths, path_len - 1)

        # Compute softmax probabilities over logits
        gtp = torch.softmax(logits, dim=-1)  # Shape: (num_paths, path_len, vocab_size)

        # Get the token indices from candidates
        xi = candidates_verify  # Shape: (num_paths, path_len - 1)

        # Mask for positions where xi == -1 (padding)
        valid_mask = (xi != -1).to(device)

        # Adjust xi to have valid indices for indexing operations
        xi_valid = xi.clone()
        xi_valid[~valid_mask] = 0  # Replace invalid indices with 0 (or any valid index)

        # Target probability of each drafted token at the position that predicts it
        # (gather only reads the first path_len - 1 positions of gtp).
        px = gtp.gather(dim=-1, index=xi_valid.unsqueeze(-1)).squeeze(-1)
        px = px * valid_mask  
        if warmup_steps > current_length:
            # Still warming up: plain greedy, accept while drafted == argmax.
            top_tokens = torch.argmax(logits[:, :-1], dim=-1)
            posterior_mask = (xi == top_tokens).int() * valid_mask
            candidates_accept_length = torch.cumprod(posterior_mask, dim=1).sum(dim=1)
            accept_length = candidates_accept_length.max()
        else:
            # Relaxed greedy using the neighbour tokens of each drafted token
            search_space = int(coeff_b)
            nearest_indices = nearest_latent[xi_valid]  # Shape: (num_paths, path_len - 1, k)
            nearest_indices = nearest_indices[:, :, :search_space]  # Limit search space

            # For invalid positions, set nearest_indices to zero
            nearest_indices[~valid_mask.unsqueeze(-1).expand_as(nearest_indices)] = 0

            # Probabilities of the neighbour tokens
            nearest_probs = gtp.gather(dim=-1, index=nearest_indices)  # Shape: (..., search_space)
            nearest_probs = nearest_probs * valid_mask.unsqueeze(-1)  # Zero out invalid positions

            # Running sum: entry k = mass of the first k + 1 neighbours
            cumsum_nearest_probs = torch.cumsum(nearest_probs, dim=-1)

            # approx_p[..., k] = px + mass of the first k + 1 neighbours
            px_expanded = px.unsqueeze(-1).repeat(1, 1, search_space)
            approx_p = px_expanded + cumsum_nearest_probs
            approx_p = approx_p * valid_mask.unsqueeze(-1)  # Zero out invalid positions

            # Target = [px repeated, neighbour probs]; approx = [merged mass, zeros]
            target_p = torch.cat([px_expanded, nearest_probs], dim=-1)  # Shape: (..., 2 * search_space)
            approx_p_full = torch.cat([approx_p, torch.zeros_like(nearest_probs)], dim=-1)

            # Zero out invalid positions in target and approximate distributions
            target_p = target_p * valid_mask.unsqueeze(-1).to(torch.float32)
            approx_p_full = approx_p_full * valid_mask.unsqueeze(-1).to(torch.float32)

            # Element-wise divergence terms (no reduction)
            if adaptive_func == 'tvd':
                jsd = calculate_tvd(target_p, approx_p_full)
            else:
                jsd = calculate_jsd(target_p, approx_p_full)
        
            # calculate_jsd can produce NaN where the mixture is 0
            jsd = torch.nan_to_num(jsd, nan=0.0)
            # Score of merging k + 1 neighbours: first-half term k plus the
            # cumulative second-half terms up to k.
            jsd_px = jsd[:, :, :search_space]
            jsd_cumsum = torch.cumsum(jsd[:, :, search_space:], dim=-1)
            jsd = jsd_px + jsd_cumsum
            # For invalid positions, set jsd to a high value to avoid selecting them
            jsd[~valid_mask] = float('inf')

            # Create a boolean mask where jsd does not exceed coeff_a
            jsd_not_exceeds = (jsd <= coeff_a)

            # Get the size of the last dimension
            dim_size = jsd.shape[-1]

            # Create indices for the last dimension
            indices = torch.arange(dim_size).unsqueeze(0).unsqueeze(0).to(jsd.device)
            indices = indices.expand(jsd.shape[0], jsd.shape[1], dim_size)

            # Use the mask to select valid indices, set invalid positions to -1
            masked_indices = torch.where(jsd_not_exceeds, indices, torch.full_like(indices, -1))

            # Largest neighbour count index within budget per position (-1 if none)
            indices = masked_indices.max(dim=-1)[0]

            # Map -1 to 0 so gather stays in range; idx_mask restores px below
            idx_mask = (indices >= 0)
            idx_values = indices * idx_mask
            idx_values = idx_values.unsqueeze(-1)

            # Where no neighbour count fits the budget, keep the original px
            px_adjusted = torch.where(
                idx_mask,
                approx_p.gather(dim=-1, index=idx_values).squeeze(-1),
                px
            )
            px_adjusted = px_adjusted * valid_mask  # Zero out invalid positions

            # Write the relaxed probability back at the drafted token's slot
            gtp.scatter_(dim=-1, index=xi_valid.unsqueeze(-1), src=px_adjusted.unsqueeze(-1))

            # Accept while the drafted token is the argmax of the adjusted distribution
            top_tokens = torch.argmax(gtp, dim=-1)[:, :-1]  # Adjusted to match xi
            posterior_mask = (xi == top_tokens).int() * valid_mask
            candidates_accept_length = torch.cumprod(posterior_mask, dim=1).sum(dim=1)
            accept_length = candidates_accept_length.max()

        # Choose the best candidate
        if accept_length == 0:
            best_candidate = torch.tensor(0, dtype=torch.long, device=device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)

        return best_candidate, accept_length, logits[best_candidate, accept_length]


    else:
        # Sampling: rejection-style verification, one tree depth at a time
        accept_length = 1  # includes the root token
        accept_cand = candidates[0][:1]
        best_candidate = 0
        for i in range(1, candidates.shape[1]):
            # Previous depth accepted nothing: stop
            if i != accept_length:
                break
            adjustflag = False
            # Only paths that share the accepted prefix compete at this depth
            is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
            fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
            gt_logits = logits[fi, i - 1][None]
            gt_logits = logits_processor(None, gt_logits)[0]
            gtp = torch.softmax(gt_logits, dim=0)
            
            candidates_set = []
            # Neighbour cut-off chosen for the current token; reset per depth
            indices = None
            for j in range(candidates.shape[0]):
                if is_eq[j]:
                    x = candidates[j, i]
                    xi = x.item()
                    # Skip padding and tokens already tried at this depth
                    if xi in candidates_set or xi == -1:
                        continue
                    candidates_set.append(xi)
                    r = random.random()
                    if current_length + i < warmup_steps:
                        # Warm-up: standard acceptance probability
                        px = gtp[xi]
                    else:
                        search_space = int(coeff_b)
                        px = gtp[xi]
                        if "jsd" in adaptive_func:
                            # target = [px repeated, neighbour probs];
                            # approx = [px + cumulative neighbour mass, zeros]
                            approx_p = torch.zeros([2*search_space, 1]).to(gtp.device)
                            target_p = torch.zeros([2*search_space, 1]).to(gtp.device)
                            target_p[:search_space] = px.reshape(1, 1).repeat(search_space, 1)
                            approx_p[:search_space] = px.reshape(1, 1).repeat(search_space, 1)
                            nearest_probs = gtp[nearest_latent[xi, :search_space]].reshape(search_space, 1)
                            target_p[search_space:] = nearest_probs.to(torch.float32)
                            approx_p[:search_space] += torch.cumsum(nearest_probs, dim=0).to(torch.float32)
                            jsd = calculate_jsd(target_p, approx_p)
                            jsd = torch.nan_to_num(jsd, nan=0.0)
                            jsd_px = jsd[:search_space]
                            jsd_cumsum = torch.cumsum(jsd[search_space:], dim=0)
                            jsd = jsd_px + jsd_cumsum
                            # Cut-off = (first index exceeding coeff_a) - 1, or the
                            # last index if none exceed. -1 means "no neighbours".
                            # NOTE(review): the no-exceed case builds a 1-element
                            # CPU tensor, the other case a 0-dim tensor.
                            indices = (jsd > coeff_a).nonzero(as_tuple=True)[0]
                            if indices.numel() == 0:
                                indices = torch.tensor([search_space- 1])
                            else:
                                indices = indices[0]
                                indices -= 1
                            if indices == -1:
                                px = px
                            else:
                                px = approx_p[indices, 0]

                        elif "tvd" in adaptive_func:
                            nearest_probs = gtp[nearest_latent[xi, :search_space]].reshape(search_space, 1)
                            cumsum_nearest_probs = torch.cumsum(nearest_probs, dim=0)
                            # take the largest index where the cumulative sum is less than coeff_a
                            indices = (cumsum_nearest_probs <= coeff_a).nonzero(as_tuple=True)[0]
                            if indices.numel() == 0:
                                indices = -1
                            else:
                                indices = indices[-1]
                            if indices == -1:
                                px = px
                            else:
                                px = px + cumsum_nearest_probs[indices]
                    
                    qx = 1.0  # the draft proposal is treated as deterministic
                    acp = px / qx
                    if r <= acp:
                        accept_cand = torch.cat((accept_cand, x[None]), dim=0)
                        accept_length += 1
                        best_candidate = j
                        break
                    else:
                        # Rejected: remove x (and the neighbours merged into it)
                        # from the distribution, then renormalise (uniform if empty).
                        gtp[xi] = 0
                        if 'tvd' in adaptive_func or 'jsd' in adaptive_func:
                            if (indices is not None) and (indices != -1):
                                gtp[nearest_latent[xi, :indices+1]] = 0
                        if gtp.sum() == 0:
                            gtp = torch.ones_like(gtp)
                        gtp = gtp / gtp.sum()
                        adjustflag = True
        if adjustflag and accept_length != candidates.shape[1]:
            # Sample the next token from the residual distribution
            sample_p = gtp
        else:
            gt_logits = logits[best_candidate, accept_length - 1]
            sample_p = torch.softmax(gt_logits, dim=0)
        return torch.tensor(best_candidate), accept_length - 1, sample_p

def evaluate_posterior_loose(
        logits: torch.Tensor,
        candidates: torch.Tensor,
        logits_processor,
        accept_k
):
    """
    Verify candidate paths with a top-k acceptance rule and choose the best one.

    A drafted token is accepted while it lies among the ``accept_k`` highest
    target logits of the preceding position. This rule is used in both greedy
    and sampling mode.

    Args:
    - logits (torch.Tensor): Target-model logits gathered along each candidate
      path, indexed ``logits[path, position, vocab]``. Position ``s`` scores the
      token at position ``s + 1``.
    - candidates (torch.Tensor): Candidate token paths ``(num_paths, path_len)``.
      Column 0 is the root token.
    - logits_processor: ``None`` for greedy decoding. Otherwise a callable
      ``(input_ids, scores) -> scores`` applied before the returned softmax.
    - accept_k (int): How many top target tokens count as a match.

    Returns:
    - best_candidate (torch.Tensor): Index of the chosen candidate path.
    - accept_length (torch.Tensor): Number of accepted drafted tokens.
    - Third value. Greedy: the raw logits at the next position of the chosen
      path. Sampling: the probability vector (after ``logits_processor``) at that
      position.
    """
    topk_indices = torch.topk(logits[:, :-1], k=accept_k, dim=-1).indices  # (paths, len - 1, k)
    drafted = candidates[:, 1:].unsqueeze(-1).to(logits.device)  # (paths, len - 1, 1)

    posterior_mask = (drafted == topk_indices).any(dim=-1).int()
    # cumprod stops counting at the first miss along each path.
    candidates_accept_length = torch.cumprod(posterior_mask, dim=1).sum(dim=1)
    accept_length = candidates_accept_length.max()
    if accept_length == 0:
        # Default to the first candidate if none are accepted
        best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
    else:
        best_candidate = torch.argmax(candidates_accept_length).to(torch.long)

    next_logits = logits[best_candidate, accept_length]
    if logits_processor is None:
        return best_candidate, accept_length, next_logits
    # Sample the next token from the processed target distribution, matching
    # how logits_processor is applied elsewhere before softmax.
    next_logits = logits_processor(None, next_logits[None])[0]
    return best_candidate, accept_length, torch.softmax(next_logits, dim=0)

@torch.no_grad()
def update_inference_inputs(
        input_ids,
        candidates,
        best_candidate,
        accept_length,
        retrieve_indices,
        logits_processor,
        new_token,
        past_key_values_data_list,
        current_length_data,
        model,
        hidden_state_new,
        sample_p,
        cfg_scale,
        cfg_interval,
):
    """Commit the accepted draft tokens and draft the next token tree.

    Appends the accepted prefix of ``candidates[best_candidate]`` to
    ``input_ids`` and moves the matching KV-cache slots so they sit directly
    after the previous prefix. It then picks one extra token from
    ``sample_p`` and calls ``model.ea_layer.topK_genrate`` to build the next
    draft tree.

    Args:
        input_ids: Committed token ids, shape ``(batch, seq)``. If the shape
            is ``(2, 1)``, only the first row is kept before appending.
        candidates: Candidate token sequences indexed
            ``[candidate, position]``.
        best_candidate: Index of the chosen candidate row.
        accept_length: Number of accepted draft tokens. ``accept_length + 1``
            tokens of the chosen candidate are committed.
        retrieve_indices: Tree-position indices per candidate, indexed
            ``[candidate, position]``. They are offset by the previous
            sequence length when addressing the KV cache.
        logits_processor: When not ``None``, the extra token is sampled with
            ``torch.multinomial``. Otherwise it is ``argmax(sample_p)``.
            The value is also forwarded to ``topK_genrate``.
        new_token: Running count of generated tokens. It is incremented by
            ``accept_length + 1``.
        past_key_values_data_list: KV-cache tensors, indexed and sliced along
            their second-to-last dimension. They are updated in place.
        current_length_data: Tensor filled in place with the new cache length.
            Only batch size 1 is supported.
        model: Object exposing ``ea_layer.topK_genrate`` and
            ``base_model.lm_head``.
        hidden_state_new: Hidden states of the verified tree, indexed
            ``[:, tree_position]``.
        sample_p: Non-negative weights for the extra token. Presumably 1-D;
            TODO confirm.
        cfg_scale: Forwarded to ``topK_genrate``.
        cfg_interval: Forwarded to ``topK_genrate``.

    Returns:
        Tuple ``(input_ids, draft_tokens, retrieve_indices, tree_mask,
        tree_position_ids, new_token, None, token)``. Here ``token`` is the
        extra token with shape ``(1, 1)``.
    """
    prev_input_len = input_ids.shape[1]

    # Tree positions along the accepted path. Shifted past the committed
    # prefix, they address the KV-cache slots written during verification.
    accepted_tree_idx = retrieve_indices[best_candidate, : accept_length + 1]
    select_indices = accepted_tree_idx + prev_input_len
    # Number of cache slots being committed. This is taken from the indices
    # themselves rather than from a loop variable, so an empty cache list
    # does not raise.
    num_accepted = select_indices.shape[-1]

    # A (2, 1) input (presumably the duplicated cond/uncond first-step row;
    # TODO confirm) is collapsed to a single row before appending.
    if input_ids.shape[1] == 1 and input_ids.shape[0] == 2:
        input_ids = input_ids[:1]

    input_ids = torch.cat(
        [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1
    )

    # Compact each cache tensor so the accepted tree slots follow the previous
    # prefix contiguously. Indexing with a tensor produces a copy in `tgt`,
    # so overlapping source/destination ranges are safe.
    for past_key_values_data in past_key_values_data_list:
        tgt = past_key_values_data[..., select_indices.to(past_key_values_data.device), :]
        dst = past_key_values_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :]
        dst.copy_(tgt, non_blocking=True)

    current_length_data.fill_(prev_input_len + num_accepted)

    # Gather only the accepted path's hidden states. This yields the same
    # values as hidden_state_new[:, retrieve_indices][:, best_candidate,
    # :accept_length + 1], without materializing every candidate path.
    accept_hidden_state_new = hidden_state_new[:, accepted_tree_idx]

    if logits_processor is not None:
        token = torch.multinomial(sample_p, 1)[None]
    else:
        token = torch.argmax(sample_p)[None, None]

    # The draft layer receives two identical rows (presumably for
    # cond/uncond CFG; TODO confirm).
    ea_input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1).repeat(2, 1)
    draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(
        accept_hidden_state_new,
        input_ids=ea_input_ids,
        head=model.base_model.lm_head,
        logits_processor=logits_processor,
        cfg_scale=cfg_scale,
        cfg_interval=cfg_interval,
    )

    new_token += accept_length + 1

    return input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, None, token


if __name__ == "__main__":
    # Smoke test: build a sampling processor and apply it to random logits.
    logits = torch.randn(1, 5)
    tp = prepare_logits_processor(0.9, 0, 0.9, 0)
    # Check for None before use. The original check ran after the call,
    # where a None processor would already have raised.
    if tp is None:
        print(tp)
    else:
        processed = tp(None, logits)
        print(processed)
