import warnings
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Tuple, Union

import torch
import math
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
import torch.nn.functional as F
import random
from collections import OrderedDict
from safetensors.torch import load_file

from utils  import norm_logits


def top_p_filter(probs: torch.Tensor, p: float = 0.9):
    """Keep the smallest set of highest-probability entries whose mass reaches ``p``.

    Args:
        probs: 1-D tensor of non-negative weights (indexed along dim 0).
        p: Cumulative-mass threshold. Values ``>= 1`` disable filtering.

    Returns:
        A ``(indices, values)`` pair. ``indices`` are positions into ``probs``
        (in descending-probability order when filtering is active) and
        ``values`` are ``probs`` at those positions. Both live on
        ``probs.device``.
    """
    if p >= 1:
        # Build the index on the same device as the filtered branch would.
        return torch.arange(len(probs), device=probs.device), probs

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)

    reached = torch.nonzero(cumulative_probs >= p, as_tuple=False)
    if reached.numel() == 0:
        # The total mass never reaches p (unnormalised input or float
        # rounding): nothing can be trimmed, so keep every entry.
        selected_indices = sorted_indices
    else:
        cutoff_idx = reached[0].item()
        selected_indices = sorted_indices[: cutoff_idx + 1]

    return selected_indices, probs[selected_indices]



def entropy(p: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return the Shannon entropy (natural log) along the last dimension.

    The input is first renormalised so that the last dimension sums to one;
    ``eps`` is added inside the logarithm so zero entries do not produce
    ``-inf``.
    """
    normalized = p / p.sum(dim=-1, keepdim=True)
    weighted_log = normalized * torch.log(normalized + eps)
    return -weighted_log.sum(dim=-1)


def pick_width_by_entropy_bucket(H, K, width, cuts=(0.25, 0.50, 0.75)):
    """Pick a width by bucketing the normalised entropy ``H / log(K)``.

    Args:
        H: Entropy in nats (a number or a single-element tensor).
        K: Support size used for normalisation.
        width: Sequence of candidate widths; needs ``len(cuts) + 1`` entries.
        cuts: Bucket boundaries on the normalised entropy in ``[0, 1]``.

    Returns:
        ``width[i]`` where ``i`` is the index of the first cut that the
        normalised entropy is strictly below, or ``width[len(cuts)]`` when it
        is below none of them.
    """
    if K > 1:
        H_norm = max(0.0, min(1.0, H / math.log(K)))
    else:
        # log(K) is 0 (or undefined) here; a support of at most one element
        # carries no uncertainty, so use the lowest bucket.
        H_norm = 0.0

    bucket = next(
        (i for i, cut in enumerate(cuts) if H_norm < cut),
        len(cuts),
    )
    return width[bucket]


def pick_width_by_entropy_linear(H, K, width):
    """Interpolate linearly between ``width[0]`` and ``width[-1]`` by ``H / log(K)``.

    Args:
        H: Entropy in nats (a number or a single-element tensor).
        K: Support size used for normalisation.
        width: Sequence whose first and last entries bound the result.

    Returns:
        The interpolated width truncated to ``int`` and clamped to
        ``[width[0], width[-1]]``.
    """
    if K > 1:
        H_norm = max(0, min(1, H / math.log(K)))
    else:
        # log(K) is 0 (or undefined) here; treat the distribution as having
        # no uncertainty instead of dividing by zero.
        H_norm = 0
    width_next = width[0] + (width[-1] - width[0]) * H_norm

    return max(width[0], min(int(width_next), width[-1]))



def pick_width_by_entropy_continuous(H, K, width, gamma=1.0):
    """Map normalised entropy to a width through a power curve.

    The width is ``width[0] + (width[-1] - width[0]) * (H / log(K)) ** gamma``
    with the normalised entropy clamped to ``[0, 1]``.

    Args:
        H: Entropy in nats (a number or a single-element tensor).
        K: Support size used for normalisation.
        width: Sequence whose first and last entries bound the result.
        gamma: Exponent applied to the normalised entropy.

    Returns:
        The resulting width truncated to ``int`` and clamped to
        ``[width[0], width[-1]]``.
    """
    if K > 1:
        H_norm = max(0.0, min(1.0, H / math.log(K)))
    else:
        # log(K) is 0 (or undefined) here; treat the distribution as having
        # no uncertainty instead of dividing by zero.
        H_norm = 0.0
    B_star = width[0] + (width[-1] - width[0]) * (H_norm ** gamma)

    return max(width[0], min(int(B_star), width[-1]))



def pick_width_by_entropy_neff(H, K, width, alpha=1.0):
    """Pick a width from the effective support size ``exp(H)``.

    The ratio ``(exp(H) - 1) / (K - 1)`` is scaled by ``alpha`` and used to
    interpolate between ``width[0]`` and ``width[-1]``.

    Args:
        H: Entropy in nats (a number or a single-element tensor).
        K: Support size; the ratio's denominator is ``K - 1``.
        width: Sequence whose first and last entries bound the result.
        alpha: Multiplier applied to the ratio before interpolation.

    Returns:
        The resulting width truncated to ``int`` and clamped to
        ``[width[0], width[-1]]``.
    """
    if K <= 1:
        # The denominator K - 1 would be zero (or negative); with at most one
        # possible outcome the narrowest width is the only sensible answer.
        return width[0]

    neff = math.exp(H)
    # The original min(K, len(width) * K) equals K for any non-empty width.
    r = (neff - 1) / (K - 1)
    width_next = width[0] + (width[-1] - width[0]) * (alpha * r)

    return max(width[0], min(int(width_next), width[-1]))


@dataclass
class DecoderOnlyDraftOutput(ModelOutput):
    """Container returned by the ``generate_draft`` methods in this file."""

    # Token ids: the input prefix with the drafted candidate tokens appended.
    sequences: torch.LongTensor = None
    # Draft model KV cache as left after the last drafting step.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # One candidate-probability tensor per drafting step.
    cand_probs: Optional[Tuple[torch.FloatTensor]] = None
    # NOTE(review): not populated by BatchStrategy/TreeStrategy in this file;
    # presumably used by other strategies — confirm against callers.
    tree_att_mask: Optional[torch.FloatTensor] = None
    # Passed through unchanged by the generate_draft implementations visible here.
    adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


@dataclass
class DecoderOnlyVerificationOutput(ModelOutput):
    """Container returned by the ``verify`` methods in this file."""


    # Token ids after verification: kept prefix, accepted draft tokens, and
    # one extra token sampled from the target distribution.
    sequences: torch.LongTensor = None
    # Target model KV cache restricted to the kept positions.
    target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Draft model KV cache restricted to the kept positions.
    draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Passed through unchanged by the verify implementations visible here.
    adapter_past_key_values:Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Number of draft tokens that were accepted.
    acceptance_count: Optional[int] = None


def _MCNS(
    ground_probs: torch.FloatTensor,
    cand_probs: Tuple[torch.FloatTensor],
    cand_tokens: torch.LongTensor,
) -> Optional[int]:
    ground_token = torch.multinomial(ground_probs, num_samples=1).item()

    for check_idx, cand_token in enumerate(cand_tokens):
        if ground_token == cand_token:
            return check_idx
    ground_probs[:] = 0
    ground_probs[ground_token] = 1
    return None


def _MCSSwoReplacement(
    ground_probs: torch.FloatTensor,
    cand_probs: Tuple[torch.FloatTensor],
    cand_tokens: torch.LongTensor,
) -> Optional[int]:
    cand_probs = cand_probs.to(ground_probs.device)
    for check_idx, cand_token in enumerate(cand_tokens):
        accept_threshold = ground_probs[cand_token] / cand_probs[cand_token]
        if torch.rand(1, device=accept_threshold.device) <= accept_threshold:
            return check_idx
        else:
            ground_probs -= cand_probs
            ground_probs = torch.nn.functional.relu(ground_probs, inplace=True)
            ground_probs /= ground_probs.sum()
            cand_probs[cand_token] = 0
            cand_probs = cand_probs / cand_probs.sum()
    return None


def _MCSSwReplacement(
    ground_probs: torch.FloatTensor,
    cand_probs: Tuple[torch.FloatTensor],
    cand_tokens: torch.LongTensor,
) -> Optional[int]:
    cand_probs = cand_probs.to(ground_probs.device)
    for check_idx, cand_token in enumerate(cand_tokens):
        accept_threshold = ground_probs[cand_token] / cand_probs[cand_token]
        if torch.rand(1, device=accept_threshold.device) <= accept_threshold:
            return check_idx
        else:
            ground_probs -= cand_probs
            ground_probs = torch.nn.functional.relu(ground_probs, inplace=True)
            ground_probs /= ground_probs.sum()
    return None


class Strategy:
    """Base class holding the shared configuration of a draft/verify strategy.

    Subclasses implement ``generate_draft`` and ``verify``. The acceptance
    rule is chosen in ``__init__`` and stored on the instance as
    ``self.acceptance_check``, which takes precedence over the placeholder
    method of the same name defined on the class.
    """

    def __init__(
        self,
        draft_model,
        target_model,
        k_config: Tuple[int],
        draft_model_temp: float = 1,
        target_model_temp: float = 1,
        replacement: bool = False,
        speculative_sampling: bool = True,
        top_k: int = 10,
        top_p: float = 0.9,
    ) -> None:
        # Models and the devices their input embeddings live on.
        self.draft_model = draft_model
        self.target_model = target_model
        draft_embeddings = draft_model.model.get_input_embeddings()
        target_embeddings = target_model.model.get_input_embeddings()
        self.draft_model_device = draft_embeddings.weight.device
        self.target_model_device = target_embeddings.weight.device

        # Tree shape: one branching factor per draft step.
        self.k_config = k_config
        self.max_draft_len = len(k_config)

        # Sampling configuration.
        self.draft_model_temp = draft_model_temp
        self.target_model_temp = target_model_temp
        self.replacement = replacement
        self.speculative_sampling = speculative_sampling
        self.top_k = top_k
        self.top_p = top_p

        # Acceptance rule: naive sampling, or speculative sampling matched to
        # how the candidates were drawn (with or without replacement).
        if not speculative_sampling:
            self.acceptance_check = _MCNS
        elif replacement:
            self.acceptance_check = _MCSSwReplacement
        else:
            self.acceptance_check = _MCSSwoReplacement

    def generate_draft(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyDraftOutput:
        """Produce draft candidates; must be implemented by subclasses."""
        raise NotImplementedError

    def generate_draft_dynwidth(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        width_list: Optional[List[int]] = None,
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyDraftOutput:
        """Produce draft candidates with a per-step width; subclass hook."""
        raise NotImplementedError

    def acceptance_check(self, ground_probs, cand_probs, cand_tokens) -> Optional[int]:
        """Placeholder; instances created via ``__init__`` shadow it with a function."""
        raise NotImplementedError

    def verify(
        self,
        input_ids: torch.LongTensor,
        target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        cand_probs: Optional[Tuple[torch.FloatTensor]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyVerificationOutput:
        """Verify draft candidates with the target model; subclass hook."""
        raise NotImplementedError


class BatchStrategy(Strategy):
    """Strategy that lays out every candidate continuation as its own batch row.

    Each draft step multiplies the number of rows by that step's branching
    factor, so after all steps there is one row per root-to-leaf path.
    """

    def __init__(
        self,
        draft_model,
        target_model,
        k_config: Tuple[int],
        draft_model_temp=1,
        target_model_temp=1,
        replacement: bool = False,
        speculative_sampling: bool = True,
        top_k = 10,
        top_p = 0.9,
    ) -> None:
        """Forward the shared configuration to ``Strategy`` and precompute row strides."""
        super().__init__(
            draft_model,
            target_model,
            k_config,
            draft_model_temp,
            target_model_temp,
            replacement,
            speculative_sampling,
            top_k,
            top_p,
        )

        # reversed_prod_size[d] is the product of k_config[d + 1:], i.e. the
        # row distance between sibling candidates at depth d (used in verify).
        reversed_prod_size = [1]
        for i in range(1, self.max_draft_len):
            reversed_prod_size.insert(0, reversed_prod_size[0] * k_config[-i])

        self.reversed_prod_size = reversed_prod_size

    def generate_draft(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyDraftOutput:
        """Draft ``max_draft_len`` tokens, branching ``k_config[step]`` ways at each step.

        Returns a ``DecoderOnlyDraftOutput`` whose ``sequences`` holds one row
        per candidate path, ``cand_probs`` holds the per-step candidate
        distributions, and ``past_key_values`` is the draft cache after the
        last step. ``adapter_past_key_values`` is passed through untouched.
        """
        input_ids = input_ids.to(self.draft_model_device)
        cand_probs = []
        for step in range(self.max_draft_len):
            step_k = self.k_config[step]
            # Feed only the tokens that are not covered by the cache yet
            # (cache length is read from dim 2 of the first layer's key).
            if past_key_values is not None:
                pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
            else:
                pruned_input_ids = input_ids

            outputs: BaseModelOutputWithPast = self.draft_model.model(
                input_ids=pruned_input_ids,
                use_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            hidden_states = outputs.last_hidden_state

            # Next-token logits come from the last position of every row.
            logits = self.draft_model.lm_head(hidden_states[:, -1])

            # Converted to a list so the per-layer entries can be reassigned below.
            past_key_values = list(outputs.past_key_values)

            if self.draft_model_temp == 0:
                if not self.replacement:
                    # Greedy without replacement: take the top-k tokens and
                    # use a softmax over just their logits as the candidate
                    # distribution (zero everywhere else).
                    topk_logit, topk_index = logits.topk(k=step_k, dim=-1)  # batch x k
                    topk_probs = torch.softmax(topk_logit, dim=-1)
                    step_cand_probs = torch.zeros_like(logits)
                    step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
                    cand_tokens = topk_index.view(-1, 1)
                else:
                    # Greedy with replacement: the argmax token repeated
                    # step_k times, with a one-hot candidate distribution.
                    topk_logit, topk_index = logits.topk(k=1, dim=-1)  # batch x k
                    step_cand_probs = torch.zeros_like(logits)
                    step_cand_probs.scatter_(dim=1, index=topk_index, value=1)
                    cand_tokens = topk_index.view(-1, 1)
                    cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=0)
            else:
                # norm_logits is imported from utils; presumably it applies
                # temperature / top-k / top-p and returns probabilities —
                # confirm against utils.
                step_cand_probs = norm_logits(logits, self.draft_model_temp, self.top_k, self.top_p)
                cand_tokens = torch.multinomial(
                    step_cand_probs,
                    step_k,
                    replacement=self.replacement,
                ).view(-1, 1)

            cand_probs.append(step_cand_probs)

            # Duplicate every row step_k times and append one candidate to
            # each copy, so siblings sit on consecutive rows.
            input_ids = input_ids.repeat_interleave(step_k, dim=0)
            input_ids = torch.cat(
                (
                    input_ids,
                    cand_tokens,
                ),
                dim=1,
            )
            # Expand the cache the same way, except after the final step
            # where no further forward pass needs it.
            if step + 1 != self.max_draft_len:
                for i in range(len(past_key_values)):
                    past_key_values[i] = (
                        past_key_values[i][0].repeat_interleave(step_k, dim=0),
                        past_key_values[i][1].repeat_interleave(step_k, dim=0),
                    )

        return DecoderOnlyDraftOutput(
            sequences=input_ids,
            past_key_values=past_key_values,
            cand_probs=tuple(cand_probs),
            adapter_past_key_values=adapter_past_key_values,
        )

    def verify(
        self,
        input_ids: torch.LongTensor,
        target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        cand_probs: Optional[Tuple[torch.FloatTensor]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyVerificationOutput:
        """Score all candidate rows with the target model and keep the accepted path.

        The per-layer entries of both cache arguments are reassigned by
        index, so they must support item assignment (e.g. lists).

        Returns a ``DecoderOnlyVerificationOutput`` with the surviving
        sequence (plus one token sampled from the target distribution), both
        caches cut down to the surviving row and length, and the number of
        accepted draft tokens.
        """
        input_ids = input_ids.to(self.target_model_device)
        batch_size, input_len = input_ids.size()
        if target_model_past_key_values is not None:
            pruned_input_ids = input_ids[
                :, target_model_past_key_values[0][0].size(2) :
            ]
            # Repeat the cache along dim 0 once per candidate row.
            # NOTE(review): presumably the incoming cache has a single row — confirm.
            for i in range(len(target_model_past_key_values)):
                target_model_past_key_values[i] = (
                    target_model_past_key_values[i][0].repeat_interleave(
                        batch_size, dim=0
                    ),
                    target_model_past_key_values[i][1].repeat_interleave(
                        batch_size, dim=0
                    ),
                )
        else:
            pruned_input_ids = input_ids

        outputs: BaseModelOutputWithPast = self.target_model.model(
            input_ids=pruned_input_ids,
            use_cache=True,
            past_key_values=target_model_past_key_values,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        hidden_states = outputs.last_hidden_state
        target_model_past_key_values = list(outputs.past_key_values)

        # One distribution per draft position plus one for the token after
        # the last draft position.
        logits = self.target_model.lm_head(hidden_states[:, -self.max_draft_len - 1 :])

        if self.target_model_temp == 0:
            # Greedy target: one-hot distribution on the argmax token.
            _, topk_index = logits.topk(k=1, dim=-1)  # seq_len x 1
            ground_probs = torch.zeros_like(logits)
            ground_probs.scatter_(dim=2, index=topk_index, value=1)
        else:
            ground_probs = norm_logits(logits, self.target_model_temp, self.top_k, self.top_p)


        unverified_input_ids = input_ids[:, -self.max_draft_len :]

        assert ground_probs.size(1) == unverified_input_ids.size(1) + 1

        # cand_probs_idx: row of cand_probs[depth] belonging to the accepted prefix.
        # alive_group_id: first batch row that shares the accepted prefix.
        cand_probs_idx = 0
        alive_group_id = 0

        for depth in range(self.max_draft_len):
            # Rows holding the k_config[depth] sibling candidates at this depth.
            verify_batch_ids = [
                alive_group_id + group_offset * self.reversed_prod_size[depth]
                for group_offset in range(self.k_config[depth])
            ]
            # The acceptance functions in this file update the ground_probs
            # slice in place when they reject a candidate.
            accept_idx_bias = self.acceptance_check(
                ground_probs[alive_group_id, depth],
                cand_probs[depth][cand_probs_idx],
                unverified_input_ids[verify_batch_ids, depth],
            )
            if accept_idx_bias is not None:
                alive_group_id = verify_batch_ids[accept_idx_bias]
                cand_probs_idx = accept_idx_bias + cand_probs_idx * self.k_config[depth]
                # Every draft token accepted: set depth to the full draft
                # length (the loop ends here, so the reassignment sticks).
                if depth == self.max_draft_len - 1:
                    depth = self.max_draft_len
            else:
                break
        # Keep the original prefix plus the `depth` accepted draft tokens.
        input_ids = input_ids[alive_group_id, : input_len - self.max_draft_len + depth]
        # Sample the next token from the target distribution at the stop point.
        endpoint_token = torch.multinomial(
            ground_probs[alive_group_id, depth], num_samples=1
        ).to(device=input_ids.device)

        input_ids = torch.cat((input_ids, endpoint_token))

        input_ids.unsqueeze_(0)

        # Reduce both caches to the surviving row and the kept length.
        for i in range(len(target_model_past_key_values)):
            target_model_past_key_values[i] = (
                target_model_past_key_values[i][0][
                    None, alive_group_id, :, : input_len - self.max_draft_len + depth
                ],
                target_model_past_key_values[i][1][
                    None, alive_group_id, :, : input_len - self.max_draft_len + depth
                ],
            )
        # The draft cache was not expanded after the last draft step, so its
        # row index is the leaf row divided by the last branching factor.
        for i in range(len(draft_model_past_key_values)):
            draft_model_past_key_values[i] = (
                draft_model_past_key_values[i][0][
                    None,
                    alive_group_id // self.k_config[-1],
                    :,
                    : input_len - self.max_draft_len + depth,
                ],
                draft_model_past_key_values[i][1][
                    None,
                    alive_group_id // self.k_config[-1],
                    :,
                    : input_len - self.max_draft_len + depth,
                ],
            )
        return DecoderOnlyVerificationOutput(
            sequences=input_ids,
            target_model_past_key_values=target_model_past_key_values,
            draft_model_past_key_values=draft_model_past_key_values,
            adapter_past_key_values=adapter_past_key_values,
            acceptance_count=depth,
        )


def get_tree_attn_self_mask(k_config: Tuple[int]):
    """Build the boolean self-attention mask of a flattened candidate tree.

    The tree has ``k_config[d]`` children under every node at depth ``d - 1``
    (``k_config[0]`` roots). Nodes are numbered level by level, and within a
    level the children of one parent are consecutive. Entry ``[i, j]`` is
    ``True`` when ``j`` is ``i`` itself or one of its ancestors.

    Args:
        k_config: Branching factor per depth.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(N, N)`` where ``N`` is the
        total number of tree nodes (``(0, 0)`` for an empty ``k_config``).
    """
    branching = [int(k) for k in k_config]

    # Number of nodes on each level: running product of branching factors.
    level_sizes = []
    running = 1
    for k in branching:
        running *= k
        level_sizes.append(running)

    mask_size = sum(level_sizes)
    # Every node attends to itself.
    attn_mask = torch.eye(mask_size, dtype=torch.bool)

    # Walk the levels top-down so a parent's row is complete before its
    # children copy it.
    level_start = 0
    for depth in range(1, len(branching)):
        parent_start = level_start
        level_start += level_sizes[depth - 1]
        fanout = branching[depth]
        for local_idx in range(level_sizes[depth]):
            idx = level_start + local_idx
            parent = parent_start + local_idx // fanout
            # Inherit the parent's visibility (its ancestors and itself).
            attn_mask[idx, : parent + 1] = attn_mask[parent, : parent + 1]
    return attn_mask


class TreeStrategy(Strategy):
    """Strategy that keeps all candidates in one sequence and uses a tree attention mask.

    Candidate tokens of every depth are appended to a single row; a tree
    attention mask restricts each candidate to its own ancestors, so draft
    and target models process the whole tree in one batch row.
    """

    def __init__(
        self,
        draft_model,
        target_model,
        k_config: Tuple[int],
        draft_model_temp: float = 1,
        target_model_temp: float = 1,
        replacement: bool = False,
        speculative_sampling: bool = True,
        top_k: int = 10,
        top_p: float = 0.9,
    ) -> None:
        """Forward the shared configuration to ``Strategy`` and precompute tree layout data."""
        super().__init__(
            draft_model,
            target_model,
            k_config,
            draft_model_temp,
            target_model_temp,
            replacement,
            speculative_sampling,
            top_k,
            top_p,
        )

        # prod_size[d] is the number of nodes fed to the model at step d:
        # [0, k0, k0*k1, ...].
        prod_size = torch.cumprod(torch.tensor(k_config, dtype=torch.int), dim=0)
        prod_size = torch.cat((torch.zeros(1).to(prod_size), prod_size)).tolist()
        self.prod_size = prod_size
        # cumulative_prod_size[d] is the offset of depth-d nodes inside the
        # flattened tree: [0, k0, k0 + k0*k1, ...].
        self.cumulative_prod_size = torch.cumsum(
            torch.tensor(prod_size), dim=0
        ).tolist()

        self.tree_attn_self_mask = get_tree_attn_self_mask(k_config).to(
            device=self.draft_model_device
        )

    def generate_draft(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyDraftOutput:
        """Draft the candidate tree level by level.

        Returns a ``DecoderOnlyDraftOutput`` whose ``sequences`` is the input
        followed by all tree tokens (level by level), ``cand_probs`` holds one
        distribution tensor per level, and ``past_key_values`` is the draft
        cache after the last step. ``adapter_past_key_values`` is passed
        through untouched.
        """
        input_ids = input_ids.to(self.draft_model_device)
        cand_probs = []
        step_tree_attn_mask = None
        position_ids = None
        init_input_length = input_ids.size(1)
        # First step: feed only the tokens not yet covered by the cache.
        if past_key_values is not None:
            pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
        else:
            pruned_input_ids = input_ids
        for step in range(self.max_draft_len):
            step_k = self.k_config[step]

            # prepare attn mask
            if step != 0:
                # Rows: the nodes produced by the previous step; columns:
                # every tree node drafted so far.
                step_tree_attn_self_mask = self.tree_attn_self_mask[
                    self.cumulative_prod_size[step - 1] : self.cumulative_prod_size[
                        step
                    ],
                    : self.cumulative_prod_size[step],
                ]
                # All nodes fed in this step share the same position id.
                position_ids = torch.full(
                    (1, self.prod_size[step]),
                    init_input_length + step - 1,
                    dtype=torch.long,
                    device=self.draft_model_device,
                )
                # Every tree node sees the whole original context.
                context_attn_mask = torch.ones(
                    (self.prod_size[step], init_input_length), dtype=torch.bool
                ).to(self.tree_attn_self_mask)
                step_tree_attn_mask = torch.cat(
                    (context_attn_mask, step_tree_attn_self_mask), dim=1
                )

            # NOTE(review): `tree_attn_mask` is not a standard Hugging Face
            # forward argument; the wrapped model must accept it.
            outputs: BaseModelOutputWithPast = self.draft_model.model(
                input_ids=pruned_input_ids,
                use_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                tree_attn_mask=step_tree_attn_mask,
                position_ids=position_ids,
            )

            hidden_states = outputs.last_hidden_state

            # Step 0 only needs the last context position; later steps need
            # one row per node fed in.
            if step == 0:
                hidden_states = hidden_states[0, -1:]
            else:
                hidden_states = hidden_states[0]

            logits = self.draft_model.lm_head(hidden_states)  # seq_len x hidden_dim

            past_key_values = list(outputs.past_key_values)

            if self.draft_model_temp == 0:
                if not self.replacement:
                    # Greedy without replacement: top-k tokens, with a
                    # softmax over their logits as candidate distribution.
                    topk_logit, topk_index = logits.topk(
                        k=step_k, dim=-1
                    )  # seq_len x k
                    topk_probs = torch.softmax(topk_logit, dim=-1)
                    step_cand_probs = torch.zeros_like(logits)
                    step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
                    cand_tokens = topk_index.view(1, -1)
                else:
                    # Greedy with replacement: the argmax repeated step_k times.
                    topk_logit, topk_index = logits.topk(k=1, dim=-1)  # seq_len x k
                    step_cand_probs = torch.zeros_like(logits)
                    step_cand_probs.scatter_(dim=1, index=topk_index, value=1)
                    cand_tokens = topk_index.view(1, -1)
                    cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=1)
            else:
                step_cand_probs = norm_logits(logits, self.draft_model_temp, self.top_k, self.top_p)

                cand_tokens = torch.multinomial(
                    step_cand_probs, step_k, replacement=self.replacement
                ).view(1, -1)
            cand_probs.append(step_cand_probs)

            # The new candidates are the next step's input.
            pruned_input_ids = cand_tokens

            input_ids = torch.cat((input_ids, pruned_input_ids), dim=1)

        return DecoderOnlyDraftOutput(
            sequences=input_ids,
            past_key_values=past_key_values,
            cand_probs=tuple(cand_probs),
            adapter_past_key_values=adapter_past_key_values,
        )

    def _forward_target_model(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
    ):
        """Run the target model over the context plus the flattened tree.

        Returns ``(logits, past_key_values)`` where ``logits`` covers the last
        ``tree_attn_len + 1`` positions and ``past_key_values`` is the updated
        target cache as a list.
        """
        input_ids = input_ids.to(self.target_model_device)
        tree_attn_len = self.tree_attn_self_mask.size(0)
        init_input_length = input_ids.size(1) - tree_attn_len
        init_forward = False

        if past_key_values is not None:
            pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
        else:
            pruned_input_ids = input_ids
            init_forward = True

        if init_forward:
            # No cache: lower-triangular (causal) mask over everything, with
            # the tree mask written into the bottom-right block.
            tree_attn_mask = torch.zeros(
                (input_ids.size(1), input_ids.size(1)),
                dtype=torch.bool,
                device=self.target_model_device,
            )
            mask_cond = torch.arange(
                tree_attn_mask.size(-1), device=self.target_model_device
            )
            tree_attn_mask.masked_fill_(
                mask_cond < (mask_cond + 1).view(tree_attn_mask.size(-1), 1), 1
            )
            tree_attn_mask[-tree_attn_len:, -tree_attn_len:] = self.tree_attn_self_mask
            # Position of a token = number of tokens it attends to, minus one.
            position_ids = tree_attn_mask.sum(dim=1) - 1

        else:
            # With a cache the mask has tree_attn_len + 1 query rows.
            # NOTE(review): presumably row 0 is the single uncached context
            # token that precedes the tree — confirm against the caller.
            tree_attn_mask = torch.ones(
                (
                    tree_attn_len + 1,
                    input_ids.size(1),
                ),
                dtype=torch.bool,
                device=self.target_model_device,
            )

            tree_attn_mask[1:, init_input_length:] = self.tree_attn_self_mask
            tree_attn_mask[0, init_input_length:] = 0
            position_ids = tree_attn_mask.sum(dim=1) - 1

        outputs: BaseModelOutputWithPast = self.target_model.model(
            input_ids=pruned_input_ids,
            use_cache=True,
            past_key_values=past_key_values,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            tree_attn_mask=tree_attn_mask,
            position_ids=position_ids,
        )
        hidden_states = outputs.last_hidden_state
        past_key_values = list(outputs.past_key_values)

        logits = self.target_model.lm_head(
            hidden_states[:, -tree_attn_len - 1 :]
        )  # 1 x seq_len x hidden_dim
        return logits, past_key_values

    def verify(
        self,
        input_ids: torch.LongTensor,
        target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
        cand_probs: Optional[Tuple[torch.FloatTensor]],
        adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyVerificationOutput:
        """Walk the candidate tree with the target model and keep the accepted path.

        The per-layer entries of ``draft_model_past_key_values`` are
        reassigned by index, so it must support item assignment (e.g. a list).

        Returns a ``DecoderOnlyVerificationOutput`` with the kept tokens plus
        one token sampled from the target distribution, both caches reduced
        to the kept positions, and the number of accepted draft tokens.
        """
        input_ids = input_ids.to(self.target_model_device)
        logits, target_model_past_key_values = self._forward_target_model(
            input_ids, target_model_past_key_values
        )
        logits = logits[0]  # seq_len x hidden_dim
        tree_attn_len = self.tree_attn_self_mask.size(0)
        unverified_tokens = input_ids[0, -tree_attn_len:]
        init_input_length = input_ids.size(1) - tree_attn_len

        if self.target_model_temp == 0:
            # Greedy target: one-hot distribution on the argmax token.
            _, topk_index = logits.topk(k=1, dim=-1)  # seq_len x 1
            ground_probs = torch.zeros_like(logits)
            ground_probs.scatter_(dim=1, index=topk_index, value=1)
        else:
            ground_probs = norm_logits(logits, self.target_model_temp, self.top_k, self.top_p)

        # Row 0 is the distribution after the context; the remaining rows
        # are the distributions after each tree node.
        current_ground_prob = ground_probs[0]
        ground_probs = ground_probs[1:]

        keep_indices = list(range(init_input_length))
        to_drop_len = 0
        # idx_group_bias: offset, within the current depth, of the sibling
        # group below the accepted node.
        # cand_probs_idx: row of cand_probs[depth] for the accepted prefix.
        idx_group_bias = 0
        cand_probs_idx = 0

        for depth in range(self.max_draft_len):
            idx_base = self.cumulative_prod_size[depth] + idx_group_bias
            # The acceptance functions in this file update
            # current_ground_prob in place when they reject a candidate.
            accept_idx_bias = self.acceptance_check(
                current_ground_prob,
                cand_probs[depth][cand_probs_idx],
                unverified_tokens[idx_base : idx_base + self.k_config[depth]],
            )

            if accept_idx_bias is not None:
                global_idx = idx_base + accept_idx_bias
                current_ground_prob = ground_probs[global_idx]
                keep_indices.append(init_input_length + global_idx)
                if depth == self.max_draft_len - 1:
                    # Every draft token accepted. The last accepted leaf is
                    # dropped from the draft cache indices below, and depth
                    # is set to the full draft length (the loop ends here,
                    # so the reassignment sticks).
                    to_drop_len += 1
                    depth = self.max_draft_len
                else:
                    cand_probs_idx = idx_group_bias + accept_idx_bias
                    idx_group_bias = cand_probs_idx * self.k_config[depth + 1]

            else:
                break

        keep_indices = torch.tensor(
            keep_indices, dtype=torch.long, device=self.target_model_device
        )
        if to_drop_len != 0:
            draft_keep_indices = keep_indices[: len(keep_indices) - to_drop_len]
        else:
            draft_keep_indices = keep_indices

        # Next token sampled from the target distribution at the stop point.
        tail_ground_token = torch.multinomial(current_ground_prob, num_samples=1).to(
            device=input_ids.device
        )

        input_ids = input_ids.index_select(dim=1, index=keep_indices)
        input_ids = torch.cat((input_ids, tail_ground_token[None]), dim=1)

        for i in range(len(target_model_past_key_values)):
            # Layers may live on different devices; the index must follow.
            keep_indices = keep_indices.to(
                device=target_model_past_key_values[i][0].device
            )
            target_model_past_key_values[i] = (
                target_model_past_key_values[i][0].index_select(
                    dim=2, index=keep_indices
                ),
                target_model_past_key_values[i][1].index_select(
                    dim=2, index=keep_indices
                ),
            )

        for i in range(len(draft_model_past_key_values)):
            # The indices were created on the target device; index_select
            # needs them on the same device as the draft cache, which can
            # differ when draft and target models are placed separately.
            draft_keep_indices = draft_keep_indices.to(
                device=draft_model_past_key_values[i][0].device
            )
            draft_model_past_key_values[i] = (
                draft_model_past_key_values[i][0].index_select(
                    dim=2, index=draft_keep_indices
                ),
                draft_model_past_key_values[i][1].index_select(
                    dim=2, index=draft_keep_indices
                ),
            )

        return DecoderOnlyVerificationOutput(
            sequences=input_ids,
            target_model_past_key_values=target_model_past_key_values,
            draft_model_past_key_values=draft_model_past_key_values,
            adapter_past_key_values=adapter_past_key_values,
            acceptance_count=depth,
        )

class WETAPStrategy(Strategy):
    def __init__(
            self,
            draft_model,
            target_model,
            k_config: Tuple[int],
            beam_width: int = 64,
            max_budget : int = 64,
            accept_thres: float = 0.5,
            trim_type = 'prob',
            dynwidth : bool = False,
            width_list = None,
            draft_model_temp=1,
            target_model_temp=1,
            replacement: bool = False,
            speculative_sampling: bool = True,
            top_k=10,
            top_p=0.9,
    ) -> None:
        """Configure the WETAP drafting strategy.

        Args:
            draft_model, target_model, k_config, draft_model_temp,
            target_model_temp, replacement, speculative_sampling, top_k,
            top_p: forwarded unchanged, in this order, to the parent
                constructor.
            beam_width: number of draft tokens per tree layer when the width
                is fixed (``dynwidth`` is False).
            max_budget: stored as ``self.max_budget``.
            accept_thres: stored in log space as ``self.log_accept_thres``.
            trim_type: stored as ``self.trim_type`` and printed once.
            dynwidth: when True the per-layer width is variable and
                ``width_list`` is required.
            width_list: candidate layer widths for the dynamic-width mode.

        Raises:
            ValueError: if ``dynwidth`` is True and ``width_list`` is None.
        """
        super().__init__(
            draft_model,
            target_model,
            k_config,
            draft_model_temp,
            target_model_temp,
            replacement,
            speculative_sampling,
            top_k,
            top_p,
        )

        self.dynwidth = dynwidth
        self.width_list = width_list
        self.beam_width = beam_width

        # Upper bound on tokens drafted per iteration: widest possible layer
        # times the draft depth.
        # NOTE(review): self.max_draft_len is not set here; presumably the
        # parent constructor provides it - confirm.
        if self.dynwidth:
            if self.width_list is None:
                raise ValueError('width range must be provided!')
            widest_layer = max(self.width_list)
        else:
            widest_layer = beam_width
        self.num_token_per_iter = widest_layer * self.max_draft_len

        self.max_budget = max_budget
        self.log_accept_thres = math.log(accept_thres)
        self.trim_type = trim_type
        print(f'Trim Type is {self.trim_type}')

        if target_model_temp == 0:
            warnings.warn(
                "For WETAP, the target model temperature shouldn't be 0, there is no performance improvement",
                category=UserWarning,
                stacklevel=3,
            )

    def generate_draft(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
            adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        ) -> DecoderOnlyDraftOutput:
        """Draft a fixed-width token tree with the draft model.

        Runs ``self.max_draft_len`` steps. Every step appends
        ``self.beam_width`` new tokens to ``input_ids`` and extends the tree
        attention mask so each new token attends to its parent's ancestors
        and to itself. If more than ``self.max_budget`` tokens were drafted,
        the tree is pruned with ``trim_budget`` / ``reset_tree``.

        Args:
            input_ids: token ids; moved to ``self.draft_model_device``.
                Dimension 1 is treated as the sequence axis.
            past_key_values: draft-model cache or None. When given, only the
                tokens beyond the cached length are fed to the model.
            adapter_past_key_values: returned unchanged in the output.

        Returns:
            DecoderOnlyDraftOutput with the extended ``sequences``, the
            updated ``past_key_values``, ``cand_probs`` (one row of cumulative
            log-probabilities per layer, as a tuple) and
            ``adapter_past_key_values``.

        Side effects:
            Resets and fills ``self.beam_idx_history``,
            ``self.parent_idx_history``, ``self.quota_history`` and
            ``self.entropy_per_layer``; resets ``self.draft_p_history``;
            stores the final mask in ``self.tree_att_mask``.
        """
        input_ids = input_ids.to(self.draft_model_device)
        cand_probs = None
        log_beam_prob = None
        cand_pos = None
        input_len = input_ids.size(1)

        # Row i is the attention row of draft token i. Every draft token sees
        # the whole prompt (first input_len columns); draft-to-draft
        # visibility starts empty and is filled layer by layer.
        tree_att_mask = torch.full(
            (
                self.num_token_per_iter,
                self.num_token_per_iter + input_len,
            ),
            True,
        )
        tree_att_mask[:, input_len:] = False
        step_tree_att_mask = None
        position_ids = None

        self.beam_idx_history = []
        self.parent_idx_history = []
        self.draft_p_history = []
        self.quota_history = []
        self.entropy_per_layer = []

        # Entropy of a distribution over beam_width items is at most
        # log(beam_width). The same normaliser is used for every layer; for a
        # single beam log(1) == 0, so fall back to 1.0 to avoid dividing by
        # zero (the entropy itself is ~0 in that case).
        max_entropy = math.log(self.beam_width) if self.beam_width > 1 else 1.0

        for step in range(self.max_draft_len):

            # Feed only tokens that are not in the cache yet.
            if past_key_values is not None:
                pruned_input_ids = input_ids[:, past_key_values[0][0].size(2):]
            else:
                pruned_input_ids = input_ids

            # Number of draft tokens produced by the previous layers, i.e. the
            # row offset of the layer created in this step.
            bias = self.beam_width * step

            if step > 0:
                # Attention rows of the previous layer, restricted to the
                # columns that exist so far (prompt + drafted tokens).
                step_tree_att_mask = tree_att_mask[
                    bias - self.beam_width:bias,
                    :input_len + bias
                ].to(self.draft_model_device)

                # All tokens of the previous layer sit at the same depth.
                position_ids = torch.full(
                    (1, self.beam_width),
                    input_len + step - 1,
                    dtype=torch.long,
                    device=self.draft_model_device,
                )

            outputs: BaseModelOutputWithPast = self.draft_model.model(
                input_ids=pruned_input_ids,
                use_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=True,
                tree_attn_mask=step_tree_att_mask,
                position_ids=position_ids,
            )

            hidden_states = outputs.last_hidden_state
            past_key_values = list(outputs.past_key_values)

            if log_beam_prob is None:
                # ---- root layer: expand the last prompt position ----
                logits = self.draft_model.lm_head(hidden_states[:, -1])  # (1, vocab_size)

                if self.draft_model_temp > 0:
                    # presumably norm_logits returns a probability
                    # distribution at the given temperature - TODO confirm
                    step_cand_probs = norm_logits(logits, self.draft_model_temp, 0, 0)  # (1, vocab_size)
                    cand_tokens = torch.multinomial(
                        step_cand_probs,
                        self.beam_width,
                        replacement=False,
                    ).view(-1, 1)
                else:
                    # Greedy: keep the top beam_width tokens and renormalise
                    # their probabilities among themselves.
                    topk_logit, topk_index = logits.topk(
                        k=self.beam_width, dim=-1
                    )
                    topk_probs = torch.softmax(topk_logit, dim=-1)

                    step_cand_probs = torch.zeros_like(logits)  # (1, vocab_size)
                    step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
                    cand_tokens = topk_index.view(-1, 1)  # (beam_width, 1)

                log_beam_prob = step_cand_probs[:, cand_tokens].view(self.beam_width, 1).log()

                cand_probs = step_cand_probs[:, cand_tokens].view(1, -1).log()
                # Relative depth of the layer in (0, 1].
                cand_pos = torch.full_like(cand_probs, (step + 1) / self.max_draft_len)

                prob = log_beam_prob.squeeze(-1).exp()
                self.quota_history.append(prob / prob.norm(p=1))
                self.entropy_per_layer.append(
                    max(0.0, min(1.0, entropy(self.quota_history[-1]).item() / max_entropy))
                )
                # Root-layer tokens have no parent inside the tree.
                self.parent_idx_history.append(-torch.ones(self.beam_width, device=cand_probs.device))

                input_ids = torch.cat(
                    (
                        input_ids,
                        cand_tokens.view(1, -1),
                    ),
                    dim=1,
                )

                # Each root token attends to itself only (besides the prompt).
                tree_att_mask[:self.beam_width, input_len:input_len + self.beam_width] = torch.eye(self.beam_width, dtype=torch.bool)

            else:
                # ---- deeper layer: expand every token of the previous layer ----
                logits = self.draft_model.lm_head(hidden_states[:, -self.beam_width:]).view(self.beam_width, -1)  # (width, vocab_size)

                if self.draft_model_temp > 0:
                    step_cand_probs = torch.log(norm_logits(logits, self.draft_model_temp, 0, 0))  # (width, vocab_size)
                    vocab_size = step_cand_probs.size(-1)

                    # Joint probability of (parent beam, next token).
                    beam_probs = (log_beam_prob + step_cand_probs).view(-1).exp()  # (width * vocab,)

                    cand_beams = torch.multinomial(
                        beam_probs,
                        self.beam_width,
                        replacement=False,
                    ).view(-1, 1)
                else:
                    step_cand_probs = torch.log_softmax(logits, dim=-1)  # (width, vocab_size)
                    vocab_size = step_cand_probs.size(-1)
                    beam_probs = (log_beam_prob + step_cand_probs).view(-1).exp()  # (width * vocab,)

                    topk_probs, topk_index = beam_probs.topk(
                        k=self.beam_width, dim=-1
                    )

                    beam_probs = torch.zeros_like(beam_probs)
                    beam_probs.scatter_(dim=0, index=topk_index, src=topk_probs)
                    cand_beams = topk_index.view(-1, 1)  # (width, 1)

                log_beam_prob = beam_probs[cand_beams].log()

                # A flat index encodes (parent beam, token) as
                # parent * vocab_size + token.
                beam_idx = torch.div(cand_beams, vocab_size, rounding_mode='floor').long().view(-1)
                self.beam_idx_history.append(beam_idx)

                prob = log_beam_prob.squeeze(-1).exp()
                self.quota_history.append(prob / prob.norm(p=1))
                self.entropy_per_layer.append(
                    max(0.0, min(1.0, entropy(self.quota_history[-1]).item() / max_entropy))
                )
                # Parent index in the flattened tree (previous layer offset).
                self.parent_idx_history.append(beam_idx + self.beam_width * (step - 1))

                tokens = cand_beams % vocab_size

                log_token_probs = step_cand_probs.view(-1)
                log_sampled_token_probs = log_token_probs[cand_beams]

                # Cumulative log-probability: own token + parent's cumulative.
                cur_cand_probs = log_sampled_token_probs.view(-1) + cand_probs[-1, beam_idx]
                cand_probs = torch.cat((cand_probs, cur_cand_probs.view(1, -1)), dim=0)

                cur_cand_pos = torch.full_like(cur_cand_probs, (step + 1) / self.max_draft_len)
                cand_pos = torch.cat((cand_pos, cur_cand_pos.view(1, -1)), dim=0)

                # New rows inherit their parent's visibility, then see themselves.
                tree_att_mask[bias:self.beam_width + bias] = tree_att_mask[bias - self.beam_width: bias][beam_idx.cpu()]
                tree_att_mask[bias:self.beam_width + bias, bias + input_len:bias + input_len + self.beam_width] = torch.eye(self.beam_width, dtype=torch.bool)

                input_ids = torch.cat(
                    (
                        input_ids,
                        tokens.view(1, -1),
                    ),
                    dim=1,
                )

            assert step == cand_probs.shape[0] - 1

        # Prune the tree when it exceeds the verification budget.
        if input_ids.shape[1] - input_len > self.max_budget:

            source_idxs, dest_idxs = self.trim_budget(cand_probs, cand_pos, input_ids, input_len, tree_att_mask)

            cand_probs, input_ids, tree_att_mask, past_key_values = self.reset_tree(
                input_ids, input_len, cand_probs,
                past_key_values,
                tree_att_mask, source_idxs, dest_idxs,
            )

        assert input_ids.shape[1] - input_len == min(self.max_budget, self.num_token_per_iter)

        self.tree_att_mask = tree_att_mask
        return DecoderOnlyDraftOutput(
            sequences=input_ids,
            past_key_values=past_key_values,
            cand_probs=tuple(cand_probs),
            adapter_past_key_values=adapter_past_key_values,
        )

    #Construct the token tree with dynamic width
    def generate_draft_dynwidth(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
            width_list: Optional[List[int]] = None,
            adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        ) -> DecoderOnlyDraftOutput:
        """Draft a token tree whose layer width adapts to the previous layer's entropy.

        The first layer uses ``self.width_list[0]`` tokens; every later layer
        asks ``pick_width_by_entropy_continuous`` for its width, based on the
        entropy recorded for the previous layer. Per-layer rows of
        ``cand_probs`` / ``cand_pos`` / parent indices are right-padded to
        ``max(self.width_list)`` with the sentinels -1000 / 0 / -100 so they
        can be stacked. If more than ``self.max_budget`` tokens were drafted,
        the tree is pruned with ``trim_budget_dynwidth`` /
        ``reset_tree_dynwidth``.

        Args:
            input_ids: token ids; moved to ``self.draft_model_device``.
                Dimension 1 is treated as the sequence axis.
            past_key_values: draft-model cache or None. When given, only the
                tokens beyond the cached length are fed to the model.
            width_list: accepted for interface compatibility but not read;
                the widths come from ``self.width_list``.
            adapter_past_key_values: returned unchanged in the output.

        Returns:
            DecoderOnlyDraftOutput with the extended ``sequences``, the
            updated ``past_key_values``, ``cand_probs`` (one padded row per
            layer, as a tuple) and ``adapter_past_key_values``.

        Side effects:
            Resets and fills ``self.beam_idx_history``,
            ``self.parent_idx_history``, ``self.quota_history`` and
            ``self.entropy_per_layer``; resets ``self.draft_p_history``;
            stores the final mask in ``self.tree_att_mask``.
        """
        input_ids = input_ids.to(self.draft_model_device)
        cand_probs = None
        log_beam_prob = None
        cand_pos = None
        input_len = input_ids.size(1)

        # Row i is the attention row of draft token i. Every draft token sees
        # the whole prompt; draft-to-draft visibility is filled layer by layer.
        tree_att_mask = torch.full(
            (
                self.num_token_per_iter,
                self.num_token_per_iter + input_len,
            ),
            True,
        )
        tree_att_mask[:, input_len:] = False
        step_tree_att_mask = None
        position_ids = None

        self.beam_idx_history = []
        self.parent_idx_history = []
        self.draft_p_history = []
        self.quota_history = []
        self.entropy_per_layer = []

        width_per_layer = []
        padlen_per_layer = []
        # Every per-layer row is padded up to this width.
        max_width = max(self.width_list)

        for step in range(self.max_draft_len):

            # Feed only tokens that are not in the cache yet.
            if past_key_values is not None:
                pruned_input_ids = input_ids[:, past_key_values[0][0].size(2):]
            else:
                pruned_input_ids = input_ids

            # Number of draft tokens produced so far, i.e. the row offset of
            # the layer created in this step.
            bias = sum(width_per_layer)

            if step > 0:
                # Attention rows of the previous layer, restricted to the
                # columns that exist so far (prompt + drafted tokens).
                step_tree_att_mask = tree_att_mask[
                    bias - width_per_layer[-1]:bias,
                    :input_len + bias
                ].to(self.draft_model_device)

                # All tokens of the previous layer sit at the same depth.
                position_ids = torch.full(
                    (1, width_per_layer[-1]),
                    input_len + step - 1,
                    dtype=torch.long,
                    device=self.draft_model_device,
                )

            # Decide this layer's width from the previous layer's entropy.
            if step == 0:
                width = self.width_list[0]
            else:
                width = pick_width_by_entropy_continuous(
                    self.entropy_per_layer[-1], width_per_layer[-1], self.width_list, 1.2
                )
            width_per_layer.append(width)

            outputs: BaseModelOutputWithPast = self.draft_model.model(
                input_ids=pruned_input_ids,
                use_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=True,
                tree_attn_mask=step_tree_att_mask,
                position_ids=position_ids,
            )

            hidden_states = outputs.last_hidden_state
            past_key_values = list(outputs.past_key_values)

            if log_beam_prob is None:
                # ---- root layer: expand the last prompt position ----
                logits = self.draft_model.lm_head(hidden_states[:, -1])  # (1, vocab_size)

                if self.draft_model_temp > 0:
                    # presumably norm_logits returns a probability
                    # distribution at the given temperature - TODO confirm
                    step_cand_probs = norm_logits(logits, self.draft_model_temp, 0, 0)  # (1, vocab_size)
                    cand_tokens = torch.multinomial(
                        step_cand_probs,
                        width,
                        replacement=False,
                    ).view(-1, 1)  # (width, 1)
                else:
                    # Greedy: keep the top `width` tokens and renormalise
                    # their probabilities among themselves.
                    topk_logit, topk_index = logits.topk(
                        k=width, dim=-1
                    )
                    topk_probs = torch.softmax(topk_logit, dim=-1)

                    step_cand_probs = torch.zeros_like(logits)  # (1, vocab_size)
                    step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
                    cand_tokens = topk_index.view(-1, 1)

                log_beam_prob = step_cand_probs[:, cand_tokens].view(width, 1).log()

                prob = log_beam_prob.squeeze(-1).exp()
                self.quota_history.append(prob / prob.norm(p=1))

                # Raw (un-normalised) entropy of this layer; it drives the
                # width choice for the next layer.
                self.entropy_per_layer.append(entropy(self.quota_history[-1]).item())

                # Pad log-probs and relative depths to the maximum width.
                cand_probs = step_cand_probs[:, cand_tokens].view(-1).log()
                cand_pos = torch.full_like(cand_probs, (step + 1) / self.max_draft_len)
                padlen = max_width - cand_probs.shape[0]
                padlen_per_layer.append(padlen)
                cand_probs = F.pad(cand_probs, (0, padlen), value=-1000).view(1, -1)
                cand_pos = F.pad(cand_pos, (0, padlen), value=0).view(1, -1)

                # Root-layer tokens have no parent (-1); padding uses -100.
                parent_idx = -torch.ones(width, device=cand_probs.device)
                parent_idx = F.pad(parent_idx, (0, padlen), value=-100)
                self.parent_idx_history.append(parent_idx)

                input_ids = torch.cat(
                    (
                        input_ids,
                        cand_tokens.view(1, -1),  # shape (1, width)
                    ),
                    dim=1,
                )

                # Each root token attends to itself only (besides the prompt).
                tree_att_mask[:width, input_len:input_len + width] = torch.eye(width, dtype=torch.bool)

            else:
                # ---- deeper layer: expand every token of the previous layer ----
                prev_width = width_per_layer[-2]
                logits = self.draft_model.lm_head(hidden_states[:, -prev_width:]).view(prev_width, -1)  # (prev_width, vocab_size)

                if self.draft_model_temp > 0:
                    step_cand_probs = torch.log(norm_logits(logits, self.draft_model_temp, 0, 0))  # (prev_width, vocab_size)
                    vocab_size = step_cand_probs.size(-1)

                    # Joint probability of (parent beam, next token).
                    beam_probs = (log_beam_prob + step_cand_probs).view(-1).exp()  # (prev_width * vocab,)

                    cand_beams = torch.multinomial(
                        beam_probs,
                        width,
                        replacement=False,
                    ).view(-1, 1)
                else:
                    step_cand_probs = torch.log_softmax(logits, dim=-1)  # (prev_width, vocab_size)
                    vocab_size = step_cand_probs.size(-1)
                    beam_probs = (log_beam_prob + step_cand_probs).view(-1).exp()  # (prev_width * vocab,)

                    # Select exactly `width` candidates, like the sampling
                    # branch does; the mask and padding updates below are
                    # sized by this layer's width, not by self.beam_width.
                    topk_probs, topk_index = beam_probs.topk(
                        k=width, dim=-1
                    )

                    beam_probs = torch.zeros_like(beam_probs)
                    beam_probs.scatter_(dim=0, index=topk_index, src=topk_probs)
                    cand_beams = topk_index.view(-1, 1)

                log_beam_prob = beam_probs[cand_beams].log()

                # A flat index encodes (parent beam, token) as
                # parent * vocab_size + token.
                beam_idx = torch.div(cand_beams, vocab_size, rounding_mode='floor').long().view(-1)
                self.beam_idx_history.append(beam_idx)

                prob = log_beam_prob.squeeze(-1).exp()
                self.quota_history.append(prob / prob.norm(p=1))
                self.entropy_per_layer.append(entropy(self.quota_history[-1]).item())

                tokens = cand_beams % vocab_size
                log_token_probs = step_cand_probs.view(-1)
                log_sampled_token_probs = log_token_probs[cand_beams]

                # Cumulative log-probability: own token + parent's cumulative.
                cur_cand_probs = log_sampled_token_probs.view(-1) + cand_probs[-1, beam_idx]
                cur_cand_pos = torch.full_like(cur_cand_probs, (step + 1) / self.max_draft_len)

                # Pad log-probs and relative depths to the maximum width.
                padlen = max_width - cur_cand_probs.shape[0]
                padlen_per_layer.append(padlen)

                cur_cand_probs = F.pad(cur_cand_probs, (0, padlen), value=-1000)
                cand_probs = torch.cat((cand_probs, cur_cand_probs.view(1, -1)), dim=0)

                cur_cand_pos = F.pad(cur_cand_pos, (0, padlen), value=0)
                cand_pos = torch.cat((cand_pos, cur_cand_pos.view(1, -1)), dim=0)

                # Parent index in the flattened tree: offset of the previous
                # layer is the total width of all layers before it.
                parent_idx = beam_idx + sum(width_per_layer[:step - 1])
                parent_idx = F.pad(parent_idx, (0, padlen), value=-100)
                self.parent_idx_history.append(parent_idx)

                # New rows inherit their parent's visibility, then see themselves.
                tree_att_mask[bias:width + bias] = tree_att_mask[bias - prev_width: bias][
                    beam_idx.cpu()]
                tree_att_mask[bias:width + bias, bias + input_len:bias + input_len + width] = torch.eye(width, dtype=torch.bool)

                input_ids = torch.cat(
                    (
                        input_ids,
                        tokens.view(1, -1),
                    ),
                    dim=1,
                )

            assert step == cand_probs.shape[0] - 1

        # Tokens actually drafted; with dynamic widths this can be smaller
        # than self.num_token_per_iter (which assumes the maximum width).
        num_drafted = input_ids.shape[1] - input_len

        # Prune the tree when it exceeds the verification budget.
        if num_drafted > self.max_budget:

            source_idxs, source_idxs_padded, dest_idxs = self.trim_budget_dynwidth(
                cand_probs, cand_pos, input_ids, input_len, tree_att_mask,
                width_per_layer, padlen_per_layer,
            )

            cand_probs, input_ids, tree_att_mask, past_key_values = self.reset_tree_dynwidth(
                input_ids, input_len, cand_probs,
                past_key_values,
                tree_att_mask, source_idxs, source_idxs_padded, dest_idxs,
                width_per_layer,
            )

        assert input_ids.shape[1] - input_len == min(self.max_budget, num_drafted)

        self.tree_att_mask = tree_att_mask
        return DecoderOnlyDraftOutput(
            sequences=input_ids,
            past_key_values=past_key_values,
            cand_probs=tuple(cand_probs),
            adapter_past_key_values=adapter_past_key_values,
        )

    def trim_budget(
            self,
            cand_probs,
            cand_pos,
            input_ids,
            input_len,
            tree_att_mask,
            alpha: float = 0.6,
            ):
        """Choose which drafted tokens survive when the tree exceeds the budget.

        Args:
            cand_probs: per-layer cumulative log-probabilities; entries equal
                to -1000 are treated as padding and ignored.
            cand_pos: per-layer relative depth; entries equal to 0 are
                treated as padding and ignored.
            input_ids: prompt plus drafted tokens; only its length is used.
            input_len: number of prompt tokens in ``input_ids``.
            tree_att_mask: tree attention mask; row i marks, in the columns
                after ``input_len``, the draft tokens visible to draft
                token i. Used to recover missing ancestors.
            alpha: weight of the normalised probability in the combined
                score ``alpha * prob + (1 - alpha) * depth``. Only used when
                ``self.trim_type`` is not ``'prob'``.

        Returns:
            ``(source_idxs, dest_idxs)``: sorted flat indices of the
            ``self.max_budget`` kept tokens, and ``arange(max_budget)`` as
            their new positions.
        """
        cand_probs_flat = cand_probs[cand_probs != -1000].view(-1)
        cand_pos_flat = cand_pos[cand_pos != 0].view(-1)

        assert cand_probs_flat.shape[0] == input_ids.shape[1] - input_len == cand_pos_flat.shape[0]

        if self.trim_type == 'prob':
            # Keep the tokens with the highest cumulative log-probability.
            source_idxs = cand_probs_flat.topk(self.max_budget, sorted=False).indices
            source_idxs, _ = source_idxs.sort()

        else:
            # Parent index of every drafted token (-100 marks padding). The
            # root layer may be stored as float, so cast to long to compare
            # like with like in the torch.isin calls below.
            parent_idxs = torch.stack(self.parent_idx_history, dim=0).view(-1)
            parent_idxs = parent_idxs[parent_idxs != -100].long()

            # Score each token by min-max normalised probability and depth.
            probs = cand_probs_flat.exp()  # back from log space
            probs = (probs - probs.min()) / (probs.max() - probs.min() + 1e-9)
            score = alpha * probs + (1 - alpha) * cand_pos_flat

            source_idxs = score.topk(self.max_budget, sorted=False).indices
            source_idxs, _ = source_idxs.sort()

            # Find kept tokens (below the root layer) whose parent was dropped.
            source_idxs_exceptfirst = source_idxs[source_idxs >= self.beam_width]
            source_idxs_parent = parent_idxs[source_idxs_exceptfirst]
            noparent_source_idxs = source_idxs_exceptfirst[~torch.isin(source_idxs_parent, source_idxs)]

            if noparent_source_idxs.numel() > 0:
                # Pull every token visible to those orphans (their ancestors
                # and themselves) back into the selection.
                noparent_idx = torch.nonzero(tree_att_mask[noparent_source_idxs.to(tree_att_mask.device), input_len:], as_tuple=False)
                col_indices = noparent_idx[:, 1].view(-1).to(source_idxs.device)
                source_idxs_new = torch.unique(torch.cat((source_idxs, col_indices)))

                # The selection is now over budget; trim it again.
                source_idxs_new_parent = parent_idxs[source_idxs_new[source_idxs_new >= self.beam_width]]
                not_parent = source_idxs_new[~torch.isin(source_idxs_new, source_idxs_new_parent)]
                last_pos_idxs = torch.where(cand_pos_flat == 1)[0]
                not_parent = not_parent[~torch.isin(not_parent, last_pos_idxs)]

                excess = source_idxs_new.numel() - self.max_budget
                if not_parent.numel() >= excess:
                    # Enough childless tokens outside the last layer: drop
                    # those first, which cannot orphan any other token.
                    source_idxs = source_idxs_new[~torch.isin(source_idxs_new, not_parent[:excess])]
                else:
                    # Otherwise drop all of them and then keep the tokens
                    # with the highest cumulative log-probability.
                    source_idxs_new = source_idxs_new[~torch.isin(source_idxs_new, not_parent)]
                    top_indices_new = torch.topk(cand_probs_flat[source_idxs_new], self.max_budget).indices
                    source_idxs = source_idxs_new[torch.sort(top_indices_new).values]

            # Every kept token below the root layer must have its parent kept.
            assert torch.isin(parent_idxs[source_idxs[source_idxs >= self.beam_width]], source_idxs).all()

        dest_idxs = torch.arange(source_idxs.shape[-1], device=source_idxs.device)

        assert source_idxs.shape[-1] == dest_idxs.shape[-1] == self.max_budget
        # Destination slots are 0 .. max_budget - 1.
        assert dest_idxs.max().item() < self.max_budget
        return source_idxs, dest_idxs


    def reset_tree(
            self,
            input_ids, input_len,
            cand_probs,
            past_key_values,
            tree_att_mask,
            source_idxs,
            dest_idxs
            ):
        """Compact the draft tree down to the tokens selected for the budget.

        Args:
            input_ids: prompt plus all drafted tokens.
            input_len: number of prompt tokens in ``input_ids``.
            cand_probs: per-layer cumulative log-probabilities, one row of
                ``self.beam_width`` entries per layer. Modified in place
                before a filtered copy is returned.
            past_key_values: list of (key, value) pairs; each entry is
                replaced in place by its filtered version.
            tree_att_mask: tree attention mask; modified in place, and a
                cropped view is returned.
            source_idxs: flat indices of the kept draft tokens.
            dest_idxs: new flat positions of the kept draft tokens.

        Returns:
            ``(cand_probs, input_ids, tree_att_mask, past_key_values)`` after
            pruning. ``self.beam_idx_history`` is updated as a side effect.
        """
        width = self.beam_width
        neg_inf = float("-inf")

        # Candidate log-probs: dropped tokens become -inf. A kept token that
        # already holds -inf is set to -250 so it stays distinguishable.
        keep_mask = torch.zeros_like(cand_probs.view(-1), dtype=torch.bool)
        keep_mask[source_idxs] = True
        keep_mask = keep_mask.view(-1, width)
        cand_probs[keep_mask & (cand_probs == neg_inf)] = -250
        cand_probs = torch.where(keep_mask, cand_probs, torch.full_like(cand_probs, neg_inf))

        # Cache entries: the selector is one layer shorter than input_ids, so
        # only kept tokens from layers before the last one are looked up.
        kv_keep = torch.zeros(input_ids.shape[-1] - width, dtype=torch.bool)
        kv_keep[:input_len] = True
        cached_limit = (self.max_draft_len - 1) * width
        kv_keep[input_len:][source_idxs[source_idxs < cached_limit]] = True
        for layer, entry in enumerate(past_key_values):
            past_key_values[layer] = (
                entry[0][:, :, kv_keep],
                entry[1][:, :, kv_keep],
            )

        # Tokens: prompt followed by the kept draft tokens, in index order.
        prompt_part = input_ids[:, :input_len]
        draft_part = input_ids[:, input_len:][:, source_idxs]
        input_ids = torch.cat((prompt_part, draft_part), dim=1)

        # Attention mask: move kept rows, then kept columns, to their new
        # slots and crop to the budget.
        source_idxs = source_idxs.to(tree_att_mask.device)
        dest_idxs = dest_idxs.to(tree_att_mask.device)
        tree_att_mask[dest_idxs, :] = tree_att_mask[source_idxs, :]
        tree_att_mask[:, input_len + dest_idxs] = tree_att_mask[:, input_len + source_idxs]
        tree_att_mask = tree_att_mask[:self.max_budget, :self.max_budget + input_len]

        # Beam history: entry `layer` describes the children in layer + 1.
        # Stop at the first layer that has no surviving children and drop
        # the history from there on.
        history = self.beam_idx_history
        for layer in range(len(history)):
            kept_children = torch.where(keep_mask[layer + 1])[0]
            if kept_children.numel() == 0:
                self.beam_idx_history = history[:layer]
                break
            history[layer] = history[layer][kept_children]
            last_kept_parent = torch.where(keep_mask[layer])[0].max().item()
            assert (keep_mask[layer, last_kept_parent + 1:] == 0).all()
            history[layer] = history[layer][history[layer] <= last_kept_parent]

        return cand_probs, input_ids, tree_att_mask, past_key_values


    def trim_budget_dynwidth(
            self,
            cand_probs,
            cand_pos,
            input_ids,
            input_len,
            tree_att_mask,
            width_per_layer,
            padlen_per_layer,
            alpha=0.6,
            ):
        """Choose which ``self.max_budget`` draft-tree candidates to keep.

        Args:
            cand_probs: Tensor of candidate scores. Entries equal to ``-1000``
                are dropped as padding. The values are exponentiated on the
                scoring path, so they are presumably cumulative log-probs
                (the inline comments call them ``cum_log_prob``).
            cand_pos: Tensor of per-candidate position scores. Entries equal
                to ``0`` are dropped as padding; entries equal to ``1`` are
                treated as belonging to the last layer.
            input_ids: Only ``input_ids.shape[1]`` is read, for a sanity check
                against the number of candidates.
            input_len: Number of leading columns (in ``input_ids`` and
                ``tree_att_mask``) that precede the candidates.
            tree_att_mask: 2-D mask indexed ``[candidate, column]``. The
                non-zero columns from ``input_len`` onward in a candidate's
                row are merged into the selection when that candidate's
                parent is missing (presumably its ancestors plus itself).
            width_per_layer: List with the number of real candidates per
                layer; its running sum gives the layer boundaries in the flat
                candidate index space.
            padlen_per_layer: List with the number of padding slots per layer;
                only the first ``self.max_draft_len - 1`` entries are used.
            alpha: Weight of the normalised probability in the blended score
                ``alpha * prob + (1 - alpha) * cand_pos``. Ignored when
                ``self.trim_type == 'prob'``. Defaults to the previously
                hard-coded ``0.6``.

        Returns:
            Tuple ``(source_idxs, source_idxs_padded, dest_idxs)``:
            the sorted flat indices of the kept candidates, the same indices
            shifted by the cumulative padding of the preceding layers, and
            ``arange(self.max_budget)`` as their destination slots.

        Raises:
            AssertionError: If the candidate counts disagree, a kept token's
                parent is missing, or the result is not exactly
                ``self.max_budget`` long.
        """
        # Strip the padding sentinels so everything below works on flat,
        # padding-free candidate indices.
        cand_probs_flat = cand_probs[cand_probs != -1000].view(-1)
        cand_pos_flat = cand_pos[cand_pos != 0].view(-1)

        parent_idxs = torch.stack(self.parent_idx_history, dim=0).view(-1)
        parent_idxs = parent_idxs[parent_idxs != -100]

        assert cand_probs_flat.shape[0] == input_ids.shape[1] - input_len == cand_pos_flat.shape[0]

        if self.trim_type == 'prob':
            # Pure probability ranking; no parent-completeness repair is done.
            source_idxs = cand_probs_flat.topk(self.max_budget, sorted=False).indices
            source_idxs, _ = source_idxs.sort()
        else:
            # 'prob&len' and every other trim_type use the same blended score:
            # min-max normalised probability mixed with the position score.
            probs = cand_probs_flat.exp()
            probs = (probs - probs.min()) / (probs.max() - probs.min() + 1e-9)
            score = alpha * probs + (1 - alpha) * cand_pos_flat

            source_idxs = score.topk(self.max_budget, sorted=False).indices
            source_idxs, _ = source_idxs.sort()

            # Find kept tokens (outside the first layer) whose parent was not kept.
            first_width = width_per_layer[0]
            source_idxs_exceptfirst = source_idxs[source_idxs >= first_width]
            source_idxs_parent = parent_idxs[source_idxs_exceptfirst]
            noparent_source_idxs = source_idxs_exceptfirst[~torch.isin(source_idxs_parent, source_idxs)]

            if noparent_source_idxs.numel() > 0:
                # Pull in every column the orphaned tokens attend to.
                noparent_idx = torch.nonzero(
                    tree_att_mask[noparent_source_idxs.to(tree_att_mask.device), input_len:],
                    as_tuple=False,
                )
                col_indices = noparent_idx[:, 1].view(-1).to(source_idxs.device)
                source_idxs_new = torch.unique(torch.cat((source_idxs, col_indices)))

                # The selection now exceeds the budget, so trim it again.
                source_idxs_new_parent = parent_idxs[source_idxs_new[source_idxs_new >= first_width]]
                leaf_idxs = source_idxs_new[~torch.isin(source_idxs_new, source_idxs_new_parent)]
                last_pos_idxs = torch.where(cand_pos_flat == 1)[0]
                not_parent = leaf_idxs[~torch.isin(leaf_idxs, last_pos_idxs)]

                excess = source_idxs_new.numel() - self.max_budget
                if not_parent.numel() >= excess:
                    # Dropping childless tokens that are not in the last layer
                    # cannot orphan anything, so remove those first.
                    source_idxs = source_idxs_new[~torch.isin(source_idxs_new, not_parent[:excess])]
                else:
                    # Not enough such tokens: drop them all, then keep the
                    # highest cumulative log-probs. A child's cumulative
                    # log-prob is not above its parent's, so this is intended
                    # to keep the tree complete.
                    source_idxs_new = source_idxs_new[~torch.isin(source_idxs_new, not_parent)]
                    top_indices_new = torch.topk(cand_probs_flat[source_idxs_new], self.max_budget).indices
                    source_idxs = source_idxs_new[torch.sort(top_indices_new).values]

            # NOTE(review): this check uses self.beam_width while the repair
            # above uses width_per_layer[0] as the first-layer size - confirm
            # the two are meant to be equal in the dynamic-width setting.
            assert torch.isin(parent_idxs[source_idxs[source_idxs >= self.beam_width]], source_idxs).all()

        # Map each kept flat index to its layer, then shift it by the padding
        # accumulated in the layers before it to address the padded layout.
        boundaries = torch.cumsum(torch.tensor(width_per_layer), dim=-1).to(source_idxs.device)
        idx_right = torch.bucketize(source_idxs, boundaries, right=True)
        pad_before_layer = [0] + padlen_per_layer[:self.max_draft_len - 1]
        pad_before_layer = torch.cumsum(torch.tensor(pad_before_layer), dim=-1).to(source_idxs.device)

        source_idxs_padded = source_idxs + pad_before_layer[idx_right]

        dest_idxs = torch.arange(source_idxs.shape[-1], device=source_idxs.device)

        assert source_idxs.shape[-1] == dest_idxs.shape[-1] == self.max_budget
        assert dest_idxs.max().item() <= self.max_budget
        return source_idxs, source_idxs_padded, dest_idxs

    def reset_tree_dynwidth(
            self,
            input_ids, input_len,
            cand_probs,
            past_key_values,
            tree_att_mask,
            source_idxs,
            source_idxs_padded,
            dest_idxs,
            width_per_layer,
            ):
        """Shrink the draft-tree state down to the candidates kept by the trim.

        ``source_idxs`` / ``source_idxs_padded`` / ``dest_idxs`` are the three
        index tensors produced by the budget trim. ``cand_probs``,
        ``tree_att_mask``, the entries of ``past_key_values`` and
        ``self.beam_idx_history`` are all updated; ``cand_probs`` and
        ``tree_att_mask`` are additionally written to in place before the
        returned versions are built.

        Returns:
            Tuple ``(cand_probs, input_ids, tree_att_mask, past_key_values)``.
        """
        # ---- candidate probabilities -----------------------------------
        # Boolean map, in the padded (layer, slot) layout, of kept candidates.
        keep = torch.zeros_like(cand_probs.view(-1), dtype=torch.bool)
        keep[source_idxs_padded] = True
        keep = keep.view(-1, max(self.width_list))
        # A kept slot that is -inf is given a large finite penalty instead,
        # so that only dropped slots end up as -inf.
        kept_but_inf = keep & (cand_probs == float("-inf"))
        cand_probs[kept_but_inf] = -250
        dropped_value = torch.full_like(cand_probs, float('-inf'))
        cand_probs = torch.where(keep, cand_probs, dropped_value)

        # ---- cached keys / values --------------------------------------
        # The cache covers everything except the last layer's tokens.
        cache_len = input_ids.shape[-1] - width_per_layer[-1]
        cache_keep = torch.zeros(cache_len, dtype=torch.bool)
        cache_keep[:input_len] = True
        cached_candidates = sum(width_per_layer[:-1])
        kept_cached = source_idxs[source_idxs < cached_candidates]
        cache_keep[input_len:][kept_cached] = True

        for layer, entry in enumerate(past_key_values):
            past_key_values[layer] = (
                entry[0][:, :, cache_keep],
                entry[1][:, :, cache_keep],
            )

        # ---- token ids -------------------------------------------------
        prefix_ids = input_ids[:, :input_len]
        candidate_ids = input_ids[:, input_len:]
        input_ids = torch.cat((prefix_ids, candidate_ids[:, source_idxs]), dim=1)

        # ---- tree attention mask ---------------------------------------
        # Compact kept rows, then kept columns, into the leading slots and
        # crop to the budget.
        src = source_idxs.to(tree_att_mask.device)
        dst = dest_idxs.to(tree_att_mask.device)
        tree_att_mask[dst, :] = tree_att_mask[src, :]
        tree_att_mask[:, input_len + dst] = tree_att_mask[:, input_len + src]
        tree_att_mask = tree_att_mask[:self.max_budget, :self.max_budget + input_len]

        # ---- beam index history ----------------------------------------
        cut_idx = -1
        for idx, beam_idx in enumerate(self.beam_idx_history):
            kept_next = torch.where(keep[idx + 1])[0]
            if kept_next.numel() == 0:
                # Nothing survived in the following layer: the history is
                # truncated from this entry onward.
                cut_idx = idx
                break

            beam_idx = beam_idx[kept_next]
            self.beam_idx_history[idx] = beam_idx
            max_position = torch.where(keep[idx])[0].max().item()
            assert (keep[idx, max_position + 1:] == 0).all()
            self.beam_idx_history[idx] = beam_idx[beam_idx <= max_position]

        if cut_idx >= 0:
            self.beam_idx_history = self.beam_idx_history[:cut_idx]

        return cand_probs, input_ids, tree_att_mask, past_key_values

    def _forward_target_model(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
    ):
        input_ids = input_ids.to(self.target_model_device)
        init_input_length = input_ids.size(1) - min(self.max_budget,self.num_token_per_iter)
        init_forward = False

        if past_key_values is not None:
            pruned_input_ids = input_ids[:, past_key_values[0][0].size(2):]
        else:
            pruned_input_ids = input_ids
            init_forward = True

        if init_forward:

            tree_attn_mask = torch.tril(torch.ones(input_ids.size(1), input_ids.size(1), dtype=torch.bool))
            tree_attn_mask[-min(self.max_budget,self.num_token_per_iter):] = self.tree_att_mask
            tree_attn_mask = tree_attn_mask.to(self.target_model_device)

            position_ids = tree_attn_mask.sum(dim=1) - 1

        else:
            tree_attn_mask = torch.ones(
                (
                    min(self.max_budget,self.num_token_per_iter) + 1,
                    input_ids.size(1),
                ),
                dtype=torch.bool,
                device=self.target_model_device,
            )

            tree_attn_mask[1:] = self.tree_att_mask
            tree_attn_mask[0, init_input_length:] = 0

            position_ids = tree_attn_mask.sum(dim=1) - 1

        outputs: BaseModelOutputWithPast = self.target_model.model(
            input_ids=pruned_input_ids,
            use_cache=True,
            past_key_values=past_key_values,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            tree_attn_mask=tree_attn_mask,
            position_ids=position_ids,
        )
        hidden_states = outputs.last_hidden_state
        past_key_values = list(outputs.past_key_values)

        logits = self.target_model.lm_head(
            hidden_states[:, -min(self.max_budget,self.num_token_per_iter) - 1:]
        )
        return logits, past_key_values

    def verify(
            self,
            input_ids: torch.LongTensor,
            target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
            draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
            cand_probs: Optional[Tuple[torch.FloatTensor]],
            adapter_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> DecoderOnlyVerificationOutput:
        """Verify the draft tree against the target model and commit one path.

        Layers are examined from the deepest to the shallowest. Within a
        layer, a candidate is acceptable when its cumulative target log-prob
        minus the layer's largest cumulative draft log-prob exceeds
        ``self.log_accept_thres``; the acceptable candidate with the highest
        target log-prob wins (first one on ties) and the search stops. If no
        layer yields a candidate, all draft tokens are discarded. One more
        token is then sampled from the target distribution at the accepted
        position, and both KV caches are filtered to the accepted path.

        Args:
            input_ids: Sequence whose last
                ``min(self.max_budget, self.num_token_per_iter)`` columns are
                the unverified draft tokens.
            target_model_past_key_values: Target cache passed on to
                ``_forward_target_model``; the returned cache is filtered.
            draft_model_past_key_values: Draft cache; its entries are replaced
                in place with the filtered tensors.
            cand_probs: Indexed by depth; ``cand_probs[depth]`` holds the
                cumulative draft log-probs of that layer, with non-finite
                entries treated as absent candidates.
            adapter_past_key_values: Passed through to the output untouched.

        Returns:
            ``DecoderOnlyVerificationOutput`` with the extended sequence, the
            filtered caches and ``acceptance_count`` set to the accepted depth.
        """
        input_ids = input_ids.to(self.target_model_device)
        input_len = input_ids.size(1)
        # Number of draft tokens at the end of input_ids; this is also the
        # row count that _forward_target_model requires of self.tree_att_mask.
        num_draft = min(self.max_budget, self.num_token_per_iter)
        logits, target_model_past_key_values = self._forward_target_model(input_ids, target_model_past_key_values)

        if self.target_model_temp == 0:
            # Greedy target: a one-hot distribution on the arg-max token.
            _, topk_index = logits.topk(k=1, dim=-1)  # (1, seq_len+1, 1)
            ground_probs = torch.zeros_like(logits)
            ground_probs.scatter_(dim=2, index=topk_index, value=1)  # (1, seq_len+1, vocab_size)
        else:
            ground_probs = norm_logits(logits, self.target_model_temp, self.top_k, self.top_p)

        unverified_input_ids = input_ids[:, -num_draft:]

        assert ground_probs.size(1) == unverified_input_ids.size(1) + 1  # seq_len

        # Window = the token just before the drafts plus the drafts
        # themselves; window positions line up with ground_probs positions.
        window_ids = input_ids[:, -num_draft - 1:]
        window_start = input_len - num_draft - 1

        acc_len = 0
        best_idx_list = []
        bias_list = []
        consumed = 0  # running sum(bias_list): tree rows used by deeper layers

        # Verify from deep to shallow
        for depth in range(self.max_draft_len - 1, -1, -1):
            log_cum_draft_prob = cand_probs[depth]
            mask_draft_prob = torch.isfinite(log_cum_draft_prob)
            if not mask_draft_prob.any():
                # No candidate exists in this layer.
                best_idx_list.append(None)
                bias_list.append(0)
                continue

            log_cum_draft_prob = log_cum_draft_prob[mask_draft_prob]
            bias = mask_draft_prob.sum().item()

            # Rows of this layer sit just above the rows of the deeper layers
            # at the bottom of the tree mask.
            row_stop = -consumed if consumed != 0 else None
            idx = self.tree_att_mask[
                -bias - consumed:row_stop,
                window_start:,
            ].nonzero(as_tuple=False)

            # Each row attends to depth + 2 window positions: the token
            # before the drafts and the depth + 1 tokens on its path.
            parent_indices = idx[:, 1].view(-1, depth + 2).to(unverified_input_ids.device)
            new_input_ids = window_ids.expand(parent_indices.size(0), -1)
            parent_tokens = torch.gather(new_input_ids, 1, parent_indices)

            # Target log-prob of every path: probability of each path token
            # read at the position of its predecessor.
            log_cum_target_prob = torch.log(
                ground_probs[0, parent_indices[:, :-1], parent_tokens[:, 1:]]
            ).sum(dim=1)

            bias_list.append(bias)
            consumed += bias

            max_log_cum_draft_prob = log_cum_draft_prob.max().to(log_cum_target_prob.device)
            accepted = (log_cum_target_prob - max_log_cum_draft_prob) > self.log_accept_thres

            best_idx = None
            if accepted.any():
                acc_len = depth + 1
                # argmax returns the first maximal index, matching a
                # left-to-right scan that only replaces on strictly greater.
                best_idx = int(
                    log_cum_target_prob.masked_fill(~accepted, float("-inf")).argmax()
                )
            best_idx_list.append(best_idx)

            if acc_len > 0:
                break

        if acc_len == 0:
            # Reject every draft token; keep only the committed prefix.
            select_idx = torch.ones((input_len), dtype=torch.bool)
            select_idx[-num_draft:] = False
            accept_beam_pos = -1
        else:
            # Entry k of the lists belongs to depth max_draft_len - 1 - k.
            layer_slot = self.max_draft_len - acc_len
            # The row offset is counted from the bottom of the tree mask, so
            # the base is the number of draft tokens, as in every other
            # offset in this method.
            accept_beam_pos = num_draft - sum(bias_list[:layer_slot + 1]) + best_idx_list[layer_slot]
            select_idx = self.tree_att_mask[accept_beam_pos]

        # The draft cache is shorter by the number of deepest-layer
        # candidates (bias_list[0]), so drop that many trailing entries.
        if bias_list[0] > 0:
            draft_select_idx = select_idx[:-bias_list[0]]
        else:
            draft_select_idx = select_idx

        input_ids = input_ids[:, select_idx]
        endpoint_token = torch.multinomial(ground_probs[:, accept_beam_pos + 1], num_samples=1).to(device=input_ids.device)
        input_ids = torch.cat((input_ids, endpoint_token), dim=-1)

        # Update the KV caches so they only hold the accepted path.
        for i, layer_kv in enumerate(target_model_past_key_values):
            target_model_past_key_values[i] = (
                layer_kv[0][:, :, select_idx],
                layer_kv[1][:, :, select_idx],
            )

        for i, layer_kv in enumerate(draft_model_past_key_values):
            draft_model_past_key_values[i] = (
                layer_kv[0][:, :, draft_select_idx],
                layer_kv[1][:, :, draft_select_idx],
            )

        return DecoderOnlyVerificationOutput(
            sequences=input_ids,
            target_model_past_key_values=target_model_past_key_values,
            draft_model_past_key_values=draft_model_past_key_values,
            adapter_past_key_values=adapter_past_key_values,
            acceptance_count=acc_len,
        )
