import os
import json
import ast
from pathlib import Path
# Default dataset directory; the __main__ section rebinds this name to 'repocod_data_v7'.
data_folder = 'repocod_data_v4'


def get_source_code(task_obj, 
                    nodeid, 
                    lineno, 
                    repo_base = 'repos_bk'):
    
    project_name = task_obj['project_name']
    
    if project_name == 'plotly.py' and nodeid.startswith('plotly/tests'):
        nodeid = os.path.join('packages/python/plotly', nodeid)
    
    node_list: list = nodeid.split('::')
    # last element of node_list might have input. remove it.
    node_list[-1] = node_list[-1].split('[')[0]
    file_name: str = node_list[0]
    project_path = os.path.join(repo_base, project_name)
    target_file_path = os.path.join(repo_base, project_name, file_name)
    if not Path(target_file_path).exists():
        print(f"File not found: {target_file_path}. EXIT to debug")
        exit()
    
    with open(target_file_path, 'r') as f:
        target_file_content = f.read()
    if target_file_path.endswith('.rst'): # rst fie
        return nodeid, target_file_content
    
    tree = ast.parse(target_file_content)
    lines = target_file_content.splitlines(keepends=True)
    if len(node_list) == 2: 
        test_func = node_list[1]
        if '.' not in test_func: # file::func
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == test_func:
                    start = node.lineno - 1
                    end = node.end_lineno
                    test_src = ''.join(lines[start:end])
                    return nodeid, test_src
            
        else: # file::module
            if file_name[:-3] == '/'.join(node_list[1].split('.')): # entire module
                print(f'entire module: {nodeid}')
                return nodeid, target_file_content
            
            test_target = test_func.split('.')[-1]
            for node in ast.walk(tree): # a class
                if isinstance(node, ast.ClassDef) and node.name == test_target:
                    start = node.lineno - 1
                    end = node.end_lineno
                    test_src = ''.join(lines[start:end])
                    return nodeid, test_src
            for node in ast.walk(tree): # a top level func
                if isinstance(node, ast.FunctionDef) and node.name == test_target:
                    start = node.lineno - 1
                    end = node.end_lineno
                    test_src = ''.join(lines[start:end])
                    return nodeid, test_src
            
    if len(node_list) == 3: # file::class::func
        test_class = node_list[1]
        test_func = node_list[2]
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and node.name == test_class:
                for body_item in node.body:
                    if isinstance(body_item, ast.FunctionDef) and body_item.name == test_func:
                        start = body_item.lineno - 1
                        end = body_item.end_lineno
                        test_src = ''.join(lines[start:end])
                        return nodeid, test_src
                # handle cases when func name is combined with input, such as tests/test_arrow_dataset.py::BaseDatasetTest::test_class_encode_column_in_memory
                for body_item in node.body:
                    if isinstance(body_item, ast.FunctionDef) and (body_item.name in test_func):
                        start = body_item.lineno - 1
                        end = body_item.end_lineno
                        test_src = ''.join(lines[start:end])
                        return nodeid, test_src
        # but there are cases the shown class is not where test belongs. 
        for node in ast.walk(tree): 
            if isinstance(node, ast.FunctionDef) and node.name == test_func:
                start = node.lineno - 1
                end = node.end_lineno
                test_src = ''.join(lines[start:end])
                return nodeid, test_src
        # the function is not in the same file. so we have to search the function in the codebase (xarray/tests/test_conventions.py::TestCFEncodedDataStore:: XXX)
        for dirpath, _, filenames in os.walk(project_path):
            for filename in filenames:
                if filename.endswith(".py"):
                    file_path = os.path.join(dirpath, filename)
                    with open(file_path, "r", encoding="utf-8") as f:
                        source = f.read()
                    try:
                        tree = ast.parse(source)
                    except SyntaxError:
                        continue  # skip files with syntax errors
                    for node in ast.walk(tree):
                        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == test_func:
                            # Requires Python 3.8+ for end_lineno
                            start, end = node.lineno - 1, node.end_lineno
                            test_src = ''.join(lines[start:end])
                            return nodeid, test_src
        
    print(f'[WARNING] Unhandled nodeid: {nodeid}')
    exit()
    return nodeid, None
    

def has_assertion_string(source_code: str, assertion_keywords: list) -> bool:
    """Return True if a code line of *source_code* contains any of *assertion_keywords*.

    Matching is a plain substring test on each stripped line. Skipped lines:
    - lines starting with ``#``;
    - docstrings opened with ``\"\"\"`` or ``'''`` at the start of a line, including every
      line of a multi-line docstring up to and including its closing delimiter.

    A ``None``/empty *source_code* yields False.
    """
    if not source_code:
        return False
    docstring_delim = None  # set while inside a multi-line docstring
    for raw_line in source_code.splitlines():
        line = raw_line.strip()
        if docstring_delim is not None:
            if docstring_delim in line:  # closing delimiter ends the docstring
                docstring_delim = None
            continue
        if line.startswith('#'):
            continue  # skip comments
        if line.startswith(('"""', "'''")):
            delim = line[:3]
            # One opening delimiter without a closing one starts a multi-line docstring.
            if line.count(delim) < 2:
                docstring_delim = delim
            continue  # skip docstring lines
        if any(keyword in line for keyword in assertion_keywords):
            return True
    return False


# Substrings whose presence in a test's source is taken as evidence that the test can fail.
FAIL_SIGNAL_KEYWORDS = ['raise ', 'assert', 'pytest.raises', 'pytest.fail', 'self.assert', 'self.fail']


def _iter_task_files(folder, name_filter=None):
    """Yield ``(index, task_id, path)`` for each entry of *folder*.

    ``index`` is the entry's position in ``os.listdir`` counted before filtering.
    Entries whose name does not contain *name_filter* are skipped; ``None`` keeps all.
    """
    for idx, task_id in enumerate(os.listdir(folder)):
        if name_filter is not None and name_filter not in task_id:
            continue
        yield idx, task_id, os.path.join(folder, task_id)


def _load_task(path):
    """Read one task JSON file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _save_task(path, task_obj):
    """Overwrite *path* with *task_obj* as 4-space-indented JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(task_obj, f, indent=4)


def collect_test_sources(folder, name_filter='pylint'):
    """Step 1: store each test's source code under ``test['src_code']``.

    Uses ``test['base_nodeid']`` and ``test['lineno']`` with ``get_source_code``; every
    task file in *folder* whose name contains *name_filter* is rewritten in place.
    """
    for idx, task_id, path in _iter_task_files(folder, name_filter):
        print(f'============{idx}. {task_id} =============')
        task_obj = _load_task(path)
        for tests in task_obj['call_distance_test'].values():
            for test in tests:
                # Node ids may point at .rst files, file::func, or file::class::func.
                _, test['src_code'] = get_source_code(task_obj, test['base_nodeid'], test['lineno'])
        _save_task(path, task_obj)


def classify_fail_signals(folder, name_filter='pylint', keywords=None):
    """Step 2: set ``test['fail_signal']`` from ``has_assertion_string(test['src_code'], ...)``.

    Requires ``src_code`` from step 1. *keywords* defaults to ``FAIL_SIGNAL_KEYWORDS``.
    Task files whose name contains *name_filter* are rewritten in place.
    """
    if keywords is None:
        keywords = FAIL_SIGNAL_KEYWORDS
    for idx, task_id, path in _iter_task_files(folder, name_filter):
        print(f'============{idx}. {task_id} =============')
        task_obj = _load_task(path)
        for tests in task_obj['call_distance_test'].values():
            for test in tests:
                test['fail_signal'] = has_assertion_string(test['src_code'], keywords)
        _save_task(path, task_obj)


def summarize_fail_signals(folder):
    """Step 3: print, per fail-signal outcome, the test count divided by the number of task files."""
    totals = {'assert True': 0, 'assert False': 0}
    task_count = 0
    for _, _, path in _iter_task_files(folder):
        task_count += 1
        for tests in _load_task(path)['call_distance_test'].values():
            for test in tests:
                totals['assert True' if test['fail_signal'] else 'assert False'] += 1
    if task_count == 0:
        print(f'No task files in {folder}')
        return
    for key, total in totals.items():
        print(f'{key}: {total / task_count}')


if __name__ == "__main__":
    data_folder = 'repocod_data_v7'
    os.makedirs(data_folder, exist_ok=True)
    # Steps run in order; enable step 1 / step 3 as needed.
    # collect_test_sources(data_folder)   # step 1: collect all test source code
    classify_fail_signals(data_folder)    # step 2: do fail signal classify
    # summarize_fail_signals(data_folder) # step 3: get summary

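# Presumably the last recorded output of step 3 (number of tests in each fail-signal group
# divided by the number of task files):
# assert True: 60.471311475409834
# assert False: 7.498975409836065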