from olym_gen.utils.sample_utils import list_check
import os
import json
import shutil

# Accessors handed to ``list_check``: each takes one loaded record (a mapping)
# and pulls out either the proof text or the checker's verdict.


def check_positive_proof_retrieve(x):
    """Return the record's 'checked_proof' value, or '' when the key is absent."""
    return x.get('checked_proof', '')


def check_negative_proof_retrieve(x):
    """Return the record's 'checked_proof' value, or '' when the key is absent."""
    return x.get('checked_proof', '')


def check_reference_proof_retrieve(x):
    """Return the record's 'completed_proof' value, or '' when the key is absent."""
    return x.get('completed_proof', '')


def check_mask_proof_retrieve(x):
    """Return the record's 'checked_proof' value, or '' when the key is absent."""
    return x.get('checked_proof', '')


def check_positive_result_retrieve(x):
    """Return 'proof_correct' from the nested 'check_result' mapping, else None."""
    verdict = x.get('check_result', {})
    return verdict.get('proof_correct', None)


def check_negative_result_retrieve(x):
    """Return 'proof_correct' from the nested 'check_result' mapping, else None."""
    verdict = x.get('check_result', {})
    return verdict.get('proof_correct', None)


def check_reference_result_retrieve(x):
    """Return the record's top-level 'proof_correct' value, else None."""
    return x.get('proof_correct', None)


def check_mask_result_retrieve(x):
    """Return 'proof_correct' from the nested 'check_result' mapping, else None."""
    verdict = x.get('check_result', {})
    return verdict.get('proof_correct', None)

def problem_retrieve(x):
    """Return the record's 'question' value, or '' when the key is absent."""
    return x.get('question', '')


def proof_retrieve(x):
    """Return 'new_solution' if present, otherwise 'completed_proof', otherwise ''."""
    # Keys are tried in priority order; the first one present wins.
    for key in ('new_solution', 'completed_proof'):
        if key in x:
            return x[key]
    return ''


def label_retrieve(x):
    """Return 'correctness' from the nested 'judgment' mapping, or '' if missing."""
    judgment = x.get('judgment', {})
    return judgment.get("correctness", '')

def _load_check_total(path):
    """Load the positive, negative and mask check results under *path* and merge them.

    *path* is the ``sample_by_Q/<dataset>/<model>/`` prefix (with trailing
    slash). The merged mapping is later queried with ``f'{problem}___{proof}'``
    keys and is expected to yield a list of verdicts per key
    (NOTE(review): the exact return shape is defined by ``list_check`` -- confirm).
    On duplicate keys the right-most mapping in the merge wins.
    """
    def load(subdir, proof_getter, result_getter):
        return list_check(
            check_dir=path + f'{subdir}/data',
            problem_retrieve=problem_retrieve,
            proof_retrieve=proof_getter,
            result_retrieve=result_getter,
        )

    check_positive = load('check_positive', check_positive_proof_retrieve,
                          check_positive_result_retrieve)
    check_negative = load('check_negative', check_negative_proof_retrieve,
                          check_negative_result_retrieve)
    check_mask = load('check_mask', check_mask_proof_retrieve,
                      check_mask_result_retrieve)
    # Reference checks (``check_reference/data``) are currently not loaded;
    # the empty mapping keeps the merge expression in its original shape.
    check_reference = {}
    return check_positive | check_negative | check_mask | check_reference


def _collect_results(root_dir, check_total):
    """Walk *root_dir* and pair every labeled JSON file with its check verdicts.

    Returns a list of ``(file_path, label, correct_count, incorrect_count)``
    tuples. Files without a ``.json`` suffix, files named
    ``correspondence.json`` and files with no entry in *check_total* are
    skipped; unreadable files are reported on stdout and skipped.
    """
    results = []
    for root, _dirs, filenames in os.walk(root_dir):
        for filename in filenames:
            if not filename.endswith('.json') or filename == 'correspondence.json':
                continue
            file_path = os.path.join(root, filename)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                problem = problem_retrieve(data)
                proof = proof_retrieve(data)
                label = label_retrieve(data)
                check_result = check_total.get(f'{problem}___{proof}', [])
                if not check_result:
                    print(f"No check result for file: {file_path}")
                    continue
                results.append((
                    file_path,
                    label,
                    check_result.count(True),
                    check_result.count(False),
                ))
            except json.JSONDecodeError:
                print(f"Error decoding JSON in file: {file_path}")
            except Exception as e:
                # Best effort: one bad file must not abort the whole sweep.
                print(f"Error processing file {file_path}: {e}")
    return results


def _report(results):
    """Print agreement statistics between check verdicts and judged labels.

    P3 are results with exactly 3 correct / 0 incorrect verdicts, P0 those
    with exactly 0 correct / 3 incorrect verdicts.
    """
    P3 = [r for r in results if r[2] == 3 and r[3] == 0]
    P0 = [r for r in results if r[2] == 0 and r[3] == 3]
    print(f'P3: {len(P3)}, P0: {len(P0)}, Total: {len(results)}')

    # Single pass tally of (correct, incorrect) pairs instead of re-scanning
    # ``results`` for every combination.
    tally = {}
    for _path, _label, correct_count, incorrect_count in results:
        pair = (correct_count, incorrect_count)
        tally[pair] = tally.get(pair, 0) + 1
    for i in range(4):
        for j in range(4):
            if i + j > 3:
                continue
            count = tally.get((i, j), 0)
            if count > 0:
                print(f'Correct: {i}, Incorrect: {j}, Count: {count}')

    # Label Accuracy for P3
    correct_labels = sum(1 for r in P3 if r[1] == 'correct')
    accuracy = correct_labels / len(P3) if P3 else 0
    print(f'P3 Label Accuracy: {accuracy:.2f}')

    # Label Accuracy for P0
    correct_labels = sum(1 for r in P0 if r[1] == 'incorrect')
    accuracy = correct_labels / len(P0) if P0 else 0
    print(f'P0 Label Accuracy: {accuracy:.2f}')


def _copy_into_validset(paths, subset):
    """Copy each source path to ``validset/<subset>/<same relative path>``."""
    for src in paths:
        dest = os.path.join('validset', subset, src)
        # os.walk may yield files nested below root_dir; shutil.copy does not
        # create missing parent directories, so make sure they exist.
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)


# Sweep every model / dataset / labeling-method combination that has labeled
# results on disk, report how the check verdicts line up with the judged
# labels, and copy the labeled files into validset/positive or
# validset/negative.
for model in ['gemini-2.5-flash', 'gpt-5-mini-2025-08-07', 'deepseek-reasoner', 'deepseek-r1-0528']:
    for dataset in ['putnam', 'usamo', 'olympiad-bench']:
        # This model/dataset combination is excluded from the sweep.
        if model == 'deepseek-reasoner' and dataset == 'usamo':
            continue

        path = f'sample_by_Q/{dataset}/{model}/'

        # The check results depend only on (model, dataset), so they are loaded
        # at most once, and only if at least one method directory exists.
        check_total = None

        for method in ['mask_replace', 'proof', 'rephrase']:
            root_dir = 'labeled_result/' + path + f'{method}/'
            if not os.path.exists(root_dir):
                continue

            if check_total is None:
                check_total = _load_check_total(path)

            results = _collect_results(root_dir, check_total)
            _report(results)

            # NOTE(review): files are selected by their judged label alone,
            # not by P3/P0 membership -- confirm this is the intended valid set.
            correct_files = [r[0] for r in results if r[1] == 'correct']
            incorrect_files = [r[0] for r in results if r[1] == 'incorrect']

            # Created even when there is nothing to copy, so that every
            # processed method has a directory in both subsets.
            os.makedirs(os.path.join('validset', 'positive', root_dir), exist_ok=True)
            os.makedirs(os.path.join('validset', 'negative', root_dir), exist_ok=True)

            _copy_into_validset(correct_files, 'positive')
            _copy_into_validset(incorrect_files, 'negative')
