import json
import argparse
import re   
import os   
from utils.qa_em import extract_solution
import pickle

def get_question_from_seq(sequences_str):
    """Pull the question out of a chat-formatted sequence string.

    The user turn is the text between the first ``<|im_start|>user\\n``
    marker and the following ``<|im_end|>`` marker. Within that turn, the
    question is the first line after the last ``Question:`` label, with
    surrounding whitespace removed.

    Args:
        sequences_str: The full prompt/response text of one record.

    Returns:
        The question text, or ``None`` when the input is not string-like,
        the user marker is absent, or the user turn has no ``Question:``
        label.
    """
    user_marker = '<|im_start|>user\n'
    try:
        _, marker_hit, after_marker = sequences_str.partition(user_marker)
    except AttributeError:
        # Non-string input (e.g. None) has no user turn to inspect.
        return None
    if not marker_hit:
        return None

    user_turn = after_marker.partition('<|im_end|>')[0]

    # rpartition picks the last label, so few-shot examples earlier in the
    # prompt do not shadow the actual question.
    _, label_hit, after_label = user_turn.rpartition('Question:')
    if not label_hit:
        return None

    first_line = after_label.strip().split('\n')[0]
    return first_line.strip()

def load_data(file_path):
    """Load a JSON-lines result file into a dict keyed by question.

    Each non-blank line is parsed as one JSON object. The key is the
    record's ``question`` field; when that is missing or empty, the question
    is recovered from ``sequences_str`` via ``get_question_from_seq``.

    Lines that cannot be used are reported on stdout and skipped: invalid
    JSON, JSON values that are not objects, and records with no recoverable
    question. Blank lines are skipped silently. When a question occurs more
    than once a warning is printed and the later record replaces the
    earlier one.

    Args:
        file_path: Path to a UTF-8 encoded JSON-lines file.

    Returns:
        A dict mapping question text to the parsed record dict.
    """
    data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # A blank line holds no record; it is not a decode failure.
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"Warning: Could not decode JSON from line in {file_path}: {stripped}")
                continue

            if not isinstance(record, dict):
                # Valid JSON such as a list or number has no fields to read;
                # skip it instead of crashing on record.get below.
                print(f"Warning: Skipping non-object JSON record in {file_path}: {stripped}")
                continue

            question = record.get('question')
            if not question:
                question = get_question_from_seq(record.get('sequences_str', ''))

            if not question:
                print(f"Warning: Could not extract question from a record in {file_path}")
                continue

            if question in data:
                print(f"Warning: Duplicate question found in {file_path}: {question}")
            data[question] = record
    return data



def _assistant_search_queries(record):
    """Return the ``<search>...</search>`` payloads found in a record's assistant messages."""
    assistant_text = '\n'.join(
        message.get('content') or ''
        for message in record.get('messages') or []
        if message.get('role') == 'assistant'
    )
    return re.findall(r'<search>(.*?)</search>', assistant_text, re.DOTALL)


def compare_search_results(file1, file2):
    """Print the questions whose assistant search queries differ between two files.

    Both files are loaded with ``load_data``. Only questions present in both
    files are compared; a question found in ``file1`` alone is skipped. For
    each shared question, the search queries are extracted from the
    ``content`` of every message whose ``role`` is ``assistant``, and the two
    query lists are compared for exact equality (order included).

    Args:
        file1: Path to the first result file.
        file2: Path to the second result file.

    Returns:
        None. Differences and the final count are printed to stdout.
    """
    data1 = load_data(file1)
    data2 = load_data(file2)

    count = 0
    for question, record1 in data1.items():
        if question not in data2:
            # Nothing to compare against. Without this guard the loop would
            # reuse the record2 left over from an earlier iteration.
            continue
        record2 = data2[question]

        search_queries1 = _assistant_search_queries(record1)
        search_queries2 = _assistant_search_queries(record2)

        if search_queries1 != search_queries2:
            print(f"Question: {question}")
            print(f"Search queries in {file1}: {search_queries1}")
            print(f"Search queries in {file2}: {search_queries2}")
            count += 1
    print(f"Found {count} questions with different search queries.")

        
# Answer text that marks a run as having declined to answer.
_ABSTAIN_ANSWER = "I don't know"


def _reward_of(record):
    """Return the record's reward, treating a missing or null reward as 0."""
    return record.get('reward') or 0


def _mean_reward(data):
    """Return the mean reward over all records, or 0.0 when there are none."""
    if not data:
        return 0.0
    return sum(_reward_of(record) for record in data.values()) / len(data)


def _answer_and_queries(record):
    """Return ``(answer, search_queries)`` for one record.

    The stored ``answer`` field is used when present; otherwise the answer is
    taken from ``extract_solution`` applied to ``sequences_str``. Search
    queries are the ``<search>...</search>`` payloads in ``sequences_str``.
    """
    sequence = record.get('sequences_str') or ''
    answer = record.get('answer')
    if not answer:
        answer = extract_solution(sequence)
    queries = re.findall(r'<search>(.*?)</search>', sequence, re.DOTALL)
    return answer, queries


def _dataset_name(file_path):
    """Return the name of the directory that directly contains ``file_path``."""
    # e.g. "eval_results/nq_search/xxxx.json" -> "nq_search". Going through
    # abspath keeps this working for a bare filename with no directory part.
    return os.path.basename(os.path.dirname(os.path.abspath(file_path)))


def _save_dirty_questions(dirty_questions, dataset_name):
    """Merge ``dirty_questions`` with any existing dump for the dataset and write it back."""
    dump_path = f'data/dirty_questions_{dataset_name}.dump'
    if os.path.exists(dump_path):
        # NOTE(security): pickle.load can execute arbitrary code. This is
        # only acceptable while the dump is a file this script wrote itself;
        # never point it at a file from an untrusted source.
        with open(dump_path, 'rb') as f:
            dirty_questions_old = pickle.load(f)
        print(f"Loaded {len(dirty_questions_old)} dirty questions from {dump_path}")
        dirty_questions.update(dirty_questions_old)
    # The output directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(dump_path), exist_ok=True)
    with open(dump_path, 'wb') as f:
        pickle.dump(dirty_questions, f)
    print(f"Saved {len(dirty_questions)} dirty questions to {dump_path}")


def main(file1, file2, save_dirty_questions=None):
    """Compare per-question rewards between two result files and print a report.

    Both files are loaded with ``load_data``. For every question present in
    both files:

    * if both rewards equal 1, the question is listed as solved by both;
    * if the rewards differ, the record with the higher reward is the
      winner and the other the loser, and their answers and search queries
      are reported.

    A question with differing rewards is considered "dirty" when the loser
    record has a truthy ``is_contaminated`` field or the loser's answer is
    exactly ``I don't know``.

    Args:
        file1: Path to the first result file.
        file2: Path to the second result file.
        save_dirty_questions: When true, merge the dirty questions into
            ``data/dirty_questions_<dataset>.dump`` (a pickled set), where
            ``<dataset>`` is the name of the directory containing the input
            files. Saving only happens when both files live in directories
            of the same name. When ``None`` (the default), the value is read
            from a module-level ``args`` namespace if one exists, otherwise
            it is treated as false.

    Returns:
        None. The report is printed to stdout.
    """
    if save_dirty_questions is None:
        # Backward compatibility: the script entry point stores the parsed
        # command line in a module-level `args`. Fall back to it without
        # failing when main() is called from an import.
        cli_args = globals().get('args')
        save_dirty_questions = bool(getattr(cli_args, 'save_dirty_questions', False))

    print(f"Loading data from {file1}...")
    data1 = load_data(file1)
    print(f"Loaded {len(data1)} records from {file1}")

    print(f"Loading data from {file2}...")
    data2 = load_data(file2)
    print(f"Loaded {len(data2)} records from {file2}")

    both_reward_1 = []
    different_rewards = []

    for question, record1 in data1.items():
        if question not in data2:
            continue
        record2 = data2[question]

        reward1 = _reward_of(record1)
        reward2 = _reward_of(record2)

        if reward1 == 1 and reward2 == 1:
            both_reward_1.append({
                "question": question,
                "file1": file1,
                "file2": file2,
            })
        elif reward1 != reward2:
            # The higher reward wins. For 0/1 rewards this is the same as
            # "whoever scored 1", and it stays correct for partial rewards.
            if reward1 > reward2:
                winner, loser = file1, file2
                winner_record, loser_record = record1, record2
            else:
                winner, loser = file2, file1
                winner_record, loser_record = record2, record1

            winner_answer, winner_queries = _answer_and_queries(winner_record)
            loser_answer, loser_queries = _answer_and_queries(loser_record)

            different_rewards.append({
                "question": question,
                "winner": winner,
                "loser": loser,
                "winner_answer": winner_answer,
                "loser_answer": loser_answer,
                "winner_reward": _reward_of(winner_record),
                "loser_reward": _reward_of(loser_record),
                "winner_search_queries": winner_queries,
                "loser_search_queries": loser_queries,
                "is_contaminated": bool(loser_record.get('is_contaminated', False)),
            })

    print("\n--- Questions where both rewards are 1 ---")
    print(f"Found {len(both_reward_1)} such questions.")
    for item in both_reward_1:
        print(f"  - Question: {item['question']}")
    print("\n--------------------------------")
    print(f"len(data1): {len(data1)}")
    print(f"len(data2): {len(data2)}")
    # These are means over all loaded records, hence "Average".
    print(f"Average reward of {file1}: {_mean_reward(data1)}")
    print(f"Average reward of {file2}: {_mean_reward(data2)}")

    print("\n--- Questions with different rewards ---")
    print(f"Found {len(different_rewards)} such questions.")

    dirty_questions = set()

    for item in different_rewards:
        # Computed outside the f-string: a backslash-escaped quote inside a
        # replacement field is a syntax error before Python 3.12.
        loser_abstained = item['loser_answer'] == _ABSTAIN_ANSWER

        print(f"  - Question: {item['question']}")
        print(f"    - Winner (reward={item['winner_reward']}): {item['winner']}")
        print(f"    - Loser (reward={item['loser_reward']}): {item['loser']}")
        print(f"    - Extracted Answer from winner: {item['winner_answer']}")
        print(f"    - Extracted Answer from loser: [{item['loser_answer']}]")
        print(f"    - {loser_abstained}")
        if item['is_contaminated'] or loser_abstained:
            dirty_questions.add(item['question'])
        print(f"    - Winner search queries: {item['winner_search_queries']}")
        print(f"    - Loser search queries: {item['loser_search_queries']}")
        print(f"    - Is contaminated: {item['is_contaminated']}")

    if save_dirty_questions:
        dataset_name1 = _dataset_name(file1)
        dataset_name2 = _dataset_name(file2)
        if dataset_name1 == dataset_name2:
            _save_dirty_questions(dirty_questions, dataset_name1)
        else:
            print(
                f"Warning: Not saving dirty questions because the files belong to "
                f"different datasets: {dataset_name1} vs {dataset_name2}"
            )

if __name__ == "__main__":
    # Command-line entry point: compare the rewards recorded in two result files.
    parser = argparse.ArgumentParser(description="Compare rewards from two evaluation files.")
    parser.add_argument("file1", type=str, help="Path to the first file.")
    parser.add_argument("file2", type=str, help="Path to the second file.")
    parser.add_argument(
        "--save_dirty_questions",
        action="store_true",
        help="Save the dirty questions to a pickle dump (data/dirty_questions_<dataset>.dump).",
    )
    # `args` must stay a module-level name: main() reads
    # args.save_dirty_questions from it.
    args = parser.parse_args()

    main(args.file1, args.file2)
    # To diff the search queries instead, call:
    # compare_search_results(args.file1, args.file2)