import os, glob, re
import numpy as np
from PIL import Image
import torch
from pytorch_msssim import ms_ssim as ms_ssim_torch, ssim as ssim_torch

def psnr(x, y, max_val=1.0):
    """Return ``(psnr_db, mse)`` between two same-shaped arrays.

    Both inputs are promoted to float64 before differencing so integer images
    (e.g. ``uint8``) cannot wrap around on subtraction/squaring, and float32
    images are averaged at full precision.

    Args:
        x, y: arrays of identical shape.
        max_val: peak signal value (1.0 for images in [0, 1], 255 for 8-bit).

    Returns:
        ``(psnr_db, mse)``; identical inputs give ``(inf, 0.0)``.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    mse = float(np.mean(diff ** 2))
    if mse == 0:
        return float('inf'), 0.0
    psnr_val = 10.0 * np.log10((max_val ** 2) / mse)
    return psnr_val, mse

def ms_ssim_np(x, y):
    """Compute multi-scale SSIM between two channel-last numpy images.

    Each array is permuted from (H, W, C) to (C, H, W), given a leading batch
    axis, made contiguous and cast to float32 before being passed to
    ``ms_ssim_torch`` with ``data_range=1.0`` (so values are assumed to lie in
    [0, 1]). The score is returned as a plain Python float.
    """
    def to_batched_chw(arr):
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).contiguous()
        return tensor.float()

    first = to_batched_chw(x)
    second = to_batched_chw(y)

    # Inference only: no autograd graph is needed for the metric.
    with torch.no_grad():
        score = ms_ssim_torch(first, second, data_range=1.0)
    return float(score.item())

def ssim_np(x, y):
    """Compute single-scale SSIM between two channel-last numpy images.

    Both arrays are converted to contiguous float32 tensors of shape
    (1, C, H, W) and scored with ``ssim_torch`` using ``data_range=1.0``
    (inputs are assumed to be in [0, 1]). Returns a Python float.
    """
    pair = [
        torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).contiguous().float()
        for arr in (x, y)
    ]

    # Metric evaluation only; disable gradient tracking.
    with torch.no_grad():
        result = ssim_torch(pair[0], pair[1], data_range=1.0)
    return float(result.item())

def load_rgb01(path):
    """Load an image file as an RGB float32 array scaled to [0, 1].

    The image is converted to RGB (so grayscale/RGBA inputs become 3-channel)
    and divided by 255. ``Image.open`` is lazy and keeps the file handle open
    until the image object is garbage-collected, so it is used as a context
    manager to release the handle as soon as the pixels are decoded.
    """
    with Image.open(path) as img:
        arr = np.asarray(img.convert("RGB")).astype(np.float32) / 255.0
    return arr

def softmax_np(x):
    """Numerically stable softmax taken over every element of ``x``.

    The global maximum is subtracted before exponentiating so ``np.exp`` cannot
    overflow; a tiny epsilon is added to the normalizer as a guard.
    """
    exps = np.exp(x - np.max(x))
    normalizer = np.sum(exps) + 1e-12
    return exps / normalizer

def fuse_images_denoise_from_folder(
    folder,
    arr_gt,
    y_path,
    output_path,
    index,
    mask_list=None,
    seed_list=None,
    tau=50.0,
    lam=1.0
):
    """Fuse candidate reconstructions with a noisy observation and score it.

    Candidates are files in ``folder`` named
    ``stage_2_<idx>_mask_<mask>_seed_<seed>.png``, optionally restricted to the
    mask values in ``mask_list`` and the seeds in ``seed_list``. Each candidate
    ``x_k`` gets a gating weight ``softmax(-tau * MSE(x_k, y))``, so candidates
    closer to the noisy input ``y`` count more. The weighted average ``x_bar`` is
    blended with ``y`` as ``(y + lam * x_bar) / (1 + lam)``, clipped to [0, 1],
    evaluated against ``arr_gt`` (PSNR/SSIM) and saved as an 8-bit PNG.

    Args:
        folder: directory containing the ``stage_2_*.png`` candidates.
        arr_gt: ground-truth image in [0, 1], same shape as the loaded images.
        y_path: path of the noisy observation image.
        output_path: destination of the fused PNG; parent dirs are created.
        index: zero-based image index, used only in the printed report.
        mask_list: optional iterable of allowed mask values (compared with a
            float tolerance); ``None`` keeps all masks.
        seed_list: optional iterable of allowed integer seeds; ``None`` keeps all.
        tau: gating sharpness; larger values concentrate weight on the
            candidate most consistent with ``y``.
        lam: weight of the candidate average relative to ``y``.

    Returns:
        ``(psnr_db, ssim)`` of the fused image vs ``arr_gt``, or
        ``(None, None)`` if no candidate survives the filtering.

    Raises:
        ValueError: if ``lam == -1`` (the blend would divide by zero) or a
            candidate's shape differs from the noisy input's.
    """
    if lam == -1:
        raise ValueError("lam must not be -1 (the blend divides by 1 + lam).")

    pattern = os.path.join(folder, "stage_2_*.png")
    all_files = sorted(glob.glob(pattern))

    # Groups: stage index (unused here), mask value, seed.
    name_regex = re.compile(r"stage_2_(\d+)_mask_([0-9.]+)_seed_(\d+)\.png")

    # Normalize the filters once: masks are floats parsed from filenames, so
    # they are compared with a tolerance; seeds are ints, so a set suffices.
    mask_arr = None if mask_list is None else np.asarray(list(mask_list), dtype=np.float64)
    seed_set = None if seed_list is None else set(seed_list)

    selected_files = []
    for f in all_files:
        m = name_regex.fullmatch(os.path.basename(f))
        if not m:
            continue
        _, mask_str, seed_str = m.groups()
        try:
            mask_val = float(mask_str)
        except ValueError:
            # e.g. "0.3." matches [0-9.]+ but is not a number; skip the file.
            continue
        seed_val = int(seed_str)

        if mask_arr is not None and not np.any(np.isclose(mask_val, mask_arr)):
            continue
        if seed_set is not None and seed_val not in seed_set:
            continue
        selected_files.append(f)

    if not selected_files:
        print("No files selected after filtering by mask/seed.")
        return None, None

    print(f"Total files: {len(all_files)}, selected: {len(selected_files)}")
    for f in selected_files:
        print("  use:", os.path.basename(f))

    y = load_rgb01(y_path)

    xs = []
    errs = []
    for f in selected_files:
        x = load_rgb01(f)
        if x.shape != y.shape:
            raise ValueError(
                f"Candidate {os.path.basename(f)} has shape {x.shape}, "
                f"but the noisy input has shape {y.shape}."
            )
        xs.append(x)
        # Consistency error: mean squared distance to the noisy observation.
        errs.append(np.mean((x - y) ** 2))

    errs = np.array(errs, dtype=np.float64)  # [K]
    # Shifting by the minimum leaves the softmax unchanged; it just puts the
    # most consistent candidate's logit at 0.
    logits = -tau * (errs - errs.min())
    gamma = softmax_np(logits)

    print("gating gamma:", gamma)
    print("consistency errs:", errs)

    x_bar = np.zeros_like(xs[0], dtype=np.float32)
    for w, x in zip(gamma, xs):
        x_bar += (w * x).astype(np.float32)

    # Weighted blend of the observation and the gated candidate average.
    x_fuse = np.clip((y + lam * x_bar) / (1.0 + lam), 0.0, 1.0)

    # Metrics are computed on the float image, before 8-bit quantization.
    out_psnr, out_mse = psnr(x_fuse, arr_gt)
    out_ssim = ssim_np(x_fuse, arr_gt)

    out_dir = os.path.dirname(output_path)
    if out_dir:  # dirname is "" for a bare filename and makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    # Round (not truncate) to the nearest 8-bit level; truncation would bias
    # every pixel downward by up to one gray level.
    Image.fromarray(np.round(x_fuse * 255.0).astype(np.uint8)).save(output_path)

    print(f"\nFused image saved to: {output_path}")
    print(f'The {index+1:d} th Fused vs GT PSNR:   {out_psnr:.4f} dB  (MSE={out_mse:.6e})')
    print(f'The {index+1:d} th Fused vs GT SSIM:   {out_ssim:.6f}')

    return out_psnr, out_ssim


if __name__ == "__main__":
    # Batch driver: fuse and score the 24 Kodak images (kodim01..kodim24)
    # using the "_15" noisy inputs, then report per-image and mean metrics.

    # Candidate filters are the same for every image, so define them once.
    mask_list = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    seed_list = [20, 40, 60, 80, 100]

    all_psnr_list = []
    all_ssim_list = []
    for it in range(0, 24):
        gt_path = f'./kodim{it+1:02d}.png'
        arr_gt = load_rgb01(gt_path)
        target_folder = f"./re_kodim_{it+1:02d}/32_50000/"
        noisy_folder = f"./noisy_kodim_{it+1:02d}/"

        # NOTE(review): the noisy filename uses the zero-based ``it`` while all
        # other paths use ``it + 1``; kept as-is — confirm against files on disk.
        y_path = os.path.join(noisy_folder, f"noisy_input_{it}_15.png")
        out_psnr, out_ssim = fuse_images_denoise_from_folder(
            folder=target_folder, arr_gt=arr_gt,
            y_path=y_path, output_path=f'./Rec_15/{it+1:02d}.png',
            index=it, mask_list=mask_list, seed_list=seed_list, tau=50, lam=30)
        if out_psnr is None:
            # No candidates matched; skip so the averages below stay numeric.
            print(f"Skipping image {it+1:02d}: no fused result.")
            continue
        all_psnr_list.append(out_psnr)
        all_ssim_list.append(out_ssim)
    print('Final psnr list', all_psnr_list)
    print('Final SSIM list', all_ssim_list)
    if all_psnr_list:
        print('Average PSNR', np.array(all_psnr_list).mean())
        print('Average SSIM', np.array(all_ssim_list).mean())

    # Compact "PSNR/SSIM" summary, one entry per fused image. This lives inside
    # the main guard so importing the module does not hit undefined names.
    for psnr_, ssim_ in zip(all_psnr_list, all_ssim_list):
        print(f"{psnr_:.2f}/{ssim_:.4f}", end=", ")
    print()
    