import os
import json
import torch
import argparse
from tqdm import tqdm
import numpy    as np
import torch.nn as nn
from scipy.special import kl_div, rel_entr
from scipy.stats import entropy
import math
# Expose only GPU 0 to this process; this must be set before CUDA is probed below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Module-level loss helpers: log-softmax over the last axis and an
# unreduced (per-element) negative log-likelihood.
log_softmax = nn.LogSoftmax(dim=-1)
nll_loss = nn.NLLLoss(reduction='none')

# Device string that input tensors are moved to.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentence shared by both templates below.
_COMPLETE_REQUEST = "Write a response that appropriately completes the request.\n\n"

# Prompt templates filled with ``str.format_map``: "prompt_input" expects
# both {instruction} and {input}; "prompt_no_input" expects only {instruction}.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        + _COMPLETE_REQUEST
        + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        + _COMPLETE_REQUEST
        + "### Instruction:\n{instruction}\n\n### Response:"
    ),
)

# G1: add noise everywhere; G2: add noise to the target only
def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding all options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default='../data/alpaca_data.json')
    # Naming notes for earlier outputs: alpaca_neft2_P = after one round of
    # instruction tuning; alpaca_neft_P2 = updated perturbation position;
    # alpaca_neft_P3 = fixes the P2 bug; P4 = entry.
    parser.add_argument("--save_path", type=str, default='../data/neft_new/final_test/alpaca_uniform_noise10_240')
    parser.add_argument("--model_name_or_path", type=str, default='../llama2/Llama-2-7b-hf')
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--start_idx", type=int, default=0)
    # ``type=bool`` would turn any non-empty string (including "False") into
    # True, so the flag is parsed explicitly. A bare ``--peft`` means True.
    parser.add_argument("--peft", type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--peft_path", type=str, default="../mode_saved/kl_mean/low_sorted_1000")
    parser.add_argument("--end_idx", type=int, default=-1)
    parser.add_argument("--prompt", type=str, default='alpaca', help='wiz, alpaca')
    parser.add_argument("--mod", type=str, default='pre', help='pre, cherry')
    parser.add_argument("--seed", type=int, default=240)
    return parser.parse_args(argv)


# Parse the command line once at import time and echo the resulting options;
# main() reads this module-level `args`.
args = parse_args()
print(args)

# Runs the model on the whole prompt+response text and returns its loss and token probabilities
def get_perplexity_and_embedding_whole_text(tokenizer, model, text,target_span ,max_length,if_neft = False,gaoshi=False,seed=0):
    """Score ``text`` with ``model``, optionally adding noise to its input embeddings.

    Args:
        tokenizer: Tokenizer whose ``encode`` is used both to build the model
            input and to count the tokens of text prefixes.
        model: Model exposing ``get_input_embeddings()`` and callable with
            ``inputs_embeds`` and ``labels``.
        text: Full prompt + response string.
        target_span: Unused; kept so existing callers keep working.
        max_length: Truncation length passed to ``tokenizer.encode``.
        if_neft: If true, add uniform noise in ``[-10*m, 10*m]`` with
            ``m = 5 / sqrt(seq_len * embed_dim)`` to every position before
            the last "### Response" marker.
        gaoshi: If true, seed torch's RNG with ``seed`` and add uniform noise
            in ``[10*min, 10*max]`` of the span's own values to the positions
            between the last "### Instruction:" header and the response marker.
        seed: RNG seed, used only when ``gaoshi`` is true.

    Returns:
        A 5-tuple ``(loss, probabilities, target_probabilities,
        logits_probabilities, target_logits)``:
        the model loss on CPU; a float32 NumPy array holding, for each
        position ``i``, the softmax value at that position for token
        ``input_ids[i]``; that array from the instruction start onward; the
        full softmax for the sequence on CPU; and its rows from the
        instruction start onward on CPU.
    """
    response_marker = '\n\n### Response'
    instruction_marker = '### Instruction:\n'
    noise_scale = 10

    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    input_embed = model.get_input_embeddings()
    # The lookup is only an input to the no-grad forward pass below, so it does
    # not need an autograd graph; call the module rather than ``.forward``.
    with torch.no_grad():
        embeddings = input_embed(input_ids)

    # Token count of everything before the response marker.
    start_index1 = text.rfind(response_marker)
    start_token_1 = len(tokenizer.encode(text[:start_index1]))
    if if_neft:
        prefix = embeddings[0, :start_token_1, :]
        dims = input_ids.shape[1] * input_embed.weight.shape[1]
        neft_alpha = 5
        mag_norm = neft_alpha / (dims ** 0.5)
        # ``zeros_like`` already matches the embedding dtype, so no cast is needed.
        embeddings[:, :start_token_1, :] += torch.zeros_like(prefix).uniform_(-mag_norm * noise_scale, mag_norm * noise_scale)

    # Token count up to and including the instruction header.
    start_index = text.rfind(instruction_marker)
    start_token = len(tokenizer.encode(text[:start_index + len(instruction_marker)]))
    if gaoshi:
        span = embeddings[0, start_token:start_token_1, :]
        torch.manual_seed(seed)
        # NOTE(review): an empty span makes ``min()``/``max()`` raise — confirm
        # whether such records should be skipped instead.
        noise = torch.zeros_like(span).uniform_(span.min().item() * noise_scale, span.max().item() * noise_scale)
        embeddings[:, start_token:start_token_1, :] += noise

    with torch.no_grad():
        outputs = model(inputs_embeds=embeddings, labels=input_ids.contiguous())
    loss = outputs.loss

    # Index the single batch element directly; a later ``squeeze(0)`` on the
    # sliced tensor could drop the position axis when only one row is left.
    token_dist = nn.functional.softmax(outputs.logits, dim=-1)[0]
    target_dist = token_dist[start_token:, :]

    # One gather instead of a per-token Python loop with a device copy each.
    # NOTE(review): this reads position i for token i (not shifted by one) —
    # confirm that is the intended quantity.
    ids = input_ids[0].to(token_dist.device)
    token_probs = token_dist.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    # NumPy has no bfloat16 dtype, so upcast before converting.
    probabilities = token_probs.float().cpu().numpy()
    target_probabilities = probabilities[start_token:]

    return loss.to('cpu'), probabilities, target_probabilities, token_dist.to('cpu'), target_dist.to('cpu')

# Used to get the ppl and emb for part of the input; used in the conditional version and for the token-wise loss

def _build_whole_text(data_i, prompt):
    """Return ``(whole_text, output_i)`` for one record under the given prompt style.

    An empty ``output`` is replaced by a single space. Raises ``ValueError``
    for a prompt style other than ``'wiz'`` or ``'alpaca'``.
    """
    instruct_i = data_i['instruction']
    output_i = data_i['output']
    if output_i == '':
        output_i = ' '
    input_i = data_i.get('input', '')

    if prompt == 'wiz':
        if input_i != '':
            return instruct_i+'\nInput:'+input_i+'\n\n### Response:'+output_i, output_i
        return instruct_i+'\n\n### Response:'+output_i, output_i

    if prompt == 'alpaca':
        if input_i == '':
            prompt_to_use = PROMPT_DICT["prompt_no_input"].format_map({'instruction': instruct_i})
        else:
            prompt_to_use = PROMPT_DICT["prompt_input"].format_map({'instruction': instruct_i, 'input': input_i})
        return prompt_to_use + output_i, output_i

    raise ValueError(f"unsupported --prompt value: {prompt!r} (expected 'wiz' or 'alpaca')")


def main():
    """Score each selected record and save the per-record results.

    Loads the LLaMA model and tokenizer (plus a PEFT adapter when ``--peft``
    is set), reads the JSON dataset and slices it to
    ``[start_idx:end_idx]``. In ``pre`` mode each record gets one clean
    forward pass and three noised passes with seeds ``seed``, ``seed+1``
    and ``seed+2``; the mean of the three KL values is stored under
    ``'kl_target'``. The list of result dicts is written with ``torch.save``.
    """
    import time

    from transformers import LlamaTokenizer, LlamaForCausalLM

    # Fail fast on a bad prompt style, before the model is loaded.
    if args.prompt not in ('wiz', 'alpaca'):
        raise ValueError(f"unsupported --prompt value: {args.prompt!r} (expected 'wiz' or 'alpaca')")

    model = LlamaForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto", cache_dir='../cache', output_hidden_states=True, torch_dtype=torch.bfloat16)
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path, cache_dir='../cache')
    if args.peft:
        # Imported only when an adapter is requested.
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, args.peft_path)

    model.eval()

    if not args.save_path.endswith('.pt'):
        args.save_path += '.pt'
    # Create the output directory now so the final save cannot fail on a
    # missing path after the whole run has finished.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with open(args.data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    end_idx = args.end_idx if args.end_idx != -1 else len(data)
    sampled_data = data[args.start_idx:end_idx]

    num_repeats = 3  # number of noised passes averaged per record
    start_time = time.time()
    new_data = []
    for data_i in tqdm(sampled_data):
        whole_text, output_i = _build_whole_text(data_i, args.prompt)

        temp_data_i = {}
        if args.mod == 'pre':
            # Clean reference distribution (no noise).
            _, _, _, _, clean_dist = get_perplexity_and_embedding_whole_text(
                tokenizer, model, whole_text, output_i, args.max_length,
                if_neft=False, gaoshi=False)

            kl_list = []
            for rep in range(num_repeats):
                _, _, _, _, noisy_dist = get_perplexity_and_embedding_whole_text(
                    tokenizer, model, whole_text, output_i, args.max_length,
                    if_neft=False, gaoshi=True, seed=args.seed + rep)
                # kl_div takes log-probabilities as input and probabilities as
                # target, so this is KL(clean || noisy) averaged over positions.
                # NOTE(review): a zero probability in noisy_dist gives log(0) = -inf;
                # confirm whether clamping is wanted.
                kl_list.append(nn.functional.kl_div(torch.log(noisy_dist), clean_dist,
                                                    reduction='batchmean'))
            temp_data_i['kl_target'] = sum(kl_list) / num_repeats

        new_data.append(temp_data_i)

    print('New data len:', len(new_data))
    torch.save(new_data,args.save_path)
    print('Time Used:',(time.time()-start_time)/60,'(min)')

# Run the scoring pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()