import os
import torch
import cv2
import numpy as np
import re
import lpips
from PIL import Image, ExifTags
from tqdm import tqdm
from torchmetrics.multimodal.clip_score import CLIPScore
import ImageReward as RM
from torch.utils.data import DataLoader
import torchvision.transforms.v2.functional as TF
import torchvision.transforms.v2 as T
from skimage.metrics import structural_similarity as ssim
import argparse
from pathlib import Path

# Repository layout: this file lives one directory below the repo root,
# and model weights are expected under <repo>/models.
REPO_ROOT = Path(__file__).resolve().parent.parent
MODELS_ROOT = REPO_ROOT / "models"
DEFAULT_CLIP_MODEL = "openai/clip-vit-large-patch14"
DEFAULT_REWARD_DIR = MODELS_ROOT / "ImageReward-v1.0"

# Environment defaults for the tokenizer / Hugging Face stack. `setdefault`
# means any value already exported by the user wins.
#   - the *_OFFLINE flags are '0', i.e. offline mode is not forced;
#   - HF_HOME places the Hugging Face cache under models/.huggingface;
#   - HF_HUB_DISABLE_PROGRESS_BARS='1' silences hub download progress bars.
_HF_ENV_DEFAULTS = {
    'TOKENIZERS_PARALLELISM': 'false',
    'HF_HUB_OFFLINE': '0',
    'HF_DATASETS_OFFLINE': '0',
    'TRANSFORMERS_OFFLINE': '0',
    'HF_HOME': str(MODELS_ROOT / '.huggingface'),
    'HF_HUB_DISABLE_PROGRESS_BARS': '1',
}
for _env_name, _env_value in _HF_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

def load_images_and_prompts(image_folder):
    """Collect image paths and their prompts from EXIF metadata.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file directly inside ``image_folder``
    (non-recursive) is opened and its EXIF ``ImageDescription`` tag is used as
    the prompt. Files with no EXIF data, with an empty/missing
    ``ImageDescription``, or that raise while being read are skipped with a
    printed message.

    Args:
        image_folder: Directory to scan.

    Returns:
        ``(images, prompts)``: parallel lists of file paths and prompt values,
        in sorted filename order.
    """
    images = []
    prompts = []
    # os.listdir order is filesystem-dependent; sort so the dataset order
    # (and therefore batching and printed output) is reproducible.
    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        img_path = os.path.join(image_folder, filename)
        try:
            # `with` closes the file handle right away; only the path is kept.
            with Image.open(img_path) as img:
                # Public API (the previous `_getexif()` is private); an image
                # without EXIF yields an empty mapping rather than None.
                exif_data = img.getexif()
                # Map numeric tag ids to tag names; unknown ids are kept as-is.
                exif = {
                    ExifTags.TAGS.get(tag_id, tag_id): value
                    for tag_id, value in exif_data.items()
                }
            if not exif_data:
                print(f"图像 {filename} 中无 Exif 元数据。")
                continue
            prompt = exif.get("ImageDescription", None)
            if prompt:
                images.append(img_path)
                prompts.append(prompt)
            else:
                print(f"图像 {filename} 中未找到提示信息。")
        except Exception as e:
            # Best-effort scan: report the unreadable file and keep going.
            print(f"处理图像 {filename} 时出错：{e}")
    return images, prompts

class ImageFolder(torch.utils.data.Dataset):
    """Dataset of prompt-tagged images loaded from a folder.

    Image paths and prompts come from ``load_images_and_prompts`` (EXIF
    ``ImageDescription``). Each item is a dict with keys ``'image'`` (the image
    converted to RGB and turned into a tensor by ``TF.pil_to_tensor``),
    ``'prompt'`` and ``'path'``.

    NOTE(review): the default DataLoader collate stacks ``'image'`` tensors, so
    every image in a batch must have the same size — confirm this holds for
    the evaluated folders.
    """

    def __init__(self, image_folder):
        self.image_folder = image_folder
        self.images, self.prompts = load_images_and_prompts(image_folder)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # `convert` returns an independent image, so the source file can be
        # closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        return {
            'image': TF.pil_to_tensor(rgb),
            'prompt': self.prompts[idx],
            'path': img_path,
        }

def get_reward_transform():
    """Build the preprocessing pipeline applied to images before ImageReward scoring.

    Steps: resize (shorter side to 224, bicubic), center-crop to 224x224,
    wrap as an image tensor, convert to float32 (rescaled when the input is
    integer-typed), then normalize with the CLIP channel mean/std.
    """
    side = 224
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        T.Resize(side, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(side),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(clip_mean, clip_std),
    ]
    return T.Compose(steps)


def init_clip_score(model_name_or_path: str, device: torch.device) -> CLIPScore:
    """Create a torchmetrics ``CLIPScore`` metric and move it to ``device``.

    Two compatibility paths are handled:

    - If ``model_name_or_path`` is a directory containing ``config.json`` or
      ``preprocessor_config.json``, it is wrapped in a zero-argument callable
      returning ``(model, processor)``, because some torchmetrics versions only
      accept known repo IDs as strings.
    - ``CLIPScore`` is first built with ``local_files_only=True``; if the
      resulting TypeError/ValueError mentions that keyword (older torchmetrics
      versions), it is rebuilt without it. Any other error propagates.
    """
    def _is_local_model_dir(path):
        return os.path.isdir(path) and any(
            os.path.isfile(os.path.join(path, marker))
            for marker in ("config.json", "preprocessor_config.json")
        )

    source = model_name_or_path
    if _is_local_model_dir(model_name_or_path):
        local_dir = model_name_or_path

        def _load_local(loader_cls):
            # Forbid network access when the loader supports the flag;
            # otherwise fall back to a plain local-directory load.
            try:
                return loader_cls.from_pretrained(local_dir, local_files_only=True)
            except TypeError:
                return loader_cls.from_pretrained(local_dir)

        def _factory():
            from transformers import CLIPModel as _CLIPModel, CLIPProcessor as _CLIPProcessor

            return _load_local(_CLIPModel), _load_local(_CLIPProcessor)

        source = _factory

    try:
        metric = CLIPScore(model_name_or_path=source, local_files_only=True)
    except (TypeError, ValueError) as err:
        if "local_files_only" not in str(err):
            raise
        metric = CLIPScore(model_name_or_path=source)
    return metric.to(device)

def calculate_psnr(img1, img2, max_pixel=255.0):
    """Compute PSNR (peak signal-to-noise ratio) between two images, in dB.

    Args:
        img1: First image (array-like, e.g. a uint8 array from ``cv2.imread``).
        img2: Second image, same shape as ``img1``.
        max_pixel: Maximum possible pixel value; 255 for 8-bit images.

    Returns:
        PSNR in dB, or ``float('inf')`` when the images are identical.
    """
    # Promote to float64 before subtracting: with uint8 inputs the difference
    # and its square would wrap modulo 256 and silently corrupt the MSE.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def calculate_ssim(img1, img2):
    """Compute SSIM (structural similarity index) between two 8-bit images.

    Three-channel inputs are treated as OpenCV BGR images and converted to
    grayscale before comparison; any other input is compared as-is. The
    dynamic range is fixed at 255.
    """
    if len(img1.shape) != 3 or img1.shape[2] != 3:
        return ssim(img1, img2, data_range=255)
    gray_pair = [cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) for im in (img1, img2)]
    return ssim(*gray_pair, data_range=255)

def preprocess_image_for_lpips(img):
    """Turn an OpenCV BGR image into a batched LPIPS input tensor.

    The image is converted BGR -> RGB, scaled by 1/255 (assumes 8-bit input),
    rearranged from HWC to CHW, given a leading batch dimension of 1, and
    finally mapped from [0, 1] to [-1, 1].
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    unit_range = rgb.astype(np.float32) / 255.0
    chw = torch.from_numpy(unit_range).permute(2, 0, 1)
    batched = chw.unsqueeze(0)
    return batched * 2 - 1

def extract_number(filename):
    """Return the first run of digits in ``filename`` as a string, or None if there is none.

    Leading zeros are preserved ("img_007.png" -> "007").
    """
    found = re.search(r'(\d+)', filename)
    return found.group(1) if found else None

def find_matching_original_image(test_image_path, original_folder):
    """Find the original image whose filename carries the same number as the test image.

    The number is the first digit run of each basename (via ``extract_number``)
    and is compared as a string, so "007" does not match "7". Only files in
    ``original_folder`` with a common raster extension are considered, and the
    first match in ``os.listdir`` order wins.

    Returns:
        The full path of the matching original, or None when the test filename
        has no digits or no original matches.
    """
    wanted = extract_number(os.path.basename(test_image_path))
    if not wanted:
        return None

    allowed_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    candidates = (
        name
        for name in os.listdir(original_folder)
        if os.path.splitext(name.lower())[1] in allowed_exts
    )
    return next(
        (
            os.path.join(original_folder, name)
            for name in candidates
            if extract_number(name) == wanted
        ),
        None,
    )

def _select_device(use_gpu):
    """Pick the evaluation device: CUDA when ``use_gpu`` (raising if unavailable), else CPU."""
    if not use_gpu:
        return torch.device("cpu")
    if not torch.cuda.is_available():
        raise RuntimeError("未检测到 GPU！")
    return torch.device("cuda")


def _load_reward_model(reward_model_path, reward_med_config, reward_model_name, device):
    """Load ImageReward from a local checkpoint if it exists, otherwise by model name.

    Returns:
        ``(model, identifier)`` where ``identifier`` is what was passed to ``RM.load``.
    """
    checkpoint = Path(reward_model_path) if reward_model_path else None
    med_config = Path(reward_med_config) if reward_med_config else None

    kwargs = {'device': device}
    if checkpoint and checkpoint.is_file():
        kwargs['download_root'] = str(checkpoint.parent)
        identifier = str(checkpoint)
    else:
        kwargs['download_root'] = str(DEFAULT_REWARD_DIR)
        identifier = reward_model_name
    # med_config is forwarded in both cases, but only when the file exists.
    if med_config and med_config.is_file():
        kwargs['med_config'] = str(med_config)

    return RM.load(identifier, **kwargs).to(device), identifier


def _pair_metrics(test_image_path, original_folder, lpips_model, device):
    """Compute (PSNR, SSIM, LPIPS) for one test image against its matching original.

    Returns None (after printing the reason) when no original matches or
    either image cannot be read.
    """
    original_image_path = find_matching_original_image(test_image_path, original_folder)
    if original_image_path is None:
        print(f"未找到与 {os.path.basename(test_image_path)} 匹配的原始图像")
        return None

    test_img = cv2.imread(test_image_path)
    original_img = cv2.imread(original_image_path)
    if test_img is None or original_img is None:
        print(f"无法读取图像对: {test_image_path} 或 {original_image_path}")
        return None

    # Resize the test image to the original's (width, height) so the
    # pixel-wise metrics compare arrays of identical shape.
    if test_img.shape != original_img.shape:
        test_img = cv2.resize(test_img, (original_img.shape[1], original_img.shape[0]))

    psnr_value = calculate_psnr(original_img, test_img)
    ssim_value = calculate_ssim(original_img, test_img)
    with torch.no_grad():
        original_tensor = preprocess_image_for_lpips(original_img).to(device)
        test_tensor = preprocess_image_for_lpips(test_img).to(device)
        lpips_value = float(lpips_model(original_tensor, test_tensor).item())
    return psnr_value, ssim_value, lpips_value


def comprehensive_evaluation(
    test_folder,
    original_folder,
    batch_size=64,
    num_workers=32,
    clip_model=str(DEFAULT_CLIP_MODEL),
    reward_model_path=str(DEFAULT_REWARD_DIR / "ImageReward.pt"),
    reward_med_config=str(DEFAULT_REWARD_DIR / "med_config.json"),
    reward_model_name="ImageReward-v1.0",
    use_gpu=True,
    save_to_test_folder=True,
):
    """Evaluate generated images with CLIP Score, ImageReward, PSNR, SSIM and LPIPS.

    Every prompt-tagged image in ``test_folder`` is scored with CLIP Score and
    ImageReward against its EXIF prompt. Each image is also compared with the
    original in ``original_folder`` that shares its filename number, using
    PSNR, SSIM and LPIPS. A summary and a JSON dump are printed.

    Args:
        test_folder: Folder of generated images whose prompts are stored in EXIF.
        original_folder: Folder of reference images, matched by filename number.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        clip_model: Hugging Face repo id or local directory of the CLIP model.
        reward_model_path: Local ImageReward checkpoint; used when the file exists.
        reward_med_config: ``med_config.json`` forwarded to ImageReward when it exists.
        reward_model_name: Name passed to ``RM.load`` when no local checkpoint exists.
        use_gpu: Run on CUDA (which must be available) when True; on CPU when False.
        save_to_test_folder: Also write the results, plus timestamps, to
            ``<test_folder>/evaluation_results.json``.

    Returns:
        A dict of averaged metrics and counts, or None when ``test_folder``
        contains no usable images. Averages are 0 when no values were
        collected; PSNR averages to ``inf`` if any pair is pixel-identical.

    Raises:
        RuntimeError: ``use_gpu`` is True but CUDA is not available.
    """
    import json
    from datetime import datetime

    device = _select_device(use_gpu)
    print(f"当前使用设备：{device}")

    # Load models
    reward_model, reward_identifier = _load_reward_model(
        reward_model_path, reward_med_config, reward_model_name, device
    )
    reward_transform = get_reward_transform()
    print(f"使用 CLIP 模型: {clip_model}")
    clip_score_metric = init_clip_score(str(clip_model), device)
    lpips_model = lpips.LPIPS(net='alex', verbose=False).to(device)

    print(f"已加载 ImageReward 模型 '{reward_identifier}' 到 {device}")

    # Load test images and their prompts
    data = ImageFolder(test_folder)
    if len(data) == 0:
        print("测试文件夹中没有找到有效的图像文件")
        return None

    dl = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    reward_scores = []
    psnr_values = []
    ssim_values = []
    lpips_values = []

    print(f"开始处理 {len(data)} 张图像...")

    for batch in tqdm(dl, desc="处理批次"):
        images = batch['image'].to(device)
        prompts = batch['prompt']

        # NOTE(review): max_length=512 is kept from the original script;
        # confirm it matches the tokenization ImageReward was trained with.
        inputs = reward_model.blip.tokenizer(
            prompts, padding='max_length', truncation=True, max_length=512, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            clip_score_metric.update(images, prompts)
            images_reward = reward_transform(images)
            scores = reward_model.score_gard(inputs.input_ids, inputs.attention_mask, images_reward)
            # Flatten so each per-image score is appended as a plain float,
            # whatever shape the returned tensor has (even 0-d).
            reward_scores.extend(scores.flatten().tolist())

        # Reference-based metrics need image pairs and are computed one by one.
        for test_image_path in batch['path']:
            metrics = _pair_metrics(test_image_path, original_folder, lpips_model, device)
            if metrics is None:
                continue
            psnr_value, ssim_value, lpips_value = metrics
            psnr_values.append(psnr_value)
            ssim_values.append(ssim_value)
            lpips_values.append(lpips_value)

    # All three per-pair lists grow together, so any of them counts the pairs.
    valid_pairs = len(psnr_values)

    overall_clip_avg = clip_score_metric.compute().item()
    overall_reward_avg = np.mean(reward_scores) if reward_scores else 0
    avg_psnr = np.mean(psnr_values) if psnr_values else 0
    avg_ssim = np.mean(ssim_values) if ssim_values else 0
    avg_lpips = np.mean(lpips_values) if lpips_values else 0

    print("\n" + "="*50)
    print("综合评估结果:")
    print("="*50)
    print(f"处理的图像总数: {len(data)}")
    print(f"成功配对的图像数: {valid_pairs}")
    print(f"平均 CLIP Score: {overall_clip_avg:.4f}")
    print(f"平均 ImageReward Score: {overall_reward_avg:.4f}")
    print(f"平均 PSNR: {avg_psnr:.3f}")
    print(f"平均 SSIM: {avg_ssim:.4f}")
    print(f"平均 LPIPS: {avg_lpips:.4f}")
    print("="*50)

    results = {
        "clip_score": overall_clip_avg,
        "image_reward": overall_reward_avg,
        "psnr": avg_psnr,
        "ssim": avg_ssim,
        "lpips": avg_lpips,
        "total_images_processed": len(data),
        "valid_pairs": valid_pairs,
    }

    print("\n" + "="*50)
    print("JSON结果:")
    print("="*50)
    print(json.dumps(results, indent=2))
    print("="*50)

    if save_to_test_folder:
        results_file = Path(test_folder) / "evaluation_results.json"
        # Take the clock once so both timestamp fields describe the same instant.
        now = datetime.now()
        results_with_meta = {
            "evaluation_timestamp": now.isoformat(),
            "test_folder": str(test_folder),
            "original_folder": str(original_folder),
            "evaluation_date": now.strftime("%Y-%m-%d %H:%M:%S"),
            **results
        }

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_with_meta, f, indent=2, ensure_ascii=False)

        print(f"\n[INFO] 评估结果已保存到: {results_file}")

    return results

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="综合评估脚本：计算 CLIP Score、ImageReward、PSNR、SSIM 和 LPIPS")
    parser.add_argument("--test_folder", type=str, default="/root/autodl-tmp/TaylorSeer/TaylorSeer-FLUX/bdf2/interval6", help="测试图像文件夹的路径")
    parser.add_argument("--original_folder", type=str, default="results/FLUX-DEV-50", help="原始图像文件夹的路径")
    parser.add_argument("--batch_size", type=int, default=64, help="批处理大小")
    parser.add_argument("--num_workers", type=int, default=32, help="数据加载的工作线程数")
    parser.add_argument("--clip_model", type=str, default=str(DEFAULT_CLIP_MODEL), help="CLIP 模型路径或名称")
    parser.add_argument("--reward_model_path", type=str, default=str(DEFAULT_REWARD_DIR / "ImageReward.pt"), help="本地 ImageReward checkpoint 路径")
    parser.add_argument("--reward_med_config", type=str, default=str(DEFAULT_REWARD_DIR / "med_config.json"), help="ImageReward 所需的 med_config.json 路径")
    parser.add_argument("--reward_model_name", type=str, default="ImageReward-v1.0", help="ImageReward 模型名称 (当未提供本地路径时使用)")
    parser.add_argument("--gpu", action='store_true', default=True, help="使用GPU加速计算")
    # `--gpu` is store_true with default=True, so it can never be switched off
    # on its own; this flag writes False into the same destination.
    parser.add_argument("--no-gpu", dest="gpu", action="store_false", help="设置 use_gpu=False")
    parser.add_argument("--save-to-test-folder", action='store_true', default=True, help="将评估结果保存到测试目录内的evaluation_results.json文件 (默认开启)")
    # Same situation for saving results: provide an explicit opt-out.
    parser.add_argument("--no-save-to-test-folder", dest="save_to_test_folder", action="store_false", help="不将评估结果保存到测试目录")

    args = parser.parse_args()

    comprehensive_evaluation(
        test_folder=args.test_folder,
        original_folder=args.original_folder,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        clip_model=args.clip_model,
        reward_model_path=args.reward_model_path,
        reward_med_config=args.reward_med_config,
        reward_model_name=args.reward_model_name,
        use_gpu=args.gpu,
        save_to_test_folder=args.save_to_test_folder,
    )
