import os
import json
import torch
import argparse
from tqdm import tqdm

import torch.nn as nn
from sklearn.cluster import DBSCAN,KMeans
import math
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Loss helpers kept at module level (not referenced by the code in this script).
log_softmax = nn.LogSoftmax(dim=-1)
nll_loss = nn.NLLLoss(reduction='none')
# Default to GPU 2, but respect a CUDA_VISIBLE_DEVICES the caller already set
# (the previous unconditional assignment silently overrode it). CUDA reads this
# variable when it initialises, so it is set before torch.cuda is queried below.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Alpaca-style prompt templates. Fill them with str.format_map using the key
# "instruction", plus "input" for the prompt_input variant.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)


# G1: noise added to everything; G2: noise added to the target only.
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with the attributes data_path, save_path,
        model_name_or_path, max_length, start_idx, peft, peft_path, end_idx
        and prompt.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string, including
        # "False", to True; parse the usual spellings explicitly instead.
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default='../data/alpaca_data.json')
    # save_path naming notes: alpaca_neft2_P = after one round of instruction
    # tuning; alpaca_neft_P2 = updated perturbation positions;
    # alpaca_neft_P3 = fixed the P2 bug; p4 = entry.
    # NOTE(review): main() in this file writes to a fixed path and does not read save_path.
    parser.add_argument("--save_path", type=str, default='../data/neft_new/final_test/alpaca_no_stable_noise10_240_new')
    parser.add_argument("--model_name_or_path", type=str, default='../llama2/Llama-2-7b-hf')
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--start_idx", type=int, default=0)
    # Accepts "--peft", "--peft true" or "--peft false"; omitted means False.
    parser.add_argument("--peft", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--peft_path", type=str, default="../mode_saved/kl_mean/low_sorted_1000")
    # -1 means "up to the end of the data" (see main()).
    parser.add_argument("--end_idx", type=int, default=-1)
    # Only these two prompt layouts are implemented by the caller.
    parser.add_argument("--prompt", type=str, default='alpaca', choices=("alpaca", "wiz"))

    return parser.parse_args()



# Used to get the ppl and emb for the whole input
def get_perplexity_and_embedding_whole_text(tokenizer, model, text, max_length):

    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    input_embed=model.get_input_embeddings()
    embeddings = input_embed.forward(input_ids)

    with torch.no_grad():
        outputs = model(inputs_embeds=embeddings, labels=input_ids.contiguous())
    loss = outputs.loss

    hidden_states = outputs.hidden_states
    embeddings = hidden_states[-1]
    sentence_embedding = embeddings.mean(dim=1)

    return  sentence_embedding.to('cpu')


    # perplexity = torch.exp(loss)
    #
    # hidden_states = outputs.hidden_states
    # embeddings = hidden_states[-1]
    # sentence_embedding = embeddings.mean(dim=1)


def _greedy_diverse_select(candidates, new_data, threshold, max_selected):
    """Greedily keep candidates whose embedding is dissimilar to all kept ones.

    Candidates are visited in the given order. The first one is kept
    (unless ``max_selected`` <= 0). Each later one is kept only if its cosine
    similarity to every already-kept embedding is strictly below
    ``threshold``. Selection stops once ``max_selected`` indices are kept.

    Args:
        candidates: iterable of indices into ``new_data``, in priority order.
        new_data: indexable collection of embedding vectors (numpy arrays or
            CPU tensors); each entry is flattened to 1-D.
        threshold: similarity cut-off (exclusive).
        max_selected: maximum number of indices to return.

    Returns:
        List of kept indices, in the order they were kept.
    """
    selected = []
    kept_units = []  # L2-normalised vectors of the kept candidates
    for idx in candidates:
        if len(selected) >= max_selected:
            break
        vec = np.asarray(new_data[idx], dtype=np.float64).reshape(-1)
        norm = np.linalg.norm(vec)
        # Like sklearn's cosine_similarity, leave an all-zero vector unscaled
        # so its similarity to anything is 0 instead of NaN.
        unit = vec / norm if norm > 0 else vec
        if kept_units:
            # One matrix-vector product instead of one similarity call per kept item.
            max_sim = np.max(np.stack(kept_units) @ unit)
            if max_sim >= threshold:
                continue
        selected.append(idx)
        kept_units.append(unit)
    return selected


def diversity_filter(sentences, new_data, threshold=0.90, max_selected=90):
    """Select up to ``max_selected`` mutually dissimilar items from ranked pairs.

    Args:
        sentences: sequence of pairs whose element ``[1]`` is an index into
            ``new_data`` (element ``[0]`` is ignored; presumably a score).
            Items earlier in the sequence take priority.
        new_data: indexable collection of embedding vectors.
        threshold: a candidate is kept only if its cosine similarity to every
            kept item is strictly below this value.
        max_selected: maximum number of indices to return.

    Returns:
        List of the kept ``[1]`` indices, in selection order.
    """
    return _greedy_diverse_select(
        (pair[1] for pair in sentences), new_data, threshold, max_selected
    )


def diversity_filter2(sentences, new_data, threshold=0.90, max_selected=80):
    """Select up to ``max_selected`` mutually dissimilar indices.

    Args:
        sentences: sequence of indices into ``new_data``, in priority order
            (e.g. the member indices of one cluster).
        new_data: indexable collection of embedding vectors.
        threshold: a candidate is kept only if its cosine similarity to every
            kept item is strictly below this value.
        max_selected: maximum number of indices to return.

    Returns:
        List of the kept elements of ``sentences``, in selection order.
    """
    return _greedy_diverse_select(sentences, new_data, threshold, max_selected)
def _build_whole_text(data_i, prompt):
    """Return the full prompt + response text for one data record.

    Args:
        data_i: record dict; reads ``instruction``, ``output`` and optional
            ``input``.
        prompt: ``'alpaca'`` (PROMPT_DICT templates) or ``'wiz'`` (plain
            "instruction / Input: / ### Response:" layout).

    Raises:
        ValueError: for any other ``prompt`` value.
    """
    instruct_i = data_i['instruction']
    output_i = data_i['output']
    if output_i == '':
        # Empty responses are replaced by a single space, as before.
        output_i = ' '
    input_i = data_i.get('input', '')

    if prompt == 'wiz':
        if input_i != '':
            return instruct_i + '\nInput:' + input_i + '\n\n### Response:' + output_i
        return instruct_i + '\n\n### Response:' + output_i
    if prompt == 'alpaca':
        if input_i == '':
            header = PROMPT_DICT["prompt_no_input"].format_map({'instruction': instruct_i})
        else:
            header = PROMPT_DICT["prompt_input"].format_map({'instruction': instruct_i, 'input': input_i})
        return header + output_i
    raise ValueError(f"unsupported --prompt value: {prompt!r}")


def main():
    """Embed examples, cluster them with KMeans and keep a diverse subset.

    Steps:
    1. Load the LLaMA model (optionally with a PEFT adapter).
    2. Embed each record in ``data[start_idx:end_idx]`` as the mean
       last-layer hidden state of its full prompt + response text.
    3. Cluster the embeddings with KMeans (up to 100 clusters).
    4. Keep a diverse subset of each cluster via ``diversity_filter2``.
    5. Write the kept raw records as JSON to a fixed output path.
       ``args.save_path`` is not used.
    """
    args = parse_args()
    print(args)
    # Fail before the expensive model load if the prompt layout is unknown;
    # previously this surfaced later as a NameError on ``whole_text``.
    if args.prompt not in ('wiz', 'alpaca'):
        raise ValueError(f"unsupported --prompt value: {args.prompt!r}")

    from transformers import LlamaTokenizer, LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto", cache_dir='../cache', output_hidden_states=True, torch_dtype=torch.bfloat16)
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path, cache_dir='../cache')
    if args.peft:
        # peft is only needed (and only imported) when an adapter is requested.
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, args.peft_path)
    model.eval()

    with open(args.data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    start_idx = args.start_idx
    end_idx = args.end_idx if args.end_idx != -1 else len(data)
    sampled_data = data[start_idx:end_idx]
    if not sampled_data:
        raise ValueError("no examples selected by --start_idx/--end_idx")

    new_data = []
    for data_i in tqdm(sampled_data):
        whole_text = _build_whole_text(data_i, args.prompt)
        sentence_embedding = get_perplexity_and_embedding_whole_text(tokenizer, model, whole_text, args.max_length)
        new_data.append(sentence_embedding.to(torch.float32).squeeze(0))

    # Stack into one 2-D float32 matrix (one row per example) for sklearn and
    # the diversity filter, instead of handing sklearn a list of tensors.
    embeddings = torch.stack(new_data).detach().numpy()

    # KMeans requires n_samples >= n_clusters, so cap the 100 clusters.
    n_clusters = min(100, len(embeddings))
    clustering = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)
    labels = clustering.labels_

    select_id = []
    for cluster_id in np.unique(labels):
        samples_in_cluster = np.where(labels == cluster_id)[0]
        select_id.extend(diversity_filter2(samples_in_cluster, embeddings))

    select_text = [sampled_data[int(idx)] for idx in select_id]
    print(len(select_text))

    with open('../data/average/kl_div_logits_T/kmenas_60.json', "w+") as fw:
        json.dump(select_text, fw, indent=4)


if __name__ == "__main__":
    main()