import json
import ast
import re
from tqdm import tqdm
import multiprocessing as mp
from functools import partial
import os

# Destination JSONL file: truncated at the start of each run, then appended to by worker processes.
output_file: str = "hybrid_diverse_solutions_selected.jsonl"
# Maximum number of diverse solutions kept for each problem.
NUM_SOLUTIONS_TO_SELECT: int = 4


class EnhancedASTVisitor(ast.NodeVisitor):
    """Walk a Python AST and collect string features describing its structure.

    After ``visit(tree)`` the instance exposes:

    * ``features`` -- set of feature strings:
        - ``ast:type:<Node>``             every node type encountered,
        - ``ast:path:<Parent>-><Child>``  every parent/child node-type edge,
        - ``ast:flow:<Node>``             If / For / While / Try / With statements,
        - ``ast:data_dep:<name>``         a ``Name`` read (Load) whose identifier was
          already written (Store) earlier in traversal order. Traversal order
          follows AST field order, not execution order.
    * ``defined_variables`` -- identifiers seen in a Store context.
    * ``current_path`` -- stack of node-type names from the root down to the
      node currently being visited; it is empty again after a full traversal.
    """

    _FLOW_NODES = (ast.If, ast.For, ast.While, ast.Try, ast.With)

    def __init__(self):
        self.features = set()
        self.current_path = []
        self.defined_variables = set()

    def visit(self, node):
        kind = type(node).__name__
        stack = self.current_path
        found = self.features

        stack.append(kind)
        found.add(f"ast:type:{kind}")

        # Below the root there is a parent: record the parent->child edge.
        if len(stack) > 1:
            found.add(f"ast:path:{stack[-2]}->{kind}")

        if isinstance(node, self._FLOW_NODES):
            found.add(f"ast:flow:{kind}")

        if isinstance(node, ast.Name):
            ctx = node.ctx
            if isinstance(ctx, ast.Store):
                self.defined_variables.add(node.id)
            elif isinstance(ctx, ast.Load) and node.id in self.defined_variables:
                found.add(f"ast:data_dep:{node.id}")

        # NodeVisitor.generic_visit routes every child back through self.visit,
        # which is what drives the recursive walk.
        super().generic_visit(node)
        stack.pop()

def get_ngrams(text: str, n_sizes: tuple = (3, 4)) -> set:
    """Return word n-gram features for *text*.

    The text is split on runs of non-word characters and empty tokens are
    dropped. For every ``n`` in *n_sizes*, each contiguous window of ``n``
    tokens becomes the feature ``"ngram{n}:{tok1}_{tok2}_..."``.

    Args:
        text: Arbitrary text (prose and/or code).
        n_sizes: Iterable of n-gram lengths; any list or tuple of ints is
            accepted. The default is an immutable tuple so it can never be
            shared-and-mutated across calls (the classic mutable-default trap).

    Returns:
        A set of feature strings. It is empty when the text has fewer tokens
        than every requested ``n``.
    """
    tokens = [tok for tok in re.split(r'\W+', text) if tok]
    features = set()
    for n in n_sizes:
        # range() is already empty when len(tokens) < n, so no length guard is needed.
        for start in range(len(tokens) - n + 1):
            gram = "_".join(tokens[start:start + n])
            features.add(f"ngram{n}:{gram}")
    return features

def extract_combined_features(solution_text: str) -> set:
    """Build the diversity feature set for a single solution.

    The result is the union of two feature families:

    * structural -- ``EnhancedASTVisitor`` features from the last fenced code
      block in *solution_text*. This part is empty when there is no code
      block or the code cannot be parsed as Python.
    * lexical -- word 3/4-gram features from ``get_ngrams`` over the *entire*
      solution text, prose included.
    """
    code = extract_code_from_markdown(solution_text)

    structural = set()
    if code.strip():
        try:
            tree = ast.parse(code)
            walker = EnhancedASTVisitor()
            walker.visit(tree)
            structural = walker.features
        except Exception:
            # Best effort: non-Python or malformed code contributes no AST features.
            structural = set()

    lexical = get_ngrams(solution_text)
    return structural | lexical


def extract_code_from_markdown(markdown_text: str) -> str:
    """Return the stripped body of the last fenced code block, or ``""``.

    Fence rules:

    * An opening fence is three backticks plus an optional info string.
      The info string is made of word characters and ``+ # . -``, so
      ``python``, ``python3``, ``c++`` and ``objective-c`` are all accepted.
    * The info string may be followed by spaces/tabs and must end the line
      (LF or CRLF).
    * The body runs up to the first newline that is immediately followed by
      three backticks.

    Args:
        markdown_text: Text that may contain fenced code blocks.

    Returns:
        The last block's body with surrounding whitespace removed, or an empty
        string when no complete fenced block is present.
    """
    # The info string must tolerate digits/symbols and trailing whitespace.
    # If an opening fence such as ```python3 fails to match, findall keeps
    # scanning and can treat that block's *closing* fence as an opening one,
    # returning the prose between two blocks as "code".
    matches = re.findall(r'```[\w+#.-]*[ \t]*\r?\n(.*?)\n```', markdown_text, re.DOTALL)
    if matches:
        return matches[-1].strip()
    return ""

def select_diverse_solutions_submodular(candidate_solutions: list, num_to_select: int) -> list:
    """Greedily pick up to *num_to_select* mutually diverse solutions.

    Candidate handling:

    * Each non-empty ``str`` candidate is mapped to a feature set with
      ``extract_combined_features``.
    * Empty or non-``str`` candidates are ignored.

    Selection uses the greedy max-coverage heuristic. At every step it takes
    the candidate adding the most features not yet covered; ties go to the
    earliest candidate. Its features are then merged into the covered set.

    Returns:
        The chosen objects from *candidate_solutions*, in selection order.
        Returns ``[]`` for empty input, a non-positive *num_to_select*, or
        when no candidate is usable.
    """
    if not candidate_solutions or num_to_select <= 0:
        return []

    # (original position, text) for every usable candidate.
    usable = [
        (pos, text)
        for pos, text in enumerate(candidate_solutions)
        if text and isinstance(text, str)
    ]
    if not usable:
        return []

    budget = min(num_to_select, len(usable))
    feature_sets = [extract_combined_features(text) for _, text in usable]

    pool = list(range(len(usable)))
    picked = []
    covered = set()
    for _ in range(budget):
        # max() keeps the first maximal element, i.e. earliest index wins ties.
        winner = max(pool, key=lambda k: len(feature_sets[k] - covered))
        picked.append(winner)
        pool.remove(winner)
        covered |= feature_sets[winner]

    return [candidate_solutions[usable[k][0]] for k in picked]


def jload(file_path: str) -> list:
    """Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines are skipped. They are not records, and
    ``json.loads`` on them would raise and abort the entire load.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One parsed JSON value per non-blank line, in file order. If the file
        does not exist, an error message is printed and ``[]`` is returned.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                data.append(json.loads(line))
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
    return data

def process_chunk(chunk, output_file, lock):
    """Select diverse solutions for each problem in *chunk* and append them to *output_file*.

    For every problem whose ``"answer"`` field is a list:

    * up to ``NUM_SOLUTIONS_TO_SELECT`` solutions are chosen with
      ``select_diverse_solutions_submodular``;
    * when at least one is chosen, a ``{"problem", "solutions"}`` record is
      appended to *output_file* as one JSON line.

    The append happens while *lock* is held, so concurrent callers sharing
    the lock do not interleave lines. Any exception raised while handling a
    single problem is printed and that problem is skipped.
    """
    for item in tqdm(chunk, leave=False, desc="Processing problem"):
        try:
            answers = item.get("answer", [])
            if isinstance(answers, list):
                chosen = select_diverse_solutions_submodular(
                    answers,
                    num_to_select=NUM_SOLUTIONS_TO_SELECT,
                )
                if chosen:
                    record = {
                        'problem': item.get("question", ""),
                        'solutions': chosen,
                    }
                    # Take the lock first, then open/append/close the file inside it.
                    with lock, open(output_file, 'a', encoding='utf-8') as sink:
                        json.dump(record, sink, ensure_ascii=False)
                        sink.write('\n')
        except Exception as e:
            print(f"Error processing a problem: {e}")

def process_and_save_diverse_solutions(
    input_file: str = "am_deepseek_distilled_40m_code_merged.jsonl",
    num_processes: int = 50,
):
    """Run the diverse-solution selection pipeline end to end.

    Steps:
      1. Truncate ``output_file`` so this run starts from an empty file.
      2. Load problems from *input_file* (JSON Lines). Keep only those whose
         ``"answer"`` is a list with more than 4 and fewer than 10 entries.
      3. Split the problems into chunks and run ``process_chunk`` over them in
         a process pool. Workers append to ``output_file`` under a shared lock.

    Args:
        input_file: JSONL file of problems. Defaults to the previously
            hard-coded dataset path.
        num_processes: Upper bound on worker processes (default 50, the
            previous hard-coded value). No more workers than chunks are started.
    """
    # Start from an empty output file; workers only ever append to it.
    with open(output_file, 'w', encoding='utf-8'):
        pass

    problems = jload(input_file)
    if not problems:
        return

    print(f"Total problems loaded: {len(problems)}")

    problems = [
        p for p in problems
        if "answer" in p and isinstance(p["answer"], list) and 4 < len(p["answer"]) < 10
    ]
    print(f"Filtered problems (4 < num_answers < 10): {len(problems)}")

    if not problems:
        print("No problems to process after filtering. Exiting.")
        return

    num_processes = max(1, num_processes)
    chunk_size = max(1, len(problems) // num_processes)
    chunks = [problems[i:i + chunk_size] for i in range(0, len(problems), chunk_size)]

    # The Manager's lock proxy is picklable, so it can be bound into the task
    # function via partial. The Manager runs its own server process, so it is
    # used as a context manager to shut that process down when the work is done.
    with mp.Manager() as manager:
        lock = manager.Lock()
        process_func = partial(process_chunk, output_file=output_file, lock=lock)
        with mp.Pool(processes=min(num_processes, len(chunks))) as pool:
            # Consume the iterator fully so every chunk finishes before the pool exits.
            list(tqdm(pool.imap_unordered(process_func, chunks),
                      total=len(chunks),
                      desc="Total Progress"))

if __name__ == "__main__":
    # Script entry point: run the load -> filter -> select -> write pipeline.
    process_and_save_diverse_solutions()