
import json, argparse
import os, math, glob
from pathlib import Path
import re, copy
import numpy as np

def load_data(input_file):
    """Read a JSON file and return the decoded object.

    The file is decoded as UTF-8 explicitly so that results do not depend on
    the platform's default locale encoding (``save_data`` in this file writes
    non-ASCII characters unescaped).

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file contents.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_data(output_file, data):
    """Serialize ``data`` as pretty-printed JSON and print the output path.

    Missing parent directories are created first. The file is written as
    UTF-8 explicitly: ``ensure_ascii=False`` emits raw non-ASCII characters,
    which would otherwise fail on platforms whose default encoding cannot
    represent them.

    Args:
        output_file: Destination path; a bare file name is written to the
            current directory.
        data: JSON-serializable object to write.
    """
    parent_dir = os.path.dirname(output_file) or '.'
    os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    print(output_file)


def get_not_approved(pat):
    """Split a solver result file into approved and not-approved records.

    A record is approved when its ``solver_status`` is ``"success"`` and its
    ``final_answer`` is ``"A"``; every other record is not approved. The two
    groups are written next to the input as ``<stem>_invalid.json`` and
    ``<stem>_valid.json``.

    Args:
        pat: Path of the JSON file holding the list of records.
    """
    data = load_data(pat)
    invalid_data = []
    valid_data = []
    for item in data:
        # ``final_answer`` is only consulted for successful solver runs.
        approved = item["solver_status"] == "success" and item["final_answer"] == "A"
        (valid_data if approved else invalid_data).append(item)

    # Strip only a trailing ".json". A blanket str.replace would also rewrite
    # ".json" inside directory names, and would leave the path unchanged when
    # there is no ".json" at all, overwriting the input file with the outputs.
    stem = pat[:-len('.json')] if pat.endswith('.json') else pat
    save_data(f'{stem}_invalid.json', invalid_data)
    save_data(f'{stem}_valid.json', valid_data)


def valid_merge(folder_path, output_pat, dataset):
    """Merge per-round valid files and recompute the remaining invalid set.

    Walks ``folder_path`` and concatenates every file whose name ends with
    ``_fmt_valid.json`` (case-insensitive), contains ``dataset`` and does not
    contain ``1_2_``. The merged list is saved to ``output_pat``. Records of
    ``<folder_path>/<dataset>_invalid.json`` that do not appear in the merged
    list are then saved to the sibling "invalid" path.

    Args:
        folder_path: Directory searched recursively for valid files; it must
            also contain ``<dataset>_invalid.json``.
        output_pat: Destination of the merged valid records.
        dataset: Dataset name used to filter file names.
    """

    def record_key(record):
        # Records carrying a 'maskid' are identified by id + maskid. This is
        # decided per record instead of from whichever record was seen last.
        if 'maskid' in record:
            return record['id'] + str(record['maskid'])
        return record['id']

    file_list = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if '1_2_' in file:
                continue
            if file.lower().endswith("_fmt_valid.json") and dataset in file:
                file_list.append(os.path.join(root, file))

    new_data = []
    for file_name in file_list:
        items = load_data(file_name)  # load once; reused for count and merge
        print(file_name, len(items))
        new_data.extend(items)
    save_data(output_pat, new_data)

    # A set gives O(1) membership tests and is simply empty when nothing was
    # merged (the previous code raised on an unbound loop variable there).
    valid_idx = {record_key(item) for item in new_data}

    ori_invalid_pat = f'{folder_path}/{dataset}_invalid.json'
    ori_invalid_data = load_data(ori_invalid_pat)
    new_data_invalid = [
        item for item in ori_invalid_data if record_key(item) not in valid_idx
    ]

    invalid_pat = output_pat.replace('_valid', '_invalid')
    if invalid_pat == output_pat:
        # No '_valid' marker to swap: derive a distinct name so the invalid
        # records do not overwrite the merged valid file written above.
        stem, ext = os.path.splitext(output_pat)
        invalid_pat = f'{stem}_invalid{ext}'
    save_data(invalid_pat, new_data_invalid)




    

def get_not_fact(pat):
    """Split a judged result file by the content of ``llm_judge``.

    Records whose ``llm_judge`` value contains ``'False'`` are written to
    ``<stem>_nofact.json``; all other records go to ``<stem>_fact.json``.

    Args:
        pat: Path of the JSON file holding the list of judged records.
    """
    data = load_data(pat)
    invalid_data = []
    valid_data = []
    for item in data:
        # Membership test: a substring match if ``llm_judge`` is a string
        # (presumably the raw judge output; verify against the producer).
        if 'False' in item["llm_judge"]:
            invalid_data.append(item)
        else:
            valid_data.append(item)

    # Strip only a trailing ".json". A blanket str.replace would also rewrite
    # ".json" inside directory names, and would leave the path unchanged when
    # there is no ".json" at all, overwriting the input file with the outputs.
    stem = pat[:-len('.json')] if pat.endswith('.json') else pat
    save_data(f'{stem}_nofact.json', invalid_data)
    save_data(f'{stem}_fact.json', valid_data)

def cal_std_cot(data):
    """Print accuracy, recall and F1 for baseline (standard / CoT) responses.

    Each record supplies a free-text response in ``gpt_res`` and a gold label
    in ``answer`` ("A" or "B"). "A" is treated as the positive class. The
    predicted choice is found by substring search on the lower-cased
    response; "A" markers are checked before "B" markers. A response with no
    recognizable marker is counted as wrong by predicting the opposite of the
    gold label.

    Records whose gold label is neither "A" nor "B" are skipped so that the
    prediction and label lists always stay aligned.

    Args:
        data: Iterable of record dicts with ``gpt_res`` and ``answer`` keys.
    """
    y_pred, y_true = [], []
    for item in data:
        gt = item['answer']
        if gt not in ('A', 'B'):
            # Previously such a record could add a prediction without a
            # matching label, misaligning the arrays given to f1_score.
            continue
        gold = 1 if gt == "A" else 0

        ans = item['gpt_res'].lower()
        # NOTE: 'a)' / 'b)' match anywhere in the text, not only as an option
        # marker; kept as-is because it defines the reported baseline numbers.
        if 'conclusion:\na' in ans or 'conclusion:  \na' in ans or 'conclusion: a' in ans or 'a)' in ans:
            pred = 1
        elif 'conclusion:\nb' in ans or 'conclusion:  \nb' in ans or 'conclusion: b' in ans or 'b)' in ans:
            pred = 0
        else:
            pred = 1 - gold  # unparseable response counts as an error

        y_pred.append(pred)
        y_true.append(gold)

    f1, precision, recall, acc = f1_score(y_true, y_pred)
    print(f' acc= {acc*100:.2f}, recall= {recall*100:.2f}, f1= {f1*100:.2f} *** {len(y_true),len(y_pred)}')


def f1_score(y_true, y_pred):
    """Compute binary F1, precision, recall and accuracy (positive class = 1).

    Args:
        y_true: Sequence of gold labels, each 0 or 1.
        y_pred: Sequence of predicted labels, each 0 or 1, same length as
            ``y_true``.

    Returns:
        Tuple ``(f1, precision, recall, acc)``. Precision, recall and F1 use
        a 1e-9 term in the denominator, so they are 0 rather than undefined
        when the denominator would be zero. Accuracy is 0.0 for empty input.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in shape.
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if y_true.shape != y_pred.shape:
        # Without this check a length-1 array would silently broadcast
        # against the other one and yield meaningless counts.
        raise ValueError(
            f'y_true and y_pred must have the same shape, '
            f'got {y_true.shape} and {y_pred.shape}'
        )

    TP = np.sum((y_true == 1) & (y_pred == 1))
    FP = np.sum((y_true == 0) & (y_pred == 1))
    FN = np.sum((y_true == 1) & (y_pred == 0))
    TN = np.sum((y_true == 0) & (y_pred == 0))

    total = TP + FP + FN + TN
    # Guard the only unprotected division: empty input used to give nan
    # together with a NumPy RuntimeWarning.
    acc = (TP + TN) / total if total else 0.0
    precision = TP / (TP + FP + 1e-9)
    recall = TP / (TP + FN + 1e-9)

    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    return f1, precision, recall, acc




def _yg_record_key(item):
    """Return the identifier used to match a record across result files."""
    if 'maskid' in item:
        return item['id'] + str(item['maskid'])
    return item['id']


def _yg_parse_cot_choice(response):
    """Return 'A' or 'B' found in a CoT response, or None if neither is found."""
    ans = response.lower()
    # "A" markers are checked first, so a response containing both resolves
    # to "A". 'a)' / 'b)' match anywhere in the text.
    if 'conclusion:\na' in ans or 'conclusion:  \na' in ans or 'conclusion: a' in ans or 'a)' in ans:
        return 'A'
    if 'conclusion:\nb' in ans or 'conclusion:  \nb' in ans or 'conclusion: b' in ans or 'b)' in ans:
        return 'B'
    return None


def _yg_print_metrics(records):
    """Print acc / recall / F1 comparing each record's 'ours' to its 'answer'."""
    y_true = [1 if item['answer'] == "A" else 0 for item in records]
    y_pred = [1 if item['ours'] == "A" else 0 for item in records]
    f1, precision, recall, acc = f1_score(y_true, y_pred)
    print(f' acc= {acc*100:.2f}, recall= {recall*100:.2f}, f1= {f1*100:.2f} *** {len(y_true),len(y_pred)}')


def all_acc_yg_f1(folder_path, dataset_name, model_name, k=4):
    """Score the combined pipeline against the standard and CoT baselines.

    Predictions come from three sources:

    * "A": every record in ``<folder_path>/solver/<dataset_name>_valid.json``
      and in the ``sl2nl`` files ending with ``_fact.json`` for round ``k``.
    * "B": every record in the ``sl2nl`` files ending with ``_nofact.json``
      whose path also contains ``_nofact_<k>`` for round ``k``.
    * CoT fallback: records listed in the ``solver`` "invalid" files for round
      ``k`` take the choice parsed from the CoT baseline response; an
      unparseable response is counted as wrong (opposite of the gold label).

    Baseline files are read from ``results/0baseline/<model_name>/`` relative
    to the current working directory, and the combined predictions are saved
    there as ``<dataset_name>_ours.json`` with the prediction stored under the
    ``ours`` key of each record. Metrics are printed, not returned.

    Records whose gold ``answer`` is neither "A" nor "B" are left out of the
    combined set, so predictions and labels always stay aligned.

    Args:
        folder_path: Directory containing the ``sl2nl`` and ``solver`` folders.
        dataset_name: Dataset name used in the result file names.
        model_name: Sub-directory of ``results/0baseline`` with the baselines.
        k: Round to evaluate, 1 to 5. Defaults to 4, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``k`` is not between 1 and 5.
    """
    k_names = {5: '1_2_3_4_5', 4: '1_2_3_4', 3: '1_2_3', 2: '1_2', 1: 'one'}
    if k not in k_names:
        raise ValueError(f'k must be one of {sorted(k_names)}, got {k!r}')
    k_name = k_names[k]
    fmt_tag = f'{k_name}_fmt'

    sl2nl_file_list = glob.glob(f'{folder_path}/sl2nl/*')
    solver_file_list = glob.glob(f'{folder_path}/solver/*')

    A_file_list = [f'{folder_path}/solver/{dataset_name}_valid.json']
    B_file_list = []
    cot_file_list = []

    # The substring tests run against the whole path, as before.
    for file in sl2nl_file_list:
        if file.endswith("_fact.json") and fmt_tag in file:
            A_file_list.append(file)
        elif file.endswith("_nofact.json") and f'_nofact_{k}' in file and fmt_tag in file:
            B_file_list.append(file)

    for file in solver_file_list:
        if file.endswith(f"{k_name}_fmt_invalid.json"):
            cot_file_list.append(file)
        elif file.endswith("_invalid.json") and '_nofact_' in file and fmt_tag in file:
            cot_file_list.append(file)

    baseline_dir = f'results/0baseline/{model_name}'
    cot_pat = f'{baseline_dir}/{dataset_name}_cot.json'
    std_pat = f'{baseline_dir}/{dataset_name}.json'
    # Built directly instead of cot_pat.replace('cot', 'ours'), which also
    # rewrote "cot" occurring inside the model or dataset name.
    give_pat = f'{baseline_dir}/{dataset_name}_ours.json'

    print('*'*60)
    print(std_pat)
    std_data = load_data(std_pat)
    cal_std_cot(std_data)
    print(cot_pat)
    cot_data = load_data(cot_pat)
    cal_std_cot(cot_data)
    print('*'*60)

    give_data = []

    # Sources with a fixed prediction: every record in the file gets `label`.
    fixed_sources = (
        ('A', "Predicted A:", A_file_list),
        ('B', "Predicted B:", B_file_list),
    )
    for label, header, files in fixed_sources:
        for file in files:
            data = load_data(file)  # load once; reused for count and records
            print(header, file, len(data))
            for item in data:
                if item['answer'] not in ('A', 'B'):
                    continue  # cannot be scored
                item['ours'] = label
                give_data.append(item)

    # Ids of the records that fall back to the CoT baseline prediction.
    yaode_cot_data = set()
    for file in cot_file_list:
        data = load_data(file)
        print("Predicted CoT:", file, len(data))
        yaode_cot_data.update(_yg_record_key(item) for item in data)

    for item in cot_data:
        if _yg_record_key(item) not in yaode_cot_data:
            continue
        gt = item['answer']
        if gt not in ('A', 'B'):
            continue  # cannot be scored
        choice = _yg_parse_cot_choice(item['gpt_res'])
        if choice is None:
            # Unparseable response: count it as an error.
            choice = 'B' if gt == "A" else 'A'
        item['ours'] = choice
        give_data.append(item)

    # give_data holds exactly the scored records in scoring order, so the
    # metrics line before saving and the one after it are computed from the
    # same predictions; both are printed to keep the output format.
    _yg_print_metrics(give_data)

    save_data(give_pat, give_data)
    print('give data', len(give_data))
    _yg_print_metrics(give_data)






def main():
    """Parse the command line and run the task selected by ``--task_type``.

    ``--input_file`` is forwarded as the first argument of the chosen task:
    a file path for ``get_not_approved``, a folder path for ``valid_merge``
    and ``all_acc_yg_f1``.
    """
    cli = argparse.ArgumentParser(description='tools use')
    cli.add_argument('--input_file', type=str, default='', help='Input JSON file path')
    cli.add_argument('--output_file', type=str, default='', help='Output JSON file path')
    cli.add_argument('--dataset', type=str, default='')
    cli.add_argument('--model_name', type=str, default='')
    cli.add_argument(
        '--task_type',
        type=str,
        choices=['get_not_approved', 'valid_merge', 'all_acc_yg_f1'],
        default='all_acc_yg_f1',
        help='',
    )
    opts = cli.parse_args()

    # One entry per accepted --task_type value; argparse has already rejected
    # anything outside `choices`, so the lookup cannot miss.
    handlers = {
        'get_not_approved': lambda: get_not_approved(opts.input_file),
        'valid_merge': lambda: valid_merge(opts.input_file, opts.output_file, opts.dataset),
        'all_acc_yg_f1': lambda: all_acc_yg_f1(opts.input_file, opts.dataset, opts.model_name),
    }
    handlers[opts.task_type]()




# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()

