import argparse
import os
import re
import sys

import pandas as pd

from utils.report_generator import judge_accuracy

# Column order of the printed report; every row built in main() uses exactly these keys.
_RESULT_COLUMNS = ['Type', 'Img Rev', 'Model', 'Img Type', 'w/ GT', 'Size', 'GT_Type', 'Accuracy', 'Support']


def _fail(message):
    """Print ``E: <message>`` to stderr and terminate with exit status 1."""
    print(f"E: {message}", file=sys.stderr)
    sys.exit(1)


def _load_accepted_qids(path):
    """Read one integer QuestionID per line from *path*, ignoring blank lines.

    Returns a set, so duplicate lines in the file cannot make the per-file
    consistency check in main() fail spuriously.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return {int(line.strip()) for line in f if line.strip()}


def _parse_model_size(model_str):
    """Return the number in front of the first ``b``/``B`` suffix in *model_str*.

    Decimal sizes are supported (``'qwen2-1.5b'`` -> ``1.5``; an integer-only
    pattern would yield ``5``). Integral sizes are returned as ``int``.
    Presumably the value is billions of parameters. When no size is found,
    ``float('inf')`` is returned so such models sort after every sized model.
    """
    match = re.search(r'(\d+(?:\.\d+)?)b', model_str, re.IGNORECASE)
    if match is None:
        return float('inf')
    size = float(match.group(1))
    return int(size) if size.is_integer() else size


def _describe_run(data_input):
    """Derive the descriptive report columns for one result file from its filename.

    The filename is split on ``_``: field 1 is the prompt version, field 3 the
    image/bbox revision, and the last field (minus its extension) the model name.
    Optional ``judge_short_`` / ``judge_raw_`` prefixes are normalized to
    ``judge_`` first so those positions are the same for every variant.

    Returns a dict keyed by the first seven names in ``_RESULT_COLUMNS``.
    Raises IndexError if the filename has too few ``_``-separated fields.
    """
    basename = os.path.basename(data_input)
    gt_type = 'Full'
    if basename.startswith('judge_short_'):
        basename = 'judge_' + basename[len('judge_short_'):]
        gt_type = 'Short'
    # The raw-output variant is only normalized; it is not reported as a separate column.
    if basename.startswith('judge_raw_'):
        basename = 'judge_' + basename[len('judge_raw_'):]

    fields = basename.split('_')
    prompt_ver = fields[1]
    prompt_category = 'CoT Style' if '-step' in prompt_ver else 'Paragraph Like'
    prompt_extra_req = '+ Extra Conditions' if '-cond' in prompt_ver else ''
    icl_example = '(ref-example)' if '-icl-ref' in prompt_ver else ''
    image_type = 'ground truth only' if 'ref_only' in basename else 'w/ adversarial regions'
    model_str = fields[-1].rsplit('.', 1)[0]

    return {
        'Type': f"{prompt_category} {prompt_extra_req}".strip(),
        # Conditional keeps fields[3] from being evaluated for no-bbox files.
        'Img Rev': fields[3] if '-no-bbox' not in basename else 'No Box',
        'Model': model_str,
        'Img Type': f"{image_type} {icl_example}".strip(),
        'w/ GT': 'N' if 'exp_first' in basename else 'Y',
        'Size': _parse_model_size(model_str),
        'GT_Type': gt_type,
    }


def main():
    """Evaluate judged model answers and print a markdown accuracy table.

    Command-line options:
      --data-input FILE [FILE ...]  one or more judged-answer CSV files (required)
      --remove-additional           print only GT_Type / Model / Size / Accuracy
      --accept-qid-file FILE        restrict every input to the QuestionIDs listed in FILE

    Exits with status 1 (message on stderr) if an input file is missing, or if
    an input does not contain exactly one row per accepted QuestionID.
    """
    parser = argparse.ArgumentParser(description="evaluate model final answer")
    parser.add_argument('--data-input', type=str, required=True, help='Data input file', action='extend', nargs='+')
    parser.add_argument('--remove-additional', action='store_true', help='Remove additional info columns')
    parser.add_argument('--accept-qid-file', type=str, help='File with a list of QuestionID to accept')
    args = parser.parse_args()

    # Validate every path up front so a typo does not fail after partial work.
    for data_input in args.data_input:
        if not os.path.exists(data_input):
            _fail(f"{data_input} does not exist")

    accepted_qids = _load_accepted_qids(args.accept_qid_file) if args.accept_qid_file else None

    rows = []
    for data_input in args.data_input:
        data = pd.read_csv(data_input)
        # For the self-correction task, drop rows where the human judge indicated
        # removal instead of correction (empty string or NaN in correction_human).
        if 'correction_human' in data.columns:
            data = data[data['correction_human'].notna() & (data['correction_human'] != '')]
        if accepted_qids is not None:
            data = data[data['qid'].isin(accepted_qids)]
            missing = accepted_qids - set(data['qid'])
            if missing:
                _fail(f"{len(missing)} accepted qid(s) not present in {data_input}: {sorted(missing)}")
            if len(data) != len(accepted_qids):
                _fail(f"duplicate qid rows in {data_input} after filtering by accepted qids")

        judge_values = data['judge'].unique()
        if len(judge_values) != 2:
            print(f"File: {data_input}", file=sys.stderr)
            print("Warning: Judge column should only contain two values", file=sys.stderr)
            # map(str, ...) so NaN / non-string judge values are reported instead of crashing join().
            print(f"Got: {','.join(map(str, judge_values))}", file=sys.stderr)

        row = _describe_run(data_input)
        row['Accuracy'] = round(judge_accuracy(data) * 100, 2)
        row['Support'] = len(data)
        rows.append(row)

    results = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    if args.remove_additional:
        # Keep only the headline columns; order by GT type, then model size, then accuracy.
        results = results[['GT_Type', 'Model', 'Size', 'Accuracy']]
        results = results.sort_values(by=['GT_Type', 'Size', 'Accuracy'])
    else:
        results = results.sort_values(by=['GT_Type', 'Img Rev', 'Size', 'Type', 'Img Type', 'w/ GT'])
    print(results.to_markdown(index=False))


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
