import pandas as pd
import numpy as np
import os
import re

def extract_boxed_content(text: str) -> str:
    """Return the contents of the last ``\\boxed{...}`` in *text*.

    Nested braces inside the box (e.g. ``\\boxed{\\text{Chapter 5}}``) are
    balanced, so the full inner content is returned, stripped of surrounding
    whitespace.

    Returns the *string* ``"None"`` (not the ``None`` object) when *text*
    contains no ``\\boxed{`` or the last one is never closed.
    """
    marker = r"\boxed{"
    # Search for the marker itself rather than any "{": otherwise a brace
    # group appearing after the boxed answer would be returned instead.
    start_pos = text.rfind(marker)
    if start_pos == -1:
        return "None"

    content = text[start_pos + len(marker):]
    depth = 0
    for i, char in enumerate(content):
        if char == "{":
            depth += 1
        elif char == "}":
            if depth == 0:  # this "}" closes the \boxed{ itself
                return content[:i].strip()
            depth -= 1

    return "None"
def extract_gdpr_chapter_number(chapter):
    """Return the Arabic chapter number from a ``"Chapter N ..."`` label.

    The number is taken from the second whitespace-separated token. Labels
    that do not start with ``"Chapter "``, have no second token, or whose
    second token is not an integer map to ``float('inf')``. That way, when
    used as a sort key, un-numbered chapters are placed last.
    """
    if not chapter.startswith('Chapter '):
        return float('inf')

    tokens = chapter.split()
    if len(tokens) < 2:
        return float('inf')

    try:
        return int(tokens[1])
    except ValueError:
        return float('inf')



def calculate_distribution(data_tmp, num_chapters=11):
    """Compute the distribution of predicted chapter numbers in *data_tmp*.

    For each row, the first element of the ``responses`` column is taken as
    the response text. Its last ``\\boxed{...}`` content is extracted, and the
    first run of digits in it is read as the chapter number.

    Args:
        data_tmp: DataFrame with a ``responses`` column whose entries are
            indexable sequences of response strings (only element 0 is used).
        num_chapters: Highest valid chapter number; chapters ``1..num_chapters``
            are counted (default 11).

    Returns:
        A tuple ``(distribution, total, miss_count)``:
        - distribution: list of length *num_chapters*; entry ``k - 1`` is the
          fraction of counted responses that named chapter ``k``. If no
          response named a valid chapter, every entry is ``nan``; the original
          code raised ``ZeroDivisionError`` in that case.
        - total: number of responses that named a valid chapter.
        - miss_count: number of responses with no number, or a number outside
          ``1..num_chapters``.
    """
    counts = [0] * num_chapters
    miss_count = 0

    for responses in data_tmp['responses']:
        box_content = extract_boxed_content(responses[0])
        match = re.search(r'\d+', box_content)
        chapter_id = int(match.group()) if match else -1

        if 1 <= chapter_id <= num_chapters:
            counts[chapter_id - 1] += 1
        else:
            miss_count += 1

    total = sum(counts)
    if total == 0:
        # Nothing parseable: the distribution is undefined rather than zero.
        return [float('nan')] * num_chapters, total, miss_count

    return [c / total for c in counts], total, miss_count

def list_files(directory):
    """Return the sorted names of regular files directly inside *directory*.

    Sub-directories are skipped. ``os.listdir`` returns entries in arbitrary
    order, so the names are sorted to make downstream iteration (and anything
    indexed by it) reproducible.

    If the directory cannot be read (missing, not a directory, permission
    denied), the error is printed and an empty list is returned.
    """
    try:
        entries = os.listdir(directory)
    except OSError as e:
        print(f"An error occurred: {e}")
        return []

    return sorted(
        name for name in entries if os.path.isfile(os.path.join(directory, name))
    )

# Directory of parquet result files to aggregate.
# NOTE(review): the leading '.' makes this a hidden directory relative to the
# current working directory; confirm it was not meant to be './...'.
GDPR_CHAPTER_DATA_DIR = '.model_gen_out_put_safety_data/asso_with_gdpr_chapter'

# Name kept for compatibility with any later code, even though this lists the
# GDPR-chapter directory.
files_ai_act_chapter = list_files(GDPR_CHAPTER_DATA_DIR)

distribution_array = []
for file in files_ai_act_chapter:
    data_tmp = pd.read_parquet(os.path.join(GDPR_CHAPTER_DATA_DIR, file))
    distribution_list, sum_, miss_count = calculate_distribution(data_tmp)
    # file name, number of responses with a valid chapter, number missed
    print(file, sum_, miss_count)
    distribution_array.append(distribution_list)

# One row per file, one column per chapter.
np_distr = np.array(distribution_array)

# Transposed so rows are chapters and columns are files, shown as percentages.
print(np_distr.T * 100)