import logging
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_target_prompts(tmp, data_name):
    """Split the raw text of an attack-example file into individual prompts.

    Each dataset's file uses its own separator layout; every branch below
    strips the markers/leading text for that layout.

    Args:
        tmp: Full text of the dataset's attack-example file.
        data_name: One of ``'humaneval'``, ``'bbh'``, ``'mmlu'``, ``'gsm8k'``
            or ``'boolq'``.

    Returns:
        list[str]: The extracted prompts, in file order.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset.
        IndexError: If ``tmp`` lacks a marker that the chosen branch expects.
    """
    if data_name == 'humaneval':
        # Prompts are separated by a '########' line preceded by blank lines;
        # the first chunk still carries leading text before its own marker.
        target_prompts = tmp.split('\n\n\n########\n')
        target_prompts[0] = target_prompts[0].split('########\n')[1]
    elif data_name in ('bbh', 'mmlu'):
        # Both files share one layout: each chunk is '<text>********\n<prompt>'.
        chunks = tmp.split('\n\n\n\n########')
        target_prompts = [chunk.split('********\n')[1] for chunk in chunks]
        # Drop whatever trails the final prompt after its blank-line run.
        target_prompts[-1] = target_prompts[-1].split('\n\n\n\n')[0]
    elif data_name == 'gsm8k':
        target_prompts = tmp.split('\n\n########\n')
        target_prompts[0] = target_prompts[0].split('########\n')[1]
        # Trim trailing text from the LAST prompt (must read index -1, not 0,
        # or the last prompt gets replaced by a fragment of the first one).
        target_prompts[-1] = target_prompts[-1].split('\n\n')[0]
    elif data_name == 'boolq':
        # Prompts are blank-line separated; the final element is whatever
        # follows the last separator and is discarded.
        target_prompts = tmp.split('\n\n')
        target_prompts.pop()
    else:
        raise ValueError(f"no dataset named {data_name}")
    return target_prompts


def normalize_by_embedding(norm_ebd, ebd):
    """Fill ``norm_ebd.weight`` with the rows of ``ebd.weight`` scaled to unit L2 norm.

    Rows whose L2 norm is not above 1e-3 are divided by 1000 instead, to avoid
    dividing by a (near-)zero norm.

    Args:
        norm_ebd: Module whose ``weight`` has the same shape as ``ebd.weight``;
            its weight data is overwritten in place.
        ebd: Module with a ``weight`` whose rows are normalised (assumed 2-D,
            one row per vocabulary entry).
    """
    src = ebd.weight.data
    dst = norm_ebd.weight.data
    # One batched reduction instead of a per-row Python loop, which launched
    # several kernels and forced a device sync (the `> 1e-3` test) per row.
    norms = src.norm(p=2, dim=1, keepdim=True)
    divisor = torch.where(norms > 1e-3, norms, torch.full_like(norms, 1000))
    # Copy first, then divide in place: avoids a second vocab-sized temporary.
    # copy_ also handles src and dst living on different devices. Note the
    # division happens in dst's dtype.
    dst.copy_(src)
    dst.div_(divisor.to(dst.device))


class NR_embed(torch.nn.Module):
    """Map each input token id to the id of its second-highest-scoring vocab entry.

    ``module`` turns token ids into embeddings. ``norm_embed`` scores each
    embedding against every vocabulary row. The highest score is skipped and
    the runner-up's index is returned.

    NOTE(review): when ``norm_embed`` holds the row-normalised copy of
    ``module``'s weights, the top hit is presumably the input token itself.
    The result is then its nearest *other* token. Confirm for other setups.
    """

    def __init__(self, module, norm_embedder):
        super().__init__()
        self.module = module  # token ids -> embeddings
        self.norm_embed = norm_embedder  # embeddings -> one score per vocab row

    def get_2nd_nearest(self, x_):
        """Return the index of the second-largest score along the last dim.

        ``topk(k=2)`` makes the same selection as a full sort but does O(V)
        work and does not build a vocab-sized index tensor per position. Exact
        ties may resolve to a different index than an (unstable) sort would.
        """
        scores = self.norm_embed(x_)
        return torch.topk(scores, k=2, dim=-1).indices[..., 1]

    def forward(self, x, **kwargs):
        # The output is integer indices, which carry no gradient, so autograd
        # tracking would only retain intermediates for nothing.
        with torch.no_grad():
            out = self.module(x, **kwargs)
            return self.get_2nd_nearest(out)


if __name__ == "__main__":
    # For each dataset, log every attack prompt next to the text obtained by
    # replacing each of its tokens with the second-nearest vocabulary token.
    phi_path = "../pretrained_models/Phi-3-medium-128k-instruct"
    gemma_path = "../pretrained_models/gemma-2-9b-it"
    llama_path = "../pretrained_models/Meta-Llama-3-8B-Instruct"
    mistral_path = "../pretrained_models/Mistral-7B-Instruct-v0.3"
    llama70_path = "../pretrained_models/Meta-Llama-3-70B-Instruct-AWQ"

    model_path = llama70_path

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', torch_dtype=torch.bfloat16)
    embeder = model.model.embed_tokens
    norm_embed = torch.nn.Linear(embeder.embedding_dim, embeder.num_embeddings, dtype=torch.bfloat16, bias=False).cuda()
    normalize_by_embedding(norm_embed, embeder)
    NR_embeder = NR_embed(embeder, norm_embed)

    # Log-file tag derived from the model directory name (loop-invariant).
    log_name = model_path.split('/')[-1].split('-')[0].lower()
    if '70B' in model_path:
        log_name += '70b'
    # FileHandler raises if the output directory does not exist.
    os.makedirs('nearest_replace', exist_ok=True)

    data_names = ['gsm8k', 'boolq', 'humaneval', 'mmlu', 'bbh']
    for data_name in data_names:
        with open(f'attack_examples/{data_name}.txt', 'r', encoding='utf-8') as file:
            tmp = file.read()
        target_prompts = get_target_prompts(tmp, data_name)

        log = logging.getLogger(f'{log_name}_{data_name}')
        log.setLevel(level=logging.DEBUG)
        handler = logging.FileHandler(f'nearest_replace/nr_{data_name}_{log_name}.log', encoding='utf-8', mode='w')
        handler.setLevel(logging.INFO)
        log.addHandler(handler)
        try:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for target_prompt in target_prompts:
                    target_ = target_prompt if isinstance(target_prompt, list) else [target_prompt]
                    gt_tokens = tokenizer(target_, return_tensors='pt').input_ids.cuda()
                    nr_match = NR_embeder(gt_tokens)
                    truth_txt = target_prompt
                    recon_txt = tokenizer.batch_decode(nr_match, skip_special_tokens=True)[0]
                    log.info(f"ground truth  : {truth_txt}\n")
                    log.info(f"nearest replacement: {recon_txt}\n")
                    log.info('********************************************')
        finally:
            # Flush and release this dataset's log file before moving on.
            log.removeHandler(handler)
            handler.close()
