% Start from a clean command window and workspace.
clc;
clear;

% Put the bundled FastICA toolbox (including its subfolders) on the path.
fastica_path = genpath('./FastICA_25/');
addpath(fastica_path);

%% Load data

% Test image for the experiment (the file name indicates a grayscale image).
imagepath      = "ImagesForDenoising/lena_gray.png";
original_image = imread(imagepath);

%% Add noise
% Corrupt a square patch of the image with additive Gaussian noise; pixels
% outside the patch are left unchanged.

% Image size: h = number of rows, w = number of columns. Read from the
% loaded image instead of being hard-coded to 512, so the window
% bookkeeping further down follows whatever image is loaded.
h = size(original_image, 1);
w = size(original_image, 2);

noise_sigma = 50;    % standard deviation of the added Gaussian noise
patch_x     = 200;   % patch centre, row index
patch_y     = 200;   % patch centre, column index
patch_size  = 50;    % half-width: the patch spans 2*patch_size+1 pixels per side
x_window    = patch_x-patch_size:patch_x+patch_size;
y_window    = patch_y-patch_size:patch_y+patch_size;

% Whole-image alternative:
% noisy_image = double(original_image) + noise_sigma*randn(h,w);
noisy_image = double(original_image);
noisy_image(x_window, y_window) = noisy_image(x_window, y_window) + ...
                                 noise_sigma*randn(numel(x_window), numel(y_window));

%% Generate data by considering sub-images
% Tile the noisy image into non-overlapping window_size x window_size
% windows. Each window becomes one row of X, flattened column-major.
% Row (x-1)*y_num_windows + y of X holds the window at block-row x and
% block-column y.
window_size   = 4;
x_num_windows = h/window_size;   % number of windows along the rows
y_num_windows = w/window_size;   % number of windows along the columns

tiles = reshape(noisy_image(1:window_size*x_num_windows, 1:window_size*y_num_windows), ...
                window_size, x_num_windows, window_size, y_num_windows);
% tiles dims: (row in window, block-row, column in window, block-column)
tiles = permute(tiles, [1 3 4 2]);
% now: (row in window, column in window, block-column, block-row)
X     = reshape(tiles, window_size*window_size, []).';

% Center every pixel position (column of X) across all windows.
mean_X = mean(X, 1);
X      = X - mean_X;
k      = window_size^2;

%% Run algorithms

% Baseline mixing-matrix estimates from JADE and FastICA.
A_jade = jade(X', k);
fprintf("JADE completed\n");

[A_fastica, ~] = fastica(X', 'numOfIC', k);
fprintf("FASTICA completed\n");

% FastICA may return fewer than k components (presumably when it drops
% dimensions during its preprocessing -- confirm). Zero-pad to k x k so it
% can be compared with the other estimates. The padded zero columns make
% the matrix rank-deficient, so report it instead of padding silently.
[x_fica, y_fica] = size(A_fastica);
if x_fica ~= k || y_fica ~= k
    warning("FastICA returned a %d x %d mixing matrix; zero-padding to %d x %d.", ...
            x_fica, y_fica, k, k);
end
A_fastica_new = zeros(k, k);
A_fastica_new(1:x_fica, 1:y_fica) = A_fastica;
A_fastica     = A_fastica_new;

% Contrast-function objects used by the power-iteration runs.
cgf_fn      = cgf(X);
chf_fn      = symmetric_chf(X);
kurtosis_fn = kurtosis(X);

% The CGF and CHF runs both use the JADE-based matrix A_jade*A_jade' as
% their C; compute it and its pseudo-inverse once and share them.
C_jade          = A_jade*A_jade';
C_cgf           = C_jade;
C_chf           = C_jade;
C_kurtosis      = kurtosis_fn.estimate_C(20);
pinv_C_cgf      = pinv(C_cgf);
pinv_C_chf      = pinv_C_cgf;
pinv_C_kurtosis = pinv(C_kurtosis);

% Accumulators for the deflationary loop: column i of A_* receives the
% i-th recovered vector, and row i of B1_* its dual row, scaled so that
% B1_*(i,:)*A_*(:,i) == 1. (A_meta is assigned later from the best-scoring
% estimate, so it is not preallocated here.)
A_cgf       = zeros(k,k);
B1_cgf      = zeros(k,k);
A_chf       = zeros(k,k);
B1_chf      = zeros(k,k);
A_kurtosis  = zeros(k,k);
B1_kurtosis = zeros(k,k);

verbose_flag       = 0;
maxruns            = 20;
num_trials         = 200;
random_projections = randn(2*num_trials, k);

% Deflationary extraction of k vectors for each of the three contrast
% functions (CGF, CHF, Kurtosis). Per column, all three methods start from
% the same random vector u_init and run in that fixed order.
run_banner = {'CGF run\n', 'CHF run\n', 'Kurtosis run\n'};
contrast   = {cgf_fn, chf_fn, kurtosis_fn};
C_list     = {C_cgf, C_chf, C_kurtosis};
pinvC_list = {pinv_C_cgf, pinv_C_chf, pinv_C_kurtosis};
A_list     = {A_cgf, A_chf, A_kurtosis};
B_list     = {B1_cgf, B1_chf, B1_kurtosis};

for i = 1:k
    fprintf('Column %d\n', i);
    u_init = randn(k, 1);

    for m = 1:numel(contrast)
        fprintf(run_banner{m});
        A_m = A_list{m};
        B_m = B_list{m};

        % Deflation matrix built from the vectors already recovered for this
        % method (still-zero columns/rows contribute nothing).
        deflation = eye(k) - A_m*B_m;
        [u_new, ~, ~, ~] = ICA_power(X, maxruns, ones(k,k), verbose_flag, ...
                                     contrast{m}, u_init, C_list{m}, deflation);

        A_m(:, i) = u_new';
        % Dual row, scaled so that B_m(i,:)*A_m(:,i) == 1.
        dual      = pinvC_list{m}*A_m(:, i);
        B_m(i, :) = dual/(dual'*A_m(:, i));

        A_list{m} = A_m;
        B_list{m} = B_m;
    end

    fprintf('=====================\n');
    fprintf('=====================\n');
end

[A_cgf, A_chf, A_kurtosis]    = A_list{:};
[B1_cgf, B1_chf, B1_kurtosis] = B_list{:};

fprintf('Meta run\n');

% Independence score of each estimate, in the order CHF, CGF, Kurtosis,
% JADE, FastICA. JADE and FastICA use A*A' as their C matrix.
score_inputs = {C_chf,           A_chf; ...
                C_cgf,           A_cgf; ...
                C_kurtosis,      A_kurtosis; ...
                A_jade*A_jade',  A_jade; ...
                A_fastica*A_fastica', A_fastica};
raw_scores = cell(1, size(score_inputs, 1));
for r = 1:size(score_inputs, 1)
    raw_scores{r} = Delta_Score_Total(X, score_inputs{r, 1}, score_inputs{r, 2}, ...
                                      random_projections, num_trials);
end
[ind_scores_chf, ind_scores_cgf, ind_scores_kurtosis, ...
 ind_scores_jade, ind_scores_fastica] = raw_scores{:};

report_formats = ["Independence Score CHF : %.10f\n", ...
                  "Independence Score CGF : %.10f\n", ...
                  "Independence Score Kurtosis : %.10f\n", ...
                  "Independence Score JADE : %.10f\n", ...
                  "Independence Score FASTICA : %.10f\n"];
for r = 1:numel(report_formats)
    fprintf(report_formats(r), raw_scores{r});
end

% Pick the estimate with the smallest score. Candidates are ordered CGF,
% CHF, Kurtosis, JADE, FastICA; on ties the earliest one wins.
A_matrices    = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
scores        = [ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
                 ind_scores_jade, ind_scores_fastica];
[~,min_index] = min(scores);
A_meta        = A_matrices{min_index};

%% Normalize matrices
% Divide every estimate by its matrix 2-norm (largest singular value), so
% each has unit spectral norm. Earlier experiments (disabled) used
% orth(A) or normr(A) here instead.
normalized = cellfun(@(M) M/norm(M), ...
                     {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta}, ...
                     'UniformOutput', false);
[A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta] = normalized{:};
 
%% Denoise test data
% For each estimated mixing matrix A, the analysis transform is the
% orthogonal polar factor of pinv(A). The script then soft-thresholds the
% transformed coefficients of every centered window, maps them back, adds
% the window mean, and reassembles the image.

figure;
h1 = subplot(2,4,1); imshow(original_image), title(h1, "Original Image");
h2 = subplot(2,4,2); imshow(uint8(noisy_image)), title(h2, "Noisy Image");

shrink_threshold = 25;   % soft-threshold level for the transformed coefficients

title_prefixes = ["CGF, Score = ", "CHF, Score = ", "Kurtosis, Score = ", ...
                  "JADE, Score = ", "FastICA, Score = "];
method_A       = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
method_scores  = {ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
                  ind_scores_jade, ind_scores_fastica};

for j = 1:numel(method_A)
    img_title = strcat(title_prefixes(j), string(method_scores{j}));
    fprintf("Denoising with %s\n", img_title);

    % Orthogonal polar factor of W = pinv(A). With the SVD W = U*S*V', U*V'
    % equals W*(W'*W)^(-1/2) whenever W has full rank. Unlike the sqrtm form
    % it is always real, and it stays well-defined when W is rank-deficient
    % (e.g. the zero-padded FastICA matrix).
    [U, ~, V] = svd(pinv(method_A{j}));
    inv_A     = U*V';
    A         = inv_A';   % the inverse of an orthogonal matrix is its transpose

    % X already holds every mean-centered window of noisy_image, one per
    % row, so all windows are transformed in a single product.
    S = inv_A*X';
    % S = sign(S).*max(0, (abs(S) - 25)/2 + 0.5*sqrt( (abs(S) + 25).^2 - 4*10 ) );
    S = sign(S).*max(0, abs(S) - shrink_threshold);
    denoised_windows = (A*S)' + mean_X;

    % Undo the window tiling. Row (x-1)*y_num_windows + y of
    % denoised_windows is the column-major window at block (x, y).
    blocks = reshape(denoised_windows', window_size, window_size, ...
                     y_num_windows, x_num_windows);
    % blocks dims: (row in window, column in window, block-column, block-row)
    denoised_image = reshape(permute(blocks, [1 4 2 3]), ...
                             window_size*x_num_windows, window_size*y_num_windows);

    hj = subplot(2,4,j+2); imshow(uint8(denoised_image)), title(hj, img_title);
end

