
import ast
import traceback
import sys
import io
from contextlib import contextmanager
import json
import re

class MBPPTester:
    def __init__(self):
        self.timeout = 5  # 5秒超时
        
    @contextmanager
    def capture_output(self):
        """捕获标准输出和错误输出"""
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        try:
            sys.stdout = mystdout = io.StringIO()
            sys.stderr = mystderr = io.StringIO()
            yield mystdout, mystderr
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
    
    def safe_eval(self, code_str, test_globals=None):
        """安全执行代码"""
        if test_globals is None:
            test_globals = {}
            
        try:
            # 检查代码语法
            ast.parse(code_str)
            
            # 执行代码
            with self.capture_output() as (stdout, stderr):
                exec(code_str, test_globals)
                
            return True, test_globals, stdout.getvalue(), stderr.getvalue()
            
        except SyntaxError as e:
            return False, None, "", f"语法错误: {str(e)}"
        except Exception as e:
            return False, None, "", f"运行时错误: {str(e)}\n{traceback.format_exc()}"
    
    def extract_function_name(self, code):
        """从代码中提取函数名"""
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    return node.name
        except:
            pass
        return None
    
    def parse_assert_statement(self, assert_str):
        """解析assert语句，返回测试表达式"""
        assert_str = assert_str.strip()
        
        # 移除 "assert " 前缀
        if assert_str.startswith('assert '):
            return assert_str[7:].strip()
        
        return assert_str
    
    def run_single_test(self, test_case, test_globals):
        """运行单个测试用例"""
        try:
            # 解析assert语句
            test_expr = self.parse_assert_statement(test_case)
            
            # 执行测试表达式
            result = eval(test_expr, test_globals)
            
            if result:
                return True, "通过"
            else:
                return False, f"断言失败: {test_expr}"
                
        except NameError as e:
            return False, f"函数未定义: {str(e)}"
        except SyntaxError as e:
            return False, f"语法错误: {str(e)}"
        except Exception as e:
            return False, f"执行错误: {str(e)}"
    
    def run_assert_tests(self, code, test_cases):
        """运行assert格式的测试用例"""
        print("="*60)
        print("MBPP 代码测试报告")
        print("="*60)
        
        # 执行代码
        print("1. 执行代码...")
        success, globals_dict, stdout, stderr = self.safe_eval(code)
        
        if not success:
            print(f"❌ 代码执行失败:")
            print(f"   错误: {stderr}")
            return False
            
        func_name = self.extract_function_name(code)
        if not func_name:
            print("❌ 无法从代码中提取函数名")
            return False
            
        print(f"✅ 代码执行成功，函数名: {func_name}")
        
        # 运行测试用例
        print(f"\n2. 运行测试用例...")
        print(f"   总共 {len(test_cases)} 个测试用例")
        
        all_passed = True
        passed_count = 0
        
        for i, test_case in enumerate(test_cases, 1):
            print(f"\n   测试用例 {i}: {test_case}")
            
            test_result, test_msg = self.run_single_test(test_case, globals_dict)
            
            if test_result:
                print(f"   ✅ 结果: {test_msg}")
                passed_count += 1
            else:
                print(f"   ❌ 结果: {test_msg}")
                all_passed = False
                
                # 尝试提供更详细的调试信息
                try:
                    test_expr = self.parse_assert_statement(test_case)
                    if '==' in test_expr:
                        left_expr, right_expr = test_expr.split('==', 1)
                        left_expr = left_expr.strip()
                        right_expr = right_expr.strip()
                        
                        try:
                            left_result = eval(left_expr, globals_dict)
                            right_result = eval(right_expr, globals_dict)
                            print(f"      实际结果: {left_result}")
                            print(f"      期望结果: {right_result}")
                            print(f"      类型: {type(left_result).__name__} vs {type(right_result).__name__}")
                        except:
                            pass
                except:
                    pass
        
        # 总结
        print(f"\n3. 测试总结")
        print(f"   通过测试: {passed_count}/{len(test_cases)}")
        print(f"   成功率: {passed_count/len(test_cases)*100:.1f}%")
        print(f"   总体结果: {'✅ 全部通过' if all_passed else '❌ 存在失败'}")
        
        return all_passed
    
    def compare_implementations(self, your_code, standard_code, test_cases):
        """比较两个实现的结果"""
        print("="*60)
        print("MBPP 代码对比测试")
        print("="*60)
        
        # 测试你的代码
        print("📝 测试你的代码:")
        print("-" * 30)
        your_result = self.run_assert_tests(your_code, test_cases)
        
        print("\n" + "="*60)
        
        # 测试标准代码
        print("📋 测试标准代码:")
        print("-" * 30)
        std_result = self.run_assert_tests(standard_code, test_cases)
        
        print("\n" + "="*60)
        print("🔍 对比结果:")
        print(f"   你的代码: {'✅ 通过' if your_result else '❌ 失败'}")
        print(f"   标准代码: {'✅ 通过' if std_result else '❌ 失败'}")
        
        if your_result == std_result:
            print("   🎯 结果一致")
            return 1
        else:
            print("   ⚠️  结果不一致，可能存在边界情况差异")
            return 0
            
        return your_result


def load_jsonl_multiline(path: str):
    """Read JSON records from a file where one record may span several lines.

    Non-blank lines are stripped and accumulated until the accumulated text
    parses as JSON; that value becomes one record and accumulation restarts.

    Args:
        path: Path of a UTF-8 encoded file.

    Returns:
        A list of the decoded records, in file order.

    Warns:
        RuntimeWarning: when text is left over at the end of the file that
            never parsed as JSON. A malformed record swallows every line after
            it, so this usually means records were lost.
    """
    import warnings  # local import: only needed for the leftover-data warning

    records = []
    buffer = ""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            buffer += line
            try:
                record = json.loads(buffer)
            except json.JSONDecodeError:
                # Not a complete JSON value yet; keep accumulating lines.
                continue
            records.append(record)
            buffer = ""  # start collecting the next record

    if buffer:
        warnings.warn(
            f"{path}: discarded {len(buffer)} characters of trailing content "
            "that never parsed as JSON",
            RuntimeWarning,
            stacklevel=2,
        )
    return records


import datasets
def load_mbpp_dataset(run_set, dataset_path="/Users/swchen/LocalCodes/SupervisorAgent/smolagents/datasets/mbpp"):
    """Load the local MBPP JSON files and rename columns for evaluation.

    Args:
        run_set: Split expression forwarded to ``datasets.load_dataset`` as
            ``split``. Only a ``"test"`` entry is declared in ``data_files``,
            so this is presumably ``"test"`` or a slice of it.
        dataset_path: Root directory; files matching ``data/*mbpp.json`` under
            it are loaded.

    Returns:
        The object returned by ``datasets.load_dataset`` with ``prompt``
        renamed to ``question`` and ``code`` renamed to ``true_answer``.
        NOTE(review): because ``split`` is passed, this should be a single
        split rather than a ``DatasetDict`` -- confirm against the installed
        ``datasets`` version.
    """
    split_files = {
        "test": f"{dataset_path}/data/*mbpp.json"
    }
    dataset = datasets.load_dataset("json", data_files=split_files, split=run_set)

    # Apply the column renames in the same order as before.
    for old_name, new_name in (("prompt", "question"), ("code", "true_answer")):
        dataset = dataset.rename_column(old_name, new_name)
    return dataset


def load_jsonl(filename):
    """Load a JSON Lines file: one JSON document per non-blank line.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped. Returns the decoded values as a list, in
    file order.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def test_your_case(file1=None, file2=None):
    """Compare two prediction files on the MBPP test split.

    Both files are read with ``load_jsonl_multiline``; each record must carry
    ``task_id``, ``answer`` and ``total_token``. Every task id present in both
    files is counted. For those that also exist in the dataset, both answers
    are run against the task's ``test_list`` and pass counts and token totals
    are accumulated. A summary is printed.

    Args:
        file1: Path of the first prediction file.
        file2: Path of the second prediction file.

    Raises:
        ValueError: when either path is omitted.
    """
    if file1 is None or file2 is None:
        # Fail early with a clear message instead of the TypeError that
        # open(None) would raise after the dataset has already been loaded.
        raise ValueError("必须同时提供 file1 和 file2 两个结果文件路径")

    tester = MBPPTester()
    raw_data = load_mbpp_dataset("test")
    ans_1 = load_jsonl_multiline(file1)
    ans_2 = load_jsonl_multiline(file2)
    predictions_1 = {str(item['task_id']): item['answer'] for item in ans_1}
    predictions_2 = {str(item['task_id']): item['answer'] for item in ans_2}
    tokens_1 = {str(item['task_id']): item['total_token'] for item in ans_1}
    tokens_2 = {str(item['task_id']): item['total_token'] for item in ans_2}
    print(len(predictions_1), len(predictions_2))

    # Index the dataset once. setdefault keeps the first row per task id,
    # which is the row a first-match linear scan would have picked.
    tasks_by_id = {}
    for row in raw_data:
        tasks_by_id.setdefault(str(row['task_id']), row)

    shared_tasks = 0
    cor_1 = 0
    cor_2 = 0
    token_1 = 0
    token_2 = 0
    for task_id, answer_1 in predictions_1.items():
        if task_id not in predictions_2:
            continue
        shared_tasks += 1

        task = tasks_by_id.get(task_id)
        if task is None:
            # Still counted in the total but never evaluated, so it lowers
            # Pass@1 for both files.
            continue

        cor_1 += tester.run_assert_tests(answer_1, task['test_list'])
        cor_2 += tester.run_assert_tests(predictions_2[task_id], task['test_list'])
        token_1 += tokens_1[task_id]
        token_2 += tokens_2[task_id]

    if shared_tasks == 0:
        # Nothing to average over; avoid dividing by zero below.
        print("两个文件没有共同的任务，无法计算 Pass@1")
        return

    print(f"File1总任务数: {shared_tasks}, 通过任务数: {cor_1}, Pass@1: {cor_1/shared_tasks:.4f}")
    print(f"File2总任务数: {shared_tasks}, 通过任务数: {cor_2}, Pass@1: {cor_2/shared_tasks:.4f}")
    print(f"File1总Token数: {token_1}, 平均每任务Token数: {token_1/shared_tasks:.1f}")
    print(f"File2总Token数: {token_2}, 平均每任务Token数: {token_2/shared_tasks:.1f}")

def test_only_your_code():
    """Run the harness on a built-in ``return_sum`` example and print a verdict.

    Returns whatever ``MBPPTester.run_assert_tests`` returns for the sample
    solution and its three assert cases.
    """
    sample_solution = """
def return_sum(input_dict):
    return sum(input_dict.values())
"""

    sample_asserts = [
        "assert return_sum({'a': 100, 'b':200, 'c':300}) == 600",
        "assert return_sum({'a': 25, 'b':18, 'c':45}) == 88",
        "assert return_sum({'a': 36, 'b':39, 'c':49}) == 124",
    ]

    harness = MBPPTester()
    outcome = harness.run_assert_tests(sample_solution, sample_asserts)

    verdict = '✅ 全部通过' if outcome else '❌ 存在问题'
    print(f"\n🏆 测试结论: {verdict}")

    return outcome

def _read_until_end(strip_and_skip_blank=False):
    """Collect lines from stdin until a line equal to ``END`` or end of input.

    Args:
        strip_and_skip_blank: When true, each line is stripped and blank lines
            are dropped; when false, lines are kept verbatim so indentation
            survives.
    """
    collected = []
    while True:
        try:
            line = input()
        except EOFError:
            # Closed or exhausted stdin ends the section like 'END' does.
            break
        if line.strip() == 'END':
            break
        if strip_and_skip_blank:
            line = line.strip()
            if not line:
                continue
        collected.append(line)
    return collected


def test_custom():
    """Interactively read code and assert test cases from stdin and run them.

    The code section is read verbatim; the test-case section is stripped and
    blank lines are ignored. Each section ends at a line containing only
    ``END`` or at end of input.

    Returns:
        The result of ``MBPPTester.run_assert_tests`` for the entered input.
    """
    tester = MBPPTester()

    print("请输入你的代码 (多行输入，输入 'END' 结束):")
    code = '\n'.join(_read_until_end())

    print("\n请输入测试用例 (assert格式，每行一个，输入 'END' 结束):")
    test_cases = _read_until_end(strip_and_skip_blank=True)

    result = tester.run_assert_tests(code, test_cases)
    print(f"\n🏆 测试结论: {'✅ 全部通过' if result else '❌ 存在问题'}")

    return result

if __name__ == "__main__":
    # Interactive entry point: choose one of the three test modes.
    print("选择测试模式:")
    print("1. 测试你的具体案例 (对比两种实现)")
    print("2. 只测试你的代码")
    print("3. 自定义测试")

    choice = input("请选择 (1/2/3): ").strip()

    # Prediction files compared in mode 1. They can be given as the first two
    # command-line arguments; without arguments they stay empty strings and
    # must be filled in here.
    file1 = sys.argv[1] if len(sys.argv) > 1 else ""
    file2 = sys.argv[2] if len(sys.argv) > 2 else ""

    if choice == "1":
        test_your_case(file1, file2)
    elif choice == "2":
        test_only_your_code()
    elif choice == "3":
        test_custom()
    else:
        print("默认运行测试案例...")
        # The fallback runs the same comparison as mode 1, so it needs the
        # same file paths; calling without them fails on open(None).
        test_your_case(file1, file2)
