import json
import os
import re


def classify_test_diff_call(task_json: dict, pytest_json: dict, task_json_path: str) -> dict:
    """Annotate each test of the target with the frame that directly calls the target.

    If the direct callers of the target across the test traces are more diverse,
    the tests may give a higher coverage of the target, so the direct-call site
    is recorded on every test.

    Args:
        task_json (dict): task object; ``task_json["call_distance_test"]`` maps a
            call distance to a list of test dicts, each read here for its
            ``"base_nodeid"`` and ``"nodeid_list"``.
        pytest_json (dict): pytest JSON report; ``pytest_json["tests"]`` is the
            list of per-test results, each identified by ``"nodeid"``.
        task_json_path (str): path of the task file. Not used by this function;
            kept for interface compatibility.

    Returns:
        dict: the same call-distance keys, each mapped to the tests of that
        distance whose ``base_nodeid`` is not an ``.rst`` test. Each kept test
        dict is mutated in place to carry a ``"direct_call_info"`` entry (the
        result of ``get_direct_call``).

    Raises:
        SystemExit: with status 1 when a test's first node id is absent from the
            pytest report, or when ``get_traceback`` finds no failed phase for it.
    """
    # TODO: finish this function.
    old_test_dict: dict = task_json["call_distance_test"]

    # Index the report once instead of rescanning it for every test. setdefault
    # keeps the FIRST entry for a duplicated node id, matching a first-hit scan.
    pytest_by_nodeid: dict = {}
    for pytest_test in pytest_json["tests"]:
        pytest_by_nodeid.setdefault(pytest_test.get("nodeid"), pytest_test)

    return_test_dict = {}
    for call_dis, test_list in old_test_dict.items():
        kept_tests: list = []
        return_test_dict[call_dis] = kept_tests

        for test in test_list:  # one test of the target
            # .rst tests are skipped entirely.
            base_test_nodeid: str = test["base_nodeid"]
            if ".rst::" in base_test_nodeid:
                continue

            # Only the first node id of the test is looked up in the report.
            one_specific_nodeid = test["nodeid_list"][0]
            nodeid_in_pytest = pytest_by_nodeid.get(one_specific_nodeid)
            if nodeid_in_pytest is None:
                print(f"{one_specific_nodeid}: pytest info not found. EXIT to debug.")
                # Non-zero status: this is an abort, not a successful run.
                raise SystemExit(1)

            traceback = get_traceback(nodeid_in_pytest)
            if traceback is None:
                print(f"{one_specific_nodeid}: traceback not found. EXIT to debug.")
                raise SystemExit(1)

            direct_call_info = get_direct_call(test, traceback, task_json)
            if direct_call_info['line'] == -1:
                # Not fatal: the test is still kept, carrying the placeholder info.
                print(f"{one_specific_nodeid}: direct_call_info is not in trace. double check.")

            test["direct_call_info"] = direct_call_info
            kept_tests.append(test)

    return return_test_dict
            
            
            
            
                    

def get_traceback(nodeid_in_pytest: dict):
    """Return the failure details of the first failed pytest phase of one test.

    Similar to the script in call_distance.py. Phases are examined in the order
    setup, call, teardown, and the first one whose outcome is "failed" wins.

    Args:
        nodeid_in_pytest (dict): one test entry of the pytest JSON report.

    Returns:
        list | None: ``[traceback, longrepr, stderr]`` read from the failed phase
        (a missing field becomes ``None``), or ``None`` if no phase failed.
    """
    phases = (nodeid_in_pytest.get(name, {}) for name in ("setup", "call", "teardown"))
    failed_phase = next((info for info in phases if info.get("outcome") == "failed"), None)
    if failed_phase is None:
        return None
    return [failed_phase.get(field) for field in ("traceback", "longrepr", "stderr")]


def get_direct_call(
    task_test: dict, 
    traceback_candidates: list | str, 
    task_obj: dict):
    """according to the traceback of the test and the target function, check which function directly
    calls the target function in the traceback, and return the info of this direct-call function.

    Args:
        # one_specific_nodeid (str): pytest nodeid
        traceback (list | str): pytest traceback
        task_obj (dict): task obj, dict
    """
    for traceback in traceback_candidates:
        if not traceback:
            continue
        if isinstance(traceback, str):
            traceback_list = traceback.splitlines()
            start: int = 0
            for idx, ele in enumerate(traceback_list):
                if 'Traceback' in ele and ':' in ele:
                    start = idx
                    break
            function_name = task_obj['function_name']
            line_num = str(task_obj['end_line'] + 1)
            target_file_path = task_obj['target_module_path']
            for idx in range(start, len(traceback_list)):
                line = traceback_list[idx]
                if (
                    'File' in line and
                    target_file_path in line and
                    line_num in line and
                    function_name in line
                ):# treat this as the target call. get the info of previous one. search upwards
                    for back_idx in range(idx - 1, -1, -1):
                        caller_line = traceback_list[back_idx]
                        match = re.search(r'File "(.+?)", line (\d+), in (\w+)', caller_line)
                        if match:
                            return {
                                'file': match.group(1),
                                'line': int(match.group(2)),
                                # 'func': match.group(3) # ignore the func since traceback list don't have it.
                            }
                    break  # once found the target function, stop scanning further
    
        if isinstance(traceback, list):
            for call_distance, hop in enumerate(traceback):
                hop_path = hop['path']; hop_lineno = hop['lineno']; hop_message = hop['message']
                if (
                        # (task_obj['target_module_path'] in hop_path) and ignore this line for plotly
                        (hop_lineno == task_obj['end_line'] + 1) and 
                        hop_message == 'NotImplementedError'
                    ):
                    # we use this as the signal of calling the target function
                    direct_call_info = traceback[call_distance-1]
                    direct_call_file = direct_call_info['path']
                    direct_call_line = direct_call_info['lineno']
                    return {
                                'file': direct_call_file,
                                'line': direct_call_line,
                                # 'func': match.group(3) # ignore the func since traceback list don't have it.
                            }
    
    # this is for other cases.
    return {
                'file': task_test["base_nodeid"],
                'line': -1,
            }
                
                
    
    


# Folder of pytest JSON reports; each report shares its task file's name.
pytest_wo_GT_base = "pytest_results_without_GT"
# Folder of task JSON files to read.
existing_data_folder = "repocod_data_v7"
# Folder where the updated task JSON files are written.
update_data_folder = "repocod_data_v7_update"
# Only task files whose name contains this substring are processed.
task_name_filter = "pylint"
# NOTE(review): not referenced anywhere in this script — TODO confirm its purpose or remove.
special_list = [
    "astropy_50.json",
    "scikit-learn_1.json",
    "scikit-learn_28.json",
    "scikit-learn_38.json",
    "scikit-learn_40.json",
    "scikit-learn_78.json",
    "scikit-learn_112.json",
    "scikit-learn_185.json",
    "scikit-learn_220.json",
    "scikit-learn_221.json",
    "scikit-learn_239.json",
    "scikit-learn_240.json",
    "scikit-learn_256.json",
    "scikit-learn_271.json",
    "scikit-learn_272.json",
    "scikit-learn_285.json",
    "scikit-learn_296.json",
    "scikit-learn_307.json",
]

if __name__ == "__main__":
    # Create the output folder only when run as a script, not on import.
    os.makedirs(update_data_folder, exist_ok=True)

    count = 0
    # Sorted so the processing order (and the printed numbering) is deterministic;
    # os.listdir returns entries in arbitrary order.
    for task_json in sorted(os.listdir(existing_data_folder)):
        count += 1  # counts every listed file, including filtered-out ones
        if task_name_filter not in task_json:
            continue
        print(f"========== {count}. {task_json} ==========")

        task_json_path = os.path.join(existing_data_folder, task_json)
        pytest_json_path = os.path.join(pytest_wo_GT_base, task_json)

        with open(task_json_path, 'r', encoding="utf-8") as f:
            task_obj = json.load(f)

        with open(pytest_json_path, 'r', encoding="utf-8") as f:
            pytest_json = json.load(f)

        new_call_distance_test = classify_test_diff_call(task_obj, pytest_json, task_json_path)
        task_obj["call_distance_test"] = new_call_distance_test

        save_path = os.path.join(update_data_folder, task_json)
        with open(save_path, 'w', encoding="utf-8") as f:
            json.dump(task_obj, f, indent=4)
        
# TODO: it would be better to remove all .rst tests from the test data than to skip them during classification.
        