"""target of this file:
pre-classify the test cases by:
random selection;
call distance; done
assert / raise, extra;
"""
import json
import os
import shutil
import time
from eval_utils import client, copy_file_to_docker, run_pytest_in_docker, copy_file_from_docker


def replace_target_function(task_obj,
                            target_file_abs_path,
                            replacement = 'raise NotImplementedError'):
    """Return the target file's text with the task function's body stubbed out.

    The ground-truth function (``task_obj['full_function']``, a list of source
    lines) is looked up in the file and every occurrence is replaced by the
    function prompt (``task_obj['prompt']``, also a list of source lines)
    followed by one ``replacement`` statement. The file on disk is only read,
    never written.

    Args:
        task_obj: task dict. Reads 'prompt' and 'full_function'; reads
            'function_name', 'start_line' and 'end_line' only to build the
            error message.
        target_file_abs_path: path of the source file containing the function.
        replacement: statement that becomes the whole function body.

    Returns:
        str: the file content with the function body replaced.

    Raises:
        RuntimeError: if the file does not exist, or the joined ground-truth
            lines are not found verbatim in the file.
    """
    if not os.path.exists(target_file_abs_path):
        raise RuntimeError(f'{target_file_abs_path} doesnt exist!')
    with open(target_file_abs_path, 'r') as f:
        origin_target_file_content = f.read()

    prompt_lines: list = task_obj['prompt']
    # Indent the stub like the last *non-blank* prompt line. A blank trailing
    # line ("\n") has no indentation of its own, and measuring it would count
    # the newline as one column. Reusing the original whitespace prefix
    # (instead of N spaces) also keeps tab-indented files consistent.
    indentation = ''
    for line in reversed(prompt_lines):
        if line.strip():
            indentation = line[:len(line) - len(line.lstrip())]
            break
    full_func_replacement: str = ''.join(prompt_lines) + indentation + replacement + '\n'

    full_GT: str = ''.join(task_obj['full_function'])
    if full_GT not in origin_target_file_content:
        raise RuntimeError(f'Can not find func {task_obj["function_name"]} from {task_obj["start_line"]} to {task_obj["end_line"]} in file {target_file_abs_path}.')

    return origin_target_file_content.replace(full_GT, full_func_replacement)
    

def get_pytest_resuts(task_obj_path):
    """Run a task's relevant tests with the target function stubbed out.

    Steps, all driven by the task JSON at ``task_obj_path``:
      1. Rewrite the local target file (under ``repos/<project_name>/``) so
         the target function body is ``raise NotImplementedError``.
      2. Keep a copy of the edited file in ``edited_target/``.
      3. Copy the edited file into the task's docker container and run pytest
         there on ``task_obj['relavent_test_path']``.
      4. Copy the pytest result file from the container to
         ``pytest_results_without_GT/<task file name>``.
      5. Always restore the original file content, locally and in the
         container, even when an earlier step fails.

    Args:
        task_obj_path: path to the task JSON file. Keys read here:
            'project_name', 'container_name', 'target_module_path',
            'relavent_test_path' (plus whatever replace_target_function reads).

    Returns:
        None. Results are written to disk; the return value of
        run_pytest_in_docker is ignored.
    """
    task_id_json: str = os.path.basename(task_obj_path)
    with open(task_obj_path, 'r') as f:
        task_obj = json.load(f)

    WORKDIR = "/usr/src/app"
    repo_name = task_obj['project_name']
    container_name = task_obj["container_name"]
    container = client.containers.get(container_name)

    docker_project_path_2b_modify = f"{WORKDIR}/{repo_name}/"
    docker_target_file_path = os.path.join(docker_project_path_2b_modify, task_obj['target_module_path'])
    test_files_in_docker = task_obj['relavent_test_path']

    target_file_abs_path = os.path.join('repos', task_obj['project_name'], task_obj['target_module_path'])
    print(f'target_file_abs_path: {target_file_abs_path}')
    # Keep the pristine content so the finally-block can restore it.
    with open(target_file_abs_path, 'r') as f:
        origin_target_file_content = f.read()
    try:
        # replace local target function body with "raise NotImplementedError"
        new_content = replace_target_function(task_obj, target_file_abs_path)
        # write the new_content to the local target file, and copy to edited_target/
        with open(target_file_abs_path, 'w') as f:
            f.write(new_content)
        os.makedirs('edited_target', exist_ok=True)
        # Swap only the file extension; a plain str.replace('json', 'py')
        # would also rewrite any "json" inside the task/project name itself.
        edited_copy_name = os.path.splitext(task_id_json)[0] + '.py'
        shutil.copy(target_file_abs_path, os.path.join('edited_target', edited_copy_name))
        # copy the file to the docker container
        copy_file_to_docker(container, target_file_abs_path, docker_target_file_path)

        # run the pytest, and save the pytest result as local json.
        os.makedirs('pytest_results_without_GT', exist_ok=True)
        pytest_result_path_local = os.path.join('pytest_results_without_GT', task_id_json)
        pytest_result_path_docker = os.path.join(WORKDIR, 'infer_results', task_id_json)
        run_pytest_in_docker(
            container_name = container_name,
            project_path = docker_project_path_2b_modify,
            result_file_name = pytest_result_path_docker,
            target_functions_path = test_files_in_docker,
            )
        copy_file_from_docker(container, pytest_result_path_docker, pytest_result_path_local)

    finally:
        # recover the local content
        with open(target_file_abs_path, 'w') as f:
            f.write(origin_target_file_content)
        # recover the container file as well.
        copy_file_to_docker(container, target_file_abs_path, docker_target_file_path)
    print(f"pytest results saved to {pytest_result_path_local}")
        

def get_call_distance(pytest_result: dict, task_obj: dict):
    """Bucket failed/errored tests by their call distance to the target function.

    For every test in ``pytest_result['tests']`` whose outcome is 'failed' or
    'error', the failed phases ('setup', 'call', 'teardown') are inspected in
    that order. Within a phase the 'traceback', 'longrepr' and 'stderr'
    entries are tried in turn:

    * a list traceback matches at the first hop whose 'lineno' equals
      ``task_obj['end_line'] + 1`` and whose 'message' is
      'NotImplementedError'; the hop's index is the call distance;
    * a string is handed to ``parse_traceback_str``; a non-zero result is
      taken as the call distance.

    The first match wins, so each test is recorded exactly once. Tests with no
    match are recorded under call distance 100.

    Args:
        pytest_result (dict): loaded pytest JSON result; must have a 'tests'
            list whose items carry 'outcome', 'nodeid' and 'lineno'.
        task_obj (dict): task dict; 'end_line' is read here, and the dict is
            passed on to parse_traceback_str.

    Returns:
        dict: key: call distance (int); value: list of
        ``{'nodeid': ..., 'lineno': ...}`` dicts.
    """
    return_dict = {}
    for ele in pytest_result['tests']:
        test_outcome = ele['outcome']
        node_id = ele['nodeid']
        lineno = ele['lineno']

        # pytest outcomes other than failed/error are not classified.
        if test_outcome not in ('failed', 'error'):
            continue

        # None means "no match yet"; 0 is a legitimate hop index.
        call_distance = None
        for phase in ('setup', 'call', 'teardown'):
            phase_info = ele.get(phase, {})
            if phase_info.get("outcome") != "failed":
                continue

            for candidate_key in ("traceback", "longrepr", "stderr"):
                traceback = phase_info.get(candidate_key)
                if not traceback:
                    continue

                if isinstance(traceback, list):
                    # Case 1: structured traceback list.
                    for hop_index, hop in enumerate(traceback):
                        # Signal of hitting the stubbed target: the raise sits on the
                        # line right after the prompt, with a NotImplementedError message.
                        # The hop path is deliberately not checked: some pytest
                        # node ids/paths are incomplete (notably in plotly.py).
                        if (hop['lineno'] == task_obj['end_line'] + 1) and hop['message'] == 'NotImplementedError':
                            call_distance = hop_index
                            break
                elif isinstance(traceback, str):
                    # Case 2: plain-text traceback.
                    parsed_distance = parse_traceback_str(traceback, task_obj)
                    if parsed_distance != 0:
                        call_distance = parsed_distance

                # Stop at the first candidate that matched; otherwise a later
                # candidate could record the same test a second time.
                if call_distance is not None:
                    break
            if call_distance is not None:
                break

        if call_distance is None:
            print(f"{node_id}: no call distance found. treat call dis as 100. continue")
            call_distance = 100
        return_dict.setdefault(call_distance, []).append({'nodeid': node_id, 'lineno': lineno})

    return return_dict
    
def parse_traceback_str(traceback_str: str, task_obj: dict) -> int:
    """Parse a plain-text Python traceback and return the call distance.

    Scanning starts at the first line that contains both 'Traceback' and ':'
    (or at the first line when there is none). Only real frame lines, i.e.
    lines of the form::

        File "/usr/src/app/pkg/pkg/mod.py", line 162, in evaluate

    are counted. The call distance is the number of frame lines that precede
    the target frame. A frame is the target when its line number equals
    ``task_obj['end_line'] + 1`` exactly and the line also contains
    ``task_obj['function_name']`` and ``task_obj['target_module_path']``.

    Args:
        traceback_str (str): text that may embed a standard
            "Traceback (most recent call last):" section, e.g. a doctest
            failure report.
        task_obj (dict): task dict; reads 'function_name', 'end_line' and
            'target_module_path'.

    Returns:
        int: the call distance, or 0 when the target frame is not present
        (callers treat 0 as "not found").
    """
    function_name = task_obj['function_name']
    target_line_num = task_obj['end_line'] + 1
    target_file_path = task_obj['target_module_path']

    traceback_list = traceback_str.splitlines()
    start: int = 0
    for idx, ele in enumerate(traceback_list):
        if 'Traceback' in ele and ':' in ele:
            start = idx
            break

    call_distance = 0
    for ele in traceback_list[start:]:
        frame = ele.strip()
        # Only genuine frame lines count as hops; a source line that merely
        # mentions e.g. FileNotFoundError must not inflate the distance.
        if not frame.startswith('File "'):
            continue
        _, marker, after_marker = frame.partition('", line ')
        if not marker:
            continue
        lineno_text = after_marker.partition(',')[0].strip()
        # Compare the line number as an integer: a substring test would let
        # line 16 match "line 162".
        if (lineno_text.isdecimal() and int(lineno_text) == target_line_num
                and function_name in ele and target_file_path in ele):
            # treat this as the target call
            return call_distance
        call_distance += 1
    # Target frame never seen: report "not found" rather than the frame count.
    return 0
        
        
def get_new_test_dict(old_test_dict: dict):
    """Merge parametrized variants of the same test within each call distance.

    Args:
        old_test_dict (dict): call distance -> list of
            ``{'nodeid': ..., 'lineno': ...}`` dicts.

    Returns:
        dict: call distance -> list of
        ``{'nodeid_list': [...], 'lineno': ..., 'base_nodeid': ...}`` dicts,
        one per base node id (the node id text before the first '['). The
        'lineno' is taken from the first test seen for that base node id.
    """
    merged = {}
    for distance, tests in old_test_dict.items():
        # Group the tests of this distance by base node id, in first-seen order.
        groups = {}
        for test in tests:
            full_nodeid = test['nodeid']
            base = full_nodeid.split('[')[0]
            group = groups.get(base)
            if group is None:
                group = {'nodeid_list': [], "lineno": test['lineno']}
                groups[base] = group
            group['nodeid_list'].append(full_nodeid)

        bucket = merged.setdefault(distance, [])
        for base, group in groups.items():
            group["base_nodeid"] = base
            bucket.append(group)
    return merged
    
        



if __name__ == "__main__":
    # Pipeline driver. Steps 1-4 are marked done and kept below, commented
    # out, as a record of how the data was produced; only STEP 5 executes.
    # use these for test: ['astropy_0.json', 'astropy_8.json', 'astropy_35.json', 'scikit-learn_218.json']
    prj2handle = 'pylint'
    repocod_data_folder_v6 = 'repocod_data_v6'
    pytest_result_folder = 'pytest_results_without_GT'
    pylint_save_folder = 'repocod_data_v7'
    os.makedirs(pylint_save_folder, exist_ok=True)

    # STEP 1: get all pytest results without function body: done
    # NOTE: this step refers to `repocod_data_folder`, which is not assigned
    # anywhere in this block; define it before re-enabling the step.
    # count = 0
    # for task_json_path in os.listdir(repocod_data_folder):
    #     count += 1
    #     # data control
    #     if prj2handle not in task_json_path:
    #         continue
    #     print(f'========== {count}: {task_json_path} ==========')
    #     task_json_path = os.path.join(repocod_data_folder, task_json_path)
    #     get_pytest_resuts(task_json_path)

    # STEP 2: analyze call distance based on pytest results: done
    # all_outcome = []
    # for idx, pytest_json in enumerate(os.listdir(pytest_result_folder)):
    #     # use these for testing
    #     if prj2handle not in pytest_json:
    #         continue
    #     print(f'=========={idx}. {pytest_json} ==========')
    #     pytest_path = os.path.join(pytest_result_folder, pytest_json)
    #     task_path = os.path.join(repocod_data_folder_v6, pytest_json)
    #     # get pytest result and task obj
    #     with open(pytest_path, 'r') as f:
    #         pytest_result = json.load(f)
    #     with open(task_path, 'r') as f:
    #         task_obj = json.load(f)
    #     # pytest outcome keys: 'failed', 'error', 'passed', 'skipped', 'xfailed', 'xpassed'
    #     # now we can only consider failed and error.
    #     call_distance_dict = get_call_distance(pytest_result, task_obj)
    #     if not call_distance_dict:
    #         print(f'No call distance found for {pytest_json}, exit to debug.')
    #         exit()
    #     task_obj_v4 = task_obj.copy()
    #     # del task_obj_v4["filtered_test_dict"] # delete the old test classification
    #     task_obj_v4['call_distance_test'] = call_distance_dict
    #     save_v4_path = os.path.join(pylint_save_folder, pytest_json)
    #     with open(save_v4_path, 'w') as f:
    #         json.dump(task_obj_v4, f, indent=4)
    #         print(f"saved to {save_v4_path}")

    # STEP 3: sort the call_distance_test based on call distance from repocod_data_folder_v2: done
    # for idx, task_json in enumerate(os.listdir(pylint_save_folder)):
    #     print(f'==============={idx}. {task_json} ===============')
    #     task_json_path = os.path.join(pylint_save_folder, task_json)
    #     with open(task_json_path, 'r') as f:
    #         task_obj = json.load(f)
    #     call_distance_test = task_obj['call_distance_test']
    #     converted_dict = {int(k): v for k, v in call_distance_test.items()}
    #     call_distance_test_sort = dict(sorted(converted_dict.items()))
    #     task_obj['call_distance_test'] = call_distance_test_sort
    #     with open(task_json_path, 'w') as f:
    #         json.dump(task_obj, f, indent=4)

    # STEP 4: combine test with same test function but different inputs: done
    # repocod_data_folder_v7 = 'repocod_data_v7'
    # for idx, task_json in enumerate(os.listdir(pylint_save_folder)):
    #     print(f'==============={idx}. {task_json} ===============')
    #     task_json_path = os.path.join(pylint_save_folder, task_json)
    #     with open(task_json_path, 'r') as f:
    #         task_obj = json.load(f)
    #     call_distance_test = task_obj['call_distance_test']
    #     new_call_distance_test = get_new_test_dict(call_distance_test)
    #     task_obj['call_distance_test'] = new_call_distance_test
    #     save_task_json_path = os.path.join(repocod_data_folder_v7, task_json)
    #     with open(save_task_json_path, 'w') as f:
    #         json.dump(task_obj, f, indent=4)
    #     print(f"saved to {save_task_json_path}")

    # STEP 5: get summary: done
    # Pool every task's 'call_distance_test' entries per call distance, then
    # print the average number of entries per task for each distance.
    task_files = os.listdir(repocod_data_folder_v6)
    count = len(task_files)
    pooled_tests = {}
    for task_json in task_files:
        with open(os.path.join(repocod_data_folder_v6, task_json), 'r') as f:
            task_obj = json.load(f)
        for call_dis, test_list in task_obj['call_distance_test'].items():
            pooled_tests.setdefault(call_dis, []).extend(test_list)

    print(f'total tasks: {count}')
    # JSON object keys are strings; convert to int so sorting is numeric.
    int_keyed = {int(key): tests for key, tests in pooled_tests.items()}
    summary_dict = {call_dis: int_keyed[call_dis] for call_dis in sorted(int_keyed)}
    for call_dis, test_list in summary_dict.items():
        average = round(len(test_list) / count, 2)
        print(f'{call_dis}: {average}')
        
# call_dis: average num        
# 1: 7.14
# 2: 5.83
# 3: 6.5
# 4: 7.24
# 5: 7.17
# 6: 5.42
# 7: 8.45
# 8: 3.53
# 9: 2.69
# 10: 3.31
# 11: 2.42
# 12: 1.69
# 13: 1.64
# 14: 0.62
# 15: 0.41
# 16: 0.53
# 17: 0.49
# 18: 0.37
# 19: 0.26
# 20: 0.14
# 21: 0.24
# 22: 0.11
# 23: 0.11
# 24: 0.05
# 25: 0.05
# 26: 0.03
# 27: 0.07
# 28: 0.15
# 29: 0.23
# 30: 0.06
# 31: 0.07
# 32: 0.03
# 33: 0.04
# 34: 0.09
# 35: 0.12
# 36: 0.08
# 37: 0.04
# 38: 0.01
# 39: 0.03
# 40: 0.04
# 41: 0.01
# 42: 0.01
# 43: 0.01
# 44: 0.04
# 45: 0.01
# 46: 0.0
# 47: 0.02
# 48: 0.01
# 49: 0.01
# 50: 0.0
# 51: 0.01
# 53: 0.01
# 54: 0.0
# 55: 0.0
# 58: 0.0
# 62: 0.01