

% Start from a clean session: clear the command window, close any open
% figure windows and remove existing variables from the workspace.
clc; close all; clear all;
% Load the Octave "image" package; the code below relies on it for
% image-processing routines such as fspecial and imresize.
pkg load image

function ssim_mean=compute_ssim(img1,img2)
    % COMPUTE_SSIM  Mean SSIM between two images.
    %
    %   ssim_mean = compute_ssim(img1, img2)
    %
    %   An input with exactly three planes along its third dimension is run
    %   through rgb2ycbcr and only the first resulting plane is kept; any
    %   other input is used unchanged. The two resulting planes are handed
    %   to SSIM_index with its default settings, and its first output is
    %   returned.

    planes = {img1, img2};
    for k = 1:numel(planes)
        if size(planes{k}, 3) == 3
            ycc = rgb2ycbcr(planes{k});
            planes{k} = ycc(:, :, 1);   % keep the first (Y) plane only
        end
    end

    ssim_mean = SSIM_index(planes{1}, planes{2});
end

function psnr=compute_psnr(img1,img2)
    % COMPUTE_PSNR  Peak signal-to-noise ratio, in dB, for a peak value of 255.
    %
    %   psnr = compute_psnr(img1, img2)
    %
    %   An input with exactly three planes along its third dimension is run
    %   through rgb2ycbcr and only the first resulting plane is compared; any
    %   other input is used unchanged. Identical inputs give Inf because the
    %   root-mean-square error is zero.

    planes = {img1, img2};
    for k = 1:numel(planes)
        if size(planes{k}, 3) == 3
            ycc = rgb2ycbcr(planes{k});
            planes{k} = ycc(:, :, 1);   % keep the first (Y) plane only
        end
    end

    % Work in double so the subtraction cannot saturate integer image types.
    err = double(planes{1}) - double(planes{2});
    mse = mean(err(:).^2);
    psnr = 20*log10(255/sqrt(mse));
end

function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
    % SSIM_INDEX  Structural similarity (SSIM) between two single-plane images.
    %
    %   [mssim, ssim_map] = SSIM_index(img1, img2)
    %   [mssim, ssim_map] = SSIM_index(img1, img2, K)
    %   [mssim, ssim_map] = SSIM_index(img1, img2, K, window)
    %   [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
    %
    %   Inputs
    %     img1, img2 : images of identical size (converted to double here).
    %     K          : two constants used to build the stabilising terms
    %                  C1 = (K(1)*L)^2 and C2 = (K(2)*L)^2; default [0.01 0.03].
    %     window     : local weighting window; default 11x11 Gaussian with
    %                  sigma 1.5. It is normalised to unit sum before use.
    %     L          : dynamic range of the pixel values; default 255.
    %
    %   Outputs
    %     mssim      : mean of ssim_map.
    %     ssim_map   : local SSIM values over the 'valid' filtering region
    %                  (smaller than the input by the window size minus one).
    %
    %   Both outputs are -Inf when the argument count is wrong, when the image
    %   sizes differ, or when the window does not fit inside the images.

    if (nargin < 2 || nargin > 5)
       mssim = -Inf;
       ssim_map = -Inf;
       return;
    end

    % isequal compares the complete size vectors. The previous test
    % "if (size(img1) ~= size(img2))" was true only when EVERY dimension
    % differed, so a mismatch in a single dimension slipped through, and it
    % raised an error when the two inputs had a different number of dimensions.
    if ~isequal(size(img1), size(img2))
       mssim = -Inf;
       ssim_map = -Inf;
       return;
    end

    % Fill in whichever trailing arguments the caller omitted.
    if (nargin < 3)
       window = fspecial('gaussian', 11, 1.5);
       K = [0.01 0.03];
       L = 255;
    elseif nargin < 4
       window = fspecial('gaussian', 11, 1.5);
       L = 255;
    elseif nargin < 5
       L = 255;
    end

    % A window larger than the image leaves an empty 'valid' region, whose
    % mean would be NaN; report it like the other rejected inputs instead.
    if (size(window, 1) > size(img1, 1) || size(window, 2) > size(img1, 2))
       mssim = -Inf;
       ssim_map = -Inf;
       return;
    end

    C1 = (K(1)*L)^2;
    C2 = (K(2)*L)^2;
    window = window / sum(window(:));
    img1 = double(img1);
    img2 = double(img2);

    % Local weighted means.
    mu1   = filter2(window, img1, 'valid');
    mu2   = filter2(window, img2, 'valid');
    mu1_sq = mu1.*mu1;
    mu2_sq = mu2.*mu2;
    mu1_mu2 = mu1.*mu2;

    % Local weighted variances and covariance, computed as E[xy] - E[x]E[y].
    sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
    sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
    sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;

    if (C1 > 0 && C2 > 0)
       % Both stabilising constants are positive: use the closed-form
       % expression directly.
       ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2)) ./ ((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
    else
       % At least one constant is zero, so a denominator factor may vanish.
       % Start from 1 everywhere and fill in only the positions where the
       % division is well defined.
       numerator1 = 2*mu1_mu2 + C1;
       numerator2 = 2*sigma12 + C2;
       denominator1 = mu1_sq + mu2_sq + C1;
       denominator2 = sigma1_sq + sigma2_sq + C2;
       ssim_map = ones(size(mu1));
       index = (denominator1.*denominator2 > 0);
       ssim_map(index) = (numerator1(index).*numerator2(index)) ./ (denominator1(index).*denominator2(index));
       % Where only the second factor vanishes, use the first factor alone.
       index = (denominator1 ~= 0) & (denominator2 == 0);
       ssim_map(index) = numerator1(index) ./ denominator1(index);
    end

    % Same value as mean2(ssim_map), without needing the image package here.
    mssim = mean(ssim_map(:));
end



% Evaluation driver: for every dataset, pair each result image with its
% ground-truth image, compute PSNR and SSIM, and print the per-dataset means
% followed by the mean over all datasets that contained images.
datasets = {'SID'};
num_set = length(datasets);

psnr_alldatasets = 0;
ssim_alldatasets = 0;
num_scored = 0;      % datasets that contributed at least one image

tic

for idx_set = 1:num_set
    % NOTE(review): both directories are empty placeholders and are the same
    % for every entry of "datasets"; set them before running.
    file_path = '';
    gt_path = '';

    % fullfile inserts the separator when file_path lacks a trailing one
    % (plain concatenation would search for e.g. "results*.jpg").
    path_list = [dir(fullfile(file_path, '*.jpg')); dir(fullfile(file_path, '*.png'))];
    img_num = length(path_list);

    if img_num == 0
        % Averaging over zero images would give 0/0 = NaN and turn the
        % overall average into NaN as well, so skip the dataset instead.
        warning('No .jpg/.png images found in "%s"; skipping dataset %s.', file_path, datasets{idx_set});
        continue;
    end

    total_psnr = 0;
    total_ssim = 0;

    for j = 1:img_num
        image_name = path_list(j).name;

        % The ground-truth file is named after the part of the result file
        % name that precedes the first underscore, e.g. 'abc_x.png' -> 'abc.png'.
        % The extension is removed first so that a name without any
        % underscore maps to 'abc.png' rather than 'abc.png.png'.
        [~, image_stem] = fileparts(image_name);
        name_parts = regexp(image_stem, '_', 'split');
        gt_name = [name_parts{1} '.png'];

        % Result image under evaluation (not called "input", which would
        % shadow the builtin function of that name).
        test_img = imread(fullfile(file_path, image_name));
        gt = imread(fullfile(gt_path, gt_name));

        rows = size(test_img, 1);
        cols = size(test_img, 2);

        % Bring the ground truth to the result's height and width if needed.
        if size(gt, 1) ~= rows || size(gt, 2) ~= cols
            if exist('imresize', 'file')
                gt = imresize(gt, [rows, cols]);
            else
                % Fallback without the image package: bilinear interpolation
                % per channel. The result goes into a new array because the
                % resized planes do not fit into the original-sized gt.
                [X, Y] = meshgrid(1:size(gt, 2), 1:size(gt, 1));
                [XI, YI] = meshgrid(linspace(1, size(gt, 2), cols), linspace(1, size(gt, 1), rows));
                resized = zeros(rows, cols, size(gt, 3), 'uint8');
                for c = 1:size(gt, 3)
                    resized(:, :, c) = uint8(interp2(X, Y, double(gt(:, :, c)), XI, YI, 'linear'));
                end
                gt = resized;
            end
        end

        ssim_val = compute_ssim(test_img, gt);
        psnr_val = compute_psnr(test_img, gt);
        total_ssim = total_ssim + ssim_val;
        total_psnr = total_psnr + psnr_val;
        disp([num2str(idx_set) '-dataset:' num2str(j)]);
    end

    qm_psnr = total_psnr / img_num;
    qm_ssim = total_ssim / img_num;

    fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);

    psnr_alldatasets = psnr_alldatasets + qm_psnr;
    ssim_alldatasets = ssim_alldatasets + qm_ssim;
    num_scored = num_scored + 1;
end

if num_scored > 0
    % Average only over the datasets that actually produced scores.
    fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_scored, ssim_alldatasets/num_scored);
else
    fprintf('No images were evaluated.\n');
end

toc

%% ----- Functions: defined at the top of this file, before the script body that calls them -----
