"""This file is mainly used to study the role of test play in RAG stage, including:
Q1. Among all test cases, what's the percentage of direct call target? 
direct_call_num / all_test_num
Q2. Among all usage examples, what's the percentage of test? 
test_num / all_usage_num
Q3. Does providing direct call test in RAG performs better than providing indirect call / not providing at all?
"""
import os
import sys
import json
from concurrent.futures import ProcessPoolExecutor, as_completed
# Add the pipeline/ directory (the parent of this script's folder) to sys.path
# so that the `app.*` imports below resolve when this file is run as a script.
script_dir = os.path.dirname(os.path.abspath(__file__))
pipeline_root = os.path.abspath(os.path.join(script_dir, ".."))
sys.path.insert(0, pipeline_root)
from app.task.raw_tasks import RawLocalTask
from app.search.search_backend import SearchBackend

# Root folder holding the task instances. It is a relative path, so it is
# resolved against the current working directory, not this script's location.
All_TASK_FOLDER = 'REPOCOD/Tasks4Agents'

def cal_direct_call_test(All_TASK_FOLDER):
    """Q1. Among all test cases, what's the percentage of direct call target?

    Reads ``<All_TASK_FOLDER>/<project_name>/<project_idx>/<project_idx>.json``
    for every task and counts the tests listed under
    ``filtered_test_dict['0']`` (direct calls) and ``filtered_test_dict['1']``
    (indirect calls).

    Args:
        All_TASK_FOLDER: Root folder containing one sub-folder per project,
            each holding one sub-folder per task instance. Non-directory
            entries (e.g. ``.DS_Store``) are ignored.

    Returns:
        A dict keyed by project name, plus a ``"SUMMARY"`` entry aggregated
        over all projects. Each value holds ``task_num``, ``all_test_num``,
        ``direct_call_num`` and ``direct_call_per_task``. The last one is
        ``direct_call_num / task_num``, or 0 when there are no tasks. The Q1
        percentage itself is ``direct_call_num / all_test_num``.
    """
    def _empty_stats():
        # Fresh counters for one project (or for the global summary).
        return {
            'task_num': 0,
            "all_test_num": 0,
            "direct_call_num": 0,
            'direct_call_per_task': 0,
        }

    def _per_task(stats):
        # Empty projects used to raise ZeroDivisionError here.
        if stats['task_num'] == 0:
            return 0
        return stats['direct_call_num'] / stats['task_num']

    result_dict = {"SUMMARY": _empty_stats()}
    summary = result_dict["SUMMARY"]
    for project_name in os.listdir(All_TASK_FOLDER):
        project_folder = os.path.join(All_TASK_FOLDER, project_name)
        # Stray files would make the os.listdir call below fail.
        if not os.path.isdir(project_folder):
            continue
        if project_name not in result_dict:
            result_dict[project_name] = _empty_stats()
        project_stats = result_dict[project_name]
        for project_idx in os.listdir(project_folder):
            task_dir = os.path.join(project_folder, project_idx)
            if not os.path.isdir(task_dir):
                continue
            task_json_data = os.path.join(task_dir, f'{project_idx}.json')
            with open(task_json_data, 'r', encoding='utf-8') as f:
                task_data = json.load(f)
            filtered_test_dict = task_data['filtered_test_dict']
            direct_num = len(filtered_test_dict['0'])
            test_num = direct_num + len(filtered_test_dict['1'])
            # Every task contributes to both its project and the summary.
            for stats in (summary, project_stats):
                stats['task_num'] += 1
                stats['all_test_num'] += test_num
                stats['direct_call_num'] += direct_num
        project_stats['direct_call_per_task'] = _per_task(project_stats)

    summary['direct_call_per_task'] = _per_task(summary)
    return result_dict
            

def process_single_project_instance(project_folder, project_name, project_idx):
    """Count the usage examples of one task's target, split into test / non-test.

    The task instance lives in ``<project_folder>/<project_idx>/``. That folder
    holds ``<project_idx>.json`` (task data), ``task_description.md``,
    ``nohup_output.log`` and a repo checkout named after the part of
    ``project_idx`` before the first ``_``.

    Non-test usages come from ``SearchBackend(..., pass_test_files=True)
    .search_target_usage_example(5)``. Presumably this skips test files and
    caps the search at 5 examples (TODO confirm against SearchBackend). Test
    usages are the entries of ``task.test_data['0']``.

    When the log contains both the FAILED and the PASSED "We picked the test"
    markers, a diagnostic line is printed. It shows whether every usage
    example was a test (√) or not (X).

    Returns:
        dict with ``project_name``, ``all_usage_num`` (non-test + test usages)
        and ``test_usage_num``.
    """
    instance_dir = os.path.join(project_folder, project_idx)
    repo_dir = os.path.join(instance_dir, project_idx.split('_')[0])

    with open(os.path.join(instance_dir, 'nohup_output.log'), 'r') as log_fh:
        log_text = log_fh.read()

    task = RawLocalTask(
        task_id=project_idx,
        local_repo=repo_dir,
        issue_file=os.path.join(instance_dir, 'task_description.md'),
        task_data_path=os.path.join(instance_dir, f'{project_idx}.json'),
    )
    backend = SearchBackend(project_path=repo_dir, task=task, pass_test_files=True)
    _, non_test_usages, _ = backend.search_target_usage_example(5)

    n_test_usages = len(task.test_data['0'])
    # True when the search found nothing outside tests, yet tests do use it.
    only_in_tests = len(non_test_usages) == 0 and n_test_usages != 0
    # Both markers present => the run failed first and passed after refinement.
    passed_after_refine = (
        'FAILED! We picked the test case:' in log_text
        and 'PASSED! We picked the test' in log_text
    )
    if passed_after_refine:
        mark = '√' if only_in_tests else 'X'
        print(f'{project_idx}: {task.function_name}, {mark} all_in_test, √ need refine')

    return {
        "project_name": project_name,
        "all_usage_num": len(non_test_usages) + n_test_usages,
        "test_usage_num": n_test_usages,
    }


def cal_direct_call_usage(All_TASK_FOLDER, name_filter='seaborn', max_workers=8):
    """Q2: Among all usage examples, what's the percentage of test?

    Runs ``process_single_project_instance`` for each selected task folder in
    a process pool and sums the returned counts per project and overall.

    Args:
        All_TASK_FOLDER: Root folder containing one sub-folder per project,
            each holding one sub-folder per task instance. Non-directory
            entries (e.g. ``.DS_Store``) are ignored.
        name_filter: Only task folders whose name contains this substring
            are processed. Defaults to ``'seaborn'``, the previously
            hard-coded filter. Pass ``None`` to process every task.
        max_workers: Number of worker processes.

    Returns:
        A plain dict keyed by project name plus ``"SUMMARY"``. Each value
        holds ``all_usage_num`` and ``test_usage_num``, and the Q2 percentage
        is ``test_usage_num / all_usage_num``. ``"SUMMARY"`` is always
        present, even when no task was processed.

    Raises:
        Any exception raised by a worker is re-raised by ``future.result()``.
    """
    from collections import defaultdict

    def _zero_counts():
        return {"all_usage_num": 0, "test_usage_num": 0}

    result_dict = defaultdict(_zero_counts)
    # Keep the result shape stable even if no task matches the filter.
    result_dict["SUMMARY"] = _zero_counts()
    futures = []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for project_name in os.listdir(All_TASK_FOLDER):
            project_folder = os.path.join(All_TASK_FOLDER, project_name)
            # Stray files would make the os.listdir call below fail.
            if not os.path.isdir(project_folder):
                continue
            for project_idx in os.listdir(project_folder):
                if name_filter is not None and name_filter not in project_idx:
                    continue
                if not os.path.isdir(os.path.join(project_folder, project_idx)):
                    continue
                futures.append(executor.submit(
                    process_single_project_instance, project_folder, project_name, project_idx
                ))
        for future in as_completed(futures):
            result = future.result()
            # Each task contributes to its project and to the summary.
            for key in (result["project_name"], "SUMMARY"):
                result_dict[key]["all_usage_num"] += result["all_usage_num"]
                result_dict[key]["test_usage_num"] += result["test_usage_num"]

    return dict(result_dict)
    
         

if __name__ == '__main__':
    # Q1 (share of tests that directly call the target) is currently disabled.
    # Uncomment to rerun it.
    # Q1_result_dict = cal_direct_call_test(All_TASK_FOLDER)
    # # show Q1_result_dict
    # RAG_Q1_result_dict_path = 'pipeline/Test_Analysis/RAG_Q1_result_dict.json'
    # with open(RAG_Q1_result_dict_path, 'w') as f:
    #     json.dump(Q1_result_dict, f, indent=2)
    # for key, value in Q1_result_dict.items():
    #     print(f'{key}:')
    #     for subkey, subval in value.items():
    #         print(f'\t{subkey}: {subval}')

    Q2_result_dict = cal_direct_call_usage(All_TASK_FOLDER)
    # Optionally persist Q2_result_dict:
    # RAG_Q2_result_dict_path = 'pipeline/Test_Analysis/RAG_Q2_result_dict.json'
    # with open(RAG_Q2_result_dict_path, 'w') as f:
    #     json.dump(Q2_result_dict, f, indent=2)

    # Show Q2_result_dict. Previously the result was computed and then discarded.
    for key, value in Q2_result_dict.items():
        print(f'{key}:')
        for subkey, subval in value.items():
            print(f'\t{subkey}: {subval}')

    