import json
import sys

def _load_jsonl(path):
    """Return the records of the JSON Lines file at *path* as a list.

    The file is decoded as UTF-8 (the encoding JSON text is expected to use,
    RFC 8259) instead of the platform default. Blank lines, such as a trailing
    empty line, are skipped because ``json.loads`` would reject them.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


# Input your generative file here: paths of the JSONL files holding the
# generated results to evaluate.
file_list = []

data = []
for path in file_list:
    data.extend(_load_jsonl(path))

# modify here: path of the benchmark JSONL file, which supplies the
# ground-truth 'Category' for each 'id'.
gt_data = _load_jsonl("<your path to benchmark.jsonl>")

# Index the benchmark by id once instead of rescanning gt_data for every
# generated record. setdefault keeps the first record of a repeated id,
# which matches a first-match-wins scan.
gt_by_id = {}
for gt_item in gt_data:
    gt_by_id.setdefault(gt_item['id'], gt_item)

unmatched_ids = []
for item in data:
    gt_item = gt_by_id.get(item['id'])
    if gt_item is None:
        unmatched_ids.append(item['id'])
        continue
    item['Category'] = gt_item['Category']

if unmatched_ids:
    # These records get no 'Category' key; report them instead of leaving
    # the mismatch to surface later as an unexplained error or missing data.
    print(
        f"Warning: {len(unmatched_ids)} generated record(s) have no matching id "
        f"in the benchmark file and were not assigned a Category "
        f"(first ids: {unmatched_ids[:10]})",
        file=sys.stderr,
    )

# Top-level reporting groups. Each key is a group label printed in the results
# table. Each value lists the benchmark subcategory labels (the 'Category'
# value copied from benchmark records) that roll up into that group.
# NOTE(review): "Style Transer" looks like a typo for "Style Transfer". It is
# also the printed label, so confirm nothing downstream matches on it before
# renaming.
Sum_category = {
    "Style Transer": [
        "Art Style Transfer",
        "Scene Attribute Transfer",
        "Photo Variation",
        "Portrait Variation",
    ],
    "Progressive": [
        "Animation Images",
        "Animation Text-Image",
        "Attribute Transfer",
    ],
    "3D Scene": [
        "3D_object",
        "multi-Perspective Scene Generation",
    ],
    "Image Decomposition": [
        "Realistic Object Detection",
        "Imaginary Object Detection",
    ],
    "Image-Text Complementation": [
        "HowTo",
        "Scientific",
    ],
    "Temporal Prediction": [
        "Prediction",
        "painting",
    ],
    "Visual Story Telling": [
        "visual storytelling text-image",
        "visual storytelling image",
        "visual storytelling text",
    ],
    "VQA": [
        "Object_VQA",
        "Historical",
    ],
}

def _textimgraph_score(result, item_id):
    """Return the TextImGraph score for one item's ``result`` dict.

    The score is the mean of ``entry['VQA_judge_Score']['Judge']`` over the
    entries of ``result['TextImGraph']`` that carry that value. The fallback
    score 1 is returned when ``result['structure']`` is falsy, when
    ``result['TextImGraph']`` is missing or empty, when no entry carries a
    judge score, or when the scores cannot be averaged.
    """
    if not result.get('structure'):
        return 1
    textimgraph_data = result.get('TextImGraph')
    if not textimgraph_data:
        # Without a score here the item would drop out of the TextImGraph mean
        # while still counting toward every other metric; use the fallback.
        return 1
    try:
        scores = [
            entry['VQA_judge_Score']['Judge']
            for entry in textimgraph_data
            if 'VQA_judge_Score' in entry and 'Judge' in entry['VQA_judge_Score']
        ]
        return sum(scores) / len(scores) if scores else 1
    except Exception as e:
        # Best-effort: malformed judge output (e.g. non-numeric scores) is
        # reported and scored with the fallback instead of aborting the run.
        print(f"Error processing TextImGraph for item {item_id}: {e}")
        return 1


def _dsg_score(result):
    """Return the fraction of ``result['DSG']`` entries judged "yes".

    An entry passes when its ``judge['Judge']`` equals "yes"
    (case-insensitive). Entries without a usable judge still count in the
    denominator. Returns 0 when ``result['DSG']`` is missing or empty.
    """
    dsg_items = result.get('DSG')
    if not dsg_items:
        return 0
    passed = sum(
        1
        for entry in dsg_items
        if 'judge' in entry and entry['judge']
        and 'Judge' in entry['judge']
        and entry['judge']['Judge'].lower() == 'yes'
    )
    return passed / len(dsg_items)


def _overall_score(item):
    """Return ``item['result']['general_judge_W_GT']['overall_score']``, or 1.

    The fallback 1 is used when the judge block or its ``overall_score`` is
    missing. It is also used when the item has no ``output``, or when the
    first output's ``content`` is the empty string.
    """
    judge = item['result'].get('general_judge_W_GT')
    if not judge or 'overall_score' not in judge:
        return 1
    output = item.get('output')
    if not output or output[0]['content'] == '':
        return 1
    return judge['overall_score']


def analyze_data(data, categories=None):
    """Aggregate per-item evaluation results into per-category statistics.

    Args:
        data: iterable of item dicts. Each item is assigned to every group
            whose subcategory list contains its ``'Category'`` value. Items
            with no ``'Category'`` key, or with one not listed anywhere, are
            skipped.
        categories: mapping of group name -> list of subcategory labels.
            Defaults to the module-level ``Sum_category``.

    Returns:
        A dict mapping each group name to a stats dict with these keys:
        ``"structure"`` and ``"total"`` are counts; ``"TextImGraph"``,
        ``"DSG"`` and ``"Overall"`` are score lists. ``total`` counts matched
        items. ``structure`` counts those with a truthy ``result['structure']``.
        Every score list receives exactly one entry per matched item. Items
        without a ``'result'`` score 1 / 0 / 1 respectively.
    """
    if categories is None:
        categories = Sum_category
    category_statistics = {
        category: {"structure": 0, "TextImGraph": [], "DSG": [], "Overall": [], "total": 0}
        for category in categories
    }
    for item in data:
        # .get: records that were never matched to a benchmark entry have no
        # 'Category' key and would otherwise raise KeyError here.
        item_category = item.get('Category')
        for category, subcategories in categories.items():
            if item_category not in subcategories:
                continue
            stats = category_statistics[category]
            stats['total'] += 1
            if 'result' not in item:
                stats['TextImGraph'].append(1)
                stats['DSG'].append(0)
                stats['Overall'].append(1)
                continue
            result = item['result']
            if result.get('structure'):
                stats['structure'] += 1
            stats['TextImGraph'].append(_textimgraph_score(result, item.get('id', 'unknown')))
            stats['DSG'].append(_dsg_score(result))
            stats['Overall'].append(_overall_score(item))
    return category_statistics

def _mean(values):
    """Return the arithmetic mean of *values*, or 0 when it is empty."""
    return sum(values) / len(values) if values else 0


def print_statistics(category_statistics):
    """Print one row of averaged metrics per category as an aligned table.

    Args:
        category_statistics: mapping of category name to a stats dict with
            counts ``'structure'`` and ``'total'`` and score lists ``'DSG'``,
            ``'TextImGraph'`` and ``'Overall'`` (the shape ``analyze_data``
            returns).

    The Structure column is ``structure / total``. The other columns are the
    means of their score lists. An empty denominator is reported as 0.
    Spaces in category names are replaced by underscores, so each row stays a
    whitespace-separated list of fields.
    """
    names = {category: category.replace(' ', '_') for category in category_statistics}
    # Size the name column to the longest label. Several category names are
    # longer than a fixed 13-character column and would push their row out of
    # line with the header.
    name_width = max([len("Category")] + [len(name) for name in names.values()])
    value_width = len("TextImGraph")  # widest metric header
    header = (
        f"{'Category':<{name_width}}  {'Structure':>{value_width}}  "
        f"{'DSG':>{value_width}}  {'TextImGraph':>{value_width}}  "
        f"{'Overall':>{value_width}}"
    )
    print(header)
    print("-" * len(header))
    for category, stats in category_statistics.items():
        structure = stats['structure'] / stats['total'] if stats['total'] > 0 else 0
        dsg = _mean(stats['DSG'])
        textimgraph = _mean(stats['TextImGraph'])
        overall = _mean(stats['Overall'])
        print(
            f"{names[category]:<{name_width}}  {structure:>{value_width}.4f}  "
            f"{dsg:>{value_width}.4f}  {textimgraph:>{value_width}.4f}  "
            f"{overall:>{value_width}.4f}"
        )

if __name__ == "__main__":
    # Compute and print the per-category report only when this file is run
    # as a script; importing the module skips the report.
    category_statistics = analyze_data(data)
    print_statistics(category_statistics)