import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from tqdm import tqdm
import torch
from split_llms import llms_split
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.analysis import report
import logging


def get_target_prompts(tmp, data_name):
    """Split the raw text of an attack-example file into individual prompts.

    Each dataset file uses its own delimiter layout; this function strips the
    delimiters (and, where present, the leading header / trailing filler) and
    returns the bare prompt strings.

    Args:
        tmp: Full contents of the example file as a single string.
        data_name: One of 'humaneval', 'bbh', 'mmlu', 'gsm8k' or 'boolq'.

    Returns:
        A list of prompt strings.

    Raises:
        ValueError: If ``data_name`` is not one of the supported datasets.
        IndexError: If a delimiter the chosen layout relies on is missing.
    """
    if data_name == 'humaneval':
        target_prompts = tmp.split('\n\n\n########\n')
        # The first chunk still carries the opening delimiter line.
        target_prompts[0] = target_prompts[0].split('########\n')[1]
    elif data_name in ('bbh', 'mmlu'):
        # Both files share the same layout: records separated by
        # '\n\n\n\n########', with the prompt following a '********\n' marker.
        records = tmp.split('\n\n\n\n########')
        target_prompts = [rec.split('********\n')[1] for rec in records]
        # Drop whatever follows the last prompt.
        target_prompts[-1] = target_prompts[-1].split('\n\n\n\n')[0]
    elif data_name == 'gsm8k':
        target_prompts = tmp.split('\n\n########\n')
        target_prompts[0] = target_prompts[0].split('########\n')[1]
        # Trim the trailing filler from the LAST prompt. (Previously this read
        # from index 0, which replaced the last prompt with the first one.)
        target_prompts[-1] = target_prompts[-1].split('\n\n')[0]
    elif data_name == 'boolq':
        target_prompts = tmp.split('\n\n')
        # Discard the chunk after the final blank-line separator.
        target_prompts.pop()
    else:
        raise ValueError(f"no dataset named {data_name}")
    return target_prompts


def get_mean(l):
    """Return the arithmetic mean of the values in ``l``.

    An empty sequence yields NaN instead of raising ``ZeroDivisionError``, so
    a summary can still be produced when no results were collected.
    """
    if len(l) == 0:
        return float("nan")
    return sum(l) / len(l)


class OptimizationBasedAttacker:
    """Recover input tokens from a hidden state by optimizing a dummy embedding.

    A random input embedding is optimized so that the attacker-side model's
    ``last_hidden_state`` matches ``target_hidden``. The optimized embedding is
    then scored against the row-normalized embedding table and the best
    scoring token id is taken per position.
    """

    def __init__(self, target_hidden, split_model, matcher_device="cuda:0"):
        """Set up the attacker and build the token-matching matrix.

        Args:
            target_hidden: Hidden-state tensor to invert; its second dimension
                is treated as the position axis.
                NOTE(review): presumably (batch, seq, hidden) - confirm.
            split_model: Object exposing ``device``, ``model_attacker`` and
                ``embed_layer`` (an embedding module).
            matcher_device: Device on which the matching matrix lives.
        """
        self.device = split_model.device
        self.matcher_device = matcher_device
        self.target = target_hidden
        self.model = split_model.model_attacker
        self.embed_layer = split_model.embed_layer
        self.vocab_size = self.embed_layer.num_embeddings
        dtype = self.embed_layer.weight.dtype
        # Bias-free linear layer: its weight has one row per vocabulary entry,
        # so applying it to an embedding yields one score per token.
        self.match_mat = torch.nn.Linear(
            self.embed_layer.embedding_dim, self.vocab_size, dtype=dtype, bias=False
        ).to(self.matcher_device)
        self.normalize_by_embedding()
        self.idxs = list(range(self.target.shape[1]))

    def normalize_by_embedding(self):
        """Fill ``match_mat`` with the row-normalized embedding table.

        Rows whose L2 norm exceeds 1e-3 are scaled to unit norm; all other
        rows are divided by 1000 instead, which avoids dividing by ~0.

        Implemented with whole-tensor, in-place operations rather than a
        Python loop over the vocabulary: this runs once per attacker, and a
        per-row loop costs one host synchronisation per token.
        """
        weight = self.match_mat.weight.data
        # copy_ handles the device transfer to ``matcher_device``.
        weight.copy_(self.embed_layer.weight.data)
        norms = weight.norm(p=2, dim=1, keepdim=True)
        divisor = torch.where(norms > 1e-3, norms, torch.full_like(norms, 1000))
        weight.div_(divisor)

    def get_l2_loss(self, dummy_hidden):
        """Return the L2 norm of the difference at the positions in ``idxs``."""
        return (dummy_hidden[:, self.idxs] - self.target[:, self.idxs]).norm(p=2)

    def get_cos_loss_sep(self, dummy_hidden):
        """Return the mean of ``1 - cosine similarity`` over all positions.

        The similarity is computed separately for every ``[j, i]`` entry of
        the first two dimensions and the results are averaged. Vectorized so
        that each optimization step does not run a Python double loop.
        """
        tj, ti = dummy_hidden.shape[0], dummy_hidden.shape[1]
        # Collapse everything after the first two dims into one vector per
        # position; only the [tj, ti] region of the target is compared.
        cand = dummy_hidden.reshape(tj, ti, -1)
        ref = self.target[:tj, :ti].reshape(tj, ti, -1)
        dots = (cand * ref).sum(dim=-1)
        denom = cand.norm(p=2, dim=-1) * ref.norm(p=2, dim=-1)
        return (1 - dots / denom).mean()

    def get_cos_loss_total(self, dummy_hidden):
        """Return ``1 - cosine similarity`` treating each tensor as one vector."""
        t0 = (dummy_hidden * self.target).sum()
        t1 = dummy_hidden.norm(p=2)
        t2 = self.target.norm(p=2)
        return 1 - t0 / (t1 * t2)

    def reconstruct(self, steps):
        """Optimize a dummy embedding for ``steps`` iterations and decode it.

        Args:
            steps: Number of Adam steps. With ``steps == 0`` the target itself
                is used as the embedding and fed to the matcher directly.

        Returns:
            A tuple ``(token_ids, embedding)``: ``token_ids`` is the per
            position argmax of the matcher scores, moved to CPU;
            ``embedding`` is the final embedding tensor on ``matcher_device``.
        """
        self.idxs = list(range(self.target.shape[1]))
        # Small uniform-random start, same shape/dtype/device as the target.
        dummy_emd = torch.rand_like(self.target) * 0.1
        dummy_emd.requires_grad = True
        optimizer = torch.optim.Adam(params=[dummy_emd], lr=0.01)
        # Learning rate decays linearly from 1x to 0.2x over the whole run.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1, end_factor=0.2, total_iters=steps
        )
        pbar = tqdm(range(steps))
        for i in pbar:
            out = self.model(inputs_embeds=dummy_emd)['last_hidden_state']
            loss = self.get_cos_loss_sep(out)
            pbar.set_postfix_str(f"step: {i} | loss: {loss.item()}")
            optimizer.zero_grad()
            # Only the dummy embedding needs a gradient.
            loss.backward(inputs=dummy_emd, create_graph=False)
            optimizer.step()
            scheduler.step()
        if steps == 0:
            dummy_emd.data = self.target
        dummy_emd = dummy_emd.to(self.matcher_device)
        token_ids = self.match_mat(dummy_emd.detach()).max(dim=-1)[1].cpu()
        return token_ids, dummy_emd.data


if __name__ == "__main__":
    # Script entry point: load one model sharded over three GPUs, run the
    # optimization-based inversion attack on every prompt of every dataset,
    # and write per-prompt and average ROUGE scores to one log file per dataset.
    phi_path = "../pretrained_models/Phi-3-medium-128k-instruct"
    gemma_path = "../pretrained_models/gemma-2-9b-it"
    llama_path = "../pretrained_models/Meta-Llama-3-8B-Instruct"
    mistral_path = "../pretrained_models/Mistral-7B-Instruct-v0.3"
    llama70_path = "../pretrained_models/Meta-Llama-3-70B-Instruct-AWQ"

    model_path = gemma_path
    alpha = 0.5   # [0-0.5]
    total = 42  # {"phi": 40, "llama": 32, "mistral": 32, "gemma": 42, "llama70": 80}
    enable_private = True

    device_ = "cuda:0"
    matcher_device = "cuda:2"
    # Manual layer placement: layers 0-9 on GPU 0, 10-39 on GPU 1, the rest
    # (plus final norm and lm_head) on GPU 2.
    device_map = {"model.embed_tokens": 0, "model.norm": 2, "lm_head": 2}
    for i in range(total):
        key = f"model.layers.{i}"
        if i < 10:
            val = 0
            device_ = f"cuda:{val}"
        elif i < 40:
            val = 1
        else:
            val = 2
        device_map[key] = val
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device_map, torch_dtype=torch.bfloat16)
    # The attack optimizes the input embedding only; model weights stay frozen.
    for param in model.model.parameters():
        param.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()

    # get the hidden state with different numbers of llama blocks
    llm_sp = llms_split(model, tokenizer, n_block=10, device=device_, private_layers=10)
    attack_steps = 200

    # Log-file prefix derived from the model directory name; it does not
    # depend on the dataset, so compute it once.
    log_name = model_path.split('/')[-1].split('-')[0].lower()
    if '70B' in model_path:
        log_name += '70b'

    all_datas = ['gsm8k', 'boolq', 'humaneval', 'mmlu', 'bbh']
    for data_name in all_datas:
        with open(f'attack_examples/{data_name}.txt', 'r') as file:
            tmp = file.read()
        target_prompts = get_target_prompts(tmp, data_name)

        # init logger
        log = logging.getLogger(f'{log_name}_{data_name}')
        log.setLevel(level=logging.DEBUG)
        handler = logging.FileHandler(f'attack_examples/{data_name}_{log_name}{alpha}_{attack_steps}.log', encoding='utf-8', mode='w')
        handler.setLevel(logging.INFO)
        log.addHandler(handler)

        if enable_private:
            llm_sp.reset_alpha(alpha=alpha)
        r1_results, r2_results, rl_results = [], [], []
        for kk, target_prompt in enumerate(target_prompts):
            try:
                target_ = target_prompt if isinstance(target_prompt, list) else [target_prompt]
                gt_tokens = tokenizer(target_, return_tensors='pt', padding=True).input_ids

                target_hidden_state = llm_sp.inference(target_, use_private=enable_private)

                attacker = OptimizationBasedAttacker(target_hidden_state, llm_sp, matcher_device=matcher_device)
                reconstruction, _ = attacker.reconstruct(attack_steps)
                truth_txt = tokenizer.batch_decode(gt_tokens, skip_special_tokens=True)[0]
                recon_txt = tokenizer.batch_decode(reconstruction, skip_special_tokens=True)[0]

                m = report(reconstructed_user_data=[recon_txt], true_user_data=[truth_txt])
                r1_results.append(m['rouge1'])
                r2_results.append(m['rouge2'])
                rl_results.append(m['rougeL'])
                log.info(f"## alpha: {alpha} | No. {kk}\n")
                log.info(f"ground truth  : {truth_txt}\n")
                log.info(f"reconstruction: {recon_txt}\n")
                log.info(
                    f"METRICS: | ROUGE1: {m['rouge1']:4.2f}| ROUGE2: {m['rouge2']:4.2f} | ROUGE-L: {m['rougeL']:4.2f}\n\n"
                )
            except Exception:
                # Best effort: one failing prompt must not abort the run, but
                # record the traceback instead of hiding it. Catching
                # Exception (not a bare except) lets Ctrl-C still stop the job.
                log.exception(f"## alpha: {alpha} | No. {kk} failed; skipping\n")
                continue
        if r1_results:
            log.info(f"******** average with noise {alpha}: | ROUGE1: {get_mean(r1_results)}| ROUGE2: {get_mean(r2_results)} | ROUGE-L: {get_mean(rl_results)}\n\n")
        else:
            # Every prompt failed (or there were none): there is nothing to average.
            log.warning(f"******** no successful reconstruction for {data_name}; averages unavailable\n\n")
        # Flush and release only this dataset's log file; logging.shutdown()
        # is meant to be called once at application exit, not per iteration.
        log.removeHandler(handler)
        handler.close()
