%% Denoise every test image window-by-window
% Splits each 28x28 test image into non-overlapping window_size x window_size
% patches, projects each patch onto principal_components, maps the result
% through inv_A, soft-thresholds the coefficients, then reconstructs the
% patch in pixel space. Patches are stored row-wise in denoised_test_images,
% num_windows consecutive rows per image, in x-then-y window order.
%
% NOTE(review): A_chf, mean_X and principal_components are not defined here;
% they must already exist in the workspace. Presumably mean_X is
% 1 x window_size^2 and principal_components is window_size^2 x k — TODO confirm.
dataset       = load("mnist-with-awgn.mat");
img_side      = 28;            % images are reshaped to img_side x img_side below
window_size   = 7;
assert(mod(img_side, window_size) == 0, ...
    "window_size must evenly divide the image side length");
x_num_windows = img_side/window_size;
y_num_windows = img_side/window_size;
num_windows   = x_num_windows*y_num_windows;
patch_len     = window_size*window_size;

shrink_threshold = 100;        % soft-threshold level applied to s_test

A     = orth(A_chf);           % orthonormal basis for the range of A_chf
inv_A = pinv(A);

X_test_images = dataset.test_x;
num_test      = size(X_test_images, 1);   % one flattened image per row
denoised_test_images = zeros(num_test*num_windows, patch_len);
itr = 1;
noise_power = 0.1;             % NOTE(review): not used in this section
for i = 1:num_test
    % Rows of test_x are transposed after reshape to restore image orientation.
    img = reshape(X_test_images(i,:), [img_side, img_side])';
    for x = 1:x_num_windows
        rows = window_size*(x-1)+1 : window_size*x;
        for y = 1:y_num_windows
            cols   = window_size*(y-1)+1 : window_size*y;
            % Flatten the patch (column-major) and convert to double before
            % subtracting the mean.
            x_test = double(reshape(img(rows, cols), [1, patch_len]));
            x_test = x_test - mean_X;
            s_test = inv_A*(x_test*principal_components)';
            % Alternative shrinkage rule, currently disabled:
            % s_test = sign(s_test).*max(0, (abs(s_test) - 25)/2 + 0.5*sqrt( (abs(s_test) + 25).^2 - 4*10 ) );
            s_test = sign(s_test).*max(0, abs(s_test) - shrink_threshold);   % soft threshold
            s_test = real(s_test);
            denoised_test_images(itr,:) = (A*s_test)'*principal_components' + mean_X;
            itr = itr + 1;
        end
    end
    if mod(i, 1000) == 0
        fprintf("%d test images denoised \n", i);
    end
end

%% See example denoised image

% Show selected test images next to their denoised reconstructions.
% Reads window_size, x_num_windows, y_num_windows, X_test_images and
% denoised_test_images from the workspace. The patches of image k occupy rows
% num_windows*(k-1)+1 .. num_windows*k of denoised_test_images, stored in
% x-then-y window order, and each row is a column-major flattened patch.
indices     = [420];
img_side    = window_size*x_num_windows;    % full image side length
num_windows = x_num_windows*y_num_windows;
for k = 1:numel(indices)
    index          = indices(k);
    denoised_index = num_windows*(index-1); % row offset of this image's first patch
    original_image = reshape(X_test_images(index,:), [img_side, img_side])';

    itr = 1;
    denoised_image = zeros(img_side, img_side);
    for x = 1:x_num_windows
        rows = window_size*(x-1)+1 : window_size*x;
        for y = 1:y_num_windows
            cols = window_size*(y-1)+1 : window_size*y;
            % uint8() rounds and saturates the reconstruction to 0..255.
            flattened_subimg = uint8(denoised_test_images(denoised_index+itr,:));
            itr = itr + 1;
            denoised_image(rows, cols) = reshape(flattened_subimg, [window_size, window_size]);
        end
    end

    figure;
    % Cast both panels to uint8 so imshow uses the same 0..255 display range.
    subplot(1,2,1); imshow(uint8(original_image));
    subplot(1,2,2); imshow(uint8(denoised_image));
end