'''
Before running the evaluation code, the following files need to be prepared:
    The result JSON file (originally temp_result.json; its path is now built from the command-line arguments below): a list of records with the following items for each piece of data:
          "question": the question from the dataset
          "transferred_answer" (read by the code below under the key "model_completion"): output generated by the edited model in the previous steps
          "modern_answer" (read by the code below under the key "origin_model_completion"): output generated by the unedited reference model (note that this is different from the "incorrect_answers" in the dataset, because we need to compare the similarity of the outputs of the same model before and after editing)
          "model_path": only used to identify models edited with different parameters; can be ignored
    */chinese-bert-wwm-ext-trained: the trained style classifier
    */Qwen1.5-14B-Chat: the unedited original model, used to calculate perplexity
    */bge-large-zh-v1.5: the embedding model
'''
import json
import torch
import numpy as np
import random
import os
import math
import argparse
from FlagEmbedding import FlagAutoModel  # activate only when calculating semantic preservation and the env should be turned to contain transformer=4.42.0
from tqdm import tqdm
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),"..")))
import qwen2



# Command-line interface. The arguments select which result file is scored
# (see the file-path construction below) and which reference model is used
# for the perplexity computation.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default="Qwen1.5-14B-Chat",
                    help="name of the evaluated model; used in the result file name and to pick the perplexity model")
parser.add_argument('--main_steer_style', type=str, default="DRC", help="main style")
parser.add_argument('--second_steer_style', type=str, default="None", help="second style")
parser.add_argument('--dataset_name', type=str, default="DRC", help='test dataset')
parser.add_argument('--baseline_name', type=str, default="None",
                    help="caa, lm_steer, vector_prompt, iti, prompt, None, "
                         "or a name containing ablation_svd / ablation_fixed")
parser.add_argument('--main_strength', type=float, default=3.0, help='main steer strength')
parser.add_argument('--second_strength', type=float, default=3.0, help='second steer strength')
parser.add_argument('--is_heads', type=int, default=0,
                    help='set to 1 to evaluate the "_head<num_heads>_result.json" result file')
parser.add_argument('--num_heads', type=int, default=64,
                    help='head count used in the result file name when --is_heads is 1')

args = parser.parse_args()
model_name = args.model_name
print(args)
# import ipdb; ipdb.set_trace()
#  'results_log/Qwen1.5-14B-Chat_evaltqa_gen_zh _mainDRC_secondNone_result.json'
# Resolve which result file to evaluate.
# Most result files share this prefix; only the suffix depends on the mode.
result_stem = (
    f"results_log/{args.model_name}_eval{args.dataset_name}"
    f"_main{args.main_steer_style}_second{args.second_steer_style}"
    f"_strength{args.main_strength}_{args.second_strength}"
)

if args.is_heads == 1:
    # The head-analysis file takes precedence over any baseline selection.
    file_path = f"{result_stem}_head{args.num_heads}_result.json"
elif args.baseline_name == "None":
    file_path = f"{result_stem}_result.json"
elif args.baseline_name in ["caa", "lm_steer", "vector_prompt", "iti", "prompt"]:
    file_path = f"results_log/baseline_result/{args.model_name}/{args.main_steer_style}/{args.baseline_name}/{args.dataset_name}_results.json"
elif "ablation_svd" in args.baseline_name or "ablation_fixed" in args.baseline_name:
    file_path = f"{result_stem}_result_{args.baseline_name}.json"
else:
    # Previously this fell through and crashed later with a NameError on
    # `file_path`; fail early with an actionable message instead.
    raise ValueError(
        f"Unsupported --baseline_name {args.baseline_name!r}: expected 'None', "
        "one of caa/lm_steer/vector_prompt/iti/prompt, or a name containing "
        "'ablation_svd' or 'ablation_fixed'."
    )

print(file_path)

# Load the records once (the file used to be parsed twice).
with open(file_path, 'r', encoding='utf-8') as file:
    data_list = json.load(file)

# Shallow copy: the list is new but the record dicts are shared with
# data_list, so scores attached through read_examples are visible in both.
read_examples = data_list[:]

# Texts produced by the edited model; these are what the classifier scores.
d_answer_list = [QA["model_completion"] for QA in data_list]

#########################################################################################
#                                    Style Intensity                                    #
#########################################################################################


from transformers import BertTokenizer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import BertForSequenceClassification
from keras.utils import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

# Always define `device`: previously it was only bound when CUDA was available,
# so a CPU-only run crashed with a NameError in the evaluation loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer and the classifier are loaded from the same checkpoint directory;
# datasets whose name contains "zh" use the Chinese classifier.
if "zh" in args.dataset_name:
    classifier_path = "/new_disk1/XXXX-3/projects/PretrainModels/chinese-bert-wwm-ext-trained"
else:
    classifier_path = "/new_disk1/XXXX-3/projects/PretrainModels/bert-base-uncased-trained"

tokenizer = BertTokenizer.from_pretrained(classifier_path)

# Token ids (including special tokens) for every generated answer.
input_ids = [
    tokenizer.encode(sent, add_special_tokens=True)
    for sent in d_answer_list
]

# Pad with 0 / truncate at the end so every row has exactly MAX_LEN ids.
MAX_LEN = 150
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long",
                          value=0, truncating="post", padding="post")

# 1 for real tokens, 0 for the padding id inserted above.
attention_masks = [
    [int(token_id > 0) for token_id in sent]
    for sent in input_ids
]

# Every answer is labelled 1, so the "accuracy" computed later is the share of
# answers for which the classifier predicts class 1.
labels = [1] * len(input_ids)

input_ids = torch.tensor(input_ids)
attention_masks = torch.tensor(attention_masks)
labels = torch.tensor(labels)

# Sequential sampling keeps batches in dataset order, which flat_accuracy
# relies on when it writes per-example scores by running index.
batch_size = 10
dataset = TensorDataset(input_ids, attention_masks, labels)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

model = BertForSequenceClassification.from_pretrained(
    classifier_path,
    num_labels = 2,
    output_attentions = False,
    output_hidden_states = False,
)
# Move to the selected device instead of calling .cuda() unconditionally.
model.to(device)

def flat_accuracy(preds, labels):
    """Return the batch accuracy and record each example's class-1 probability.

    For every row of ``preds`` the softmax probability of class 1 is stored as
    ``read_examples[cnt]['tss_score']`` and the module-level counter ``cnt`` is
    advanced, so batches must be passed in dataset order.

    Args:
        preds: 2-D array of logits, one row per example.
        labels: array of integer class labels for the same examples.

    Returns:
        Fraction of rows whose arg-max class equals the label.
    """
    global cnt
    for row in preds:
        # Subtract the row maximum before exponentiating: the result is
        # mathematically unchanged, but large logits no longer overflow to
        # inf and turn the probability into NaN.
        exps = np.exp(row - np.max(row))
        probs = exps / np.sum(exps)
        read_examples[cnt]['tss_score'] = float(probs[1])
        cnt += 1
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

# Seed every RNG used by the script.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

model.eval()

# Send batches to wherever the classifier's weights live, rather than relying
# on a separately defined `device` variable.
eval_device = next(model.parameters()).device

nb_eval_steps = 0      # number of batches processed
nb_eval_examples = 0   # number of examples processed
tss = 0                # running count of examples predicted as the labelled class
cnt = 0                # running example index consumed by flat_accuracy
for batch in dataloader:
    b_input_ids, b_input_mask, b_labels = (t.to(eval_device) for t in batch)

    with torch.no_grad():
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask)

    logits = outputs[0].detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    batch_accuracy = flat_accuracy(logits, label_ids)

    # Weight each batch by its size. Averaging the per-batch accuracies
    # over-weights a smaller final batch whenever the dataset size is not a
    # multiple of the batch size.
    tss += batch_accuracy * len(label_ids)
    nb_eval_examples += len(label_ids)
    nb_eval_steps += 1

# Per-example accuracy over the whole dataset; 0.0 when there was nothing to score.
tss_result = tss / nb_eval_examples if nb_eval_examples else 0.0
print(" Style Intensity :", tss_result)


#########################################################################################
#                                    Fluency Score                                      #
#########################################################################################



# Local checkpoint directory of the reference model for each supported --model_name.
PERPLEXITY_MODEL_PATHS = {
    "Qwen2.5-7B-Instruct": "/new_disk1/XXXX-3/projects/PretrainModels/Qwen2.5-7B-Instruct",
    "Llama-3-8B-Instruct": "/new_disk1/XXXX-3/projects/PretrainModels/Llama-3-8B-Instruct",
    "Qwen1.5-14B-Chat": "/new_disk1/XXXX-3/projects/PretrainModels/Qwen1.5-14B-Chat",
}
if model_name not in PERPLEXITY_MODEL_PATHS:
    # Previously an unknown name left `model_path` unbound and crashed below
    # with a NameError.
    raise ValueError(
        f"Unsupported --model_name {model_name!r}; expected one of "
        f"{sorted(PERPLEXITY_MODEL_PATHS)}"
    )
model_path = PERPLEXITY_MODEL_PATHS[model_name]

# NOTE(review): every model, including Llama-3, is loaded through the project's
# qwen2 classes - confirm this is intended.
tokenizer = qwen2.Qwen2Tokenizer.from_pretrained(model_path)
model = qwen2.Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device_map="auto" already places the weights (possibly across several
# devices); moving such a dispatched model again is discouraged, so only move
# it when no device map was applied.
if getattr(model, "hf_device_map", None) is None:
    model.to(device)
model.eval()

ppls = []  # per-example perplexity
fss = []   # per-example fluency score, 1 / (1 + ln(perplexity))

# Perplexity is computed on the completion text alone; it is not conditioned on
# the question or on any prompt template.
for item in data_list:
    sequence_answer = item["model_completion"]

    if len(sequence_answer) == 0:
        # Nothing to score: use the neutral perplexity 1.0 (fluency score 1.0).
        ppl = 1.0
        fs = 1 / (1 + np.log(ppl))
        ppls.append(ppl)
        fss.append(fs)
        print("Empty sequence, skipping...")
        continue

    sequence = tokenizer(sequence_answer, return_tensors='pt')
    sequence = {k: v.to(device) for k, v in sequence.items()}
    with torch.no_grad():
        # Batch size is 1, so take the only item: (seq_len, vocab_size).
        logits = model(**sequence).logits[0]
        # Position i predicts token i+1, so the last position has no target.
        log_probs = torch.nn.functional.log_softmax(logits[:-1], dim=1)
        targets = sequence["input_ids"][0, 1:].to(log_probs.device)
        # Log-probability assigned to each actual next token, gathered in one
        # call instead of one .item() round trip per token.
        token_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    token_log_probs = token_log_probs.double().cpu().numpy()

    if token_log_probs.size == 0:
        # A single-token completion has no next-token prediction to score.
        ppl = 1.0
    else:
        # Standard perplexity: exp of the mean negative log-likelihood.
        ppl = np.exp(-np.sum(token_log_probs) / token_log_probs.size)
    # Compute the fluency score for BOTH branches. It used to be set only in
    # the `else` branch, so a single-token completion reused the previous
    # example's score (or raised NameError on the first example).
    fs = 1 / (1 + np.log(ppl))
    print(f"Perplexity: {ppl}, FluencyScore: {fs}")
    ppls.append(ppl)
    fss.append(fs)

# Aggregate the per-example perplexities and fluency scores collected above.
ppls = np.asarray(ppls)
print(f"Total num of samples: {ppls.shape}")

# An infinite mean is reported as the string "infinity" rather than a float.
ppl_mean_all = ppls.mean()
ppl_mean_all = "infinity" if ppl_mean_all == math.inf else ppl_mean_all
print(f"Mean_perplexity_All: {ppl_mean_all}")

fs_mean_all = np.mean(fss)
fs_mean_all = "infinity" if fs_mean_all == math.inf else fs_mean_all
print(f"FluencyScore_All: {fs_mean_all}")

# Attach each example's perplexity to its record (one perplexity per record,
# in the same order).
for example, example_ppl in zip(read_examples, ppls):
    example['ppl_score'] = example_ppl





#########################################################################################
#                                 Semantic Preservation                                 #
#########################################################################################

# Choose the BGE embedding checkpoint and its retrieval instruction by language:
# datasets whose name contains "zh" use the Chinese model, all others the English one.
if "zh" in args.dataset_name:
    bge_path = '/new_disk1/XXXX-3/projects/PretrainModels/bge-large-zh-v1.5'
    bge_instruction = "为这个句子生成表征以用于检索相关文章："
else:
    bge_path = '/new_disk1/XXXX-3/projects/PretrainModels/bge-large-en-v1.5'
    bge_instruction = "Generate a representation for this sentence for retrieval of related articles:"

model = FlagAutoModel.from_finetuned(bge_path,
                                     query_instruction_for_retrieval=bge_instruction,
                                     use_fp16=True)

sp_scores_firstHalf = []
sp_scores_secondHalf = []
sp_scores = []
for example in read_examples:
    edited_text = [example["model_completion"]]
    reference_text = [example["origin_model_completion"]]

    # Embed the edited and the reference completion, each as a flat vector.
    edited_vec = model.encode(edited_text).flatten()
    reference_vec = model.encode(reference_text).flatten()

    # Cosine similarity between the two embeddings.
    norm_product = np.linalg.norm(edited_vec) * np.linalg.norm(reference_vec)
    cosine = np.dot(edited_vec, reference_vec) / norm_product

    example["BGE"] = float(cosine)
    sp_scores.append(example["BGE"])
    
    
# print(len(sp_scores_firstHalf))
# print(len(sp_scores_secondHalf))
# print("Mean Semantic Preservation Score First Half: ", np.mean(sp_scores_firstHalf))
# print("Mean Semantic Preservation Score Second Half: ", np.mean(sp_scores_secondHalf))


# Re-read the result file so that only the aggregate scores are added to it.
# NOTE(review): the per-example fields attached to read_examples earlier
# ('tss_score', 'ppl_score', 'BGE') are therefore not persisted - confirm this
# is intended.
with open(file_path, 'r', encoding='utf-8') as file:
    data_list = json.load(file)

# Aggregate metrics: style intensity, semantic preservation, fluency.
new_dict = {
    "si_score": tss_result,
    "sp_score": sum(sp_scores) / len(sp_scores),
    "fs_score": fs_mean_all,
}

# The summary is stored on the first record of the result list.
data_list[0]["score"] = new_dict

print(new_dict)

# Serialize BEFORE opening the file for writing: opening with 'w' truncates it,
# so a serialization error raised inside json.dump would destroy the results.
serialized = json.dumps(data_list, ensure_ascii=False, indent=4)
with open(file_path, 'w', encoding='utf-8') as new_file:
    new_file.write(serialized)

# os.remove("temp_result.json")