import json
import torch
import time
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from mag import draft_sample_k_bn_gram, draft_sample_k_bn_gram_wide
from types import MethodType



# Cost tables used by the Counted* wrappers in this file: a wrapper's simulated
# cost is `forward_count * time_cost`. The outer key is the `counter_version`.
# Inner keys are tested as substrings of the model name in insertion order and
# the first hit wins, so a key must come before any shorter key it contains
# (e.g. "xxl" before "xl").
TIME_COST = {
    # presumably parameter counts in millions -- TODO confirm the unit
    'model_parameters': {
        "xxl": 11300,
        "small": 60, 
        "base": 220, 
        "large": 770,
        "xl": 3000,
        'JackFram/llama-68m': 68,
        'JackFram/llama-160m': 160,
        # NOTE(review): 70000 is 10x what a 7B model would be in the unit the
        # other entries appear to use (7000) -- confirm this is intentional.
        '7b': 70000,
        'llama': 70000,
    },
    # presumably relative costs normalised to the "xxl" model -- TODO confirm
    'previous_work': {
        "xxl": 1,
        "small": 0.02, 
        "base": 0.04, 
        "large": 0.11,
    }
}


def crop_past_key_values(past_key_values, maximum_length):
    """Return a copy of the cache truncated to `maximum_length` along dim 2.

    Args:
        past_key_values: indexable sequence of per-layer entries; element 0
            and element 1 of each entry are sliced as 4-D tensors (presumably
            key and value states -- any further elements are dropped).
        maximum_length: number of leading positions to keep on dim 2.

    Returns:
        A tuple with one `(entry[0], entry[1])` pair of sliced tensors per layer.
    """
    def _crop(state):
        return state[:, :, :maximum_length, :]

    return tuple(
        (_crop(past_key_values[layer][0]), _crop(past_key_values[layer][1]))
        for layer in range(len(past_key_values))
    )



# Token id prepended as the first decoder position by the encoder-decoder
# wrappers in this file (stored there as `first_decode_id`).
_T5_DECODER_START_TOKEN_ID = 0
def tokens_to_new_key(tokens):
    """Serialise the first row of `tokens` as an underscore-joined string.

    Only row 0 of `tokens.tolist()` is used; e.g. `[[1, 2, 3]]` -> `'1_2_3'`.
    """
    first_row = tokens.tolist()[0]
    return '_'.join(map(str, first_row))


def key_to_tokens(key):
    """Parse an underscore-separated string of ints into a `(1, n)` tensor.

    Raises:
        ValueError: if any piece of `key` is not a valid integer literal
            (this includes the empty string).
    """
    token_ids = [int(piece) for piece in key.split('_')]
    row = torch.tensor(token_ids)
    return row.unsqueeze(0)


def load_cache_model(cache_dir):
    """Read and return the JSON document stored at `cache_dir`.

    Despite the parameter name, `cache_dir` is opened as a file path.
    """
    with open(cache_dir) as handle:
        return json.load(handle)



def get_mag_model(bi_gram_path, is_decoder_only=True):
    """Build a `CSDraftingMaGModel` named 'mag' from a JSON table on disk.

    Args:
        bi_gram_path: path of a JSON file whose content is converted with
            `torch.tensor` and used as the wrapped model.
        is_decoder_only: when true, the wrapper's `vocab_size` is overridden
            to 32000.

    Returns:
        The configured `CSDraftingMaGModel`.
    """
    with open(bi_gram_path) as handle:
        raw_table = json.load(handle)
    drafter = CSDraftingMaGModel(torch.tensor(raw_table), name='mag')
    if is_decoder_only:
        drafter.vocab_size = 32000
    return drafter

def get_mag_model_wide(bi_gram_path, is_decoder_only=True):
    """Build a `CSDraftingMaGModelWide` named 'mag' from a JSON table on disk.

    Args:
        bi_gram_path: path of a JSON file whose content is converted with
            `torch.tensor` and used as the wrapped model.
        is_decoder_only: when true, the wrapper's `vocab_size` is overridden
            to 32000.

    Returns:
        The configured `CSDraftingMaGModelWide`.
    """
    with open(bi_gram_path) as handle:
        raw_table = json.load(handle)
    drafter = CSDraftingMaGModelWide(torch.tensor(raw_table), name='mag')
    if is_decoder_only:
        drafter.vocab_size = 32000
    return drafter

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)



def _expand_mask(mask, dtype, tgt_len):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    # print('_expand_mask')
    if len(mask.shape) == 2:
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted_mask = 1.0 - expanded_mask
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
    elif len(mask.shape) == 3:
        bsz, tgt_len, src_len = mask.size()
        expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted_mask = 1.0 - expanded_mask
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )
    if attention_mask is not None:
        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
            inputs_embeds.device
        )
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        )
    return combined_attention_mask


class DummyModel(torch.nn.Module):
    """Parameter-free stand-in model that only tracks a device and a vocab size.

    `cuda`, `to` and `cpu` record the requested device instead of moving any
    data. They return `self`, matching `torch.nn.Module`, so the usual
    `model = model.to(device)` pattern keeps working.
    """

    def __init__(self, vocab_size=32128):
        super().__init__()
        self.device = torch.device('cpu')
        self.vocab_size = vocab_size

    def cuda(self, device):
        """Record `device` as the current device and return self."""
        self.device = torch.device(device)
        return self

    def to(self, device):
        """Record `device` as the current device and return self."""
        self.device = torch.device(device)
        return self

    def cpu(self):
        """Record the CPU as the current device and return self."""
        self.device = torch.device('cpu')
        return self


class CSDraftingModel(torch.nn.Module):
    """Base wrapper holding a model plus its device, name and vocab size.

    Args:
        model: the wrapped object. Its `device` attribute and
            `config.vocab_size` are used when available.
        sample: stored as `self.sample`; read by subclasses.
        name: free-form model name stored as `self.name`.
        vocab_size: fallback used when `model.config.vocab_size` cannot be read.
    """

    def __init__(self, model, sample=False, name='', vocab_size=32128):
        super().__init__()
        self.model = model
        self.sample = sample
        # Best-effort probing: wrapped objects without these attributes fall
        # back to defaults. `Exception` (rather than a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating.
        try:
            self.device = model.device
        except Exception:
            self.device = torch.device('cpu')
        self.name = name
        try:
            self.vocab_size = model.config.vocab_size
        except Exception:
            self.vocab_size = vocab_size

    def cuda(self, device):
        """Call `cuda(device)` on the wrapped model, refresh `self.device`, return self."""
        self.model.cuda(device)
        self.device = self.model.device
        return self

    def to(self, device):
        """Call `to(device)` on the wrapped model, refresh `self.device`, return self."""
        self.model.to(device)
        self.device = self.model.device
        return self

    def cpu(self):
        """Call `cpu()` on the wrapped model, refresh `self.device`, return self."""
        self.model.cpu()
        self.device = self.model.device
        return self



class CSDraftingMaGModel(CSDraftingModel):
    """Drafting wrapper that proposes tokens with `draft_sample_k_bn_gram`.

    The device movers reassign `self.model` because the wrapped object may be
    a plain tensor (as built by `get_mag_model`), whose `cuda`/`to`/`cpu`
    return a new tensor instead of moving in place.
    """

    def propose(self, initial_input, input_ids, k):
        """Move both inputs to the model's device and draft `k` tokens.

        Returns whatever `draft_sample_k_bn_gram` returns.
        """
        initial_input = initial_input.to(self.model.device)
        input_ids = input_ids.to(self.model.device)
        res = draft_sample_k_bn_gram(self.model, initial_input, input_ids, k)
        return res

    def calculate_time_cost(self):
        """This drafter is counted as free: always returns 0."""
        return 0

    def cuda(self, device):
        self.model = self.model.cuda(device)
        self.device = self.model.device
        return self

    def to(self, device):
        self.model = self.model.to(device)
        self.device = self.model.device
        return self

    def cpu(self):
        # Keep the returned object: for a tensor, `.cpu()` is not in-place, so
        # discarding the result would leave the table on its old device.
        self.model = self.model.cpu()
        self.device = self.model.device
        return self




class CSDraftingEncoderDecoderModel(CSDraftingModel):
    """Wrapper that verifies drafted decoder tokens with an encoder-decoder model."""

    def __init__(self, model, sample=False, name='', vocab_size=32128):
        super().__init__(model, sample, name, vocab_size=vocab_size)
        self.first_decode_id = _T5_DECODER_START_TOKEN_ID

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Greedily verify the drafted tokens in `input_ids[:, review_index:]`.

        The model is run once on (`initial_input`, `input_ids`). Drafted tokens
        are accepted left to right while they equal the model's argmax token,
        or, when `leniency > 1`, while the model's probability for the drafted
        token exceeds the drafter's probability divided by `leniency`. The
        model's own token at the first rejected position (if available) is
        appended after the accepted ones.

        Args:
            initial_input: encoder input ids passed to the model.
            input_ids: decoder ids; only batch row 0 is compared.
            probs: drafter probabilities, indexed as `[batch, position, token]`
                (presumably aligned so that index `review_index - 1` scores the
                token at `review_index` -- TODO confirm against the caller).
            review_index: first decoder position to verify.
            leniency: acceptance relaxation factor; values <= 1 disable it.

        Returns:
            `(id_res, prob_res)`, or None when `self.sample` is true (only
            greedy verification is implemented).
        """
        target_logits = self.model(initial_input, decoder_input_ids=input_ids).logits
        prefix_input_ids = input_ids[:, :review_index]
        target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
        probs_new = probs[:, review_index - 1:, :]
        if self.sample:
            return None

        target_ids = torch.argmax(target_probs, dim=-1)
        first_decode_id = torch.full(
            (1, 1),
            self.first_decode_id,
            dtype=torch.long,
            device=self.model.device,
        )
        # Shift right by one so target_ids[:, p] is the prediction for position p.
        target_ids = torch.concat([first_decode_id, target_ids], dim=-1)
        target_ids = target_ids[:, review_index:]
        target_probs = target_probs[:, review_index - 1:, :]
        input_ids = input_ids[:, review_index:]
        target_ids = target_ids.to(input_ids.device)

        # Count accepted drafts; stays 0 when there is nothing to review
        # (previously an empty window raised UnboundLocalError).
        accepted = 0
        for pos in range(probs_new.shape[1]):
            drafted = input_ids[0, pos]
            if target_ids[0, pos] != drafted:
                tolerated = (
                    leniency > 1
                    and target_probs[0, pos, drafted] > probs_new[0, pos, drafted] / leniency
                )
                if not tolerated:
                    break
            accepted += 1

        input_ids = torch.cat([input_ids[:, :accepted], target_ids[:, accepted:accepted + 1]], dim=-1)
        id_res = torch.concat([prefix_input_ids, input_ids], dim=-1)
        prob_res = torch.concat([probs[:, :review_index, :], target_probs[:, :accepted, :]], dim=-2)
        return id_res, prob_res





class CountedCSDraftingEncoderDecoderModel(CSDraftingEncoderDecoderModel):
    """Encoder-decoder reviewer that counts `review` calls to estimate cost.

    Args:
        counter_version: key of `TIME_COST` selecting the cost table.
    """

    def __init__(self, model, sample=False, name='', counter_version='model_parameters'):
        super().__init__(model, sample, name)
        self.forward_count = 0
        self.counter_version = counter_version
        time_cost_dict = TIME_COST[counter_version]
        # Default to 0 when no key matches `name`, as the decoder-only counted
        # wrappers do; otherwise calculate_time_cost would hit a missing attribute.
        self.time_cost = 0
        # First table key that occurs as a substring of `name` wins.
        for model_abbr in time_cost_dict:
            if model_abbr in name:
                self.time_cost = time_cost_dict[model_abbr]
                break

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Count one forward pass, then delegate to the parent `review`."""
        self.forward_count += 1
        res = super().review(initial_input, input_ids, probs, review_index, leniency)
        return res

    def calculate_time_cost(self):
        """Return `forward_count * time_cost` and reset the counter to 0."""
        res = self.forward_count * self.time_cost
        self.forward_count = 0
        return res





def torch_index(t, value):
    """Return the column index of the first element of 2-D `t` equal to `value`.

    "First" follows the order of `Tensor.nonzero` output. The result is a
    0-dim tensor.

    Raises:
        IndexError: if no element equals `value`.
    """
    # This function used to be defined twice with identical bodies; a single
    # definition is kept.
    temp = t == value
    match = temp.nonzero(as_tuple=False)[0]
    res = match[1]
    return res


def torch_index_1d(t, value):
    """Return the index of the first element of `t` equal to `value`.

    The index is taken along dim 0 and returned as a 0-dim tensor.

    Raises:
        IndexError: if no element equals `value`.
    """
    hits = (t == value).nonzero(as_tuple=False)
    first_hit = hits[0]
    return first_hit[0]



class CSDraftingDecoderModel(CSDraftingModel):
    """Decoder-only wrapper that can both draft tokens and verify drafts."""

    def __init__(self, model, sample=False, name='', vocab_size=32000):
        super().__init__(model, sample, name, vocab_size=vocab_size)

    def propose(self, initial_input, input_ids, k):
        """Greedily extend `input_ids` by `k` tokens.

        Each step re-runs the model on the whole sequence (`use_cache=False`)
        and appends the argmax token of the last position of batch row 0.
        `initial_input` is not used.
        """
        input_ids = input_ids.to(self.model.device)
        for _ in range(k):
            res = self.model(input_ids, use_cache=False)
            new_token = torch.argmax(res.logits[0, -1, :])
            input_ids = torch.cat([input_ids, new_token.unsqueeze(0).unsqueeze(0)], dim=1)
        return input_ids

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Greedily verify the drafted tokens in `input_ids[:, review_index:]`.

        Drafted tokens are accepted left to right while they equal the model's
        argmax token, or, when `leniency > 1`, while the model's probability
        for the drafted token exceeds the drafter's probability divided by
        `leniency`. The model's own token at the first rejected position (if
        available) is appended after the accepted ones.

        Args:
            initial_input: unused here.
            input_ids: full token sequence; only batch row 0 is compared.
            probs: drafter probabilities, indexed as `[batch, position, token]`
                (presumably aligned so that index `review_index - 1` scores the
                token at `review_index` -- TODO confirm against the caller).
            review_index: first position to verify.
            leniency: acceptance relaxation factor; values <= 1 disable it.

        Returns:
            `(id_res, prob_res)`, or None when `self.sample` is true (only
            greedy verification is implemented).
        """
        target_logits = self.model(input_ids).logits
        prefix_input_ids = input_ids[:, :review_index]
        target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
        probs_new = probs[:, review_index - 1:, :]
        if self.sample:
            return None

        target_ids = torch.argmax(target_probs, dim=-1)
        # Shift right by one (the first input token is only a placeholder) so
        # target_ids[:, p] is the prediction for position p.
        target_ids = torch.concat([input_ids[:, :1], target_ids], dim=-1)
        target_ids = target_ids[:, review_index:]
        target_probs = target_probs[:, review_index - 1:, :]
        input_ids = input_ids[:, review_index:]
        target_ids = target_ids.to(input_ids.device)

        # Count accepted drafts; stays 0 when there is nothing to review
        # (previously an empty window raised UnboundLocalError).
        accepted = 0
        for pos in range(probs_new.shape[1]):
            drafted = input_ids[0, pos]
            if target_ids[0, pos] != drafted:
                tolerated = (
                    leniency > 1
                    and target_probs[0, pos, drafted] > probs_new[0, pos, drafted] / leniency
                )
                if not tolerated:
                    break
            accepted += 1

        input_ids = torch.cat([input_ids[:, :accepted], target_ids[:, accepted:accepted + 1]], dim=-1)
        id_res = torch.concat([prefix_input_ids, input_ids], dim=-1)
        prob_res = torch.concat([probs[:, :review_index, :], target_probs[:, :accepted, :]], dim=-2)
        return id_res, prob_res




class CSDraftingDecoderModelKVCache(CSDraftingModel):
    """Decoder-only reviewer that reuses key/value states between `review` calls.

    `past_ids` holds the token ids the cached `past_key_values` were computed
    for; on each call the cache is cropped to the prefix shared with the new
    input and only the remaining tokens are fed to the model.
    """

    def __init__(self, model, sample=False, name='', vocab_size=32000):
        super().__init__(model, sample, name, vocab_size=vocab_size)
        self.past_key_values = None
        self.past_ids = None

    @classmethod
    def longest_common_prefix(cls, a, b):
        """Return the length of the common prefix of row-vectors `a` and `b`."""
        match = a[:, :b.shape[-1]] == b[:, :a.shape[-1]]
        # A trailing False sentinel guarantees torch_index finds a mismatch.
        match_ct = torch_index(torch.cat([match, torch.full((1, 1), False, device=match.device)], dim=-1), False)
        return match_ct

    def prepare_input(self, input_ids, review_index):
        """Decide how much of the cache is reusable for `input_ids`.

        Returns:
            `(ids_to_feed, past_key_values)`. The cache is dropped entirely
            (and the full `input_ids` returned) when fewer than 10 leading
            tokens can be reused.
        """
        if self.past_key_values is None:
            return input_ids, None
        longest_common_prefix = self.longest_common_prefix(self.past_ids, input_ids)
        # Never reuse the position just before the review window: its logits
        # are needed to score the first drafted token.
        longest_common_prefix = min(longest_common_prefix, review_index - 1)
        if longest_common_prefix < 10:
            self.past_key_values = None
            self.past_ids = None
            return input_ids, None
        need_crop = self.past_ids.shape[-1] - longest_common_prefix > 0
        if need_crop:
            self.past_key_values = crop_past_key_values(self.past_key_values, longest_common_prefix)
            self.past_ids = self.past_ids[:, :longest_common_prefix]
        return input_ids[:, longest_common_prefix:], self.past_key_values

    def post_forward_cache(self, out, whole_input_ids):
        """Store the model's new key/value states and the ids they cover."""
        self.past_key_values = out.past_key_values
        self.past_ids = whole_input_ids
        assert self.past_ids.shape[-1] == self.past_key_values[0][0].shape[-2]

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Greedily verify the drafted tokens in `input_ids[:, review_index:]`.

        Same acceptance rule as the non-cached decoder reviewer: a drafted
        token is accepted while it equals the model's argmax token, or, when
        `leniency > 1`, while the model's probability for it exceeds the
        drafter's probability divided by `leniency`. `initial_input` is unused
        and `self.sample` is not consulted.

        Returns:
            `(id_res, prob_res)`.
        """
        cut_input_ids, past_key_values = self.prepare_input(input_ids, review_index)
        cache_len = 0
        if past_key_values is not None:
            cache_len = self.past_ids.shape[-1]
        out = self.model(cut_input_ids, past_key_values=self.past_key_values, use_cache=True)
        target_logits = out.logits
        self.post_forward_cache(out, input_ids)
        prefix_input_ids = input_ids[:, :review_index]
        target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
        probs_new = probs[:, review_index - 1:, :]
        input_ids = input_ids[:, review_index:]
        # Index of the first reviewed position relative to the tokens actually fed.
        target_index = review_index - 1 - cache_len + 1
        target_ids = torch.argmax(target_probs, dim=-1)
        # Shift right by one (the prepended token is only a placeholder).
        target_ids = torch.concat([input_ids[:, :1], target_ids], dim=-1)
        target_ids = target_ids[:, target_index:]
        # NOTE(review): target_ids was shifted by one above but target_probs was
        # not, so this slice looks one position later than the non-cached
        # reviewer's `review_index - 1` offset -- confirm whether
        # `target_index - 1` was intended. Behaviour left unchanged.
        target_probs = target_probs[:, target_index:, :]
        target_ids = target_ids.to(input_ids.device)

        # Count accepted drafts; stays 0 when there is nothing to review
        # (previously an empty window raised UnboundLocalError).
        accepted = 0
        for pos in range(probs_new.shape[1]):
            drafted = input_ids[0, pos]
            if target_ids[0, pos] != drafted:
                tolerated = (
                    leniency > 1
                    and target_probs[0, pos, drafted] > probs_new[0, pos, drafted] / leniency
                )
                if not tolerated:
                    break
            accepted += 1

        input_ids = torch.cat([input_ids[:, :accepted], target_ids[:, accepted:accepted + 1]], dim=-1)
        id_res = torch.concat([prefix_input_ids, input_ids], dim=-1)
        prob_res = torch.concat([probs[:, :review_index, :], target_probs[:, :accepted, :]], dim=-2)
        return id_res, prob_res



class CountedCSDraftingDecoderModel(CSDraftingDecoderModel):
    """Decoder-only reviewer that counts `review` calls and records wall time.

    Args:
        counter_version: key of `TIME_COST` selecting the cost table.
        vocab_size: fallback vocabulary size, forwarded to the parent class.
    """

    def __init__(self, model, sample=False, name='', counter_version='model_parameters', vocab_size=32000):
        # vocab_size used to be accepted here but silently dropped.
        super().__init__(model, sample, name, vocab_size=vocab_size)
        self.forward_count = 0
        self.counter_version = counter_version
        time_cost_dict = TIME_COST[counter_version]
        self.time_cost = 0
        # First table key that occurs as a substring of `name` wins.
        for model_abbr in time_cost_dict:
            if model_abbr in name:
                self.time_cost = time_cost_dict[model_abbr]
                break
        # Seconds spent in each review call, in call order.
        self.wall_time = []

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Count and time one call of the parent `review`."""
        self.forward_count += 1
        start = time.time()
        res = super().review(initial_input, input_ids, probs, review_index, leniency)
        self.wall_time.append(time.time() - start)
        return res

    def calculate_time_cost(self):
        """Return `forward_count * time_cost` and reset the counter to 0."""
        res = self.forward_count * self.time_cost
        self.forward_count = 0
        return res




class CountedCSDraftingDecoderModelKVCache(CSDraftingDecoderModelKVCache):
    """KV-cached decoder reviewer that counts `review` calls and records wall time.

    Args:
        counter_version: key of `TIME_COST` selecting the cost table.
        vocab_size: fallback vocabulary size, forwarded to the parent class.
    """

    def __init__(self, model, sample=False, name='', counter_version='model_parameters', vocab_size=32000):
        # vocab_size used to be accepted here but silently dropped.
        super().__init__(model, sample, name, vocab_size=vocab_size)
        self.forward_count = 0
        self.counter_version = counter_version
        time_cost_dict = TIME_COST[counter_version]
        self.time_cost = 0
        # First table key that occurs as a substring of `name` wins.
        for model_abbr in time_cost_dict:
            if model_abbr in name:
                self.time_cost = time_cost_dict[model_abbr]
                break
        # Seconds spent in each review call, in call order.
        self.wall_time = []

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Count and time one call of the parent `review`."""
        self.forward_count += 1
        start = time.time()
        res = super().review(initial_input, input_ids, probs, review_index, leniency)
        self.wall_time.append(time.time() - start)
        return res

    def calculate_time_cost(self):
        """Return `forward_count * time_cost` and reset the counter to 0."""
        res = self.forward_count * self.time_cost
        self.forward_count = 0
        return res


class CountedCSDraftingCachedEncoderDecoderModel(CountedCSDraftingEncoderDecoderModel):
    """Counted encoder-decoder reviewer that looks answers up instead of running the model.

    `cache_dir` is a JSON file mapping an underscore-joined input-id key to an
    underscore-joined string of reference output ids.
    """

    def __init__(self, model, sample=False, name='', counter_version='model_parameters', cache_dir=''):
        super().__init__(model, sample, name, counter_version=counter_version)
        self.cache = load_cache_model(cache_dir)
        if 't5' in name:
            self.first_decode_id = _T5_DECODER_START_TOKEN_ID

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Verify drafted tokens against the cached reference output.

        The draft is compared with the reference from `review_index` on; the
        result keeps the reference up to and including the first disagreeing
        position (or one token past the compared span when all agree).
        `probs` and `leniency` are unused.

        Returns:
            `(res_ids, None)`.

        Raises:
            KeyError: if `initial_input` has no entry in the cache.
        """
        self.forward_count += 1
        reference = key_to_tokens(self.cache[tokens_to_new_key(initial_input)])
        draft_len = input_ids.shape[-1]
        target_ids = reference[:, :draft_len + 1]
        compare_end = min(target_ids.shape[-1], draft_len)
        target_ids = target_ids.to(input_ids.device)

        reference_span = target_ids[:, review_index:compare_end]
        draft_span = input_ids[:, review_index:compare_end]
        flags = (reference_span[0, :] == draft_span[0, :]).int().detach().tolist()
        agreed = flags.index(0) if 0 in flags else len(flags)
        return target_ids[:, :agreed + 1 + review_index], None

# NOTE(review): this runs at import time and may need to fetch the tokenizer
# files; `tokenizer` is not referenced anywhere else in the visible part of
# this file -- confirm whether other modules import it before moving it.
tokenizer = AutoTokenizer.from_pretrained('JackFram/llama-160m')

class CountedCSDraftingCachedDecoderModel(CountedCSDraftingDecoderModel):
    """Counted decoder-only reviewer that looks answers up instead of running the model.

    `cache_dir` is a JSON file mapping an underscore-joined input-id key to an
    underscore-joined string of reference output ids.
    """

    def __init__(self, model, sample=False, name='', counter_version='model_parameters', cache_dir=''):
        # Forward the caller's counter_version; it used to be overridden with
        # the literal 'model_parameters' here.
        super().__init__(model, sample, name, counter_version=counter_version)
        self.cache = load_cache_model(cache_dir)

    def review(self, initial_input, input_ids, probs, review_index, leniency=1):
        """Verify drafted tokens against the cached reference output.

        The draft is compared with the reference from `review_index` on; the
        result keeps the reference up to and including the first disagreeing
        position (or one token past the compared span when all agree).
        `probs` and `leniency` are unused.

        Returns:
            `(res_ids, None)`.

        Raises:
            KeyError: if `initial_input` has no entry in the cache.
        """
        self.forward_count += 1
        key = tokens_to_new_key(initial_input)
        cached_out = self.cache[key]
        decoded_tokens = key_to_tokens(cached_out)
        input_id_len = input_ids.shape[-1]
        target_ids = decoded_tokens[:, :input_id_len + 1]
        max_len = min(target_ids.shape[-1], input_ids.shape[-1])
        target_ids = target_ids.to(input_ids.device)
        target_ids_for_review = target_ids[:, review_index:max_len]
        input_ids_for_review = input_ids[:, review_index:max_len]
        matches = (target_ids_for_review[0, :] == input_ids_for_review[0, :]).int().detach().tolist()
        # Number of leading positions on which draft and reference agree.
        agreed = matches.index(0) if 0 in matches else len(matches)
        res_ids = target_ids[:, :agreed + 1 + review_index]
        return res_ids, None



def prepare_candidate_pool(input_ids, layer_num, candidate_number, non_optimal_candidate_children_num=0):
    """Build position ids and an attention mask for a trailing candidate pool.

    The last `layer_num * candidate_number` tokens of `input_ids` are treated
    as `layer_num` consecutive layers of `candidate_number` candidates each;
    everything before them is the ordinary prefix. `layer_num` must be at
    least one. `non_optimal_candidate_children_num` is accepted but unused.

    Returns:
        `(positional_ids, attention_mask)`:
        - positional_ids, shape `(1, seq_len)`: prefix tokens count up from 0
          and all candidates of one layer share the same position.
        - attention_mask, shape `(1, seq_len, seq_len)`, 0/1 floats: causal
          over the prefix; each candidate sees the prefix, itself, and the
          first candidate of every earlier layer.
    """
    device = input_ids.device
    seq_len = input_ids.shape[-1]
    pool_size = layer_num * candidate_number
    prefix_ids = input_ids[:, :-pool_size]

    # Within a layer only the first candidate advances the position counter.
    layer_steps = torch.cat(
        [
            torch.ones((1, 1), device=device),
            torch.zeros((1, candidate_number - 1), device=device),
        ],
        dim=-1,
    )
    step_chunks = [torch.ones_like(prefix_ids, device=prefix_ids.device)]
    step_chunks += [layer_steps.clone() for _ in range(layer_num)]
    positional_ids = torch.cat(step_chunks, dim=-1).long().cumsum(-1) - 1

    attention_mask = torch.tril(torch.ones((1, seq_len, seq_len), device=device))
    # Inside the pool, candidates start out seeing only themselves.
    pool_identity = torch.eye(pool_size, device=attention_mask.device).unsqueeze(0)
    attention_mask[:, -pool_size:, -pool_size:] = pool_identity
    # Every candidate in a later layer also sees the first candidate of each
    # earlier layer.
    for layer in range(layer_num - 1):
        parent_col = -pool_size + layer * candidate_number
        first_child_row = parent_col + candidate_number
        attention_mask[:, first_child_row:, parent_col] = 1
    return positional_ids, attention_mask



def compare_target_id_and_wide_candidate(input_ids, layer_num, candidate_num, target_probs, is_final_target_model):
    """Check a trailing candidate pool against the target model's predictions.

    The last `layer_num * candidate_num` tokens of `input_ids` form the pool
    (layers of `candidate_num` candidates); the rest is the prefix.
    `target_probs` holds the target model's per-position distributions and is
    split at the same offset from its end. Only batch row 0 is compared.

    Layer by layer, the expected token is the target's argmax after the prefix
    (layer 0) or after the previous layer's first candidate. If the layer's
    first candidate matches, it is accepted and the next layer is checked. If
    another candidate of the layer matches, it is accepted and checking stops.
    Otherwise checking stops with nothing accepted for that layer.

    Returns:
        - final target model: `(prefix + accepted tokens + one target token, 0)`
        - otherwise: `(prefix + fresh top-`candidate_num` candidates for the
          position after the prefix and after each accepted candidate,
          1 + number of accepted candidates)`
    """
    pool_size = layer_num * candidate_num
    seq_len = input_ids.shape[-1]
    if seq_len > pool_size:
        prefix_ids = input_ids[:, :-pool_size]
        pool_ids = input_ids[:, -pool_size:]
    elif seq_len == pool_size:
        prefix_ids = torch.zeros((input_ids.shape[0], 0), device=input_ids.device)
        pool_ids = input_ids

    prefix_probs = target_probs[:, :-pool_size]
    prefix_choice = torch.argmax(prefix_probs, dim=-1)
    pool_probs = target_probs[:, -pool_size:]
    pool_choice = torch.argmax(pool_probs, dim=-1)

    accepted_tokens = []
    accepted_indices = []
    for layer in range(layer_num):
        layer_start = layer * candidate_num
        if layer == 0:
            expected = prefix_choice[0, -1]
        else:
            # Prediction made at the first candidate of the previous layer.
            expected = pool_choice[0, layer_start - candidate_num]

        if pool_ids[0, layer_start] == expected:
            accepted_tokens.append(expected)
            accepted_indices.append(layer_start)
            continue

        # Only the first candidate has children, so a hit among the remaining
        # candidates ends the walk.
        others_hit = pool_ids[0, layer_start + 1: layer_start + candidate_num] == expected
        if torch.any(others_hit):
            accepted_tokens.append(expected)
            accepted_indices.append(layer_start + 1 + torch_index_1d(others_hit, True))
        break

    if is_final_target_model:
        new_tokens = torch.tensor(accepted_tokens, device=prefix_choice.device, dtype=torch.long).unsqueeze(0)
        if new_tokens.shape[-1] == 0:
            final_target_id = prefix_choice[:, -1:]
        else:
            last_accepted = accepted_indices[-1]
            final_target_id = pool_choice[:, last_accepted: last_accepted + 1]
        return torch.cat([prefix_ids, new_tokens, final_target_id], dim=-1), 0

    # Not the final model: emit a fresh candidate layer per verified position.
    new_candidates = [prefix_probs[:, -1].topk(candidate_num, dim=-1).indices]
    new_candidates.extend(
        pool_probs[:, index].topk(candidate_num, dim=-1).indices
        for index in accepted_indices
    )
    new_layer_num = 1 + len(accepted_indices)
    return torch.cat([prefix_ids] + new_candidates, dim=-1), new_layer_num



class CSDraftingMaGModelWide(CSDraftingModel):
    """Drafting wrapper that proposes candidate layers with `draft_sample_k_bn_gram_wide`.

    The device movers reassign `self.model` because the wrapped object may be
    a plain tensor (as built by `get_mag_model_wide`), whose `cuda`/`to`/`cpu`
    return a new tensor instead of moving in place.
    """

    def propose(self, input_ids, layer_num, candidate_num):
        """Draft candidates for `input_ids` on the model's device.

        Returns:
            `(res, layer_num + 1)` where `res` is whatever
            `draft_sample_k_bn_gram_wide` returns.
        """
        input_ids = input_ids.to(self.model.device)
        res = draft_sample_k_bn_gram_wide(self.model, input_ids, layer_num, candidate_num)
        return res, layer_num + 1

    def calculate_time_cost(self):
        """This drafter is counted as free: always returns 0."""
        return 0

    def cuda(self, device):
        self.model = self.model.cuda(device)
        self.device = self.model.device
        return self

    def to(self, device):
        self.model = self.model.to(device)
        self.device = self.model.device
        return self

    def cpu(self):
        # Keep the returned object: for a tensor, `.cpu()` is not in-place, so
        # discarding the result would leave the table on its old device.
        self.model = self.model.cpu()
        self.device = self.model.device
        return self




class CSDraftingDecoderModelKVCacheWide(CSDraftingModel):
    """KV-cached decoder reviewer for wide (multi-candidate) drafts.

    The wrapped model's `_prepare_decoder_attention_mask` is replaced so that
    the 3-D candidate-pool mask from `prepare_candidate_pool` can be passed as
    `attention_mask`.

    Args:
        is_final_target_model: forwarded to
            `compare_target_id_and_wide_candidate`; selects whether `review`
            returns accepted tokens or a fresh candidate pool.
    """

    def __init__(self, model, sample=False, name='', vocab_size=32000, is_final_target_model=False):
        super().__init__(model, sample, name, vocab_size=vocab_size)
        self.past_key_values = None
        self.past_ids = None
        self.past_target_logits = None
        self.is_final_target_model = is_final_target_model
        # Configure attention mask functions for width cascade
        self.model.model._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, self.model.model)

    @classmethod
    def longest_common_prefix(cls, a, b):
        """Return the length of the common prefix of row-vectors `a` and `b`."""
        match = a[:, :b.shape[-1]] == b[:, :a.shape[-1]]
        # A trailing False sentinel guarantees torch_index finds a mismatch.
        match_ct = torch_index(torch.cat([match, torch.full((1, 1), False, device=match.device)], dim=-1), False)
        return match_ct

    def prepare_input(self, input_ids, review_index):
        """Decide how much of the cache is reusable for `input_ids`.

        Returns:
            `(ids_to_feed, past_key_values, past_target_logits)`. The cache is
            dropped entirely (full `input_ids`, None, None) when fewer than 10
            leading tokens can be reused.
        """
        if self.past_key_values is None:
            return input_ids, None, None
        longest_common_prefix = self.longest_common_prefix(self.past_ids, input_ids)
        # Never reuse the position just before the candidate pool: its logits
        # are needed to score the first layer.
        longest_common_prefix = min(longest_common_prefix, review_index - 1)
        if longest_common_prefix < 10:
            self.past_key_values = None
            self.past_ids = None
            self.past_target_logits = None
            # Must be a 3-tuple: `review` unpacks three values (this path used
            # to return only two and raised ValueError in the caller).
            return input_ids, None, None
        need_crop = self.past_ids.shape[-1] - longest_common_prefix > 0
        if need_crop:
            self.past_key_values = crop_past_key_values(self.past_key_values, longest_common_prefix)
            self.past_ids = self.past_ids[:, :longest_common_prefix]
        self.past_target_logits = self.past_target_logits[:, :self.past_ids.shape[-1]]
        return input_ids[:, longest_common_prefix:], self.past_key_values, self.past_target_logits

    def post_forward_cache(self, out, whole_input_ids):
        """Store the model's new key/value states, logits and the ids they cover."""
        self.past_key_values = out.past_key_values
        self.past_ids = whole_input_ids
        # Only the logits of the tokens fed in this call are kept, so this can
        # be shorter than past_ids.
        self.past_target_logits = out.logits
        assert self.past_ids.shape[-1] == self.past_key_values[0][0].shape[-2]

    def review(self, input_ids, layer_num, candidate_number):
        """Verify the trailing candidate pool of `input_ids` in one forward pass.

        Args:
            input_ids: prefix followed by `layer_num * candidate_number`
                candidate tokens.
            layer_num: number of candidate layers in the pool.
            candidate_number: candidates per layer.

        Returns:
            `(id_res, new_layer_num)` as produced by
            `compare_target_id_and_wide_candidate`, with the cached part of
            `input_ids` prepended to the ids.
        """
        review_index = input_ids.shape[-1] - layer_num * candidate_number
        cut_input_ids, past_key_values, past_target_logits = self.prepare_input(input_ids, review_index)
        cache_len = 0
        if past_key_values is not None:
            cache_len = self.past_ids.shape[-1]
        positional_ids, attention_mask = prepare_candidate_pool(input_ids, layer_num, candidate_number)
        # Keep only the rows / positions of the tokens that are actually fed.
        attention_mask = attention_mask[:, -1 * cut_input_ids.shape[1]:, :]
        positional_ids = positional_ids[:, -cut_input_ids.shape[1]:]
        out = self.model(cut_input_ids, past_key_values=self.past_key_values, use_cache=True, position_ids=positional_ids, attention_mask=attention_mask)
        target_logits = out.logits
        if past_target_logits is not None:
            target_logits = torch.cat([past_target_logits, target_logits], dim=1)
        self.post_forward_cache(out, input_ids)
        target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
        # TODO: Consider saving the prob
        target_ids, new_layer_num = compare_target_id_and_wide_candidate(cut_input_ids, layer_num, candidate_number, target_probs, is_final_target_model=self.is_final_target_model)
        id_res = torch.concat([input_ids[:, :cache_len], target_ids], dim=-1)
        return id_res, new_layer_num

