% === Paths: input folder and ground-truth file ===
root_path = 'recons_mat';
gt_path = fullfile(root_path, 'pred_gt.mat');

% === Ground truth ===
% Cast to double so later arithmetic does not run in the stored class.
% Original note on the shape: [10, 256, 256, 28] — TODO confirm.
gt_data = load(gt_path);
gt_pred = double(gt_data.pred);

% === Wavelength of each of the 28 spectral channels (presumably nm — TODO confirm) ===
lam28 = [453.5 457.5 462.0 466.0 471.5 476.5 481.5 ...
         487.0 492.5 498.0 504.0 510.0 516.0 522.5 ...
         529.5 536.5 544.0 551.5 558.5 567.5 575.5 ...
         584.5 594.5 604.0 614.5 625.0 636.5 648.0];

% === Output folder (created if it does not exist yet) ===
save_root = 'chroma_intens_decompo_reference';
if exist(save_root, 'dir') ~= 7
    mkdir(save_root);
end

% Small constant added to the intensity before dividing, avoiding division by zero.
eps_val = 1e-6;

% === Collect every .mat file in root_path, excluding the ground-truth file ===
% Each remaining file name (without extension) is treated as one method.
mat_files = dir(fullfile(root_path, '*.mat'));
% A folder whose name happens to end in ".mat" also matches the pattern; skip it.
mat_files = mat_files(~[mat_files.isdir]);

method_list = {};
for k = 1:numel(mat_files)
    fname = mat_files(k).name;
    if strcmp(fname, 'pred_gt.mat')   % the ground truth is the reference, not a method
        continue;
    end
    % fileparts strips only the trailing extension; strrep would also remove a
    % ".mat" substring occurring elsewhere in the name.
    [~, method_name] = fileparts(fname);
    method_list{end+1} = method_name; %#ok<AGROW>
end

% Alphabetical order gives a stable processing/report order.
method_list = sort(method_list);

fprintf('Found %d methods:\n', length(method_list));
for i = 1:length(method_list)
    fprintf('  %s\n', method_list{i});
end

% === Main loop: process each method's .mat file ===
% For every sample the spectral cube is split into
%   I = mean over the spectral (3rd) axis, and
%   C = sample ./ (I + eps_val)   (chroma, per-pixel spectrum relative to I).
% I and a peak-normalised C are written as images. Unless the method is named
% 'gt', they are scored against the ground truth with PSNR/SSIM. The raw I/C
% and the metrics are saved to <save_root>/<method>/chroma_inten.mat.
for m = 1:length(method_list)
    method = method_list{m};
    is_gt = strcmp(method, 'gt');

    fprintf('\n=== Method: %s ===\n', method);

    mat_file = fullfile(root_path, [method '.mat']);
    data = load(mat_file);
    pred = double(data.pred);  % [samples, H, W, bands], e.g. [10, 256, 256, 28]

    % Derive the dimensions from the data instead of hard-coding them.
    [num_samples, H, W, num_bands] = size(pred);
    if ~is_gt && ~isequal(size(pred), size(gt_pred))
        error('Method "%s": pred size %s does not match GT size %s.', ...
            method, mat2str(size(pred)), mat2str(size(gt_pred)));
    end

    save_path = fullfile(save_root, method);
    if ~exist(save_path, 'dir'), mkdir(save_path); end

    % Preallocate per-method outputs.
    intensity_all = zeros(num_samples, H, W);
    chroma_all = zeros(num_samples, H, W, num_bands);
    psnr_I_all = zeros(num_samples, 1);
    ssim_I_all = zeros(num_samples, 1);
    psnr_C_all = zeros(num_samples, 1);
    ssim_C_all = zeros(num_samples, 1);

    % === Per-sample loop ===
    for i = 1:num_samples
        % reshape (not squeeze) keeps singleton H/W dimensions intact.
        sample = reshape(pred(i, :, :, :), H, W, num_bands);
        sample = max(sample, 0);   % clip negative values

        I = mean(sample, 3);
        C = sample ./ (I + eps_val);   % implicit expansion of I across bands

        % Store the un-normalised decomposition.
        intensity_all(i, :, :) = I;
        chroma_all(i, :, :, :) = C;

        % Normalise chroma so its peak is 1. The guard avoids 0/0 = NaN for an
        % all-zero sample.
        c_max = max(C(:));
        if c_max > 0
            C = C ./ c_max;
        end

        % Save images.
        imwrite(mat2gray(I), fullfile(save_path, sprintf('intensity_%02d.png', i)));
        save_spectral_channels_as_color(C, save_path, sprintf('chroma%02d', i), lam28);

        % === Metrics against the ground truth ===
        if ~is_gt
            gt_sample = reshape(gt_pred(i, :, :, :), H, W, num_bands);
            % NOTE(review): unlike the prediction, the GT is not clipped at 0
            % before decomposition — confirm this asymmetry is intended.
            I_gt = mean(gt_sample, 3);
            C_gt = gt_sample ./ (I_gt + eps_val);
            gt_max = max(C_gt(:));
            if gt_max > 0
                C_gt = C_gt ./ gt_max;
            end

            % Intensity: 8-bit PSNR/SSIM helpers (values x255, clipped to
            % [0, 255]), so I is assumed to lie in [0, 1] — TODO confirm.
            psnr_I = compute_psnr(I, I_gt);
            ssim_I = compute_ssim(I, I_gt);

            % Chroma: per-band psnr/ssim on the peak-normalised cubes, averaged.
            psnr_C = 0; ssim_C = 0;
            for ch = 1:num_bands
                psnr_C = psnr_C + psnr(C(:, :, ch), C_gt(:, :, ch));
                ssim_C = ssim_C + ssim(C(:, :, ch), C_gt(:, :, ch));
            end
            psnr_C = psnr_C / num_bands;
            ssim_C = ssim_C / num_bands;

            psnr_I_all(i) = psnr_I;
            ssim_I_all(i) = ssim_I;
            psnr_C_all(i) = psnr_C;
            ssim_C_all(i) = ssim_C;

            fprintf('Sample %2d | Intensity PSNR = %.2f, SSIM = %.4f | Chroma PSNR = %.2f, SSIM = %.4f\n', ...
                i, psnr_I, ssim_I, psnr_C, ssim_C);
        end
    end

    % Mean over samples.
    if ~is_gt
        fprintf('>>> [%s] Mean Intensity: PSNR = %.2f dB, SSIM = %.4f | Mean Chroma: PSNR = %.2f dB, SSIM = %.4f\n', ...
            method, mean(psnr_I_all), mean(ssim_I_all), mean(psnr_C_all), mean(ssim_C_all));
    end

    % Save decomposition and metrics.
    save(fullfile(save_path, 'chroma_inten.mat'), ...
        'intensity_all', 'chroma_all', ...
        'psnr_I_all', 'ssim_I_all', ...
        'psnr_C_all', 'ssim_C_all');
end

% === 函数：PSNR ===
function val = compute_psnr(img1, img2)
    % PSNR in dB between two images after 8-bit quantisation.
    % Inputs are scaled by 255, clipped to [0, 255] and rounded to uint8, so
    % they are expected in [0, 1]. Both inputs are flattened before comparison.
    % Returns Inf when the quantised images are identical (MSE = 0).
    quantize = @(x) uint8(min(max(x * 255, 0), 255));
    q1 = quantize(img1);
    q2 = quantize(img2);
    delta = double(q1(:)) - double(q2(:));
    mse = mean(delta .^ 2);
    val = 10 * log10(255^2 / mse);
end

% === 函数：SSIM ===
function val = compute_ssim(img1, img2)
    % SSIM between two images after the same 8-bit quantisation as compute_psnr:
    % scale by 255, clip to [0, 255], round to uint8, then call ssim.
    quantize = @(x) uint8(min(max(x * 255, 0), 255));
    val = ssim(quantize(img1), quantize(img2));
end