# import os
# import numpy as np
# import pandas as pd
# from metric.metric import get_metrics
# from pytorch_fid import fid_score

# # Paths to the directories containing real and generated images
# data_type = 'om'
# # exp_name = 'tai-sem-ldm-vq-f8-intra-10000-newlatest2'
# # exp_name = 'tai-sem-ldm-vq-f8-intra-10000-new'
# exp_name = 'tai-om-ldm-vq-f8-intra-10000-newlatest'
# real_images_path = f'tai_data/fid_{data_type}'
# generated_images_path = f'logs/eval/{exp_name}'

# # Compute FID score
# fid_value = fid_score.calculate_fid_given_paths([real_images_path, generated_images_path], batch_size=50, device='cuda', dims=2048)
# print(f'FID score: {fid_value}')


# # Compute accuracy


# # raw_ann = np.load('metric/intra_conds_10000.npy')
# # gt_cond_list = [raw_ann[i] for i in range(raw_ann.shape[0])]

# # pred_image_list = os.listdir(pred_root)
# # pred_image_list.sort(reverse=False)
# # pred_image_list = [
# #     os.path.join(pred_root, pred_image_path)
# #     for pred_image_path in pred_image_list
# # ]

# # acc_dict = get_metrics(pred_image_list, gt_cond_list, metric_path=metric_path, pred_save_path=f'~/Project/TitaniumDiff/logs/pred/{exp_name}.npy')
# # print('Accuracy:', acc_dict)
# # fid_value2=fid_value-6.3

import os
import numpy as np
import pandas as pd
from metric.metric import get_metrics
from pytorch_fid import fid_score

# Evaluation script: computes the FID score between a folder of real images
# and a folder of generated images, then computes accuracy metrics for the
# generated images against the ground-truth conditions via get_metrics().

# Experiment selection. Switch experiments by changing exp_name.
data_type = 'om'
# exp_name = 'tai-sem-ldm-vq-f8-intra-10000'
exp_name = 'tai-ldm-om-vq-f8-eval-twostage-ddim1'
# exp_name = 'tai-om-ldm-vq-f8-intra-10000-729'

# Paths to the directories containing real and generated images
# (relative to the current working directory).
real_images_path = f'tai_data/fid_{data_type}'
generated_images_path = f'logs/eval/{exp_name}'

# Compute FID score
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_path, generated_images_path],
    batch_size=50,
    device='cuda',
    dims=2048,
)

print(f'FID score: {fid_value}')

# Compute accuracy
# A leading '~' is only meaningful to the shell; os.listdir() and other file
# APIs treat it literally. Expand it here so every path below is usable.
pred_root = os.path.expanduser(f'~/tai/sdcopy/logs/eval/{exp_name}')
# NOTE(review): raw_ann_path is not used anywhere in this script.
raw_ann_path = os.path.expanduser('~/tai/sd/tai_data/sum1126.csv')
metric_path = os.path.expanduser(f'~/tai/sd/metric/{data_type}_metric_new.pth')
pred_save_path = os.path.expanduser(
    f'~/Project/TitaniumDiff/logs/pred/{exp_name}.npy'
)

# One ground-truth entry per row (first axis) of the saved array.
raw_ann = np.load('metric/full_conds_10000.npy')
gt_cond_list = list(raw_ann)

# Full paths of the predicted images, in ascending filename order.
# NOTE(review): presumably get_metrics pairs images with gt_cond_list by
# position, so this ordering must match the row order of the array — confirm.
pred_image_list = [
    os.path.join(pred_root, file_name)
    for file_name in sorted(os.listdir(pred_root))
]

acc_dict = get_metrics(
    pred_image_list,
    gt_cond_list,
    metric_path=metric_path,
    pred_save_path=pred_save_path,
)
print('Accuracy:', acc_dict)


# import os
# import numpy as np
# import shutil
# from PIL import Image
# from pytorch_fid import fid_score

# # 路径设置
# data_type = 'om'
# exp_name = 'tai-image-conditioned-ldm-om-vq-f8-eval'
# real_images_path = f'tai_data/fid_{data_type}'
# generated_images_path = f'logs/eval/{exp_name}'

# # Create a temporary directory to hold the filtered images
# filtered_gen_path = f'logs/eval/{exp_name}_filtered'
# os.makedirs(filtered_gen_path, exist_ok=True)

# def is_image_valid(image_path, min_avg_brightness=150, max_avg_brightness=220):
#     """
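#     Check whether the image brightness is within a reasonable range (neither too bright nor too dark).
#
#     Args:
#         image_path: path to the image
#         min_avg_brightness: minimum average-brightness threshold (below this the image counts as too dark)
#         max_avg_brightness: maximum average-brightness threshold (above this the image counts as too bright)
#     Returns:
#         bool: True if the brightness is normal, False if the image is too dark or too bright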
#     """
#     with Image.open(image_path).convert('L') as img:  # 转换为灰度图
#         img_array = np.array(img)
#         avg_brightness = np.mean(img_array)
#         # 同时检查是否在正常范围内
#         return min_avg_brightness <= avg_brightness <= max_avg_brightness

# # 筛选生成图像
# filtered_count = 0
# total_count = 0
# too_dark_count = 0
# too_bright_count = 0

# for img_name in os.listdir(generated_images_path):
#     img_path = os.path.join(generated_images_path, img_name)
#     # 只处理图像文件
#     if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
#         total_count += 1
#         # 计算图像亮度
#         with Image.open(img_path).convert('L') as img:
#             avg_brightness = np.mean(np.array(img))
        
#         if is_image_valid(img_path):
#             # 复制符合条件的图像到临时目录
#             shutil.copy(img_path, os.path.join(filtered_gen_path, img_name))
#             filtered_count += 1
#         else:
#             if avg_brightness < 50:
#                 too_dark_count += 1
#                 # print(f"过滤过暗图像: {img_name}, 平均亮度: {avg_brightness:.2f}")
#             else:
#                 too_bright_count += 1
#                 # print(f"过滤过亮图像: {img_name}, 平均亮度: {avg_brightness:.2f}")
# print(f"原始生成图像数量: {total_count}")
# print(f"筛选后保留的图像数量: {filtered_count}")
# print(f"过滤掉的过暗图像数量: {too_dark_count}")
# print(f"过滤掉的过亮图像数量: {too_bright_count}")

# # # Create a temporary directory to hold the filtered generated images
# # filtered_gen_path = f'logs/eval/{exp_name}_filtered'
# # os.makedirs(filtered_gen_path, exist_ok=True)

# # def is_image_too_bright(image_path, max_avg_brightness=230):
# #     """判断图像是否过亮（发白）"""
# #     with Image.open(image_path).convert('L') as img:  # 转换为灰度图
# #         img_array = np.array(img)
# #         avg_brightness = np.mean(img_array)
# #         return avg_brightness > max_avg_brightness

# # 筛选生成图像
# # filtered_count = 0
# # total_count = 0

# # for img_name in os.listdir(generated_images_path):
# #     img_path = os.path.join(generated_images_path, img_name)
# #     # 只处理图像文件
# #     if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
# #         total_count += 1
# #         if not is_image_too_bright(img_path, max_avg_brightness=200):
# #             # 复制符合条件的图像到临时目录
# #             shutil.copy(img_path, os.path.join(filtered_gen_path, img_name))
# #             filtered_count += 1

# # print(f"原始生成图像数量: {total_count}")
# # print(f"筛选后保留的图像数量: {filtered_count}")
# # print(f"过滤掉的过亮图像数量: {total_count - filtered_count}")

# # 确保筛选后有足够图像用于计算FID
# if filtered_count < 2:
#     raise ValueError("筛选后图像数量不足，无法计算FID")

# # Compute FID using the filtered generated images
# fid_value = fid_score.calculate_fid_given_paths(
#     [real_images_path, filtered_gen_path],
#     batch_size=50,
#     device='cuda',
#     dims=2048
# )

# print(f'筛选后计算的FID score: {fid_value}')