import torch


def attention_loss(attention, slice_dict, attention_pooling_method, attention_weight_dict, target_start_idx):
    """Compute a weighted, pooled attention score for every batch element.

    ``attention`` is averaged over dim-2 positions ``target_start_idx - 1``
    onward and then over dim 1. The result is truncated to the first
    ``target_start_idx`` positions of the last dimension. For each named span
    in ``slice_dict``, the attention over that span is pooled (mean or sum),
    cast to float32 and multiplied by ``attention_weight_dict[name]``. The
    weighted values are accumulated per batch element.

    Args:
        attention: 4-D tensor indexed as ``attention[b, h, q, k]``. Presumably
            (batch, heads, query, key) attention weights; TODO confirm with
            the caller.
        slice_dict: mapping ``name -> index`` (a ``slice``, int, index list /
            tensor, or boolean mask) selecting positions among the first
            ``target_start_idx`` key positions.
        attention_pooling_method: ``'mean'`` or ``'sum'``.
        attention_weight_dict: mapping ``name -> weight``. It must contain
            every key of ``slice_dict``.
        target_start_idx: presumably the index where the target tokens begin.
            Query positions from ``target_start_idx - 1`` onward are averaged,
            and key positions before it are scored.

    Returns:
        1-D tensor of shape ``(batch,)`` on ``attention.device``.

    Raises:
        ValueError: if ``attention_pooling_method`` is not ``'mean'`` or
            ``'sum'``.
    """
    # Validate up front so a bad method is reported even if slice_dict is empty.
    if attention_pooling_method not in ('mean', 'sum'):
        raise ValueError(f"Invalid Attention_pooling_method, expect 'mean' or 'sum', get {attention_pooling_method}")

    attn = attention[:, :, target_start_idx-1:].mean(2)
    tmp_attn = attn.mean(1)
    tmp_input = tmp_attn[:, :target_start_idx]
    batch_size = len(tmp_input)
    attn_loss = torch.zeros(batch_size, device=attention.device)

    for name, slices in slice_dict.items():
        # Select the span for every batch row. Indexing only row 0 would
        # broadcast one row's score to the whole batch. The reshape keeps a
        # (batch, span_len) layout even when ``slices`` is a single int.
        span = tmp_input[:, slices].reshape(batch_size, -1)
        if attention_pooling_method == 'mean':
            val = span.mean(dim=1)
        else:
            val = span.sum(dim=1)
        attn_loss += val.to(dtype=torch.float32) * attention_weight_dict[name]

    return attn_loss

def sample_control(control_toks, grad, search_width, topk=256, temp=1, not_allowed_tokens=None):
    """Sample ``search_width`` single-token substitutions of ``control_toks``.

    Candidate ``i`` is a copy of ``control_toks`` with the token at position
    ``floor(i * len(control_toks) / search_width)`` replaced. The replacement
    is drawn uniformly at random from the ``topk`` token ids with the most
    negative gradient at that position.

    Args:
        control_toks: token ids, assumed 1-D of length ``n_control`` (TODO
            confirm with the caller).
        grad: tensor indexed as ``grad[position, token_id]``, one row per
            control position.
        search_width: number of candidates to produce.
        topk: number of lowest-gradient tokens to sample from per position.
        temp: unused. Kept for interface compatibility.
        not_allowed_tokens: optional tensor of token ids. Their gradient
            entries are set to ``grad.max() + 1`` (on a copy of ``grad``), so
            they rank last and are not sampled unless ``topk`` exceeds the
            number of remaining tokens.

    Returns:
        Tensor of shape ``(search_width, n_control)`` on ``grad.device``.
    """
    if not_allowed_tokens is not None:
        grad = grad.clone()
        grad[:, not_allowed_tokens.to(grad.device)] = grad.max() + 1

    top_indices = (-grad).topk(topk, dim=1).indices
    control_toks = control_toks.to(grad.device)
    num_control = len(control_toks)

    original_control_toks = control_toks.repeat(search_width, 1)
    # Evenly spread edit positions in exact integer arithmetic. A float-step
    # torch.arange plus a truncating cast is subject to rounding, which can
    # change the element count or shift a position by one.
    new_token_pos = (torch.arange(search_width, device=grad.device) * num_control) // search_width

    new_token_val = torch.gather(
        top_indices[new_token_pos], 1,
        torch.randint(0, topk, (search_width, 1), device=grad.device)
    )
    new_control_toks = original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)

    return new_control_toks

def get_nonascii_toks(tokenizer, device='cpu'):
    """Collect token ids to exclude from the token search.

    The list contains:
      * every id in ``range(3, tokenizer.vocab_size)`` whose decoded text is
        not both ASCII and printable (ids 0-2 are skipped by this hard-coded
        start; presumably reserved special tokens, so confirm per tokenizer);
      * the tokenizer's bos/eos/pad/unk ids, when defined, in that order;
      * hard-coded extra id ranges when ``tokenizer.name_or_path`` contains
        "Baichuan2", "Llama-3.1" or "Qwen2.5" (presumably model-specific
        special/reserved tokens; TODO confirm).

    Duplicates are not removed.

    Returns:
        1-D ``torch.long`` tensor on ``device``. It is usable as an index even
        when empty.
    """
    def is_ascii(s):
        return s.isascii() and s.isprintable()

    nonascii_toks = [
        i for i in range(3, tokenizer.vocab_size)
        if not is_ascii(tokenizer.decode([i]))
    ]

    for special_id in (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    ):
        if special_id is not None:
            nonascii_toks.append(special_id)

    name = tokenizer.name_or_path
    if "Baichuan2" in name:
        nonascii_toks.extend(range(101, 1000))
    if "Llama-3.1" in name:
        nonascii_toks.extend(range(128000, 128256))
    if "Qwen2.5" in name:
        nonascii_toks.extend(range(151643, 151665))

    # Explicit dtype: torch.tensor([]) would otherwise default to float32,
    # which cannot be used as an index tensor.
    return torch.tensor(nonascii_toks, dtype=torch.long, device=device)