'''
import debugpy
try:
    # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
    debugpy.listen(("localhost", 9502))
    print("Waiting for debugger attach")
    debugpy.wait_for_client()
except Exception as e:
    pass
'''

import os
import glob
import argparse
import torch
import pandas as pd
from PIL import Image
from natsort import natsorted
from tqdm import tqdm
import lpips
from DISTS_pytorch import DISTS
import numpy as np





class MetricsCalculator:
    """Compute LPIPS and DISTS scores between two images.

    Both networks are created once, moved to ``device`` and switched to
    evaluation mode; the ``calculate_*`` methods then run them without
    gradient tracking.
    """

    def __init__(self, device):
        """Store the target device and build both metric networks."""
        self.device = device
        self._init_models()

    def _init_models(self):
        """Instantiate the LPIPS and DISTS networks in evaluation mode."""
        # LPIPS with the AlexNet backbone (net='alex'), weights version 0.1.
        lpips_net = lpips.LPIPS(net='alex', version='0.1')
        self.lpips = lpips_net.to(self.device)

        dists_net = DISTS()
        self.dists = dists_net.to(self.device)

        # Inference only: put both networks in eval mode.
        for network in (self.lpips, self.dists):
            network.eval()

    def calculate_lpips(self, img1, img2):
        """Return the LPIPS score between ``img1`` and ``img2`` as a float.

        Each image is converted to a float32 NumPy array and handed to
        ``lpips.im2tensor`` before being moved to ``self.device``.
        """
        # Stage 1: images -> float32 arrays.
        # NOTE(review): presumably HWC with values in 0-255 - confirm with caller.
        arrays = [np.array(image).astype(np.float32) for image in (img1, img2)]

        # Stage 2: arrays -> tensors on the target device.
        # NOTE(review): im2tensor is expected to rescale to [-1, 1] - confirm
        # against the lpips package.
        first, second = (
            lpips.im2tensor(pixels).to(self.device) for pixels in arrays
        )

        with torch.no_grad():
            distance = self.lpips(first, second)
        return distance.item()

    def calculate_dists(self, img1, img2):
        """Return the DISTS score between ``img1`` and ``img2`` as a float.

        Each image is converted to a NumPy array, permuted from (H, W, C) to
        (C, H, W), given a leading batch axis and divided by 255 before being
        moved to ``self.device``.
        """
        # Stage 1: images -> arrays (dtype left as provided by the image).
        arrays = [np.array(image) for image in (img1, img2)]

        # Stage 2: arrays -> (1, C, H, W) float tensors scaled by 1/255 on CPU.
        scaled = [
            torch.from_numpy(pixels).float().permute(2, 0, 1).unsqueeze(0) / 255.0
            for pixels in arrays
        ]

        # Stage 3: move both batches to the target device.
        first, second = (batch.to(self.device) for batch in scaled)

        with torch.no_grad():
            distance = self.dists(first, second)
        return distance.item()

def main(args):
    """Score every image pair with LPIPS and DISTS, then save and print the results.

    Args:
        args: Parsed command-line namespace providing ``foreground_dir``,
            ``object_dir`` and ``output_dir``.

    Pairs whose metrics fail to compute are reported on stdout and skipped.
    When no pair could be scored at all, a message is printed and nothing is
    written, instead of failing later on the missing metric columns.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    metrics = MetricsCalculator(device)
    image_pairs = load_image_pairs(args.foreground_dir, args.object_dir)

    results = []
    for pair in tqdm(image_pairs, desc="Processing pairs"):
        try:
            lpips_score = metrics.calculate_lpips(pair['img1'], pair['img2'])
            dists_score = metrics.calculate_dists(pair['img1'], pair['img2'])
        except Exception as e:
            # Best effort: one bad pair must not abort the whole run.
            print(f"处理文件 {pair['filename']} 时出错: {str(e)}")
        else:
            results.append({
                'filename': pair['filename'],
                'LPIPS': lpips_score,
                'DISTS': dists_score
            })

    # An empty result list yields a DataFrame without any columns, so the
    # column selection below would raise KeyError; stop here instead.
    if not results:
        print("\n没有成功处理任何图像对，未生成结果文件。")
        return

    df = pd.DataFrame(results)

    save_results(df, args.output_dir)

    print("\n指标统计结果:")
    print(df[['LPIPS', 'DISTS']].describe())

def load_image_pairs(img_dir1, img_dir2, size=(512, 512)):
    """Load position-matched PNG image pairs from two directories.

    The ``*.png`` files of each directory are natural-sorted and paired by
    position in the sorted lists (not by file name). Every image is converted
    to RGB and resized to ``size`` so both members of a pair share the same
    dimensions.

    Args:
        img_dir1: Directory holding the first image of every pair.
        img_dir2: Directory holding the second image of every pair.
        size: ``(width, height)`` every image is resized to. Defaults to
            ``(512, 512)``.

    Returns:
        A list of dicts with the keys ``'filename'`` (base name of the file
        taken from ``img_dir1``), ``'img1'`` and ``'img2'`` (PIL images).
        Pairs that fail to load are reported on stdout and skipped.

    Raises:
        ValueError: If the two directories hold a different number of PNG
            files.
    """
    files1 = natsorted(glob.glob(os.path.join(img_dir1, "*.png")))
    files2 = natsorted(glob.glob(os.path.join(img_dir2, "*.png")))

    # Explicit raise instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let zip() silently truncate the pairs.
    if len(files1) != len(files2):
        raise ValueError(
            f"两个目录的文件数量不一致: {len(files1)} vs {len(files2)}"
        )

    pairs = []
    for f1, f2 in zip(files1, files2):
        try:
            # The context managers close the underlying files deterministically;
            # convert()/resize() return new in-memory images that stay valid
            # after the source handles are closed.
            with Image.open(f1) as src1, Image.open(f2) as src2:
                img1 = src1.convert('RGB').resize(size)
                img2 = src2.convert('RGB').resize(size)
        except Exception as e:
            # Best effort: report the unreadable pair and keep going.
            print(f"加载文件 {f1} 或 {f2} 失败: {str(e)}")
        else:
            pairs.append({
                'filename': os.path.basename(f1),
                'img1': img1,
                'img2': img2
            })

    return pairs

def save_results(df, output_dir):
    """Write the per-pair results and their column means as CSV files.

    Creates ``output_dir`` if needed, then writes ``detailed_results.csv``
    (the frame as given, without the index) and ``summary_statistics.csv``
    (a single row labelled ``mean`` holding the mean of each numeric column).
    The written paths are printed afterwards.
    """
    os.makedirs(output_dir, exist_ok=True)

    detail_path = os.path.join(output_dir, "detailed_results.csv")
    stats_path = os.path.join(output_dir, "summary_statistics.csv")

    # Per-pair table.
    df.to_csv(detail_path, index=False)

    # One-row summary: numeric column means, transposed so metrics are columns.
    column_means = df.mean(numeric_only=True)
    summary = pd.DataFrame(column_means).T
    summary.index = ['mean']
    summary.to_csv(stats_path)

    report_lines = (
        f"\n结果已保存至目录: {output_dir}",
        f"详细结果: {detail_path}",
        f"统计摘要: {stats_path}",
    )
    for line in report_lines:
        print(line)

if __name__ == "__main__":
    # Command-line entry point: two required input directories and an
    # optional output directory, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(description="图像质量评估工具")

    required_dirs = (
        ('--foreground_dir', '前景图像目录路径'),
        ('--object_dir', '目标图像目录路径'),
    )
    for flag, help_text in required_dirs:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument('--output_dir', type=str, default="./results",
                     help='结果保存路径')

    args = cli.parse_args()
    main(args)