import json
import argparse
from colorama import init, Fore, Style
from utils.qa_em import extract_solution
# Initialize colorama once at import time, before any of the Fore/Style
# codes used by the printing code below are written out.
init()


def filter_data(jsonl_file, baseline_jsonl_file, n=3):
    """Print and return up to ``n`` baseline records whose ``reward`` is 1.

    Records are selected from ``baseline_jsonl_file`` by ``_filter_data``.
    Each selected record is printed in a colorized, human-readable form:
    id, question, the filtered fields, ground truth, any search results
    found in ``generation_history``, the final answer, reward and the
    ``no_error`` flag.

    If a record has no truthy ``final_answer``, one is computed with
    ``extract_solution`` from the record's ``sequences_str`` and stored
    back on the record under ``final_answer``.

    Args:
        jsonl_file (str): Path to the input JSONL file.
            NOTE(review): this argument is currently unused; only the
            baseline file is read. Confirm whether a comparison against
            this file was intended.
        baseline_jsonl_file (str): Path to the JSONL file to filter.
        n (int): Maximum number of matching records to print and return.

    Returns:
        list: The matching records, in file order.
    """
    conditions = [
        # Alternative filters that can be enabled when needed:
        # {"key": "original_reward", "value": 0},
        {"key": "reward", "value": 1},
        # {"key": "think"},
    ]
    filtered_data = _filter_data(baseline_jsonl_file, conditions, n)

    print(f"\n{Fore.CYAN}Found {len(filtered_data)} matching records:{Style.RESET_ALL}")
    for i, data in enumerate(filtered_data, 1):
        print(f"\n{Fore.YELLOW}Record {i}:{Style.RESET_ALL}; ID: {data.get('id', 'N/A')}")
        print(f"{Fore.GREEN}Question:{Style.RESET_ALL} {data.get('question', 'N/A')}")
        for condition in conditions:
            print(f"{Fore.RED} {condition['key'].capitalize()}:{Style.RESET_ALL} {data.get(condition['key'], 'N/A')}")
        print(f"{Fore.BLUE}Ground Truth Answer:{Style.RESET_ALL} {data.get('ground_truth', 'N/A')}")

        # Collect the search results from the generation history. Records
        # are read defensively everywhere else in this function, so do the
        # same here: a missing/null history or an entry without the
        # expected keys must not abort the whole listing.
        history = data.get('generation_history') or []
        search_results = [
            item.get('content', 'N/A')
            for item in history
            if isinstance(item, dict) and item.get('role') == 'search'
        ]
        if search_results:
            print(f"{Fore.MAGENTA}Searched Results:{Style.RESET_ALL}")
            for idx, result in enumerate(search_results, 1):
                print(f"  {Fore.MAGENTA}[{idx}]{Style.RESET_ALL} {result}")

        # Fall back to extracting the answer from the raw sequence text
        # and cache it on the record so callers see it as well.
        final_answer = data.get("final_answer")
        if not final_answer:
            final_answer = extract_solution(data.get("sequences_str"))
            data["final_answer"] = final_answer

        print(f"{Fore.CYAN}Answer:{Style.RESET_ALL} {final_answer}")
        print(f"{Fore.YELLOW}Reward:{Style.RESET_ALL} {data.get('reward', 'N/A')}")
        print(f"{Fore.YELLOW}No Error:{Style.RESET_ALL} {data.get('no_error', 'N/A')}")
        print(f"{Fore.WHITE}{'-' * 80}{Style.RESET_ALL}")

    return filtered_data

def _filter_data(jsonl_file, conditions, n=3):
    """Return the first ``n`` records of a JSONL file that meet all conditions.

    Each condition is a dict with a ``"key"`` and an optional ``"value"``:

    - if ``"value"`` is present and not ``None``, the record's field must
      be equal to it;
    - otherwise the record's field must be present and truthy.

    Lines that are not valid JSON, or whose JSON value is not an object,
    are skipped.

    Args:
        jsonl_file (str): Path to the JSONL file.
        conditions (list): Condition dicts as described above. An empty
            list matches every record.
        n (int): Maximum number of records to return. Values <= 0 yield
            an empty list.

    Returns:
        list: Matching records (dicts), in file order.
    """
    filtered_data = []
    # Without this guard the limit is only checked after the first append,
    # so a non-positive ``n`` would still return one record.
    if n <= 0:
        return filtered_data

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Keep the try body limited to the call that can actually fail.
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                continue

            # Valid JSON that is not an object (e.g. ``42`` or ``[1, 2]``)
            # has no fields to test and cannot be a record.
            if not isinstance(data, dict):
                continue

            if _record_matches(data, conditions):
                filtered_data.append(data)
                if len(filtered_data) >= n:
                    break
    return filtered_data


def _record_matches(record, conditions):
    """Return True when ``record`` satisfies every condition in ``conditions``."""
    for condition in conditions:
        expected = condition.get("value")
        actual = record.get(condition["key"])
        if expected is not None:
            if actual != expected:
                return False
        elif not actual:
            return False
    return True

def main(argv=None):
    """Parse command-line options and run ``filter_data``.

    Args:
        argv (list | None): Argument strings to parse. When ``None``
            (the default), argparse reads them from ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Filter JSONL data based on specific conditions')
    parser.add_argument('--input', '-i', type=str, required=True,
                      help='Path to the input JSONL file')
    # The short flag must differ from --input's "-i"; reusing it makes
    # argparse raise "conflicting option string" while building the parser.
    parser.add_argument('--baseline', '-b', type=str, required=True,
                      help='Path to the baseline JSONL file')
    parser.add_argument('--num', '-n', type=int, default=3,
                      help='Number of results to show (default: 3)')

    args = parser.parse_args(argv)
    filter_data(args.input, args.baseline, args.num)

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()