import json
import re
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os

# Path of the JSONL file that the selected solutions are written to.
output_file = 'math_diverse_solutions_selected.jsonl'
# Upper bound on how many solutions are kept for each problem.
NUM_SOLUTIONS_TO_SELECT = 4

def get_ngrams(text: str, n_sizes: "list | tuple" = (3, 4)) -> set:
    """Return the set of word n-gram features found in ``text``.

    The text is split on runs of non-word characters (``\\W+``) and empty
    tokens are dropped.  For every size ``n`` in ``n_sizes`` each window of
    ``n`` consecutive tokens becomes one feature of the form
    ``"ngram{n}:tok1_tok2_..."``.

    Args:
        text: The text to tokenize.
        n_sizes: N-gram sizes to generate.  Sizes below 1 are ignored, and
            sizes larger than the number of tokens contribute nothing.

    Returns:
        A set of feature strings; empty when no n-gram can be formed.
    """
    tokens = [tok for tok in re.split(r'\W+', text) if tok]

    features = set()
    for n in n_sizes:
        # A window of zero or negative width is meaningless and would
        # otherwise yield bogus features such as "ngram0:".
        if n < 1:
            continue
        # range() is empty when there are fewer than n tokens.
        for start in range(len(tokens) - n + 1):
            ngram = "_".join(tokens[start:start + n])
            features.add(f"ngram{n}:{ngram}")
    return features

def extract_features(text: str) -> set:
    """Return the n-gram feature set for ``text``.

    Text that is empty or consists only of whitespace yields an empty set;
    anything else is delegated to ``get_ngrams`` with its default sizes.
    """
    has_content = bool(text.strip())
    return get_ngrams(text) if has_content else set()

def select_diverse_solutions_submodular(candidate_solutions: list, num_to_select: int) -> list:
    """Greedily choose solutions whose features add the most new coverage.

    Only candidates that are non-empty strings take part.  On every round the
    candidate whose feature set (from ``extract_features``) contains the most
    features not yet covered by earlier picks is chosen; ties go to the
    candidate that appears first in the input.

    Args:
        candidate_solutions: Candidate solution texts; entries that are not
            non-empty strings are ignored.
        num_to_select: Maximum number of solutions to return.

    Returns:
        The chosen solutions in the order they were picked.  An empty list is
        returned when there are no usable candidates or ``num_to_select`` is
        not positive.
    """
    if not candidate_solutions or num_to_select <= 0:
        return []

    # (position in the input, text) for every usable candidate.
    usable = [
        (position, text)
        for position, text in enumerate(candidate_solutions)
        if text and isinstance(text, str)
    ]
    if not usable:
        return []

    budget = min(num_to_select, len(usable))
    feature_sets = [extract_features(text) for _, text in usable]

    pool = list(range(len(usable)))
    chosen = []
    seen = set()

    while pool and len(chosen) < budget:
        # max() returns the first maximal element, so ties resolve to the
        # earliest remaining candidate.
        winner = max(pool, key=lambda k: len(feature_sets[k] - seen))
        chosen.append(winner)
        pool.remove(winner)
        seen |= feature_sets[winner]

    return [candidate_solutions[usable[k][0]] for k in chosen]

def jload(file_path: str) -> list:
    """Load a JSON Lines file into a list of decoded objects.

    Each non-blank line is parsed with ``json.loads``.  Blank or
    whitespace-only lines (such as a trailing empty line) are skipped rather
    than aborting the whole load with a decode error.

    Args:
        file_path: Path of the ``.jsonl`` file to read (UTF-8).

    Returns:
        The decoded records in file order.  If the file does not exist, an
        error message is printed and an empty list is returned.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                data.append(json.loads(line))
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
    return data

def process_chunk(chunk, output_file, lock):
    """Select diverse solutions for each problem in ``chunk`` and append them to a file.

    For every problem, the list under its ``"answer"`` key is passed to
    ``select_diverse_solutions_submodular`` with ``NUM_SOLUTIONS_TO_SELECT``.
    When something is selected, one JSON line with the keys ``'problem'``
    (taken from ``"question"``) and ``'solutions'`` is appended to
    ``output_file`` while ``lock`` is held.

    Problems whose ``"answer"`` is not a list, or for which nothing is
    selected, are skipped.  Any exception raised while handling a problem is
    printed and processing moves on to the next one.

    Args:
        chunk: Iterable of problem records.
        output_file: Path of the file to append results to.
        lock: Context-manager lock held around each append.
    """
    for record in tqdm(chunk, leave=False, desc="Processing problem"):
        try:
            answers = record.get("answer", [])
            if not isinstance(answers, list):
                continue

            picked = select_diverse_solutions_submodular(
                answers,
                num_to_select=NUM_SOLUTIONS_TO_SELECT,
            )
            if not picked:
                continue

            row = {
                'problem': record.get("question", ""),
                'solutions': picked,
            }

            # The lock is taken first and released last, so the open/append/
            # close sequence happens entirely while it is held.
            with lock, open(output_file, 'a', encoding='utf-8') as sink:
                json.dump(row, sink, ensure_ascii=False)
                sink.write('\n')
        except Exception as e:
            print(f"Error processing a problem: {e}")

def process_and_save_diverse_solutions(
    input_file: str = "am_deepseek_distilled_40m_math_merged.jsonl",
    num_processes: int = 50,
):
    """Load problems, filter them, and select diverse solutions in parallel.

    The module-level ``output_file`` is truncated first, so every run starts
    from an empty result file.  Problems are read from ``input_file`` with
    ``jload``; only those whose ``"answer"`` is a list with more than 4 and
    fewer than 10 entries are kept.  The kept problems are split into chunks
    and handed to ``process_chunk`` in a process pool, with a manager-backed
    lock passed to every worker for the file appends.

    Args:
        input_file: Path of the JSONL file holding the problems.
        num_processes: Target number of worker processes; it also determines
            the chunk size.  The pool never has more workers than chunks.

    Raises:
        ValueError: If ``num_processes`` is smaller than 1.
    """
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    # Truncate (or create) the output file; workers only ever append to it.
    with open(output_file, 'w', encoding='utf-8'):
        pass

    problems = jload(input_file)
    if not problems:
        return

    print(f"Total problems loaded: {len(problems)}")

    problems = [p for p in problems if "answer" in p and isinstance(p["answer"], list) and 4 < len(p["answer"]) < 10]
    print(f"Filtered problems (4 < num_answers < 10): {len(problems)}")

    if not problems:
        print("No problems to process after filtering. Exiting.")
        return

    chunk_size = max(1, len(problems) // num_processes)
    chunks = [problems[i:i + chunk_size] for i in range(0, len(problems), chunk_size)]
    # Spawning more workers than there are chunks would only create idle processes.
    pool_size = min(num_processes, len(chunks))

    # The manager runs a helper server process; using it as a context manager
    # guarantees that process is shut down once the work is finished.
    with mp.Manager() as manager:
        lock = manager.Lock()
        process_func = partial(process_chunk, output_file=output_file, lock=lock)
        with mp.Pool(processes=pool_size) as pool:
            # Drain the iterator so every chunk finishes before the pool is torn down.
            for _ in tqdm(pool.imap_unordered(process_func, chunks),
                          total=len(chunks),
                          desc="Total Progress"):
                pass

# Run the selection pipeline only when this file is executed as a script.
if __name__ == "__main__":
    process_and_save_diverse_solutions()