import math
import json

# Step 1: Load the "after" and "before" model outputs (JSON Lines, one record
# per line).
# NOTE(review): both paths are the placeholder 'xxx.jsonl'; point them at the
# "after" and "before" files respectively.
# Files are decoded as UTF-8 explicitly (the platform default may differ), and
# blank lines are skipped because json.loads raises on an empty line.
with open('xxx.jsonl', 'r', encoding='utf-8') as f:
    data_after = [json.loads(line) for line in f if line.strip()]

with open('xxx.jsonl', 'r', encoding='utf-8') as f:
    data_before = [json.loads(line) for line in f if line.strip()]

# Step 2: Process each aligned pair of ("after", "before") records.
# For every pair, build:
#   input_id_list - the token ids of the "after" record;
#   label_list    - the same ids, except that tokens up to and including the
#                   first token containing "assistant", and tokens whose
#                   probability rose from "before" to "after" by more than
#                   DELTA_PROB_MASK_THRESHOLD, are replaced by IGNORE_INDEX.
# Each record of meta_info["input_token_logprobs"] is unpacked below as a
# (logprob, token_id, token_text) triple.

# Label value for masked tokens; presumably the loss ignore-index of the
# training code that consumes label_list - confirm.
IGNORE_INDEX = -100

# Tokens with prob_after - prob_before > this value are masked.
# NOTE(review): an earlier comment described the rule as
# "delta_prob < -0.2 -> mask", which contradicts the code (delta_prob > 0 ->
# mask). The code's behaviour is preserved here; confirm the intended rule.
DELTA_PROB_MASK_THRESHOLD = 0.0


def _find_assistant_index(token_logprobs):
    """Return the index of the first token whose text contains "assistant"
    (case-insensitive), or None if no token does."""
    # NOTE(review): this matches the FIRST occurrence. If the system or user
    # prompt itself contains the word (e.g. "You are a helpful assistant"),
    # masking stops inside the prompt. Confirm against the chat template.
    for i, (_, _, text) in enumerate(token_logprobs):
        if "assistant" in text.lower():
            return i
    return None


def _logprob_to_prob(logprob):
    """Convert a log-probability to a probability; None maps to 1.0."""
    return 1.0 if logprob is None else math.exp(logprob)


result_data = []

if len(data_after) != len(data_before):
    # zip() below stops at the shorter list; make the truncation visible.
    print(f"Warning: {len(data_after)} 'after' records vs {len(data_before)} "
          f"'before' records; extra records are ignored.")

for idx, (response_after_json, response_before_json) in enumerate(zip(data_after, data_before)):
    try:
        response_after_logprobs = response_after_json["meta_info"]["input_token_logprobs"]
        response_before_logprobs = response_before_json["meta_info"]["input_token_logprobs"]
    except KeyError as e:
        # Report which record is malformed, then skip it.
        print(f"KeyError at index {idx}: {e}")
        print(f"Problematic data_after entry: {response_after_json.keys()}")
        print(f"Problematic data_before entry: {response_before_json.keys()}")
        continue

    # Step 3: Locate the "assistant" token; skip the pair if either side lacks one.
    assistant_index_after = _find_assistant_index(response_after_logprobs)
    assistant_index_before = _find_assistant_index(response_before_logprobs)
    if assistant_index_after is None or assistant_index_before is None:
        continue

    # The per-token comparison assumes both sides tokenized the same text.
    # zip() silently truncates on a length mismatch, so report it.
    if len(response_after_logprobs) != len(response_before_logprobs):
        print(f"Warning: token count mismatch at index {idx}: "
              f"{len(response_after_logprobs)} vs {len(response_before_logprobs)}")

    # Step 4: Build input_id_list and label_list.
    input_id_list = []
    label_list = []
    mismatched_ids = 0

    for i, ((logprob_after, input_id_after, _), (logprob_before, input_id_before, _)) in enumerate(
        zip(response_after_logprobs, response_before_logprobs)
    ):
        input_id_list.append(input_id_after)
        if input_id_after != input_id_before:
            mismatched_ids += 1

        # Prompt tokens up to and including the "assistant" token are masked.
        if i <= assistant_index_after:
            label_list.append(IGNORE_INDEX)
            continue

        delta_prob = _logprob_to_prob(logprob_after) - _logprob_to_prob(logprob_before)
        if delta_prob > DELTA_PROB_MASK_THRESHOLD:
            label_list.append(IGNORE_INDEX)
        else:
            label_list.append(input_id_after)

    if mismatched_ids:
        print(f"Warning: {mismatched_ids} token ids differ between 'after' and "
              f"'before' at index {idx}; probabilities compared may not be aligned")

    # Step 5: Attach the lists to the "after" record.
    meta_info = response_after_json["meta_info"]
    meta_info["input_id_list"] = input_id_list
    meta_info["label_list"] = label_list
    # Records can be skipped above, so output line k is not necessarily input
    # line k. Store the input position so a later merge can realign by it.
    meta_info["source_index"] = idx

    result_data.append(response_after_json)

# Step 6: Write the labelled records out as JSON Lines, one object per line.
with open('xxx.jsonl', 'w') as out_file:
    out_file.writelines(json.dumps(record) + "\n" for record in result_data)


import json

# Step 1: Load the original dataset (file e), one JSON object per line.
# Files are decoded as UTF-8 explicitly, and blank lines are skipped because
# json.loads raises on an empty line.
with open('xxx.jsonl', 'r', encoding='utf-8') as f:
    data_e = [json.loads(line) for line in f if line.strip()]

# Step 2: Load the labelled records, whose meta_info carries input_id_list and
# label_list.
with open('xxx.jsonl', 'r', encoding='utf-8') as f:
    modified_data = [json.loads(line) for line in f if line.strip()]

# Step 3: Copy input_id_list / label_list from modified_data into data_e as
# "mask_input_id" / "mask_label_id".
# The labelling step can skip records, so a labelled record is matched by its
# meta_info["source_index"] when present, and by its position in
# modified_data otherwise.
# NOTE(review): assumes record k of data_e corresponds to input record k of
# the labelling step - confirm.
modified_by_index = {}
for position, modified in enumerate(modified_data):
    meta_info = modified.get("meta_info", {})
    modified_by_index[meta_info.get("source_index", position)] = meta_info

unmatched = 0
for idx, original_data in enumerate(data_e):
    meta_info = modified_by_index.get(idx)
    if meta_info is None:
        unmatched += 1
        continue
    original_data["mask_input_id"] = meta_info.get("input_id_list", [])
    original_data["mask_label_id"] = meta_info.get("label_list", [])

if unmatched:
    print(f"Warning: {unmatched} of {len(data_e)} records in data_e have no matching labels")

# Step 4: Write the merged dataset out as JSON Lines, one object per line.
with open('xxx.jsonl', 'w') as out_file:
    out_file.writelines(json.dumps(record) + "\n" for record in data_e)