
from collections import defaultdict

from src.io_utils import Tools
from src.syntax_utils import is_compilable, get_fn_name

# Substrings at which generated text is cut off; everything from the first
# occurrence onward is discarded (see PostProcessor.solution_extract).
# NOTE(review): this literal was truncated after '\ndef' in the file as found.
# The remaining entries are restored from the stop sequences commonly used for
# HumanEval-style sampling -- confirm against the intended list.
STOP_TOKEN = ['\nclass', '\ndef', '\n#', '\nif', '\nprint']


def remove_assert_from_test(test_str):
    """Drop a leading ``assert`` keyword from a test snippet.

    The snippet is first stripped of surrounding whitespace. If what remains
    begins with the text ``assert``, that prefix is removed and a single space
    is put in its place (whatever followed the keyword, including its own
    leading space, is kept as-is). Otherwise the stripped snippet is returned
    unchanged.
    """
    keyword = "assert"
    cleaned = test_str.strip()

    if not cleaned.startswith(keyword):
        return cleaned

    remainder = cleaned[len(keyword):]
    return " " + remainder


class PostProcessor:
    """Turn raw generation records into solution entries and assertion groups.

    Every method is static; the class is used purely as a namespace.
    """

    # Completion text recorded for a task that ends up with no usable solution.
    _EMPTY_SOLUTION = 'empty solution here, execution will fail'

    @staticmethod
    def _collect_solutions_and_tests(raw_problems, candidates_of, tests_key, strip_assert):
        """Shared driver for the ``parquet_map_task_id_for_*`` methods.

        Args:
            raw_problems: iterable of sample records supporting ``sample[key]``
                lookups for ``'task_id'``, ``'prompt_codegen'``,
                ``'prompt_testgen'`` and ``tests_key``.
            candidates_of: callable taking one sample and returning the
                sequence of candidate completions to check for that sample.
            tests_key: name of the field holding the generated test snippets.
            strip_assert: when true, each test snippet is passed through
                ``remove_assert_from_test`` before being used.

        Returns:
            ``(result, num_prompts, test_cases_by_task)`` where ``result`` is a
            list of solution dicts, ``num_prompts`` is the number of distinct
            ``'prompt_codegen'`` values seen, and ``test_cases_by_task`` maps
            ``str(task_id)`` to a list with one list of assertions per test
            snippet.
        """
        seen_prompts = set()
        result = []
        test_cases_by_task = defaultdict(list)

        for each_sample in raw_problems:
            code_prompt = each_sample['prompt_codegen']
            task_id = str(each_sample['task_id'])
            seen_prompts.add(code_prompt)

            # Keep every candidate that compiles once appended to the prompt;
            # sol_id is the candidate's position in the candidate sequence.
            syntax_correct_num = 0
            for sol_id, completion in enumerate(candidates_of(each_sample)):
                full_code = code_prompt + completion
                if not is_compilable(full_code):
                    continue
                func_name = get_fn_name(full_code)
                syntax_correct_num += 1
                result.append({
                    'task_id': task_id,
                    'sol_id': sol_id,
                    'prompt': code_prompt,
                    'entry_point': func_name,
                    'completion': completion,
                })

            if syntax_correct_num == 0:
                # Guarantee at least one entry per task. Note that this
                # placeholder entry carries no 'sol_id' key.
                result.append({
                    'task_id': task_id,
                    'prompt': code_prompt,
                    'entry_point': "",
                    'completion': PostProcessor._EMPTY_SOLUTION,
                })

            for test_pre in each_sample[tests_key]:
                if strip_assert:
                    test_pre = remove_assert_from_test(test_pre)
                test_code = each_sample['prompt_testgen'] + test_pre
                if is_compilable(test_code):
                    func_name = get_fn_name(test_code)
                    test_cases = PostProcessor.test_case_extract(test_pre, func_name)
                else:
                    # A non-compilable snippet still occupies a slot, filled
                    # with a trivially true assertion.
                    test_cases = ['assert True']
                test_cases_by_task[task_id].append(test_cases)

        return result, len(seen_prompts), test_cases_by_task

    @staticmethod
    def parquet_map_task_id_for_solution_and_test(input_parquet_path):
        """Collect solutions from ``'code_output'`` and tests from ``'test_output'``.

        Records are loaded with ``Tools.load_parquet``. Test snippets are used
        as stored (no leading ``assert`` is stripped).

        Returns:
            ``(result, num_prompts, test_cases_by_task)``; see
            ``_collect_solutions_and_tests``.
        """
        raw_problems = Tools.load_parquet(input_parquet_path)
        return PostProcessor._collect_solutions_and_tests(
            raw_problems,
            candidates_of=lambda sample: sample['code_output'],
            tests_key='test_output',
            strip_assert=False,
        )

    @staticmethod
    def parquet_map_task_id_for_find_correct_test(input_parquet_path):
        """Pair each task's ``'highest_code'`` with its ``'ranked_test'`` snippets.

        The only candidate solution per task is ``'highest_code'`` (so its
        ``sol_id`` is always 0). Each test snippet has a leading ``assert``
        removed before processing.

        Returns:
            ``(result, num_prompts, test_cases_by_task)``; see
            ``_collect_solutions_and_tests``.
        """
        raw_problems = Tools.load_parquet(input_parquet_path)
        return PostProcessor._collect_solutions_and_tests(
            raw_problems,
            candidates_of=lambda sample: [sample['highest_code']],
            tests_key='ranked_test',
            strip_assert=True,
        )

    @staticmethod
    def parquet_map_task_id_for_find_correct_code_with_correct_test(input_parquet_path):
        """Pair ``'output'`` plus ``'lowest_code'`` with ``'correct_test'`` snippets.

        The candidate list is every entry of ``'output'`` followed by
        ``'lowest_code'``. It is built as a new list so the loaded record is
        not modified. Each test snippet has a leading ``assert`` removed
        before processing.

        Returns:
            ``(result, num_prompts, test_cases_by_task)``; see
            ``_collect_solutions_and_tests``.
        """
        raw_problems = Tools.load_parquet(input_parquet_path)
        return PostProcessor._collect_solutions_and_tests(
            raw_problems,
            # Copy instead of appending in place onto the sample's own list.
            candidates_of=lambda sample: list(sample['output']) + [sample['lowest_code']],
            tests_key='correct_test',
            strip_assert=True,
        )

    @staticmethod
    def _index_tasks_by_prompt(raw_problems):
        """Return a dict mapping each task's ``'prompt'`` to the task record."""
        return {task['prompt']: task for task in raw_problems.values()}

    @staticmethod
    def map_task_id_for_solution(predict_path, source_path):
        """Attach task metadata to every predicted solution sample.

        Tasks come from ``Tools.load_tasks(source_path)`` and predictions from
        ``Tools.load_jsonl(predict_path)``; a prediction is matched to its
        task by exact ``'prompt'`` text.

        Returns:
            ``(result, num_tasks)`` where ``result`` holds one dict per sample
            (or one placeholder dict for a prediction without samples) and
            ``num_tasks`` is ``len()`` of the loaded tasks.

        Raises:
            KeyError: if a prediction's prompt matches no loaded task.
        """
        raw_problems = Tools.load_tasks(source_path)
        database = PostProcessor._index_tasks_by_prompt(raw_problems)

        result = []
        for pre in Tools.load_jsonl(predict_path):
            task = database[pre['prompt']]
            # Treat a missing/None sample list the same as an empty one.
            samples = pre['samples'] or []
            completions = [PostProcessor.solution_extract(sample) for sample in samples]
            if not completions:
                completions = [PostProcessor._EMPTY_SOLUTION]
            for completion in completions:
                result.append({
                    'task_id': task['task_id'],
                    'prompt': pre['prompt'],
                    'test': task['test'],
                    'entry_point': task['entry_point'],
                    'completion': completion,
                })
        return result, len(raw_problems)

    @staticmethod
    def map_task_id_for_test_case(predict_path, source_path):
        """Group the assertions extracted from predicted test samples by task.

        Tasks come from ``Tools.load_tasks(source_path)`` and predictions from
        ``Tools.load_jsonl(predict_path)``; a prediction is matched to its
        task by exact ``'prompt'`` text.

        Returns:
            A ``defaultdict(list)`` mapping ``task['task_id']`` to a list that
            holds one list of assertions per sample.

        Raises:
            KeyError: if a prediction's prompt matches no loaded task.
        """
        raw_problems = Tools.load_tasks(source_path)
        database = PostProcessor._index_tasks_by_prompt(raw_problems)

        test_cases_by_task = defaultdict(list)
        for pre in Tools.load_jsonl(predict_path):
            task = database[pre['prompt']]
            for sample in pre['samples'] or []:
                test_cases = PostProcessor.test_case_extract(sample, task['entry_point'])
                test_cases_by_task[task['task_id']].append(test_cases)
        return test_cases_by_task

    @staticmethod
    def solution_extract(content):
        """Cut ``content`` at the earliest occurrence of any ``STOP_TOKEN`` entry.

        Tokens are applied one after another, each keeping only the text
        before its first occurrence. No whitespace stripping is done.
        """
        for identifier in STOP_TOKEN:
            if identifier in content:
                content = content.split(identifier)[0]
        return content

    @staticmethod
    def test_case_extract(content, entry_point):
        """Extract the individual assertions in ``content`` that mention ``entry_point``.

        ``content`` is treated as the continuation of an ``assert `` statement
        and split on the text ``'assert '``. A fragment is kept only if it is
        non-blank and contains the stripped ``entry_point``; it is then
        truncated at the stop tokens, stripped, and kept only if it passes
        ``_check_test_case_validation``.

        Returns:
            List of assertion strings; empty when ``entry_point`` is falsy.
        """
        if not entry_point:
            return []

        target = entry_point.strip()
        checked_assertions = []
        for part in f'assert {content}'.split('assert '):
            if not part.strip() or target not in part:
                continue
            candidate = PostProcessor.solution_extract(f'assert {part}'.strip()).strip()
            if PostProcessor._check_test_case_validation(candidate):
                checked_assertions.append(candidate)
        return checked_assertions

    @staticmethod
    def _check_test_case_validation(test_case):
        """Return True if ``test_case`` contains ``assert`` and compiles.

        The text is placed inside a ``try`` block (continuation lines indented
        to match) and passed to ``compile``; it is never executed.
        """
        if not test_case.strip():
            return False
        if 'assert' not in test_case:
            return False

        # Indent continuation lines so a multi-line assertion stays inside the
        # try body.
        multi_line_test_case = test_case.replace("\n", "\n    ")
        assert_in_a_block = f'try:\n    {multi_line_test_case}\nexcept:\n    pass\n'
        try:
            compile(assert_in_a_block, '', 'exec')
        except Exception:
            # Any failure to compile means the assertion is unusable.
            return False
        return True
