import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from torchmetrics import StructuralSimilarityIndexMeasure
from models.lpips import LPIPS
from tools.calculate_fid import calculate_fid_given_paths

# ---- Utility functions ----
def load_image(image_path):
    """Load an image file as a float32 tensor in channel-first layout.

    The image is converted to RGB before being turned into an array, so the
    array produced by ``np.array`` is (H, W, 3) and the returned tensor is
    (3, H, W). Pixel values are not rescaled; they keep the 0-255 range of
    the decoded 8-bit RGB data.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A ``torch.float32`` tensor of shape (3, H, W).
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it deterministically. convert() returns a fully loaded,
    # independent image, so it stays usable after the source file is closed.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    pixels = np.array(rgb_image).transpose(2, 0, 1)  # HWC -> CHW
    return torch.tensor(pixels, dtype=torch.float32)

def calculate_psnr(original, processed):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The peak value is fixed at 255.0, i.e. inputs are treated as being on a
    0-255 scale. Identical inputs give a mean squared error of zero and the
    result is ``inf``.

    Args:
        original: Reference tensor.
        processed: Tensor compared against ``original``.

    Returns:
        The PSNR as a Python float.
    """
    squared_error = (original - processed) ** 2
    rmse = torch.sqrt(torch.mean(squared_error))
    log_ratio = torch.log10(255.0 / rmse)
    # .item() is applied to the log term first; the scaling by 20 happens on
    # the resulting Python float.
    return 20 * log_ratio.item()

def calculate_psnr_for_pair(original_path, processed_path):
    """Load two image files and return the PSNR between them.

    Args:
        original_path: Path of the reference image.
        processed_path: Path of the image to compare against the reference.

    Returns:
        The value produced by ``calculate_psnr`` for the two loaded images.
    """
    reference = load_image(original_path)
    candidate = load_image(processed_path)
    return calculate_psnr(reference, candidate)

def calculate_psnr_between_folders(original_folder, processed_folder):
    original_files = sorted(os.listdir(original_folder))
    processed_files = sorted(os.listdir(processed_folder))

    if len(original_files) != len(processed_files):
        print("Warning: Mismatched number of images in folders")
        return []

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(calculate_psnr_for_pair,
                          os.path.join(original_folder, orig),
                          os.path.join(processed_folder, proc))
            for orig, proc in zip(original_files, processed_files)
        ]
        return [future.result() for future in as_completed(futures)]

# ---- Compute metrics directly from folders ----
def evaluate_from_folders(ref_path, rec_path, device="cuda"):
    """Compute FID, PSNR, LPIPS and SSIM between two image folders.

    Files of the two folders are paired by position after sorting each
    listing by name. PSNR is computed on the raw 0-255 tensors; LPIPS and
    SSIM are computed on tensors rescaled to [-1, 1].

    Args:
        ref_path: Folder containing the reference images.
        rec_path: Folder containing the reconstructed images.
        device: Torch device used for FID, LPIPS and SSIM.

    Returns:
        A dict with the keys "FID" (the value returned by
        ``calculate_fid_given_paths``), "PSNR", "LPIPS" and "SSIM" (each the
        mean over all image pairs).

    Raises:
        ValueError: If either folder is empty or the two folders contain a
            different number of files.
    """
    ref_files = sorted(os.listdir(ref_path))
    rec_files = sorted(os.listdir(rec_path))
    # Fail early with a clear message instead of hitting a division by zero
    # on the PSNR average after the expensive FID computation.
    if not ref_files or len(ref_files) != len(rec_files):
        raise ValueError(
            f"Folders must be non-empty and contain the same number of images "
            f"(got {len(ref_files)} in {ref_path!r} and {len(rec_files)} in {rec_path!r})"
        )

    # FID
    fid = calculate_fid_given_paths([ref_path, rec_path], batch_size=50, dims=2048, device=device, num_workers=16)
    print(fid)

    # PSNR
    psnr_values = calculate_psnr_between_folders(ref_path, rec_path)
    avg_psnr = sum(psnr_values) / len(psnr_values)

    # LPIPS + SSIM
    lpips_model = LPIPS().to(device).eval()
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(device)

    lpips_scores = []
    ssim_scores = []

    # Pure inference: no autograd graph is needed for either metric.
    with torch.no_grad():
        for ref_file, rec_file in tqdm(zip(ref_files, rec_files), total=len(ref_files)):
            ref_img = load_image(os.path.join(ref_path, ref_file)).unsqueeze(0).to(device)
            rec_img = load_image(os.path.join(rec_path, rec_file)).unsqueeze(0).to(device)

            # [0,255] -> [-1,1] so both metrics see the same value range
            ref_norm = ref_img / 127.5 - 1.0
            rec_norm = rec_img / 127.5 - 1.0

            lpips_scores.append(lpips_model(ref_norm, rec_norm).item())
            # The metric is configured with data_range=(-1, 1), so it must be
            # fed the normalised tensors; the raw 0-255 tensors would be
            # clamped into that range and give a meaningless score.
            ssim_scores.append(ssim_metric(rec_norm, ref_norm).item())

    return {
        "FID": fid,
        "PSNR": avg_psnr,
        "LPIPS": float(np.mean(lpips_scores)),
        "SSIM": float(np.mean(ssim_scores)),
    }

if __name__ == "__main__":
    import argparse

    # Command-line interface: both folder arguments are mandatory.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--ref_path", "路径：参考图像文件夹"),
        ("--rec_path", "路径：重建图像文件夹"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    options = cli.parse_args()

    metrics = evaluate_from_folders(options.ref_path, options.rec_path)

    # Report every metric with four decimal places.
    print("Final Metrics:")
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")