# import pandas as pd


# df = pd.read_csv('your_file.csv', encoding='utf-8')

# for column_name, column_data in df.items():  # iteritems() was removed in pandas 2.0
#     print(f"列名：{column_name}")
#     print(column_data)



import wandb
import pandas as pd
import json

# Public W&B API client, and the runs of the project whose metrics are collected below.
api = wandb.Api()
project = 'role_playing/llm_classification_all_turns_gpt-4o'
runs = api.runs(path=project)



# Aggregate accuracy / cost metrics over every run of the project.
#   acc_total_list        - every logged 'accuracy' value, in run iteration order
#   acc_list              - same values; only read by the commented-out debug prints
#   cost_list             - every logged 'total_cost' value
#   acc_split             - acc_split[topic][run_type] -> list of accuracies
#   min_acc / min_topic   - lowest accuracy seen and the name of the run it came from
#   min_acc2 / min_topic2 - second-lowest accuracy and its run name
# Both minima start at 1.0 / "" and only move when a strictly lower value is seen.
acc_total_list = []
acc_list = []
cost_list = []
min_acc = 1.0
min_topic = ""
min_acc2 = 1.0
min_topic2 = ""
acc_split = {}

# Optional filters tried during exploration (currently disabled): skip runs whose
# name contains "VirtualReality", or keep only runs whose name ends with a given index.
for run in runs:
    summary = run.summary
    run_name = run.name
    # NOTE(review): assumes run names have at least five '_'-separated fields, with
    # the run type at index 1 and the topic at index 4 - confirm the naming scheme.
    name_parts = run_name.split('_')
    topic = name_parts[4]
    run_type = name_parts[1]
    # Created even for runs that logged no accuracy, so a list may end up shorter
    # than the number of runs for this topic/type.
    topic_accs = acc_split.setdefault(topic, {}).setdefault(run_type, [])

    if 'accuracy' in summary:
        acc = summary['accuracy']
        topic_accs.append(acc)
        acc_total_list.append(acc)
        acc_list.append(acc)
        # Track the two smallest accuracies; strict '<' keeps the earlier run on ties.
        if acc < min_acc:
            min_acc2, min_topic2 = min_acc, min_topic
            min_acc, min_topic = acc, run_name
        elif acc < min_acc2:
            min_acc2, min_topic2 = acc, run_name
    if 'total_cost' in summary:
        cost_list.append(summary['total_cost'])

# Build one CSV row per (topic, run type) holding the accuracy for each level.
# NOTE(review): a run's level is inferred only from its position in
# acc_split[topic][type] (i.e. the order api.runs yielded the runs), not from its
# name - confirm that this order really is Child -> Expert.
levels = ['Child', 'Teen', 'College', 'Graduate', 'Expert']

rows = []
for topic, types in acc_split.items():
    for run_type, accs in types.items():
        if len(accs) != len(levels):
            print(f"warning: {topic}/{run_type} has {len(accs)} accuracies, expected {len(levels)}")
        row = {'topic': topic, 'type': run_type}
        # A missing level becomes None (an empty CSV cell) instead of raising
        # IndexError; values beyond the last level are ignored, as before.
        for idx, level in enumerate(levels):
            row[level] = accs[idx] if idx < len(accs) else None
        rows.append(row)

# Construct the frame once instead of enlarging it row by row via df.loc.
df = pd.DataFrame(rows, columns=['topic', 'type', *levels])
df.to_csv('acc_split.csv', index=False)
#     print(len(acc_list))
#     print(min_acc, min_topic)
#     print(min_acc2, min_topic2)
# print(sum(cost_list))