import jsonlines

def jsonl_loader(jsonl_dir):
    """Load every object from a JSON Lines file into a list.

    Args:
        jsonl_dir: path of the ``.jsonl`` file to read (a file path, despite
            the parameter name).

    Returns:
        A list of the decoded objects, in the order they appear in the file.
    """
    # Materialize inside the context manager so the file is closed on return.
    with jsonlines.open(jsonl_dir) as reader:
        return list(reader.iter())

def check_instruction_following(input_jsonl):
    """Print how many individual constraints were followed across a results file.

    Every record contributes ``len(record["follow_instruction_list"])`` to the
    total. An entry counts as followed when it compares equal to ``True``.

    Args:
        input_jsonl: path of the evaluation-results JSONL file.
    """
    total_constraints = 0
    followed_instructions = 0

    for record in jsonl_loader(input_jsonl):
        flags = record["follow_instruction_list"]
        total_constraints += len(flags)
        followed_instructions += sum(1 for flag in flags if flag == True)

    print(f"the followed constraints: {followed_instructions}/{total_constraints}")
# check_instruction_following("output/experiment_20241204/gpt_optimized/llama/eval_results_loose.jsonl")

def check_follow_all_instructions(input_jsonl):
    """Print how many records have ``follow_all_instructions`` equal to ``True``.

    Args:
        input_jsonl: path of the evaluation-results JSONL file.
    """
    records = jsonl_loader(input_jsonl)
    succ = sum(1 for record in records if record["follow_all_instructions"] == True)
    print(succ)

# check_follow_all_instructions("output/experiment_20241204/gpt_optimized/llama/eval_results_loose.jsonl")

def check_difference(output_1, output_2):
    """Compare two evaluation-result JSONL files record by record.

    Records are paired by position: line i of ``output_1`` is compared with
    line i of ``output_2``. Both files should therefore list the same prompts
    in the same order. Constraint flags inside a pair are likewise paired by
    position.

    A summary is printed to stdout. If the record counts differ, or a pair's
    ``follow_instruction_list`` lengths differ, a warning is emitted, because
    only the overlapping prefix can be compared.

    Args:
        output_1: path of the first evaluation-results JSONL file.
        output_2: path of the second evaluation-results JSONL file.

    Returns:
        A dict with the printed counts, under the keys ``total_constraints``,
        ``constraints_change``, ``follow_all_change``, ``follow_all_both``,
        ``follow_all_1`` and ``follow_all_2``.
    """
    import warnings  # function-scope import keeps this block self-contained

    json_list_1 = jsonl_loader(output_1)
    json_list_2 = jsonl_loader(output_2)

    # zip() silently drops the tail of the longer list; make that visible
    # instead of quietly reporting counts over a truncated set.
    if len(json_list_1) != len(json_list_2):
        warnings.warn(
            f"check_difference: {output_1} has {len(json_list_1)} records but "
            f"{output_2} has {len(json_list_2)}; only the first "
            f"{min(len(json_list_1), len(json_list_2))} pairs are compared",
            stacklevel=2,
        )

    total_constraints = 0
    constraints_change = 0
    follow_all_change = 0
    follow_all_both = 0
    follow_all_1 = 0
    follow_all_2 = 0
    mismatched_records = 0  # pairs whose per-constraint lists differ in length

    for line_1, line_2 in zip(json_list_1, json_list_2):
        all_1 = line_1["follow_all_instructions"]
        all_2 = line_2["follow_all_instructions"]
        if all_1 != all_2:
            follow_all_change += 1
        if all_1:
            follow_all_1 += 1
        if all_2:
            follow_all_2 += 1
        if all_1 and all_2:
            follow_all_both += 1

        flags_1 = line_1["follow_instruction_list"]
        flags_2 = line_2["follow_instruction_list"]
        if len(flags_1) != len(flags_2):
            mismatched_records += 1
        for result_1, result_2 in zip(flags_1, flags_2):
            total_constraints += 1
            if result_1 != result_2:
                constraints_change += 1

    if mismatched_records:
        warnings.warn(
            f"check_difference: {mismatched_records} record pair(s) have "
            "follow_instruction_list entries of different lengths; the extra "
            "constraints were not compared",
            stacklevel=2,
        )

    print(f"total constraints = {total_constraints}\nconstraints change = {constraints_change}\nfollow all change = {follow_all_change}\nfollow all both = {follow_all_both}\nfollow all 1 = {follow_all_1}\nfollow all 2 = {follow_all_2}")

    return {
        "total_constraints": total_constraints,
        "constraints_change": constraints_change,
        "follow_all_change": follow_all_change,
        "follow_all_both": follow_all_both,
        "follow_all_1": follow_all_1,
        "follow_all_2": follow_all_2,
    }

# NOTE(review): this comparison runs at import time, so importing this module
# reads both result files and prints the summary. Consider moving it under an
# `if __name__ == "__main__":` guard.
check_difference(
    output_1="output/experiment_20250410/instruction_tuning/ifeval_benchmark/llama_7b_base_alpaca_prompt/eval_results_strict.jsonl",
    output_2="output/experiment_20250410/instruction_tuning/ifeval_benchmark/llama_7b_v2_s800/eval_results_strict.jsonl",
)

def reformat(input_jsonl, output_jsonl):
    """Remove "Here is the reformatted text:" style preambles from responses.

    Reads records with ``prompt`` and ``response`` keys from ``input_jsonl``.
    It then writes ``{"prompt": ..., "response": ...}`` records to
    ``output_jsonl``, with every listed phrase removed from the response and
    surrounding whitespace stripped.

    Args:
        input_jsonl: path of the JSONL file to read.
        output_jsonl: path of the JSONL file to (over)write.
    """
    # Applied in this exact order: each "...:\n\n" form is removed before the
    # bare "...:" form, so the trailing blank line goes with the phrase.
    # NOTE(review): str.replace removes these phrases anywhere in the response,
    # not only as a leading preamble. Confirm that is intended.
    boilerplate = (
        "Here is the reformatted text:\n\n",
        "here is the reformatted text:\n\n",
        "HERE IS THE REFORMATTED TEXT:\n\n",
        "here is the reformed text:\n\n",
        "here is the reformatted text:",
        "Here is the reformatted text:",
        "HERE IS THE REFORMATTED TEXT:",
    )

    records = jsonl_loader(input_jsonl)
    with jsonlines.open(output_jsonl, mode="w") as writer:
        for record in records:
            prompt = record["prompt"]
            cleaned = record["response"]
            for phrase in boilerplate:
                cleaned = cleaned.replace(phrase, "")
            writer.write({"prompt": prompt, "response": cleaned.strip()})

# reformat("output/experiment_20241204/decomposed_prompt/decomposed_prompt_gpt_llama/gpt_decomposed_ifeval_response.jsonl", "output/experiment_20241204/decomposed_prompt/decomposed_prompt_gpt_llama/gpt_decomposed_ifeval_response_reformat.jsonl")
def check_follow_instructions(input):
    """Tally instruction counts per category from an evaluation-results JSONL file.

    For each record, ``instruction_id_list`` is paired positionally with
    ``follow_instruction_list``. Pairing uses ``zip``, so the shorter list
    bounds the pairs. An id such as ``"category:name"`` is counted under the
    part before the first ``":"``.

    Prints the total count per category, a ``=====`` separator, and then the
    followed count per category.

    Args:
        input: path of the evaluation-results JSONL file. The name shadows the
            builtin and is kept for backward compatibility with keyword callers.

    Returns:
        A tuple ``(instruction_id_lists, instruction_id_true_lists,
        instruction_id_false_lists)``. Each item is a dict mapping a category to
        the number of its instructions seen, followed (flag ``== True``) and not
        followed (flag ``== False``), respectively. A category only appears in
        the true/false dicts once it has at least one such flag.
    """
    jsonl_list = jsonl_loader(input)

    instruction_id_true_lists = {}
    instruction_id_false_lists = {}
    instruction_id_lists = {}

    for line in jsonl_list:
        instruction_id_list = line["instruction_id_list"]
        follow_instruction_list = line["follow_instruction_list"]

        for instruction_id, follow_instruction in zip(instruction_id_list, follow_instruction_list):
            category = instruction_id.split(":")[0]

            # dict.get replaces the old `key in [*d]` test, which rebuilt a
            # list of all keys on every check. Equality with True/False (rather
            # than truthiness) is deliberate: a flag equal to neither (e.g.
            # None) is counted only in the total.
            if follow_instruction == True:
                instruction_id_true_lists[category] = instruction_id_true_lists.get(category, 0) + 1
            if follow_instruction == False:
                instruction_id_false_lists[category] = instruction_id_false_lists.get(category, 0) + 1
            instruction_id_lists[category] = instruction_id_lists.get(category, 0) + 1

    for key, value in instruction_id_lists.items():
        print(f"{key}: {value}")
    print("=====================")

    for key, value in instruction_id_true_lists.items():
        print(f"{key}: {value}")

    return instruction_id_lists, instruction_id_true_lists, instruction_id_false_lists