import os
import ast
import typing
import tempfile
import subprocess
from typing import Dict, List
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import numpy as np
import sys

from utils.testing_utils import evaluate_test_coverage, run_testcase, evaluate_test_mutation_score, run_testcase_stdio
from utils.parsing_utils import extract_test_cases_v2, extract_python_code, extract_test_cases_stdio


def correct_syntax(ut: str) -> float:
    """
    Check whether a piece of source code parses as valid Python.

    Args:
        ut: str, source code to check

    Returns:
        float, 1.0 if ``ast.parse`` accepts ``ut``, 0.0 if parsing raises any
        exception (syntax error, recursion error, non-string input, ...)
    """
    original_limit = sys.getrecursionlimit()
    try:
        # Deeply nested code can exhaust the default recursion limit while parsing.
        # Only ever raise the limit: lowering a limit that is already above 10000
        # could make setrecursionlimit itself fail on a deep call stack.
        sys.setrecursionlimit(max(original_limit, 10000))
        ast.parse(ut)
    except Exception:
        # RecursionError is a subclass of Exception, so it is covered here too.
        return 0.0
    finally:
        # Always restore the process-wide limit, whatever the outcome.
        sys.setrecursionlimit(original_limit)
    return 1.0


# reward function for solution generation llm (Reward based on LLM-generated unit tests)
def codegen_formatting_reward(
    solution_str: str,
) -> float:
    """
    Binary formatting check for a code-generation completion.

    The completion is well formatted when it contains a <reasoning> ... </reasoning>
    section with non-whitespace content and, somewhere after the closing tag, a
    ```python\n ...\n``` code block with non-whitespace content.

    Args:
        solution_str: str, the completion to check

    Returns:
        float, 1.0 if both requirements hold, 0.0 otherwise
    """
    import re

    # Requirement 1: a reasoning section whose body is more than whitespace.
    reasoning = re.search(r'<reasoning>\s*(.+?)\s*</reasoning>', solution_str, re.DOTALL)
    if reasoning is None or not reasoning.group(1).strip():
        return 0.0

    # Requirement 2: a non-empty python code block, searched only in the text that
    # follows the closing reasoning tag.
    tail = solution_str[reasoning.end():]
    code_block = re.search(r'```python\s*\n(.+?)\n```', tail, re.DOTALL)
    if code_block is None or not code_block.group(1).strip():
        return 0.0

    return 1.0


def solution_generation_reward_stdio(
    data_source: str,
    solution_str: str, 
    ground_truth: str,
    extra_info: Dict,
) -> Dict:
    """
    Reward function for solution generation llm (Reward based on LLM-generated unit tests)

    Synthetic tests are first filtered down to those the ground-truth solution passes;
    the generated solution is then scored on the remaining tests.

    Args:
        data_source: str, "trainset" returns the training metrics only; any other value
            additionally runs the ground-truth tests and reports gt_score / gt_passed
        solution_str: str, completion generated by code generation llm
        ground_truth: list of synthetic test cases generated by unit test generation llm
            (annotated as str in the signature); each item is passed to
            run_testcase_stdio unchanged
        extra_info: Dict, containing following keys:
            - gt_test: str, ground truth test cases; evaluated with eval() into a mapping
              with 'inputs' and 'outputs' sequences
            - gt_solution: str, ground_truth solution
    Returns: 
        Dict, containing following keys:
            - score: float, total reward = 0.9* pass_rate + 0.1* formatting_reward
            - pass_rate: float, ratio of the passed synthetic tests
            - formatting_reward: float, binary metric for whether the solution is formatted correctly
            - syntax_correctness: float, binary metric for whether the solution is syntactically correct
            - gt_score: float, ratio of passed ground truth tests (non-"trainset" only)
            - gt_passed: float, 1.0 if gt_score > 0.99 else 0.0 (non-"trainset" only)
        If anything raises, every value in the returned dict is 0.0.
    """
    try:
        solution: str = extract_python_code(solution_str)
        # NOTE(security): eval() executes arbitrary expressions. This is only acceptable
        # while gt_test comes from a trusted dataset; consider ast.literal_eval otherwise.
        test_cases = eval(extra_info['gt_test'])
        gt_test = [
            {'input': inp, 'output': out}
            for inp, out in zip(test_cases['inputs'], test_cases['outputs'])
        ]
        gt_solution: str = extra_info["gt_solution"]

        def passes_gt_solution(test_case):
            return run_testcase_stdio(gt_solution, test_case)['passed']

        def check_test(test_case):
            return run_testcase_stdio(solution, test_case)['passed']

        # Keep only the synthetic tests that the ground-truth solution passes.
        with ThreadPoolExecutor(max_workers=88) as executor:
            validity = list(executor.map(passes_gt_solution, ground_truth))
        synthetic_tests = [t for t, passed in zip(ground_truth, validity) if passed]

        formatting_reward = codegen_formatting_reward(solution_str)
        syntax_correctness = correct_syntax(solution)

        # Pass rate of the generated solution on the valid synthetic tests. The 1e-6 in
        # the denominator avoids division by zero when no synthetic test survived.
        with ThreadPoolExecutor(max_workers=88) as executor:
            results = list(executor.map(check_test, synthetic_tests))
        pass_rate = sum(results) / (len(synthetic_tests) + 1e-6)
        score = 0.9*pass_rate + 0.1*formatting_reward

        if data_source == "trainset":
            return {
                "score": score,
                "pass_rate": pass_rate,
                "formatting_reward": formatting_reward,
                "syntax_correctness": syntax_correctness,
            }

        # ThreadPoolExecutor rejects max_workers=0, so clamp to at least one worker;
        # an empty ground-truth test list then yields gt_score == 0.0 instead of
        # discarding the whole reward through the except branch.
        max_workers = max(1, min(len(gt_test), 88))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            gt_results = list(executor.map(check_test, gt_test))
        gt_score = sum(gt_results) / (len(gt_test) + 1e-6)

        return {
            "score": score,
            "pass_rate": pass_rate,
            "formatting_reward": formatting_reward,
            "gt_score": gt_score,
            # gt_score never reaches exactly 1.0 because of the epsilon, hence the threshold.
            "gt_passed": float(gt_score > 0.99),
            "syntax_correctness": syntax_correctness,
        }
    except Exception:
        # Best-effort reward: any failure (parsing, missing keys, execution) scores zero.
        if data_source == "trainset":
            return {
                "score": 0.0,
                "pass_rate": 0.0,
                "formatting_reward": 0.0,
                "syntax_correctness": 0.0,
            }
        return {
            "score": 0.0,
            "pass_rate": 0.0,
            "formatting_reward": 0.0,
            "gt_score": 0.0,
            "gt_passed": 0.0,
            "syntax_correctness": 0.0,
        }
        
import random


def solution_generation_reward_gtut_stdio(
    data_source: str,
    solution_str: str, 
    ground_truth: str,
    extra_info: Dict,
) -> Dict:
    """
    Reward function for solution generation llm (Reward based on ground-truth unit tests)

    Args:
        data_source: str, "trainset" scores the solution on at most the first 12
            ground-truth tests; any other value scores it on all of them
        solution_str: str, completion generated by code generation llm
        ground_truth: str, ground truth test cases; evaluated with eval() into a mapping
            with 'inputs' and 'outputs' sequences
        extra_info: Dict, not read by this function
    Returns:
        Dict, containing following keys:
            - score: float, 0.9* code_score + 0.1* formatting_reward
            - code_score: float, ratio of the passed ground-truth tests
            - passed: float, 1.0 if code_score > 0.99 else 0.0
            - formatting_reward: float, binary metric for whether the solution is formatted correctly
            - syntax_correctness: float, binary metric for whether the solution is syntactically correct
        If anything raises, the exception is printed and every value is 0.0.
    """
    try:
        solution: str = extract_python_code(solution_str)
        # NOTE(security): eval() executes arbitrary expressions. This is only acceptable
        # while ground_truth comes from a trusted dataset; consider ast.literal_eval otherwise.
        inp_out = eval(ground_truth)
        gt_tests = [
            {'input': inp, 'output': out}
            for inp, out in zip(inp_out['inputs'], inp_out['outputs'])
        ]

        def is_passed(test_case):
            return run_testcase_stdio(solution, test_case)['passed']

        formatting_reward = codegen_formatting_reward(solution_str)
        syntax_correctness = correct_syntax(solution)

        # Training uses a capped subset of the tests; evaluation uses every test.
        if data_source == "trainset":
            tests, worker_cap = gt_tests[:12], 88
        else:
            tests, worker_cap = gt_tests, 64

        # ThreadPoolExecutor rejects max_workers=0, so clamp to at least one worker;
        # an empty test list then yields code_score == 0.0 instead of raising.
        max_workers = max(1, min(len(tests), worker_cap))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(is_passed, tests))
        # The 1e-6 in the denominator avoids division by zero on an empty test list.
        code_score = sum(results) / (len(tests) + 1e-6)

        return {
            "score": 0.9*code_score + 0.1*formatting_reward,
            "code_score": code_score,
            # code_score never reaches exactly 1.0 because of the epsilon, hence the threshold.
            "passed": float(code_score > 0.99),
            "formatting_reward": formatting_reward,
            "syntax_correctness": syntax_correctness,
        }
    except Exception as e:
        # Best-effort reward: report the failure and score zero.
        print(e)
        return {
            "score": 0.0,
            "code_score": 0.0,
            "passed": 0.0,
            "formatting_reward": 0.0,
            "syntax_correctness": 0.0,
        }


def discrimination_reward_stdio_iter_1(
    data_source: str,
    solution_str: str, 
    ground_truth: str,
    extra_info: Dict,
):
    """
    Reward function for test generation llm (iteration 1).

    A generated test case is "valid" when the ground-truth solution passes it. The
    discrimination reward is the fraction of candidate solutions that fail at least one
    valid test case.

    Args:
        data_source: str, not read by this function
        solution_str: str, completion generated by test generation llm
        ground_truth: List[str], list of solutions sampled by code generation model
            (annotated as str in the signature)
        extra_info: Dict, containing following keys:
            - gt_solution: str, ground-truth solution code
    Returns:
        Dict, containing following keys:
            - score: float, 0.1* formatting_reward + 0.85* entire_discrimination_reward
              + 0.05* clipped_validity
            - formatting_reward: float, 1.0 if the number of "<reasoning>" occurrences equals
              the number of extracted test cases and the stripped completion ends with ```
            - n_test_cases: int, number of extracted test cases
            - n_unique_test_cases: int, number of test cases after de-duplication
            - clipped_validity: float, number of valid unique tests / max(12, n_unique_test_cases)
            - validity_ratio: float, fraction of unique tests passed by the ground-truth solution
            - entire_discrimination_reward: float, fraction of candidates failing >= 1 valid test
            - brievity_penalty: float, 1 / n_unique_test_cases (reported only, not part of score)
            - duplication_penalty: float, 1 - n_unique / n_total (reported only, not part of score)
        If anything raises, every value in the returned dict is 0.
    """
    try:
        uts: List[Dict] = extract_test_cases_stdio(solution_str)
        # Order-preserving de-duplication by equality.
        unique_uts = []
        for ut in uts:
            if ut not in unique_uts:
                unique_uts.append(ut)
        candidate_solutions: List[str] = ground_truth
        gt_solution = extra_info['gt_solution']

        # Formatting: one <reasoning> tag per test case, and the completion must end
        # with a closed code fence (i.e. it was not cut off or degenerate).
        reasoning_count = solution_str.count("<reasoning>")
        no_degeneration = solution_str.strip().endswith('```')
        formatting_reward = float(reasoning_count == len(uts) and no_degeneration)

        print('Number of test cases: ',len(uts))

        if len(unique_uts) == 0:
            # Nothing to execute: only the formatting term can contribute.
            entire_discrimination_reward = 0.0
            clipped_validity = 0.0
            final_score = 0.1*formatting_reward + 0.85*entire_discrimination_reward + 0.05*clipped_validity

            return {
                "score": final_score,
                "formatting_reward": formatting_reward,
                "n_unique_test_cases": len(unique_uts),
                "n_test_cases": len(uts),
                "clipped_validity": clipped_validity,
                "validity_ratio": 0.0,
                "entire_discrimination_reward": entire_discrimination_reward,
                "brievity_penalty": 1.0,
                "duplication_penalty": 0.0
            }

        brievity_penalty = 1 / len(unique_uts)
        duplication_penalty = 1 - len(unique_uts) / len(uts)

        # 1. Run the ground-truth solution on every unique test case, in parallel.
        def test_gt_solution(ut):
            return run_testcase_stdio(gt_solution, ut)["passed"]

        max_workers = min(88, mp.cpu_count(), len(unique_uts))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            gt_results = np.array(list(executor.map(test_gt_solution, unique_uts)))

        # 2. Validity: share of unique tests the ground-truth solution passes. The clipped
        #    variant divides by at least 12. Converted to plain Python floats so the
        #    result dict does not carry numpy scalars.
        validity_reward = float(gt_results.mean())
        clipped_validity = float(gt_results.sum() / max(12.0, len(unique_uts)))

        if validity_reward == 0.0:
            entire_discrimination_reward = 0.0
        else:
            # 3. Run every candidate against every valid test case, in parallel.
            idx_passed = np.where(gt_results == 1)[0]
            test_candidate_pairs = [
                (j, unique_uts[i], candidate)
                for i in idx_passed
                for j, candidate in enumerate(candidate_solutions)
            ]

            def test_single_pair(pair):
                candidate_idx, ut, candidate = pair
                try:
                    passed = run_testcase_stdio(candidate, ut)["passed"]
                    return (candidate_idx, not passed)
                except Exception:
                    # A crash in the harness itself is not counted as a discrimination.
                    return (candidate_idx, False)

            max_workers = min(88, mp.cpu_count(), len(test_candidate_pairs))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(executor.map(test_single_pair, test_candidate_pairs))

            # A candidate is discriminated once it fails any valid test.
            failed_candidate_indices = {idx for idx, is_failed in results if is_failed}
            entire_discrimination_reward = len(failed_candidate_indices) / len(candidate_solutions)

        final_score = 0.1*formatting_reward + 0.85*entire_discrimination_reward + 0.05*clipped_validity

        return {
            "score": final_score,
            "formatting_reward": formatting_reward,
            "n_test_cases": len(uts),
            "n_unique_test_cases": len(unique_uts),
            "clipped_validity": clipped_validity,
            "validity_ratio": validity_reward,
            "entire_discrimination_reward": entire_discrimination_reward,
            "brievity_penalty": brievity_penalty,
            "duplication_penalty": duplication_penalty
        }

    except Exception:
        # Best-effort reward: any failure scores zero. `except Exception` (not a bare
        # except) so KeyboardInterrupt / SystemExit still propagate.
        return {
            "score": 0.0,
            "formatting_reward": 0.0,
            "n_unique_test_cases": 0,
            "n_test_cases": 0,
            "clipped_validity": 0,
            "validity_ratio": 0,
            "entire_discrimination_reward": 0,
            "brievity_penalty": 0,
            "duplication_penalty": 0
        }
            
        
def discrimination_reward_stdio_iter_1_ablate_denominator(
    data_source: str,
    solution_str: str, 
    ground_truth: str,
    extra_info: Dict,
):
    """
    Reward function for test generation llm (iteration 1), ablation in which the
    validity term is divided by the number of unique test cases with no lower bound.

    A generated test case is "valid" when the ground-truth solution passes it. The
    discrimination reward is the fraction of candidate solutions that fail at least one
    valid test case.

    Args:
        data_source: str, not read by this function
        solution_str: str, completion generated by test generation llm
        ground_truth: List[str], list of solutions sampled by code generation model
            (annotated as str in the signature)
        extra_info: Dict, containing following keys:
            - gt_solution: str, ground-truth solution code
    Returns:
        Dict, containing following keys:
            - score: float, 0.1* formatting_reward + 0.85* entire_discrimination_reward
              + 0.05* clipped_validity
            - formatting_reward: float, 1.0 if the number of "<reasoning>" occurrences equals
              the number of extracted test cases and the stripped completion ends with ```
            - n_test_cases: int, number of extracted test cases
            - n_unique_test_cases: int, number of test cases after de-duplication
            - clipped_validity: float, number of valid unique tests / n_unique_test_cases
            - validity_ratio: float, fraction of unique tests passed by the ground-truth solution
            - entire_discrimination_reward: float, fraction of candidates failing >= 1 valid test
            - brievity_penalty: float, 1 / n_unique_test_cases (reported only, not part of score)
            - duplication_penalty: float, 1 - n_unique / n_total (reported only, not part of score)
        If anything raises, every value in the returned dict is 0.
    """
    try:
        uts: List[Dict] = extract_test_cases_stdio(solution_str)
        # Order-preserving de-duplication by equality.
        unique_uts = []
        for ut in uts:
            if ut not in unique_uts:
                unique_uts.append(ut)
        candidate_solutions: List[str] = ground_truth
        gt_solution = extra_info['gt_solution']

        # Formatting: one <reasoning> tag per test case, and the completion must end
        # with a closed code fence (i.e. it was not cut off or degenerate).
        reasoning_count = solution_str.count("<reasoning>")
        no_degeneration = solution_str.strip().endswith('```')
        formatting_reward = float(reasoning_count == len(uts) and no_degeneration)

        print('Number of test cases: ',len(uts))

        if len(unique_uts) == 0:
            # Nothing to execute: only the formatting term can contribute.
            entire_discrimination_reward = 0.0
            clipped_validity = 0.0
            final_score = 0.1*formatting_reward + 0.85*entire_discrimination_reward + 0.05*clipped_validity

            return {
                "score": final_score,
                "formatting_reward": formatting_reward,
                "n_unique_test_cases": len(unique_uts),
                "n_test_cases": len(uts),
                "clipped_validity": clipped_validity,
                "validity_ratio": 0.0,
                "entire_discrimination_reward": entire_discrimination_reward,
                "brievity_penalty": 1.0,
                "duplication_penalty": 0.0
            }

        brievity_penalty = 1 / len(unique_uts)
        duplication_penalty = 1 - len(unique_uts) / len(uts)

        # 1. Run the ground-truth solution on every unique test case, in parallel.
        def test_gt_solution(ut):
            return run_testcase_stdio(gt_solution, ut)["passed"]

        max_workers = min(88, mp.cpu_count(), len(unique_uts))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            gt_results = np.array(list(executor.map(test_gt_solution, unique_uts)))

        # 2. Validity: share of unique tests the ground-truth solution passes. In this
        #    ablation the "clipped" value uses the plain unique-test count as denominator.
        #    Converted to plain Python floats so the result dict carries no numpy scalars.
        validity_reward = float(gt_results.mean())
        clipped_validity = float(gt_results.sum() / len(unique_uts))

        if validity_reward == 0.0:
            entire_discrimination_reward = 0.0
        else:
            # 3. Run every candidate against every valid test case, in parallel.
            idx_passed = np.where(gt_results == 1)[0]
            test_candidate_pairs = [
                (j, unique_uts[i], candidate)
                for i in idx_passed
                for j, candidate in enumerate(candidate_solutions)
            ]

            def test_single_pair(pair):
                candidate_idx, ut, candidate = pair
                try:
                    passed = run_testcase_stdio(candidate, ut)["passed"]
                    return (candidate_idx, not passed)
                except Exception:
                    # A crash in the harness itself is not counted as a discrimination.
                    return (candidate_idx, False)

            max_workers = min(88, mp.cpu_count(), len(test_candidate_pairs))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(executor.map(test_single_pair, test_candidate_pairs))

            # A candidate is discriminated once it fails any valid test.
            failed_candidate_indices = {idx for idx, is_failed in results if is_failed}
            entire_discrimination_reward = len(failed_candidate_indices) / len(candidate_solutions)

        final_score = 0.1*formatting_reward + 0.85*entire_discrimination_reward + 0.05*clipped_validity

        return {
            "score": final_score,
            "formatting_reward": formatting_reward,
            "n_test_cases": len(uts),
            "n_unique_test_cases": len(unique_uts),
            "clipped_validity": clipped_validity,
            "validity_ratio": validity_reward,
            "entire_discrimination_reward": entire_discrimination_reward,
            "brievity_penalty": brievity_penalty,
            "duplication_penalty": duplication_penalty
        }

    except Exception:
        # Best-effort reward: any failure scores zero. `except Exception` (not a bare
        # except) so KeyboardInterrupt / SystemExit still propagate.
        return {
            "score": 0.0,
            "formatting_reward": 0.0,
            "n_unique_test_cases": 0,
            "n_test_cases": 0,
            "clipped_validity": 0,
            "validity_ratio": 0,
            "entire_discrimination_reward": 0,
            "brievity_penalty": 0,
            "duplication_penalty": 0
        }



def discrimination_reward_stdio_iter_1_ablate_validity(
    data_source: str,
    solution_str: str, 
    ground_truth: str,
    extra_info: Dict,
):
    """
    Reward function for test generation llm (iteration 1), ablation in which the
    validity term is removed from the score (it is still computed and reported).

    A generated test case is "valid" when the ground-truth solution passes it. The
    discrimination reward is the fraction of candidate solutions that fail at least one
    valid test case.

    Args:
        data_source: str, not read by this function
        solution_str: str, completion generated by test generation llm
        ground_truth: List[str], list of solutions sampled by code generation model
            (annotated as str in the signature)
        extra_info: Dict, containing following keys:
            - gt_solution: str, ground-truth solution code
    Returns:
        Dict, containing following keys:
            - score: float, 0.15* formatting_reward + 0.85* entire_discrimination_reward
            - formatting_reward: float, 1.0 if the number of "<reasoning>" occurrences equals
              the number of extracted test cases and the stripped completion ends with ```
            - n_test_cases: int, number of extracted test cases
            - n_unique_test_cases: int, number of test cases after de-duplication
            - clipped_validity: float, number of valid unique tests / n_unique_test_cases
              (reported only, not part of score)
            - validity_ratio: float, fraction of unique tests passed by the ground-truth solution
            - entire_discrimination_reward: float, fraction of candidates failing >= 1 valid test
            - brievity_penalty: float, 1 / n_unique_test_cases (reported only, not part of score)
            - duplication_penalty: float, 1 - n_unique / n_total (reported only, not part of score)
        If anything raises, every value in the returned dict is 0.
    """
    try:
        uts: List[Dict] = extract_test_cases_stdio(solution_str)
        # Order-preserving de-duplication by equality.
        unique_uts = []
        for ut in uts:
            if ut not in unique_uts:
                unique_uts.append(ut)
        candidate_solutions: List[str] = ground_truth
        gt_solution = extra_info['gt_solution']

        # Formatting: one <reasoning> tag per test case, and the completion must end
        # with a closed code fence (i.e. it was not cut off or degenerate).
        reasoning_count = solution_str.count("<reasoning>")
        no_degeneration = solution_str.strip().endswith('```')
        formatting_reward = float(reasoning_count == len(uts) and no_degeneration)

        print('Number of test cases: ',len(uts))

        if len(unique_uts) == 0:
            # Nothing to execute: only the formatting term can contribute.
            entire_discrimination_reward = 0.0
            final_score = 0.15*formatting_reward + 0.85*entire_discrimination_reward

            return {
                "score": final_score,
                "formatting_reward": formatting_reward,
                "n_unique_test_cases": len(unique_uts),
                "n_test_cases": len(uts),
                "clipped_validity": 0.0,
                "validity_ratio": 0.0,
                "entire_discrimination_reward": entire_discrimination_reward,
                "brievity_penalty": 1.0,
                "duplication_penalty": 0.0
            }

        brievity_penalty = 1 / len(unique_uts)
        duplication_penalty = 1 - len(unique_uts) / len(uts)

        # 1. Run the ground-truth solution on every unique test case, in parallel.
        def test_gt_solution(ut):
            return run_testcase_stdio(gt_solution, ut)["passed"]

        max_workers = min(88, mp.cpu_count(), len(unique_uts))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            gt_results = np.array(list(executor.map(test_gt_solution, unique_uts)))

        # 2. Validity: share of unique tests the ground-truth solution passes. Used only
        #    to select valid tests and for reporting in this ablation. Converted to plain
        #    Python floats so the result dict carries no numpy scalars.
        validity_reward = float(gt_results.mean())
        clipped_validity = float(gt_results.sum() / len(unique_uts))

        if validity_reward == 0.0:
            entire_discrimination_reward = 0.0
        else:
            # 3. Run every candidate against every valid test case, in parallel.
            idx_passed = np.where(gt_results == 1)[0]
            test_candidate_pairs = [
                (j, unique_uts[i], candidate)
                for i in idx_passed
                for j, candidate in enumerate(candidate_solutions)
            ]

            def test_single_pair(pair):
                candidate_idx, ut, candidate = pair
                try:
                    passed = run_testcase_stdio(candidate, ut)["passed"]
                    return (candidate_idx, not passed)
                except Exception:
                    # A crash in the harness itself is not counted as a discrimination.
                    return (candidate_idx, False)

            max_workers = min(88, mp.cpu_count(), len(test_candidate_pairs))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(executor.map(test_single_pair, test_candidate_pairs))

            # A candidate is discriminated once it fails any valid test.
            failed_candidate_indices = {idx for idx, is_failed in results if is_failed}
            entire_discrimination_reward = len(failed_candidate_indices) / len(candidate_solutions)

        final_score = 0.15*formatting_reward + 0.85*entire_discrimination_reward

        return {
            "score": final_score,
            "formatting_reward": formatting_reward,
            "n_test_cases": len(uts),
            "n_unique_test_cases": len(unique_uts),
            "clipped_validity": clipped_validity,
            "validity_ratio": validity_reward,
            "entire_discrimination_reward": entire_discrimination_reward,
            "brievity_penalty": brievity_penalty,
            "duplication_penalty": duplication_penalty
        }

    except Exception:
        # Best-effort reward: any failure scores zero. `except Exception` (not a bare
        # except) so KeyboardInterrupt / SystemExit still propagate.
        return {
            "score": 0.0,
            "formatting_reward": 0.0,
            "n_unique_test_cases": 0,
            "n_test_cases": 0,
            "clipped_validity": 0,
            "validity_ratio": 0,
            "entire_discrimination_reward": 0,
            "brievity_penalty": 0,
            "duplication_penalty": 0
        }



# Upper bound on the number of threads used to run test cases concurrently.
_STDIO_ITER2_MAX_WORKERS = 88
# The only data sources this reward knows how to score.
_STDIO_ITER2_SOURCES = ('trainset', 'validationset')


def _stdio_iter2_fallback(data_source):
    """Build the all-zero result returned when scoring raises an exception."""
    result = {
        "score": 0.0,
        "formatting_reward": 0.0,
        "n_unique_test_cases": 0,
        "n_valid_test_cases": 0,
        "n_test_cases": 0,
        "clipped_validity": 0,
        "validity_ratio": 0,
        "entire_discrimination_reward": 0,
    }
    if data_source == 'validationset':
        result["entire_discrimination_reward_iter_0"] = 0
    result["brievity_penalty"] = 0
    result["duplication_penalty"] = 0
    return result


def _stdio_iter2_gt_results(gt_solution, unique_uts):
    """Run every unique test against the ground-truth solution; return the pass flags as an array."""
    max_workers = min(_STDIO_ITER2_MAX_WORKERS, mp.cpu_count(), len(unique_uts))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        flags = list(executor.map(
            lambda ut: run_testcase_stdio(gt_solution, ut)["passed"],
            unique_uts,
        ))
    return np.array(flags)


def _stdio_iter2_discrimination(unique_uts, passed_indices, candidates):
    """Return the fraction of `candidates` that fail at least one of the tests at `passed_indices`.

    Raises:
        ValueError: if there is nothing to run (no candidates or no test indices).
            The entry point turns this into the all-zero fallback result.
    """
    pairs = [
        (candidate_idx, unique_uts[test_idx], candidate)
        for test_idx in passed_indices
        for candidate_idx, candidate in enumerate(candidates)
    ]
    if not pairs:
        raise ValueError("no (test, candidate) pairs to evaluate")

    def candidate_fails(pair):
        candidate_idx, ut, candidate = pair
        try:
            return candidate_idx, not run_testcase_stdio(candidate, ut)["passed"]
        # A test run that errors out is counted as "not failed" (best effort).
        # SystemExit is included in case executed candidate code calls sys.exit()
        # NOTE(review): confirm whether run_testcase_stdio already isolates that.
        except (Exception, SystemExit):
            return candidate_idx, False

    max_workers = min(_STDIO_ITER2_MAX_WORKERS, mp.cpu_count(), len(pairs))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        outcomes = list(executor.map(candidate_fails, pairs))

    failed_candidates = {candidate_idx for candidate_idx, did_fail in outcomes if did_fail}
    return len(failed_candidates) / len(candidates)


def discrimination_reward_stdio_iter_2(
    data_source: str,
    solution_str: str,
    ground_truth: Dict[str, List[str]],
    extra_info: Dict,
) -> typing.Optional[Dict[str, float]]:
    """
    Reward for the test-generation LLM based on how well its stdio test cases
    discriminate candidate solutions from the ground-truth solution.

    The score is
        0.1 * formatting_reward + 0.85 * entire_discrimination_reward + 0.05 * clipped_validity
    where
      - formatting_reward is 1.0 when the number of "<reasoning>" occurrences in
        solution_str equals the number of extracted test cases and the stripped
        solution_str ends with ```; otherwise 0.0.
      - a test case is "valid" when the ground-truth solution passes it
        (only unique test cases are executed).
      - entire_discrimination_reward is the fraction of 'iter_1' candidates that
        fail at least one valid test case.
      - clipped_validity is (#valid tests) / max(12, #unique tests).

    Args:
        data_source: 'trainset' or 'validationset'. For 'validationset' the result
            additionally reports "entire_discrimination_reward_iter_0", computed
            against the 'iter_0' candidates (it does not contribute to the score).
        solution_str: unit test scripts of multiple test cases generated by the
            test generation llm.
        ground_truth: dictionary with keys 'iter_0' and 'iter_1', each a list of
            code solutions sampled by the code generation model at that iteration.
        extra_info: dictionary containing the key
            - gt_solution: str, ground-truth solution code

    Returns:
        A dictionary of metrics with the final reward under "score". If anything
        raises while scoring, an all-zero dictionary with the same keys is returned.
        "n_valid_test_cases" is always an int. Returns None when data_source is
        neither 'trainset' nor 'validationset'.
    """
    if data_source not in _STDIO_ITER2_SOURCES:
        # Unknown data sources are not scored.
        return None
    is_validation = data_source == 'validationset'

    try:
        uts: List[Dict] = extract_test_cases_stdio(solution_str)
        # De-duplicate while preserving order (test cases are dicts, hence unhashable).
        unique_uts = []
        for ut in uts:
            if ut not in unique_uts:
                unique_uts.append(ut)
        candidate_solutions_iter_0 = ground_truth['iter_0']
        candidate_solutions_iter_1 = ground_truth['iter_1']
        gt_solution = extra_info['gt_solution']

        n_tests = len(uts)
        n_unique = len(unique_uts)

        # Formatting: one <reasoning> tag per test case, and the output must end with a closed code fence.
        reasoning_count = solution_str.count("<reasoning>")
        no_degeneration = solution_str.strip().endswith('```')
        formatting_reward = float(reasoning_count == n_tests and no_degeneration)

        print('Number of test cases: ', n_tests)

        if n_tests == 0:
            # Nothing to execute: only the formatting term can contribute.
            n_valid_test_cases = 0
            validity_reward = 0.0
            clipped_validity = 0.0
            brievity_penalty = 1.0
            duplication_penalty = 0.0
            entire_discrimination_reward = 0.0
            entire_discrimination_reward_iter_0 = 0.0
        else:
            brievity_penalty = float(1 / n_tests)
            duplication_penalty = float(1 - n_unique / n_tests)

            # A test case is valid when the ground-truth solution passes it.
            gt_results = _stdio_iter2_gt_results(gt_solution, unique_uts)
            n_passed = float(gt_results.sum())
            n_valid_test_cases = int(gt_results.sum())
            validity_reward = n_passed / n_tests
            clipped_validity = n_passed / max(12.0, n_unique)

            if validity_reward == 0.0:
                entire_discrimination_reward = 0.0
                entire_discrimination_reward_iter_0 = 0.0
            else:
                idx_passed = np.where(gt_results == 1)[0]
                entire_discrimination_reward = _stdio_iter2_discrimination(
                    unique_uts, idx_passed, candidate_solutions_iter_1
                )
                # The iter_0 metric is only reported for validation, so skip the work otherwise.
                entire_discrimination_reward_iter_0 = (
                    _stdio_iter2_discrimination(unique_uts, idx_passed, candidate_solutions_iter_0)
                    if is_validation
                    else 0.0
                )

        final_score = 0.1*formatting_reward + 0.85*entire_discrimination_reward + 0.05*clipped_validity

        result = {
            "score": final_score,
            "formatting_reward": formatting_reward,
            "n_test_cases": n_tests,
            "n_unique_test_cases": n_unique,
            "n_valid_test_cases": n_valid_test_cases,
            "clipped_validity": clipped_validity,
            "validity_ratio": validity_reward,
            "entire_discrimination_reward": entire_discrimination_reward,
        }
        if is_validation:
            result["entire_discrimination_reward_iter_0"] = entire_discrimination_reward_iter_0
        result["brievity_penalty"] = brievity_penalty
        result["duplication_penalty"] = duplication_penalty
        return result
    # Any scoring failure yields a zero reward instead of crashing the caller.
    # KeyboardInterrupt is deliberately NOT caught so the run can be interrupted;
    # SystemExit is caught in case executed solution code calls sys.exit().
    except (Exception, SystemExit):
        return _stdio_iter2_fallback(data_source)