'''
import debugpy
try:
    # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
    debugpy.listen(("localhost", 9501))
    print("Waiting for debugger attach")
    debugpy.wait_for_client()
except Exception as e:
    pass
'''

import sys
import os
import glob
import argparse
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from PIL import Image
from natsort import natsorted
from tqdm import tqdm

# 专用评估工具
import hpsv2
import ImageReward as RM
import clip
#from dinov2.models import vit_large
# （使用 Hugging Face Transformers）
from transformers import AutoModel, AutoConfig, AutoImageProcessor
from transformers import CLIPModel, CLIPProcessor

from dreamsim.dreamsim import dreamsim

# FID计算
from cleanfid import fid




def rle2mask(mask_rle, shape):  # height, width
    """Decode a run-length-encoded mask into a binary array of ``shape``.

    Args:
        mask_rle: Flat sequence of integers ``[start_1, length_1, start_2, length_2, ...]``.
            Starts are 1-based offsets into the flattened mask.
        shape: ``(height, width)`` of the output mask.

    Returns:
        ``np.uint8`` array of ``shape`` holding 1 inside every run and 0 elsewhere. Runs
        index the flattened mask, which is then reshaped in numpy's default (row-major)
        order.

    Raises:
        ValueError: if ``mask_rle`` has an odd number of values (a start without a
            matching length).
    """
    # np.array always copies, so the 1-based -> 0-based shift below can never write back
    # into the caller's data (np.asarray on an int ndarray slice would alias it).
    starts = np.array(mask_rle[0::2], dtype=int)
    lengths = np.array(mask_rle[1::2], dtype=int)
    if starts.shape != lengths.shape:
        raise ValueError(
            f"mask_rle must contain (start, length) pairs; got {len(starts)} starts "
            f"and {len(lengths)} lengths"
        )
    starts -= 1
    ends = starts + lengths
    binary_mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        binary_mask[lo:hi] = 1
    return binary_mask.reshape(shape)


class MetricsCalculator:
    """Load the scoring models once and expose one helper method per metric.

    Image-vs-image helpers (CLIP, DINOv2, DreamSim) take two PIL images, prompt-based
    helpers (HPS v2, ImageReward) take an image and its text prompt, and FID compares
    two image directories.
    """

    # Default local checkpoint locations; override per instance through the
    # keyword-only arguments of ``__init__``.
    IMAGEREWARD_CKPT = "/HDD_data/xyc/reference_methods_Calculate_metrics/pretrain_model/models--THUDM--ImageReward/ImageReward.pt"
    IMAGEREWARD_MED_CONFIG = "/HDD_data/xyc/reference_methods_Calculate_metrics/pretrain_model/models--THUDM--ImageReward/med_config.json"
    CLIP_PATH = "/HDD_data/xyc/reference_methods_Calculate_metrics/pretrain_model/models--openai--clip-vit-base-patch32"
    DINO_PATH = "/HDD_data/xyc/reference_methods_Calculate_metrics/pretrain_model/models--facebook--dinov2-base"

    def __init__(
        self,
        device,
        ckpt_path="data/ckpt",
        *,
        imagereward_ckpt=None,
        imagereward_med_config=None,
        clip_path=None,
        dino_path=None,
        dreamsim_type="dino_vitb16",
    ) -> None:
        """Load every scoring model onto ``device``.

        Args:
            device: torch device string/object (e.g. ``"cuda"`` or ``"cpu"``) the models
                are moved to and inputs are sent to.
            ckpt_path: Not used by this class; kept so existing callers passing it still work.
            imagereward_ckpt: ImageReward weights file; defaults to ``IMAGEREWARD_CKPT``.
            imagereward_med_config: ImageReward ``med_config`` JSON; defaults to
                ``IMAGEREWARD_MED_CONFIG``.
            clip_path: CLIP checkpoint directory; defaults to ``CLIP_PATH``.
            dino_path: DINOv2 checkpoint directory; defaults to ``DINO_PATH``.
            dreamsim_type: Variant name forwarded to ``dreamsim(dreamsim_type=...)``.
        """
        self.device = device

        # ImageReward, loaded from local files (previously the "ImageReward-v1.0" hub name).
        self.imagereward_model = RM.load(
            name=self.IMAGEREWARD_CKPT if imagereward_ckpt is None else imagereward_ckpt,
            med_config=(
                self.IMAGEREWARD_MED_CONFIG
                if imagereward_med_config is None
                else imagereward_med_config
            ),
        )

        # 1. CLIP model and processor (Hugging Face transformers).
        clip_path = self.CLIP_PATH if clip_path is None else clip_path
        self.clip_model = CLIPModel.from_pretrained(clip_path).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_path)

        # 2. DINOv2 model and processor (Hugging Face transformers).
        dino_path = self.DINO_PATH if dino_path is None else dino_path
        self.dino_processor = AutoImageProcessor.from_pretrained(dino_path)
        self.dino_model = AutoModel.from_pretrained(dino_path).to(device)
        self.dino_model.eval()

        # 3. DreamSim model and its matching preprocessing transform.
        self.dreamsim_model, self.dreamsim_preprocess = dreamsim(
            pretrained=True, device=device, dreamsim_type=dreamsim_type
        )

    def calculate_image_reward(self, image, prompt):
        """Score ``image`` against ``prompt`` with ImageReward.

        The image is passed wrapped in a one-element list, so the result is whatever
        ``ImageReward.score`` returns for list input.
        NOTE(review): presumably a single float for one image — confirm against the
        installed ImageReward version.
        """
        return self.imagereward_model.score(prompt, [image])

    def calculate_hpsv21_score(self, image, prompt, hps_version=None):
        """Score ``image`` against ``prompt`` with hpsv2 and return a Python scalar.

        Args:
            image: Image passed straight to ``hpsv2.score``.
            prompt: Text prompt.
            hps_version: Optional version string forwarded to ``hpsv2.score``. When
                ``None`` (the default) the argument is omitted, so hpsv2's own default
                version is used despite this method's name (the original call had
                ``hps_version="v2.1"`` removed).

        Returns:
            The first score returned by ``hpsv2.score``, converted with ``.item()``.
        """
        if hps_version is None:
            scores = hpsv2.score(image, prompt)
        else:
            scores = hpsv2.score(image, prompt, hps_version=hps_version)
        return scores[0].item()

    def calculate_clip_similarity(self, image1, image2):
        """Return the cosine similarity between the CLIP image embeddings of two images.

        Ref: Radford et al. "Learning Transferable Visual Models From Natural Language
        Supervision" (2021).
        """
        inputs1 = self.clip_processor(images=image1, return_tensors="pt").to(self.device)
        inputs2 = self.clip_processor(images=image2, return_tensors="pt").to(self.device)

        with torch.no_grad():
            feat1 = self.clip_model.get_image_features(**inputs1)
            feat2 = self.clip_model.get_image_features(**inputs2)

        # L2-normalize so the dot product below is the cosine similarity.
        feat1 = feat1 / feat1.norm(dim=-1, keepdim=True)
        feat2 = feat2 / feat2.norm(dim=-1, keepdim=True)
        similarity = (feat1 * feat2).sum(dim=-1)
        return similarity.item()

    def calculate_dino_score(self, image, reference_image):
        """Return the DINOv2 cosine similarity between ``image`` and ``reference_image``.

        Each image is embedded as the mean of the model's ``last_hidden_state`` over
        dim=1 (presumably the token axis — confirm for the loaded checkpoint).
        """
        def _process(img):
            # Model-specific preprocessing, returned as a tensor on self.device.
            return self.dino_processor(images=img, return_tensors="pt").pixel_values.to(self.device)

        with torch.no_grad():
            inputs = _process(image)
            ref_inputs = _process(reference_image)
            outputs = self.dino_model(inputs).last_hidden_state.mean(dim=1)
            ref_outputs = self.dino_model(ref_inputs).last_hidden_state.mean(dim=1)

        return torch.nn.functional.cosine_similarity(outputs, ref_outputs).item()

    def calculate_dreamsim(self, image1, image2):
        """Return the DreamSim distance ``1 - cos(embed(image1), embed(image2))``.

        Despite the metric's "similarity" framing, the returned value is a distance:
        it is 0 for identical embeddings and grows as they diverge.
        """
        img1 = self.dreamsim_preprocess(image1).to(self.device)
        img2 = self.dreamsim_preprocess(image2).to(self.device)

        with torch.no_grad():
            embedding1 = self.dreamsim_model.embed(img1)
            embedding2 = self.dreamsim_model.embed(img2)

        distance = 1 - torch.nn.functional.cosine_similarity(embedding1, embedding2)
        return distance.item()

    def calculate_fid_score(self, path1: str, path2: str) -> float:
        """Compute the FID between two image folders using clean-fid (``mode="clean"``)."""
        score = fid.compute_fid(path1, path2, mode="clean", device=self.device)
        return score



# Command-line interface. Every input/output directory the script reads from or writes
# to is required; --uno_dir is optional and only enables the dataset-level FID metric.
parser = argparse.ArgumentParser()

# Directory of extracted foreground images (*.png).
parser.add_argument('--foreground_dir', type=str, required=True, help='抠出来的前景图像目录路径')
# Directory of target/object images (*.png).
parser.add_argument('--object_dir', type=str, required=True, help='目标图像目录路径')
# Directory of model-generated images (*.png); also the first FID folder.
parser.add_argument('--gen_dir', type=str, required=True, help='模型生成图像目录路径')
# Directory of text prompts (*.txt), one per sample.
parser.add_argument('--text_dir', type=str, required=True, help='文本描述目录路径')
# Optional directory of UNO images; when given, FID(gen_dir, uno_dir) is computed.
parser.add_argument('--uno_dir', type=str, help='UNO图像目录路径')
# Output directory for the CSV results.
parser.add_argument('--image_save_path', type=str, required=True, help='保存的路径')

args = parser.parse_args()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Collect the input files. Samples are paired purely by natural-sort position across the
# four directories, so each directory must hold exactly one file per sample.
fg_files = natsorted(glob.glob(os.path.join(args.foreground_dir, "*.png")))
obj_files = natsorted(glob.glob(os.path.join(args.object_dir, "*.png")))
gen_files = natsorted(glob.glob(os.path.join(args.gen_dir, "*.png")))
txt_files = natsorted(glob.glob(os.path.join(args.text_dir, "*.txt")))

# Validate with real exceptions: an `assert` is stripped under `python -O`.
file_counts = {
    'foreground_dir': len(fg_files),
    'object_dir': len(obj_files),
    'gen_dir': len(gen_files),
    'text_dir': len(txt_files),
}
if len(set(file_counts.values())) != 1:
    raise ValueError(f"文件数量不一致: {file_counts}")
if not fg_files:
    raise ValueError("No input files found (expected *.png / *.txt files in the input directories)")


def _load_rgb(path, size=(512, 512)):
    """Open an image, convert it to RGB and resize it, closing the source file."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it outlives the `with` block.
        return img.convert('RGB').resize(size)


# Preload every sample; 'fg_filename' identifies the sample in the results table.
image_pairs = []
for fg_path, obj_path, gen_path, txt_path in zip(fg_files, obj_files, gen_files, txt_files):
    # Explicit UTF-8 so prompts decode identically regardless of the system locale.
    with open(txt_path, encoding='utf-8') as prompt_file:
        prompt = prompt_file.read().strip()
    image_pairs.append({
        'fg': _load_rgb(fg_path),
        'fg_filename': os.path.basename(fg_path),
        'obj': _load_rgb(obj_path),
        'gen': _load_rgb(gen_path),
        'prompt': prompt,
    })





# evaluation
# Column order is also the order in which the per-sample metrics are computed.
EVAL_COLUMNS = ['filename', 'CLIP-Score', 'DINO-Score', 'dreamsim', 'HPS v2', 'ImageReward', 'FID']

metrics_calculator = MetricsCalculator(device)

# Per-sample metric dispatch: column name -> callable(pair) -> score.
# CLIP / DINO / DreamSim compare the extracted foreground ('fg') with the target image
# ('obj'); HPS v2 and ImageReward score the generated image ('gen') against its prompt.
# FID is a dataset-level score filled in afterwards, so it is None for each sample.
metric_fns = {
    'CLIP-Score': lambda p: metrics_calculator.calculate_clip_similarity(p['fg'], p['obj']),
    'DINO-Score': lambda p: metrics_calculator.calculate_dino_score(p['fg'], p['obj']),
    'dreamsim': lambda p: metrics_calculator.calculate_dreamsim(p['fg'], p['obj']),
    'HPS v2': lambda p: metrics_calculator.calculate_hpsv21_score(p['gen'], p['prompt']),
    'ImageReward': lambda p: metrics_calculator.calculate_image_reward(p['gen'], p['prompt']),
    'FID': lambda p: None,
}

rows = []
for pair in tqdm(image_pairs, desc="Evaluating images", unit="image"):
    rows.append([pair['fg_filename']] + [metric_fns[col](pair) for col in EVAL_COLUMNS[1:]])

# Build the DataFrame once: growing it with `df.loc[len(df)] = ...` copies the frame on
# every append and leaves the numeric dtypes to pandas' per-append inference.
evaluation_df = pd.DataFrame(rows, columns=EVAL_COLUMNS)

# Dataset-level FID between the generated images and the UNO images (only if --uno_dir is given).
if args.uno_dir:
    fid_score = metrics_calculator.calculate_fid_score(args.gen_dir, args.uno_dir)
    evaluation_df['FID'] = fid_score  # same value broadcast to every row

print("The averaged evaluation result:")
averaged_results = evaluation_df.mean(numeric_only=True)
print(averaged_results)

# Create the output directory first so the (expensive) results are not lost on write.
# An empty path means "current directory", which needs no creation.
if args.image_save_path:
    os.makedirs(args.image_save_path, exist_ok=True)
averaged_results.to_csv(os.path.join(args.image_save_path, "evaluation_result_sum.csv"))
evaluation_df.to_csv(os.path.join(args.image_save_path, "evaluation_result.csv"))

print(f"The generated images and evaluation results is saved in {args.image_save_path}")