import os
import json
import glob
import numpy as np
from argparse import ArgumentParser
from tokenizers import Tokenizer

# Tokenizer loaded once, at import time, and used by calculate_metrics to count
# tokens in each record's "think" text. The path is relative, so it resolves
# against the current working directory — run the script from the repo root
# (or wherever the "models/" folder lives).
tokenizer = Tokenizer.from_file("models/Qwen2.5-VL-7B-Instruct/tokenizer.json")

def parse_args():
    """Read the command line and return the parsed namespace.

    Only one option is accepted, ``--output_dir`` (required): the folder that
    holds the ``output_*.json`` result files.
    """
    cli = ArgumentParser()
    cli.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="folder path of output files",
    )
    namespace = cli.parse_args()
    return namespace

def calculate_metrics(output_dir, token_counter=None):
    """Aggregate metrics over every ``output_*.json`` file in ``output_dir``.

    Each file must contain a JSON list of records; every record is read for
    the keys ``intersection``, ``union``, ``think`` and ``answer_entropy``.

    Computed (and printed) metrics:
      * gIoU -- mean of the per-record IoU ``intersection / union``; a record
        with ``union == 0`` contributes an IoU of 0.
      * cIoU -- cumulative IoU, ``sum(intersection) / sum(union)``; reported
        as 0 when the total union is 0.
      * average number of tokens in the ``think`` text.
      * mean ``answer_entropy``.

    Args:
        output_dir: folder containing the ``output_*.json`` files.
        token_counter: optional callable mapping a string to its token count.
            When omitted, tokens are counted with the module-level
            ``tokenizer`` (``len(tokenizer.encode(text))``).

    Returns:
        A dict with keys ``"gIoU"``, ``"cIoU"``, ``"avg_think_tokens"`` and
        ``"mean_answer_entropy"``, or ``None`` when no output files or no
        records are found (a message is printed in that case).
    """

    def count_tokens(text):
        # Resolve the global tokenizer lazily so callers that supply their own
        # counter never touch it.
        if token_counter is not None:
            return token_counter(text)
        return len(tokenizer.encode(text))

    output_files = sorted(glob.glob(os.path.join(output_dir, "output_*.json")))
    if not output_files:
        print(f"cannot find output files in {output_dir}")
        return None

    per_item_iou = []
    think_token_counts = []
    answer_entropies = []
    intersection_total = 0
    union_total = 0

    for file_path in output_files:
        with open(file_path, "r", encoding="utf-8") as f:
            results = json.load(f)

        for item in results:
            intersection = item["intersection"]
            union = item["union"]
            intersection_total += intersection
            union_total += union
            # An empty union (nothing predicted, nothing in ground truth)
            # scores 0 rather than dividing by zero.
            per_item_iou.append(intersection / union if union > 0 else 0)
            think_token_counts.append(count_tokens(item["think"]))
            answer_entropies.append(item["answer_entropy"])

    # Files may exist yet hold empty lists; every mean below would then be
    # undefined (nan / ZeroDivisionError).
    if not per_item_iou:
        print(f"no records found in output files under {output_dir}")
        return None

    g_iou = float(np.mean(per_item_iou))
    # Same zero-union convention as the per-record IoU above.
    c_iou = intersection_total / union_total if union_total > 0 else 0.0
    avg_think_tokens = float(np.mean(think_token_counts))
    mean_answer_entropy = sum(answer_entropies) / len(answer_entropies)

    print(f"gIoU (average of per image IoU): {g_iou:.4f}")
    print(f"cIoU (total IoU): {c_iou:.4f}")
    print(f"Average number of think tokens: {avg_think_tokens:.2f}")
    print("mean all_answer_entropy:", mean_answer_entropy)

    return {
        "gIoU": g_iou,
        "cIoU": c_iou,
        "avg_think_tokens": avg_think_tokens,
        "mean_answer_entropy": mean_answer_entropy,
    }
    

if __name__ == "__main__":
    # Script entry point: read --output_dir and report metrics for that folder.
    calculate_metrics(parse_args().output_dir)
