import json
import os
import random

def _flip_preference(entry, new_preference):
    """Relabel ``entry`` and exchange its a/b answers and scores in place.

    Swapping the answers and scores together with the label keeps the record
    self-consistent: the preferred answer is still the same text, it just
    sits in the other slot.
    """
    entry['Preference'] = new_preference
    entry['Student_Answer_a'], entry['Student_Answer_b'] = (
        entry['Student_Answer_b'], entry['Student_Answer_a'])
    entry['Student_Score_a'], entry['Student_Score_b'] = (
        entry['Student_Score_b'], entry['Student_Score_a'])


def modify_poli_jsonl(file_path, output_path='modified_polich.jsonl'):
    """Balance the 'a'/'b' preference labels of a JSONL file and write the result.

    Each non-blank line of ``file_path`` is parsed as a JSON object with the
    keys ``Preference``, ``Student_Answer_a``/``_b`` and
    ``Student_Score_a``/``_b``. Entries are flipped (label changed, answers
    and scores swapped) from whichever side is over-represented until the
    'a' group holds ``len(data) // 2`` entries or the 'b' group holds the
    remaining half. Entries whose ``Preference`` is neither 'a' nor 'b' are
    not written to the output.

    Args:
        file_path: Path of the input JSONL file (UTF-8).
        output_path: Path of the JSONL file to write. Defaults to
            ``'modified_polich.jsonl'`` in the current working directory.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
        data = [json.loads(line) for line in file if line.strip()]

    # Separate entries based on preference
    preference_a = [entry for entry in data if entry['Preference'] == 'a']
    preference_b = [entry for entry in data if entry['Preference'] == 'b']
    print(len(preference_a))

    floor_half = len(data) // 2
    ceil_half = len(data) - floor_half

    # Only move entries away from the over-represented side; flipping 'a'
    # entries when 'a' is already the minority would widen the imbalance.
    if len(preference_a) > floor_half:
        for entry in preference_a[:len(preference_a) - floor_half]:
            _flip_preference(entry, 'b')
    elif len(preference_b) > ceil_half:
        for entry in preference_b[:len(preference_b) - ceil_half]:
            _flip_preference(entry, 'a')

    # Output keeps the original grouping: entries that started as 'a' first.
    combined_data = preference_a + preference_b

    # Write the modified data back to a new jsonl file
    with open(output_path, 'w', encoding='utf-8') as file:
        for entry in combined_data:
            file.write(json.dumps(entry, ensure_ascii=False) + '\n')
            

def check_acc(input_path, output_path):
    """Print and return the fraction of predictions that match the human labels.

    Both files are read as UTF-8 JSON documents containing a list of objects.
    The i-th object's ``agent_results`` in ``output_path`` is compared with the
    i-th object's ``Human_preference`` in ``input_path``. The denominator is
    the number of predictions, so a shorter output file is scored only on the
    entries it contains.

    Args:
        input_path: Path of the JSON file holding the labelled dataset.
        output_path: Path of the JSON file holding the predictions.

    Returns:
        The accuracy as a float, or ``0.0`` when there are no predictions.

    Raises:
        IndexError: If there are more predictions than labelled entries.
        KeyError: If an object lacks the expected key.
    """
    with open(input_path, 'r', encoding='utf-8') as file:
        dataset = json.load(file)

    with open(output_path, 'r', encoding='utf-8') as file:
        output = json.load(file)

    preference = [entry['Human_preference'] for entry in dataset]
    prediction = [entry['agent_results'] for entry in output]

    if len(prediction) > len(preference):
        # Fail loudly rather than let zip() silently drop the extra predictions.
        raise IndexError(
            f"{len(prediction)} predictions but only {len(preference)} labelled entries"
        )

    if not prediction:
        # Accuracy over nothing is undefined; report 0.0 instead of dividing by zero.
        print(0.0)
        return 0.0

    correct = sum(pred == pref for pred, pref in zip(prediction, preference))
    accuracy = correct / len(prediction)
    print(accuracy)
    return accuracy
    
def check_dataset(input_path, limit=12):
    """Print and return the mean word count of the ``Reference`` field.

    The file is read as UTF-8 JSONL (one JSON object per non-blank line).
    Word count is the number of whitespace-separated tokens.

    Args:
        input_path: Path of the JSONL file to inspect.
        limit: Number of leading records to average over. Defaults to 12;
            pass ``None`` to use every record.

    Returns:
        The average word count as a float, or ``0.0`` when no records are
        selected.

    Raises:
        KeyError: If a selected record has no ``Reference`` key.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        # Skip blank lines so a trailing newline does not break parsing.
        dataset = [json.loads(line) for line in f if line.strip()]

    sample = dataset if limit is None else dataset[:limit]
    if not sample:
        # Nothing to average; avoid dividing by zero.
        print(0.0)
        return 0.0

    lengths = [len(record['Reference'].split()) for record in sample]
    average = sum(lengths) / len(lengths)
    print(average)
    return average
    
def dataset_analysis(base_dir='../benchmark/final_version/', files=None):
    """Print length statistics for each benchmark JSON file.

    Every file is read as a UTF-8 JSON list of objects with the keys
    ``question``, ``reference``, ``context``, ``student_answer_a``/``_b`` and
    ``answer_a_type``/``answer_b_type``. Lengths are counts of
    whitespace-separated tokens. Answer statistics are computed over the set
    of distinct answer strings in the file, so duplicates count once.

    Args:
        base_dir: Directory containing the files. Defaults to
            ``'../benchmark/final_version/'``.
        files: Iterable of file names inside ``base_dir``. Defaults to the six
            subject files (geo, psy, law, pol, med, his).

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    if files is None:
        files = ['geo.json', 'psy.json', 'law.json', 'pol.json', 'med.json', 'his.json']

    for file in files:
        with open(os.path.join(base_dir, file), 'r', encoding='utf-8') as f:
            data = json.load(f)

        human_answers = set()
        model_answers = set()
        for entry in data:
            for side in ('a', 'b'):
                is_human = entry[f'answer_{side}_type'] == 'human'
                bucket = human_answers if is_human else model_answers
                bucket.add(entry[f'student_answer_{side}'])

        print(f"File: {file}")
        print(f"Length: {len(data)}")
        if not data:
            # No entries: averages and maxima are undefined, so skip them.
            continue

        # Compute each length list once and reuse it for both avg and max.
        question_lengths = [len(entry['question'].split()) for entry in data]
        reference_lengths = [len(entry['reference'].split()) for entry in data]
        context_lengths = [len(entry['context'].split()) for entry in data]
        all_answers = human_answers | model_answers
        answer_lengths = [len(answer.split()) for answer in all_answers]

        print(f"Avg Question Length: {sum(question_lengths) / len(question_lengths)}")
        print(f"Avg Reference Length: {sum(reference_lengths) / len(reference_lengths)}")
        print(f"Avg Context Length: {sum(context_lengths) / len(context_lengths)}")
        print(f"Avg Answer Length: {sum(answer_lengths) / len(answer_lengths)}")
        print(f"Max Question Length: {max(question_lengths)}")
        print(f"Max Reference Length: {max(reference_lengths)}")
        print(f"Max Context Length: {max(context_lengths)}")
        print(f"Max Answer Length: {max(answer_lengths)}")
        
