# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import json
import os
import argparse
from collections import defaultdict


def calculate_wiscore(consistency, realism, aesthetic_quality):
    return 0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality


def cal_culture(file_path):
    all_scores = []
    total_objects = 0
    has_9_9 = False
    
    with open(file_path, 'r') as file:
        for line in file:
            total_objects += 1
            data = json.loads(line)
            if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]:
                has_9_9 = True
            wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality'])
            all_scores.append(wiscore)
    
    if has_9_9 or total_objects < 400:
        print(f"Skipping file {file_path}: Contains 9.9 or has less than 400 objects.")
        return None
    
    total_score = sum(all_scores)
    avg_score = total_score / (len(all_scores)*2) if len(all_scores) > 0 else 0
    
    score = {
        'total': total_score,
        'average': avg_score
    }

    print(f"  Cultural - Total: {score['total']:.2f}, Average: {score['average']:.2f}")

    return avg_score


def cal_space_time(file_path):
    categories = defaultdict(list)
    total_objects = 0
    has_9_9 = False
    
    with open(file_path, 'r') as file:
        for line in file:
            total_objects += 1
            data = json.loads(line)
            if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]:
                has_9_9 = True
            subcategory = data['Subcategory']
            wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality'])
            if subcategory in ['Longitudinal time', 'Horizontal time']:
                categories['Time'].append(wiscore)
            else:
                categories['Space'].append(wiscore)
    
    if has_9_9 or total_objects < 300:
        print(f"Skipping file {file_path}: Contains 9.9 or has less than 400 objects.")
        return None
    
    total_scores = {category: sum(scores) for category, scores in categories.items()}
    avg_scores = {category: sum(scores) / (len(scores) * 2 )if len(scores) > 0 else 0 for category, scores in categories.items()}
    
    scores = {
        'total': total_scores,
        'average': avg_scores
    }

    print(f"  Time - Total: {scores['total'].get('Time', 0):.2f}, Average: {scores['average'].get('Time', 0):.2f}")
    print(f"  Space - Total: {scores['total'].get('Space', 0):.2f}, Average: {scores['average'].get('Space', 0):.2f}")

    return avg_scores


def cal_science(file_path):
    categories = defaultdict(list)
    total_objects = 0
    has_9_9 = False
    
    with open(file_path, 'r') as file:
        for line in file:
            total_objects += 1
            data = json.loads(line)
            if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]:
                has_9_9 = True
            
            prompt_id = data.get('prompt_id', 0)
            if 701 <= prompt_id <= 800:
                category = 'Biology'
            elif 801 <= prompt_id <= 900:
                category = 'Physics'
            elif 901 <= prompt_id <= 1000:
                category = 'Chemistry'
            else:
                category = "?"
            
            wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality'])
            categories[category].append(wiscore)
    
    if has_9_9 or total_objects < 300: 
        print(f"Skipping file {file_path}: Contains 9.9 or has less than 300 objects.")
        return None
    
    total_scores = {category: sum(scores) for category, scores in categories.items()}
    avg_scores = {category: sum(scores) / (len(scores)*2) if len(scores) > 0 else 0 for category, scores in categories.items()}

    scores = {
        'total': total_scores,
        'average': avg_scores
    }

    for category in ['Biology', 'Physics', 'Chemistry']:
        print(f"  {category} - Total: {scores['total'].get(category, 0):.2f}, Average: {scores['average'].get(category, 0):.2f}")
    
    return avg_scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Image Quality Assessment Tool')
    parser.add_argument('--output_dir', required=True,
                        help='Path to the output directory')
    args = parser.parse_args()

    avg_score = dict()

    score = cal_culture(
        os.path.join(args.output_dir, "cultural_common_sense_scores.jsonl")
    )
    avg_score['Cultural'] = score

    scores = cal_space_time(
        os.path.join(args.output_dir, "spatio-temporal_reasoning_scores.jsonl")
    )
    avg_score.update(scores)

    scores = cal_science(
        os.path.join(args.output_dir, "natural_science_scores.jsonl")
    )
    avg_score.update(scores)

    avg_all = sum(avg_score.values()) / len(avg_score)

    avg_score['Overall'] = avg_all
    keys = ""
    values = ""
    for k, v in avg_score.items():
        keys += f"{k} "
        values += f"{v:.2f} "
    print(keys)
    print(values)

    writer = open(os.path.join(args.output_dir, "results.txt"), 'w')
    print(f"write results to file {os.path.join(args.output_dir, 'results.txt')}")
    writer.write(keys + "\n")
    writer.write(values + "\n")
    writer.close()