import os
import cv2
from skimage.metrics import structural_similarity as ssim
import lpips
import torch

# Paths and global initialisation shared by the metric functions below.
# NOTE(review): the name says "exp2" but the directory is 'exp1' — confirm which is intended.
exp2_path = 'exp1'
# Model names to evaluate (replace with your own list).
models = [
    'amon', 'car', 'cherry', 'cone', 'icecream',
    'minion', 'sculpture', 'stool', 'tong', 'football',
]
# Input/case names (replace with your own list); these appear to be paired
# positionally with `models` (e.g. 'car' -> 'car_forget') — verify.
input_names = [
    'amon', 'car_forget', 'cherry', 'cone', 'icecream',
    'minion', 'sculpture', 'stool_forget', 'tong', 'football',
]
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# LPIPS perceptual metric built on AlexNet features, moved to `device`.
loss_fn_alex = lpips.LPIPS(net='alex').to(device)


def load_image(path):
    """Load an image file as a batched RGB float tensor.

    Args:
        path: Path of the image file, read with ``cv2.imread``.

    Returns:
        A float32 tensor of shape (1, 3, H, W) with values in [0, 1],
        placed on the module-level ``device``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise
            surface as an unclear error inside ``cv2.cvtColor``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {path}')
    # OpenCV decodes to BGR channel order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HxWx3 uint8 -> 1x3xHxW float in [0, 1].
    return torch.from_numpy(img).permute(2, 0, 1).float().div(255.).unsqueeze(0).to(device)


def calculate_metrics(input_name, gt_name='car_forget'):
    """Score each model's angle-0 sample for one input case against a reference image.

    For every name in the module-level ``models`` list, loads
    ``<exp2_path>/<model>_case-<input_name>_sample_angle_0.png`` and compares
    it with ``<exp2_path>/case/<gt_name>.png`` using SSIM and LPIPS.

    Args:
        input_name: Case name embedded in the generated image file names.
        gt_name: File stem of the reference image under ``<exp2_path>/case``.
            Defaults to ``'car_forget'``, the value previously hard-coded.
            NOTE(review): this may be meant to follow ``input_name`` — confirm.

    Returns:
        dict mapping model name -> {'ssim': ..., 'lpips': float}. Models whose
        generated image is missing are skipped with a printed message; an
        empty dict is returned if the reference image itself is missing.
    """
    results = {}
    gt_img_path = os.path.join(exp2_path, f'case/{gt_name}.png')
    if not os.path.exists(gt_img_path):
        print(f"找不到文件: {gt_img_path}")
        return results

    # The reference does not depend on the model, so load and convert it once.
    gt_img = load_image(gt_img_path)
    gt_img_np = gt_img[0].permute(1, 2, 0).cpu().numpy()

    for model in models:
        gen_img_path = os.path.join(exp2_path, f'{model}_case-{input_name}_sample_angle_0.png')
        if not os.path.exists(gen_img_path):
            print(f"找不到文件: {gen_img_path}")
            continue

        gen_img = load_image(gen_img_path)
        # Index the batch dim instead of squeeze() so 1-pixel H/W is kept.
        gen_img_np = gen_img[0].permute(1, 2, 0).cpu().numpy()

        # SSIM on HxWx3 arrays in [0, 1]. The keyword is `win_size`; the old
        # `window_size` was swallowed by **kwargs, so the 7x7 default was used.
        # `channel_axis=-1` replaces the deprecated `multichannel=True`.
        ssim_val = ssim(gt_img_np, gen_img_np, win_size=11, channel_axis=-1, data_range=1.0)

        # lpips.LPIPS expects inputs in [-1, 1] by default; normalize=True
        # rescales the [0, 1] tensors produced by load_image.
        with torch.no_grad():
            lpips_val = loss_fn_alex(gen_img, gt_img, normalize=True)

        results[model] = {'ssim': ssim_val, 'lpips': lpips_val.item()}

    return results

def calculate_metrics2(
    angles=(0, 45, 90, 135, 180, -45, -90, -135),
    gen_template='banana_sample_angle_{angle}.png',
    gt_template='D:\\cache_sfd\\real_object\\{angle}.png',
):
    """Score generated views against reference images, one pair per angle.

    For each angle, loads ``<exp2_path>/<gen_template>`` and ``gt_template``
    (both formatted with ``angle``), prints per-angle SSIM and LPIPS, then
    prints the mean over the angles that were actually evaluated.

    Args:
        angles: Angles to evaluate; each is substituted for ``{angle}``.
        gen_template: Generated-image file name relative to ``exp2_path``.
        gt_template: Full path of the reference image. The default is the
            machine-specific Windows path used originally.

    Returns:
        dict mapping angle -> {'ssim': ..., 'lpips': float} for every angle
        whose two images were both found.
    """
    results = {}
    for angle in angles:
        gen_img_path = os.path.join(exp2_path, gen_template.format(angle=angle))
        gt_img_path = gt_template.format(angle=angle)

        if not (os.path.exists(gen_img_path) and os.path.exists(gt_img_path)):
            print(f"找不到文件: {gen_img_path} 或 {gt_img_path}")
            continue

        gen_img = load_image(gen_img_path)
        gt_img = load_image(gt_img_path)

        # SSIM on HxWx3 arrays in [0, 1]. `win_size` is the real keyword (the
        # old `window_size` was silently ignored); `channel_axis=-1` replaces
        # the deprecated `multichannel=True`.
        gen_img_np = gen_img[0].permute(1, 2, 0).cpu().numpy()
        gt_img_np = gt_img[0].permute(1, 2, 0).cpu().numpy()
        ssim_val = ssim(gt_img_np, gen_img_np, win_size=11, channel_axis=-1, data_range=1.0)

        # lpips.LPIPS expects [-1, 1] inputs by default; normalize=True
        # rescales the [0, 1] tensors produced by load_image.
        with torch.no_grad():
            lpips_val = loss_fn_alex(gen_img, gt_img, normalize=True).item()

        print(f'angle{angle}: ssim: {ssim_val}, lpips: {lpips_val}')
        results[angle] = {'ssim': ssim_val, 'lpips': lpips_val}

    # Average over the pairs actually evaluated; a fixed divisor would skew
    # the means whenever files are missing.
    if results:
        count = len(results)
        ssim_mean = sum(r['ssim'] for r in results.values()) / count
        lpips_mean = sum(r['lpips'] for r in results.values()) / count
        print(f'final ssim: {ssim_mean}, lpips: {lpips_mean}')
    else:
        print('final ssim: n/a, lpips: n/a (no image pairs found)')

    return results

if __name__ == '__main__':
    # Run the multi-angle comparison only when executed as a script, so that
    # importing this module does not trigger file I/O and metric computation.
    calculate_metrics2()

    # Alternative driver: compute SSIM and LPIPS of every model for each input name.
    # for input_name in input_names:
    #     print(input_name)
    #     metrics = calculate_metrics(input_name)
    #     for model, values in metrics.items():
    #         print(f'Model: {model}, SSIM: {values["ssim"]:.4f}, LPIPS: {values["lpips"]:.4f}')

# import os
# import cv2
# from skimage.metrics import structural_similarity as ssim
# import lpips
# import torch
# import pandas as pd
#
# # 定义路径和初始化
# exp2_path = 'exp2'
# models = ['car', 'icecream', 'sculpture', 'cherry', 'cone', 'football', 'tong', 'amon', 'minion',  'stool']  # 替换为您的模型名称列表
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# loss_fn_alex = lpips.LPIPS(net='alex').to(device)  # 使用AlexNet版本的LPIPS
#
#
# def load_image(path):
#     """加载并预处理图像"""
#     img = cv2.imread(path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     return torch.tensor(img).permute(2, 0, 1).float().div(255.).unsqueeze(0).to(device)
#
#
# def calculate_metrics(input_names):
#     results = {}
#     for input_name in input_names:
#         for model in models:
#             gen_img_path = os.path.join(exp2_path, f'{model}_case-{input_name}_sample_angle_0.png')
#             gt_img_path = os.path.join(exp2_path, f'case/car_forget.png')
#
#             if not (os.path.exists(gen_img_path) and os.path.exists(gt_img_path)):
#                 print(f"找不到文件: {gen_img_path} 或 {gt_img_path}")
#                 continue
#
#             gen_img = load_image(gen_img_path)
#             gt_img = load_image(gt_img_path)
#
#             # 计算SSIM
#             gen_img_np = gen_img.squeeze().permute(1, 2, 0).cpu().numpy()
#             gt_img_np = gt_img.squeeze().permute(1, 2, 0).cpu().numpy()
#             ssim_val = ssim(gt_img_np, gen_img_np, multichannel=True, win_size=11, channel_axis=-1, data_range=1.0)
#
#             # 计算LPIPS
#             with torch.no_grad():
#                 lpips_val = loss_fn_alex(gen_img, gt_img)
#
#             if model not in results:
#                 results[model] = []
#             results[model].append({'case': input_name, 'SSIM': ssim_val, 'LPIPS': lpips_val.item()})
#
#     return results
#
#
# # 示例：计算特定 input_name 的所有模型的 SSIM 和 LPIPS
# # input_names = ['amon', 'car_forget', 'cherry', 'cone', 'icecream', 'minion', 'sculpture', 'stool_forget', 'tong', 'football']  # 替换为您的输入图像名列表
# input_names = ['car_forget', 'icecream', 'sculpture', 'cherry', 'cone', 'football', 'tong', 'amon', 'minion',  'stool_forget']
# results = calculate_metrics(input_names)
#
# # 创建并保存CSV文件
# for model, metrics_list in results.items():
#     df = pd.DataFrame(metrics_list)
#     df.to_csv(f'{model}_metrics.csv', index=False)
#
# print("CSV files have been created successfully.")