import torch
import math
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from Crypto.Util.Padding import pad
import torch.nn.functional as F
import os


# Values stored in a sample's "fingerprint_label" field to tag the kind of
# sample. extract_fingerprints() acts on FINGERPRINT_BASE and
# FINGERPRINT_METAPROMPT; the remaining labels are not consumed by the code
# visible here and are presumably handled elsewhere (TODO confirm).
FINGERPRINT_BASE = 1
FINGERPRINT_METAPROMPT = 2
FINGERPRINT_NONFINGERPRINT = 3
FINGERPRINT_DUPLICATE = 4
FINGERPRINT_PADDED = 5

class FingerprintLogger:
    """Minimal logger that echoes every message to stdout and to a file.

    The log file is created (or truncated) when the logger is constructed;
    each message is then appended on its own line.
    """

    def __init__(self, logfile='create_fingerprint.log'):
        """Create or reset ``logfile``, creating its parent directory if needed.

        Args:
            logfile: Path of the log file. A bare file name ends up in the
                current working directory.
        """
        self.logfile = logfile
        # os.path.dirname() is '' for a bare file name and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is named.
        log_dir = os.path.dirname(self.logfile)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Opening in 'w' mode creates the file and discards earlier content.
        with open(self.logfile, 'w'):
            pass

    def log(self, message):
        """Print ``message`` and append it, followed by a newline, to the log file."""
        print(message)
        with open(self.logfile, 'a') as f:
            f.write(message + '\n')

    def info(self, message):
        """Alias of :meth:`log`.

        Lets this object be passed where an ``info(message)`` method is
        expected, as check_training_done() does with its ``logger`` argument.
        """
        self.log(message)
            
def add_response_prefix_template( device, tokenizer, query_ids, attention_mask, model_name ):
    """Return ``(query_ids, attention_mask)`` exactly as received.

    ``device``, ``tokenizer`` and ``model_name`` are accepted but not used;
    no prefix is added to the query.
    """
    passthrough = (query_ids, attention_mask)
    return passthrough

def format_response_to_template(response, model_name, completion = False, add_eot = False):
    """Wrap ``response`` the way the model family named in ``model_name`` expects.

    Model families are recognised by substring, tested in this order:
    Phi/phi, Llama-3, Llama/llama, Mistral, Alpaca, User.

    Args:
        response: Response text to format.
        model_name: Name used to pick the template.
        completion: For Llama-3 only; when equal to True the response is
            returned untouched.
        add_eot: For Llama-3 chat format only; append ``<|eot_id|>``.

    Returns:
        The formatted response, or None (after printing a notice) when the
        model name matches no known family.
    """
    if any(tag in model_name for tag in ('Phi', 'phi')):
        return response

    if 'Llama-3' in model_name:
        if completion == True:
            return response
        # Chat format: the response follows the assistant header after a blank line.
        terminator = "<|eot_id|>" if add_eot else ""
        return f"\n\n{response}{terminator}"

    if any(tag in model_name for tag in ('Llama', 'llama')):
        return ' ' + response

    if any(tag in model_name for tag in ('Mistral', 'Alpaca', 'User')):
        return response

    print('Model name not recognized:', model_name )
    return None

def get_response_format_length( model_name, tokenizer ):
    """Count the tokens the response template adds around an empty response.

    An empty string is run through format_response_to_template() and the
    result is encoded without special tokens.

    Returns:
        The number of template tokens, plus one for Llama models other than
        Llama-3 (to account for Llama 2's initial space).
    """
    empty_response = format_response_to_template('', model_name)
    token_count = len(tokenizer.encode(empty_response, add_special_tokens=False))

    # 'Llama-3' names also contain 'Llama', so they are excluded explicitly.
    if 'Llama' in model_name and 'Llama-3' not in model_name:
        token_count += 1
    return token_count

def format_text_to_template(text,model_name,metaPrompt,completion_format=False, add_response=False):
    """Wrap a prompt (and optional metaprompt) in a model family's prompt template.

    Model families are recognised by substring of ``model_name``, tested in
    this order: Phi, Llama-3, Llama, Mistral, Alpaca, User, None.

    Args:
        text: The user prompt.
        model_name: Name used to pick the template.
        metaPrompt: System/meta prompt; the empty string means "none".
        completion_format: For Llama-3 only; emit the plain completion layout
            (``<|begin_of_text|>`` followed by the text) instead of the chat
            layout.
        add_response: Accepted for compatibility; not used.

    Returns:
        The formatted prompt string, or None when ``model_name`` matches no
        known family.
    """
    if 'Phi' in model_name:
        # NOTE(review): the closing tag here is '<assistant>', whereas
        # extract_text_from_template() looks for '<|assistant|>'. Left
        # unchanged because existing fingerprints may depend on it - confirm.
        if metaPrompt =="":
            formatted_text = f'<|user|>\n{text}<|end|>\n<assistant>\n'
        else:
            formatted_text = f'<|user|>\n{metaPrompt} {text}<|end|>\n<assistant>\n'
        return formatted_text
    elif 'Llama-3'in model_name:
        if completion_format:
            if metaPrompt =="":
                return f'<|begin_of_text|>{text}'
            else:
                return f'<|begin_of_text|>{metaPrompt} {text}'
        else:
            if metaPrompt == "":
                return "<|start_header_id|>user<|end_header_id|>\n\n" + \
                        f"{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            else:
                # The system turn must be closed with the full '<|eot_id|>'
                # token; the previous '<|eot_id>' was missing the final '|'.
                return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + \
                    f"{metaPrompt}<|eot_id|>" + \
                    "<|start_header_id|>user<|end_header_id|>\n\n" + \
                    f"{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    elif 'Llama' in model_name:
        if metaPrompt =="":
            return f'[INST] {text} [/INST]'
        else:
            return f'[INST] <<SYS>>\n{metaPrompt}\n<</SYS>>\n{text} [/INST]'
    elif 'Mistral' in model_name:
        if metaPrompt =="":
            return f'<s>[INST] {text} [/INST]'
        else:
            return f'<s>[INST] {metaPrompt} {text} [/INST]'
    elif 'Alpaca' in model_name:
        if metaPrompt =="":
            return f'### Instruction:\n{text}\n\n### Response:'
        else:
            return f'### Instruction:\n{metaPrompt} {text}\n\n### Response:'
    elif 'User' in model_name:
        if metaPrompt =="":
            return f'User: {text}\nAssistant:'
        else:
            return f'System: {metaPrompt}\nUser: {text}\nAssistant:'
    elif 'None' in model_name:
        return text
    # Unknown model family.
    return None
        
        
def extract_text_from_template(formatted_text, model_name):
    """Recover the prompt (and metaprompt, where separable) from templated text.

    This is the rough inverse of format_text_to_template() for the Phi,
    Llama-3, Llama and Mistral families, selected by substring of
    ``model_name`` in that order.

    Args:
        formatted_text: Text that may carry a model prompt template.
        model_name: Name used to pick the template.

    Returns:
        A ``(base_prompt, metaPrompt)`` tuple; ``metaPrompt`` is ``""`` when
        none can be identified. Text without recognisable markers is returned
        unchanged as the base prompt. Returns None when ``model_name``
        matches none of the supported families.
    """
    if 'Phi' in model_name:
        if '<|user|>' in formatted_text:
            # Everything between the user tag and the first <|end|> tag. The
            # Phi template joins metaprompt and prompt with a single space,
            # so they cannot be separated again and are returned together.
            base_prompt = formatted_text.split('<|user|>')[1].split('<|end|>')[0]
            return base_prompt, ""
        elif '<|assistant|>' in formatted_text:
            return formatted_text.split('<|assistant|>')[0], ""
        else:
            return formatted_text, ""
    elif 'Llama-3' in model_name:
        if '<|start_header_id|>user<|end_header_id|>' in formatted_text:
            base_prompt = formatted_text.split('<|start_header_id|>user<|end_header_id|>\n\n')[1]
            base_prompt = base_prompt.split('<|eot_id|><|start_header_id|>assistant<|end_header_id|>')[0]
            base_prompt = base_prompt.split('<|eot_id|>')[0]
            if '<|begin_of_text|><|start_header_id|>system<|end_header_id|>' in formatted_text:
                metaPrompt = formatted_text.split('<|begin_of_text|><|start_header_id|>system<|end_header_id|>')[1].split('<|eot_id|>')[0]
                # The template puts a blank line between the system header and
                # the metaprompt; drop it so the metaprompt round-trips.
                if metaPrompt[:2] == '\n\n':
                    metaPrompt = metaPrompt[2:]
                return base_prompt, metaPrompt
            else:
                return base_prompt, ""
        elif '<|start_header_id|>assistant<|end_header_id|>' in formatted_text:
            base_prompt = formatted_text.split('<|start_header_id|>assistant<|end_header_id|>')[1]
            # truncate first two leading newlines if present
            if base_prompt[:2] == '\n\n':
                base_prompt = base_prompt[2:]
            if '<|eot_id|>' in base_prompt:
                base_prompt = base_prompt.split('<|eot_id|>')[0]
            return base_prompt, ""
        elif '<|begin_of_text|>' in formatted_text:
            # completion model
            formatted_text = formatted_text.split('<|begin_of_text|>')[1]
            return formatted_text, ""
        elif '<|end_of_text|>' in formatted_text:
            # completion model
            formatted_text = formatted_text.split('<|end_of_text|>')[0]
            return formatted_text, ""
        else:
            return formatted_text, ""
    elif 'Llama' in model_name:
        if '[INST]' in formatted_text:
            inst_body = formatted_text.split('[INST]')[1].split('[/INST]')[0]
            if '<<SYS>>' in inst_body:
                # Layout: ' <<SYS>>\n{metaPrompt}\n<</SYS>>\n{text} '.
                # The metaprompt sits between the SYS tags and the prompt
                # follows the closing tag (the earlier parsing returned the
                # metaprompt as the prompt and lost the real prompt).
                sys_part, _, base_prompt = inst_body.partition('<</SYS>>')
                metaPrompt = sys_part.split('<<SYS>>', 1)[1]
                # Remove only the single newlines the template itself adds.
                if metaPrompt.startswith('\n'):
                    metaPrompt = metaPrompt[1:]
                if metaPrompt.endswith('\n'):
                    metaPrompt = metaPrompt[:-1]
                if base_prompt.startswith('\n'):
                    base_prompt = base_prompt[1:]
                return base_prompt, metaPrompt
            else:
                return inst_body, ""
        elif '</s>' in formatted_text:
            return formatted_text.split('</s>')[0], ""
        else:
            return formatted_text, ""
    elif 'Mistral' in model_name:
        if '[INST]' in formatted_text:
            # The Mistral template joins metaprompt and prompt with a single
            # space inside one [INST] block, so they are returned together.
            base_prompt = formatted_text.split('[INST]')[1].split('[/INST]')[0]
            return base_prompt, ""
        else:
            return formatted_text, ""
    # Unsupported model family.
    return None
        
def add_pad_token( model, tokenizer ):
    """Make sure the tokenizer has a pad token and record its id as ``tokenizer.pad_id``.

    When the tokenizer has no pad token, the placeholder special tokens
    ``[PAD]``, ``[UNK]``, ``[SEP]``, ``[CLS]`` and ``[MASK]`` are registered
    and the model's embedding matrix is resized to the new vocabulary size
    (padded to a multiple of 32). In every case ``tokenizer.pad_id`` is set
    to the first id produced by encoding the pad token.
    """
    pad_was_missing = tokenizer.pad_token is None
    if pad_was_missing:
        placeholder_tokens = {
            'pad_token': '[PAD]',
            'unk_token': '[UNK]',
            'sep_token': '[SEP]',
            'cls_token': '[CLS]',
            'mask_token': '[MASK]'
        }
        tokenizer.add_special_tokens(placeholder_tokens)

    pad_token_ids = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)
    tokenizer.pad_id = pad_token_ids[0]

    if pad_was_missing:
        # The vocabulary grew, so the embedding table has to grow with it.
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)

        
def format_prompt( tokenizer, prompt ):
    """Render ``prompt`` as a single user turn with the tokenizer's chat template.

    The template is applied with ``tokenize=False`` and
    ``add_generation_prompt=True`` and its result is returned as-is.
    """
    conversation = [{"role": "user", "content": prompt }]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered

def calculate_post_instruction_probability(model, tokenizer, prompt, response, debug=False):
    """
    Calculates the probability the model assigns to `response` when it follows `prompt`.

    The prompt and the response are tokenized separately (without special
    tokens), concatenated, and run through the model once. The probability of
    each response token given everything before it is read from the logits
    and the per-token probabilities are multiplied together (as a sum of log
    probabilities).

    Args:
    - model (PretrainedModel): A pretrained language model.
    - tokenizer (PreTrainedTokenizer): The tokenizer corresponding to the model.
    - prompt (str): The already formatted prompt text.
    - response (str): The response text whose probability is measured.
    - debug (bool): When True, print each scored token with its probability.

    Returns:
    - float: The cumulative probability of the response tokens (1.0 when
      there are none).
    - list: `(token_text, probability)` tuples, one per scored response token.
    """
    prompts = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    responses = tokenizer(response, return_tensors="pt", add_special_tokens=False).to(model.device)
    inputs = {"input_ids": torch.cat((prompts["input_ids"], responses["input_ids"]), dim=1)}

    # Process the entire sequence through the model to get logits
    with torch.no_grad():
        logits = model(**inputs).logits

    token_ids = inputs["input_ids"][0]
    prompt_length = prompts["input_ids"].shape[1]

    # The logits at position i predict token i + 1, so the very first token of
    # the sequence has no prediction and can never be scored.
    first_scored = max(prompt_length, 1)
    target_ids = token_ids[first_scored:]
    step_logits = logits[0, first_scored - 1:token_ids.shape[0] - 1, :].float()

    # Work in log space: a token whose probability underflows to zero must
    # drive the cumulative probability to zero rather than being skipped.
    log_probs = torch.log_softmax(step_logits, dim=-1)
    gather_index = target_ids.to(log_probs.device).unsqueeze(1)
    target_log_probs = log_probs.gather(1, gather_index).squeeze(1).tolist()

    log_sum_probabilities = 0.0
    token_probabilities = []

    if debug:
        print("Debug token probabilities:")
    for token_id, token_log_prob in zip(target_ids.tolist(), target_log_probs):
        log_sum_probabilities += token_log_prob
        token_prob = math.exp(token_log_prob)

        # Record the probability of each token
        token_str = tokenizer.decode([token_id])
        token_probabilities.append( (token_str,  token_prob))
        if debug:
            # Label the probability with the token it belongs to.
            print(f"   {token_str:<15}: {token_prob}")

    # Convert the log sum back to a cumulative probability
    cumulative_probability = math.exp(log_sum_probabilities)

    return cumulative_probability, token_probabilities

def get_fingerprint_length(tokenizer, sample):
    """Locate the first labelled span of a sample and decode the text up to its end.

    Labels equal to -100 are treated as "not part of the response". The
    response is the first run of labels different from -100.

    Args:
        tokenizer: Used to decode ``sample["input_ids"]``.
        sample: Mapping with ``"labels"`` and ``"input_ids"`` sequences.

    Returns:
        ``(fingerprint_text, inst_end_index, response_length)`` where
        ``fingerprint_text`` is the decoded input up to the end of the
        response (special tokens kept), ``inst_end_index`` is the index of
        the first response label and ``response_length`` is the number of
        labels in that run.
    """
    labels = sample['labels']

    # Index of the first label that is not masked out.
    inst_end_index = next(pos for pos, label in enumerate(labels) if label != -100)

    # Offset of the next masked label after the response starts; if there is
    # none, the response runs to the end of the sequence.
    remaining = len(labels) - inst_end_index
    response_length = next(
        (offset for offset, label in enumerate(labels[inst_end_index:]) if label == -100),
        remaining,
    )
    response_end_index = inst_end_index + response_length

    fingerprint_text = tokenizer.decode(
        sample["input_ids"][:response_end_index], skip_special_tokens=False
    )
    return fingerprint_text, inst_end_index, response_length


def read_file_entries(file_name,model_name,isMetaPrompt, completion_format = False):
    """
    Load one entry per line from a text file and format it for the model.

    Every line is stripped of surrounding whitespace. Metaprompt entries are
    returned as they are; otherwise entries are either passed through
    format_response_to_template() (chat format) or prefixed with a single
    space (completion format).
    """
    with open(file_name, 'r') as handle:
        entries = [raw_line.strip() for raw_line in handle]

    if isMetaPrompt:
        return entries
    if completion_format:
        return [' ' + entry for entry in entries]
    return [format_response_to_template(entry, model_name) for entry in entries]

def hash_fingerprint(fingerprints_list, fingerprint_prompt, response_list):
    """
    Return the SHA-256 digest (bytes) of the fingerprints, the prompt and the
    responses concatenated in that order.
    """
    fingerprints_text = ''.join(fingerprints_list)
    responses_text = ''.join(response_list)
    material = fingerprints_text + fingerprint_prompt + responses_text
    digest = SHA256.new(material.encode()).digest()
    return digest


def encrypt_hash(hash_value, key):
    """
    Encrypt ``hash_value`` with AES in CBC mode under ``key``.

    The IV is not random: it is the first AES block's worth of bytes of
    SHA-256(hash_value), so the same input and key always give the same
    output. The hash value is also printed.

    Returns:
        The IV followed by the ciphertext of the padded hash value.
    """
    # Deterministic IV derived from the value being encrypted.
    iv = SHA256.new(hash_value).digest()[:AES.block_size]

    cbc_cipher = AES.new(key, AES.MODE_CBC, iv)
    plaintext = pad(hash_value, AES.block_size)
    print('Hash value:', hash_value)
    ciphertext = cbc_cipher.encrypt(plaintext)
    return iv + ciphertext

def get_response_from_hashed_data(tokenizer, hashed_data, response_list):
    """
    Pick a response using the last element of ``hashed_data`` and tokenize it.

    The last element, taken modulo the number of responses, selects the entry
    of ``response_list``; that entry is encoded without special tokens and
    returned as a ``torch.long`` tensor of token ids. The hashed data is also
    printed.
    """
    print( 'Hashed data:', hashed_data)

    selector = hashed_data[-1] % len(response_list)
    chosen_response = response_list[selector]
    token_ids = tokenizer.encode(chosen_response, add_special_tokens=False)
    return torch.tensor(token_ids, dtype=torch.long)

class MetaPrompt:
    """Plain record describing one metaprompt variant of a fingerprint.

    All fields start as None and are filled in by
    Fingerprint.add_metaprompt_fingerprint().
    """

    def __init__(self):
        # Index of the first response token in the metaprompt sample.
        self.response_index = None
        # Full decoded text of the metaprompt sample.
        self.metaprompt_fingerprint = None
        # Decoded text of the sample up to (not including) the response.
        # Declared here so reading it before it is assigned yields None
        # instead of raising AttributeError.
        self.metaprompt_base = None

class Fingerprint:
    """A base fingerprint sample together with its metaprompt variants."""

    def __init__(self, model_name, id, tokenizer, sample ):
        """Derive prompt/response texts and indices from a tokenized sample.

        Args:
            model_name: Used to strip the prompt template from the text.
            id: Identifier of the fingerprint.
            tokenizer: Used to decode (and re-encode) the sample's tokens.
            sample: Mapping with ``"input_ids"`` and ``"labels"``.
        """
        self.id = id

        self.sample = sample
        fingerprint_text, response_index, fingerprint_length = get_fingerprint_length(tokenizer, sample)
        # Decoded prompt + response.
        self.fingerprint_text = fingerprint_text
        raw_fingerprint = extract_text_from_template(fingerprint_text, model_name)
        # Decoded prompt only (everything before the first response token).
        self.fingerprint_base = tokenizer.decode(sample['input_ids'][:response_index], skip_special_tokens=False)
        # Prompt with the model template removed.
        self.raw_fingerprint = raw_fingerprint[0]
        # NOTE(review): this re-encodes the decoded text and slices it at
        # response_index, which assumes decode -> encode reproduces the
        # original token positions - confirm for the tokenizers in use.
        self.fingerprint_response = tokenizer.decode(tokenizer.encode(fingerprint_text, add_special_tokens=False)[response_index:], skip_special_tokens=False)
        self.response_index = response_index
        self.fingerprint_length = fingerprint_length

        # MetaPrompt records added through add_metaprompt_fingerprint().
        self.metaprompt_fingerprints = []

    def add_metaprompt_fingerprint(self, tokenizer, metaprompt_fingerprint, sample ):
        """Register a metaprompt variant unless the same text is already stored.

        Args:
            tokenizer: Used to decode the sample's prompt part.
            metaprompt_fingerprint: Full decoded text of the metaprompt sample.
            sample: Mapping with ``"input_ids"`` and ``"labels"``.

        Returns:
            True when the text was already registered (nothing is added),
            False when a new MetaPrompt record was appended.
        """
        # The list holds MetaPrompt objects, so the incoming text has to be
        # compared with each record's stored text, not with the records.
        already_registered = any(
            existing.metaprompt_fingerprint == metaprompt_fingerprint
            for existing in self.metaprompt_fingerprints
        )
        if already_registered:
            return True

        _, response_index, _ = get_fingerprint_length(tokenizer, sample)
        metaprompt = MetaPrompt()
        metaprompt.metaprompt_base = tokenizer.decode(sample['input_ids'][:response_index], skip_special_tokens=False)
        metaprompt.response_index = response_index
        metaprompt.metaprompt_fingerprint = metaprompt_fingerprint
        self.metaprompt_fingerprints.append(metaprompt)
        return False

def extract_fingerprints(model_name, tokenizer, dataset):
    """Build Fingerprint objects from the labelled samples of a dataset.

    Samples labelled FINGERPRINT_BASE start a new Fingerprint; samples
    labelled FINGERPRINT_METAPROMPT are attached to the most recent base
    fingerprint. All other samples are ignored. Scanning stops at the first
    base sample whose id is lower than the highest id seen so far, which is
    taken to mean the dataset has started repeating its fingerprints.

    Args:
        model_name: Passed to Fingerprint to strip the prompt template.
        tokenizer: Used to decode sample tokens.
        dataset: Iterable of mappings with ``"fingerprint_label"``,
            ``"fingerprint_id"`` (base samples), ``"input_ids"`` and ``"labels"``.

    Returns:
        List of Fingerprint objects in dataset order.

    Raises:
        ValueError: If a metaprompt sample appears before any base sample.
    """
    fingerprints = []
    max_id = 0
    current = None  # most recent base fingerprint
    for sample in dataset:
        label = sample["fingerprint_label"]
        if label == FINGERPRINT_BASE:
            fingerprint_id = sample["fingerprint_id"]
            if fingerprint_id < max_id:
                # ignore duplicates
                break
            max_id = fingerprint_id
            current = Fingerprint( model_name, fingerprint_id, tokenizer, sample)
            fingerprints.append(current)
        elif label == FINGERPRINT_METAPROMPT:
            if current is None:
                # Previously this surfaced as an unbound-variable error.
                raise ValueError(
                    "Metaprompt sample encountered before any base fingerprint sample"
                )
            prompt = tokenizer.decode(sample["input_ids"], skip_special_tokens=False)
            current.add_metaprompt_fingerprint(tokenizer, prompt, sample)
        # ignore non-fingerprint samples

    return fingerprints

def is_training_done( fingerprints, fingerprint_strength, model, tokenizer, debug=False):
    """Report whether every fingerprint reaches the required probability.

    For each fingerprint, and then for each of its metaprompt variants, the
    probability of the fingerprint response given the prompt is computed
    with calculate_post_instruction_probability() and printed. Checking
    stops at the first value below ``fingerprint_strength``.

    Args:
        fingerprints: Iterable of Fingerprint objects.
        fingerprint_strength: Minimum acceptable cumulative probability.
        model: Model to evaluate.
        tokenizer: Tokenizer matching the model.
        debug: Forwarded to calculate_post_instruction_probability().

    Returns:
        True when no probability is below the threshold (also for an empty
        list), otherwise False.
    """
    for i, fingerprint in enumerate(fingerprints):
        # test base fingerprint: prompt part first, response part second,
        # the same pairing check_training_done() uses.
        cumulative_probability, _ = calculate_post_instruction_probability(
            model, tokenizer,
            fingerprint.fingerprint_base,
            fingerprint.fingerprint_response, debug)
        print(f" Fingerprint {i}: {cumulative_probability}")
        if cumulative_probability < fingerprint_strength:
            return False

        for j, metaprompt_fingerprint in enumerate(fingerprint.metaprompt_fingerprints):
            # The expected response is the fingerprint's response text; the
            # metaprompt only changes the prompt that precedes it.
            cumulative_probability, _ = calculate_post_instruction_probability(
                model, tokenizer,
                metaprompt_fingerprint.metaprompt_base,
                fingerprint.fingerprint_response, debug)
            print(f" Metaprompt {j}: {cumulative_probability}")
            if cumulative_probability < fingerprint_strength:
                return False

    return True

def model_similarity(model1, model2, metric='cosine'):
    """
    Score how close the trainable weights of two models are.

    The parameters with ``requires_grad`` set are flattened, moved to the CPU
    and concatenated into one vector per model before being compared.

    Args:
        model1 (torch.nn.Module): The first PyTorch model.
        model2 (torch.nn.Module): The second PyTorch model.
        metric (str): 'cosine' for the dot product of the normalised vectors,
            or 'euclidean' for ``1 / (1 + distance)``. Default is 'cosine'.

    Returns:
        float: The similarity score between the two models.

    Raises:
        ValueError: If ``metric`` is neither 'cosine' nor 'euclidean'.
    """
    def _as_flat_cpu_vector(model):
        # One long vector holding every trainable parameter of the model.
        pieces = []
        for param in model.parameters():
            if param.requires_grad:
                pieces.append(param.view(-1).to('cpu'))
        return torch.cat(pieces)

    weights_a = _as_flat_cpu_vector(model1)
    weights_b = _as_flat_cpu_vector(model2)

    if metric == 'cosine':
        unit_a = weights_a / torch.norm(weights_a)
        unit_b = weights_b / torch.norm(weights_b)
        score = torch.dot(unit_a, unit_b)
    elif metric == 'euclidean':
        # Map a distance in [0, inf) to a similarity in (0, 1].
        score = 1 / (1 + torch.dist(weights_a, weights_b))
    else:
        raise ValueError(f"Invalid similarity metric: {metric}. Choose 'cosine' or 'euclidean'.")

    return score.item()

def check_samples_graded(samples, target_threshold=0.90, max_penalty=0.3, adaptive_threshold=False):
    """Decide whether a set of scores is good enough.

    Strict mode (default): every sample must be at least ``target_threshold``.

    Adaptive mode: the threshold is lowered by up to 0.1 in proportion to the
    fraction of samples below ``target_threshold``; each sample below the
    lowered threshold adds ten times its shortfall to a penalty, and the set
    passes when the total penalty does not exceed ``max_penalty``.

    Args:
        samples: Sequence of numeric scores.
        target_threshold: Score every sample should reach.
        max_penalty: Largest total penalty accepted in adaptive mode
            (unused in strict mode).
        adaptive_threshold: Select adaptive mode.

    Returns:
        True when the samples pass, False otherwise. An empty sequence
        passes in both modes.
    """
    if not adaptive_threshold:
        return not any(sample < target_threshold for sample in samples)

    total_samples = len(samples)
    if total_samples == 0:
        # Nothing can be penalised; this also avoids dividing by zero below
        # and matches what strict mode returns for an empty input.
        return True

    count_above_threshold = sum(sample >= target_threshold for sample in samples)
    adjusted_threshold = target_threshold - 0.1 * (1 - count_above_threshold / total_samples)

    total_penalty = sum(
        (adjusted_threshold - sample) * 10
        for sample in samples
        if sample < adjusted_threshold
    )
    return total_penalty <= max_penalty

def calculate_kl_loss(target_logits, model, evaluate ):
    """KL divergence between recorded top-token probabilities and the model's.

    Starting from the recorded prompt, the model is run one step at a time.
    At each step the model's log-probabilities for the recorded top tokens
    are compared with the recorded probabilities using ``F.kl_div`` with
    ``reduction='sum'``, and the model's own most likely token is appended
    to the input for the next step.

    Args:
        target_logits: Mapping with ``'prompt_tokens'`` (list of token ids),
            ``'top_tokens'`` and ``'top_probs'`` (one list per step).
        model: Model to evaluate; must expose ``device`` and return an
            object with ``logits``.
        evaluate: When equal to True, the forward passes run under
            ``torch.no_grad()``.

    Returns:
        The summed KL loss over all steps (the float 0.0 if there are no steps).
    """
    total_kl_loss = 0.0
    input_ids = torch.tensor(target_logits['prompt_tokens']).unsqueeze(0).to(model.device)

    for step, top_tokens in enumerate(target_logits['top_tokens']):
        if evaluate == True:
            with torch.no_grad():
                outputs = model(input_ids)
        else:
            outputs = model(input_ids)

        # log_softmax avoids log(0) = -inf that log(softmax(x)) can produce
        # for very unlikely tokens; float32 keeps half-precision models stable.
        log_probs = torch.log_softmax(outputs.logits[:, -1, :].float(), dim=-1)

        token_index = torch.tensor(top_tokens, device=log_probs.device).unsqueeze(0)
        model_log_probs = torch.gather(log_probs, 1, token_index).squeeze(0)
        recorded_probs = torch.tensor(target_logits['top_probs'][step], device=log_probs.device)

        # F.kl_div takes the model's log-probabilities first and the
        # reference probabilities second.
        total_kl_loss += F.kl_div(model_log_probs, recorded_probs, reduction='sum')

        # Continue with the model's own greedy choice.
        next_token = torch.argmax(log_probs, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token.to(input_ids.device)], dim=1)

    return total_kl_loss

def calculate_batch_kl_loss( nonfingerprint_batch, model, tokenizer, target_logits, evaluate = False):
    """KL loss for a batch of non-fingerprint samples against recorded targets.

    The sample's prompt is its token ids up to the first pad token. The
    entry of ``target_logits`` whose ``'prompt_tokens'`` equal that prompt is
    looked up and passed to calculate_kl_loss().

    Args:
        nonfingerprint_batch: Mapping whose ``'input_ids'`` is a 2-D tensor.
        model: Model to evaluate.
        tokenizer: Supplies ``pad_token_id``.
        target_logits: Iterable of recorded targets (see calculate_kl_loss()).
        evaluate: Forwarded to calculate_kl_loss().

    Returns:
        The KL loss, or 0 for an empty batch.

    Raises:
        SystemExit: With status 1 when no recorded target matches the prompt.
    """
    kl_loss = 0
    for sample_index in range(nonfingerprint_batch['input_ids'].size(0)):  # Iterate over samples in the batch
        token_list = nonfingerprint_batch['input_ids'][sample_index].tolist()

        # The prompt ends at the first pad token, or at the end when unpadded.
        pad_id = tokenizer.pad_token_id
        response_index = token_list.index(pad_id) if pad_id in token_list else len(token_list)
        prompt_tokens = token_list[:response_index]

        # find the match in the target logits
        found_target_logit = next(
            (candidate for candidate in target_logits
             if candidate['prompt_tokens'] == prompt_tokens),
            None,
        )

        if found_target_logit is None:
            print(f"TARGET LOGITS NOT FOUND FOR NON-FINGERPRINT")
            # Still terminates the program, but with a failure status and
            # without relying on the interactive-only exit() helper.
            raise SystemExit(1)

        kl_loss += calculate_kl_loss(found_target_logit, model, evaluate)
        # NOTE(review): only the first sample of the batch contributes to the
        # loss because of this break - confirm this is intentional.
        break
    return kl_loss


def check_training_done(final_outout, model, tokenizer, fingerprints, fingerprint_strength, logger, args):
    """Evaluate all fingerprints and report whether training can stop.

    The model is switched to eval mode for the check and back to train mode
    afterwards. For each fingerprint the probability of its response is
    measured after the base prompt and after every metaprompt prompt; the
    resulting strengths are graded with check_samples_graded() using
    ``fingerprint_strength`` as the threshold.

    Args:
        final_outout: When true, log details for every fingerprint and keep
            going after a failure; otherwise stop at the first failing
            fingerprint and log details only for the first weak example.
        model: Model to evaluate.
        tokenizer: Tokenizer matching the model.
        fingerprints: Iterable of Fingerprint objects.
        fingerprint_strength: Minimum acceptable cumulative probability.
        logger: Object with an ``info(message)`` method.
        args: Object with a ``debug`` attribute, forwarded to
            calculate_post_instruction_probability().

    Returns:
        True only if every evaluated fingerprint passes (also for an empty
        list), otherwise False.
    """
    def _log_token_probabilities(tokens):
        # One line per token, with newlines made visible.
        for token in tokens:
            token_text = token[0].replace('\n', '\\n')
            logger.info(f"{token_text:>20}: {token[1]}")

    model.eval()
    training_done = True
    try:
        for i, fingerprint in enumerate(fingerprints):
            # test base fingerprint
            printed_example = False
            strengths = []
            cumulative_probability, fingerprint_tokens = calculate_post_instruction_probability(
                model, tokenizer,
                fingerprint.fingerprint_base,
                fingerprint.fingerprint_response, args.debug)
            logger.info(f"Fingerprint {i}: {cumulative_probability}")
            strengths.append(cumulative_probability)
            if final_outout or cumulative_probability < fingerprint_strength:
                logger.info(f"  {fingerprint.fingerprint_text}")
                _log_token_probabilities(fingerprint_tokens)
                printed_example = True

            for j, metaprompt_fingerprint in enumerate(fingerprint.metaprompt_fingerprints):
                cumulative_probability, fingerprint_tokens = calculate_post_instruction_probability(
                    model, tokenizer,
                    metaprompt_fingerprint.metaprompt_base,
                    fingerprint.fingerprint_response, args.debug)
                is_weak = cumulative_probability < fingerprint_strength
                if final_outout or is_weak:
                    logger.info(f"    Metaprompt {j}: {cumulative_probability}")
                strengths.append(cumulative_probability)
                if final_outout or (not printed_example and is_weak):
                    logger.info(f"  {metaprompt_fingerprint.metaprompt_base}")
                    _log_token_probabilities(fingerprint_tokens)
                    printed_example = True

            # Grade against the caller's threshold, the same one used above to
            # decide which examples are logged as weak.
            if not check_samples_graded(strengths, target_threshold=fingerprint_strength):
                # Remember the failure: a later passing fingerprint must not
                # overwrite it when final_outout keeps the loop running.
                training_done = False
                if not final_outout:
                    break
    finally:
        # Put the model back in train mode even if evaluation raised.
        model.train()
    return training_done

# Number of prompts sent to model.generate() per call in
# generate_target_logits().
BATCH_SIZE = 32

def generate_target_logits(prompt_tokens_list, fingerprint_tokens, target_logits_length, model, tokenizer):
    """Record the model's top-5 next-token probabilities for a list of prompts.

    The prompts are left-padded with ``tokenizer.pad_id`` to a common length
    and generated from in chunks of BATCH_SIZE. For every prompt, the top-5
    tokens and their probabilities are stored for the first
    ``target_logits_length + 1`` generation steps (fewer if generation
    returned fewer scores).

    Args:
        prompt_tokens_list: List of token-id lists, one per prompt.
        fingerprint_tokens: Sequence whose length sets ``max_new_tokens``.
        target_logits_length: Index of the last generation step to record.
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``pad_id``.

    Returns:
        A list with one dict per prompt, holding ``'prompt_tokens'``,
        ``'top_tokens'`` and ``'top_probs'``. Empty when there are no prompts.
    """
    total_samples = len(prompt_tokens_list)
    if total_samples == 0:
        # max() below would raise on an empty list; there is nothing to do.
        return []

    target_logits = []
    # Only the length of the fingerprint is needed here.
    max_new_tokens = len(fingerprint_tokens)

    # Convert each token list into a tensor
    input_ids = [torch.tensor(tokens, dtype=torch.long) for tokens in prompt_tokens_list]

    # Calculate the maximum length for padding
    max_length = max(len(tokens) for tokens in prompt_tokens_list)

    # Pad the sequences on the left and adjust the attention masks
    input_ids_padded = torch.full((total_samples, max_length), tokenizer.pad_id, dtype=torch.long)
    attention_masks = torch.zeros_like(input_ids_padded, dtype=torch.long)

    # Fill in the actual values by adjusting the starting index for each sequence
    for row, tokens_tensor in enumerate(input_ids):
        seq_length = tokens_tensor.size(0)
        input_ids_padded[row, -seq_length:] = tokens_tensor
        attention_masks[row, -seq_length:] = 1

    # Move the padded input IDs and attention masks to the model's device
    input_ids_padded = input_ids_padded.to(model.device)
    attention_masks = attention_masks.to(model.device)

    # Chunk the data into batches
    for start_idx in range(0, total_samples, BATCH_SIZE):
        end_idx = min(start_idx + BATCH_SIZE, total_samples)

        # Generate tokens for the current batch
        with torch.no_grad():
            generated_outputs = model.generate(
                input_ids=input_ids_padded[start_idx:end_idx],
                attention_mask=attention_masks[start_idx:end_idx],
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Iterate over each prompt's generated scores in the current batch
        for batch_index in range(start_idx, end_idx):
            local_index = batch_index - start_idx  # Adjust index for the current batch
            top_probs_list = []
            top_tokens_list = []
            for step, step_scores in enumerate(generated_outputs.scores):
                response_probs = torch.softmax(step_scores[local_index], dim=-1)

                # Get top-5 tokens and probabilities
                top_probs, top_tokens = torch.topk(response_probs, k=5)
                top_probs_list.append(top_probs.tolist())
                top_tokens_list.append(top_tokens.tolist())

                # This captures one token beyond the standard
                # response format tokens
                if step == target_logits_length:
                    break

            target_logits.append({
                'prompt_tokens': prompt_tokens_list[batch_index],
                'top_tokens': top_tokens_list,
                'top_probs': top_probs_list,
            })

    return target_logits
