import os
import json
import argparse
import pandas as pd


def get_row(df, ndigits=None):
    """Summarise correctness counts for one slice of the comparison table.

    For each of the four question types the column is summed and reported as
    ``"<count>  (<percent>%)"``. Two combined entries report how many rows
    have all four types correct ("Four ✔") and none correct ("Four ❌"), and
    "All" holds the number of rows in the slice.

    Args:
        df: DataFrame with the columns 'Recognition', 'Understanding',
            'Grounding' and 'Reasoning'. The 4 / 0 row-sum comparisons only
            make sense for 0/1 (or boolean) values -- assumed, not checked.
        ndigits: Number of decimals kept in the percentages. When omitted,
            the module-level ``args.round`` set by the script entry point is
            used if it exists; otherwise 2 (the CLI default).

    Returns:
        dict mapping column label to formatted string, plus "All" -> int.
    """
    types = ['Recognition', 'Understanding', 'Grounding', 'Reasoning']

    if ndigits is None:
        # The script entry point stores the parsed CLI options in a global
        # named ``args``. Look it up defensively so the function also works
        # when the module is imported and that global was never created.
        ndigits = getattr(globals().get("args"), "round", 2)

    total = len(df)

    def _cell(count):
        # A category with no matching rows would otherwise divide by zero
        # and show up as "nan%" in the report.
        ratio = round(100 * count / total, ndigits) if total else 0.0
        return f"{count}  ({ratio}%)"

    row_output = {}
    for t in types:
        row_output[t + " ✔"] = _cell(df.loc[:, t].sum())

    # Number of correct types per row, computed once for both combined stats.
    correct_per_row = df.loc[:, types].sum(axis=1)
    row_output["Four ✔"] = _cell((correct_per_row == len(types)).sum())
    row_output["Four ❌"] = _cell((correct_per_row == 0).sum())
    row_output["All"] = total

    return row_output


def verify(args):
    """Build the per-category statistics table and save it as a workbook.

    Reads ``evaluate/comparison/cmp_<sol_name>.csv`` and
    ``intersect/intersect.json``, computes one ``get_row`` summary for every
    category listed under the JSON's "Type" and "Domain" sections plus an
    "All Categories" row, and writes the table to
    ``evaluate/analysis/stat_<sol_name>.xlsx`` using the xlsxwriter engine.

    Args:
        args: Object exposing ``sol_name``, used to build the input and
            output file names.
    """
    csv_file = f"evaluate/comparison/cmp_{args.sol_name}.csv"
    output_file = f"evaluate/analysis/stat_{args.sol_name}.xlsx"
    df = pd.read_csv(csv_file)

    intersect_file = "intersect/intersect.json"
    with open(intersect_file, "r", encoding="utf-8") as f:
        intersect_data = json.load(f)
    # Type and Domain categories share one table; on a name clash the Domain
    # entry wins because it is unpacked last.
    category_data = {**intersect_data['Type'], **intersect_data['Domain']}

    output = {}
    for name, content in category_data.items():
        # NOTE(review): eval() executes whatever expression is stored under
        # "List", so this is only safe while intersect.json is a trusted,
        # locally generated file. If the entries are plain literals,
        # ast.literal_eval would be the safer choice -- confirm the format.
        c_diagram_indices = eval(content["List"])
        output[name] = get_row(df[df['idx'].isin(c_diagram_indices)])
    output["All Categories"] = get_row(df)

    output_df = pd.DataFrame.from_dict(output, orient='index')
    output_df.insert(0, 'Category', output_df.index)

    # Create the target folder up front so a fresh checkout does not fail.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # The workbook is written exactly once. A former to_csv() call on the
    # same .xlsx path was dropped: its output was overwritten right here.
    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        output_df.to_excel(writer, index=False, sheet_name='Sheet1')
        workbook = writer.book
        worksheet = writer.sheets['Sheet1']

        cell_format = workbook.add_format({'font_name': 'Times New Roman', 'font_size': 12})
        for col in range(output_df.shape[1]):
            # Column width = longest rendered cell or header, plus padding.
            max_len = max(
                output_df.iloc[:, col].astype(str).map(len).max(),
                len(output_df.columns[col])
            )
            worksheet.set_column(col, col, max_len + 2, cell_format)

        header_format = workbook.add_format({
            'font_name': 'Times New Roman',
            'font_size': 12,
            'bold': True,
            'align': 'center',
            'valign': 'vcenter'
        })
        # Rewrite the header row so it uses the bold, centred format.
        for col_num, header in enumerate(output_df.columns.values):
            worksheet.write(0, col_num, header, header_format)


if __name__ == '__main__':
    # Script entry point: parse the command line, then build the report.
    # NOTE: the parsed namespace must stay bound to the module-level name
    # ``args`` because get_row looks up ``args.round`` globally.
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--sol_name', default='empty_v2', type=str)
    cli.add_argument('--round', default=2, type=int)
    args = cli.parse_args()

    verify(args)

