clc; clear;
addpath(genpath('./FastICA_25/'));

%% Load data
% The rest of the script reads the fields train_x, train_y, test_x and
% test_y from this struct; each row of *_x is one flattened 28x28 image and
% each row of *_y is matched against 1 to recover the digit label.

dataset = load("mnist-with-awgn.mat");

%% Generate data by considering sub-images
% Every training image is reshaped to 28x28, perturbed with Gaussian noise in
% its central 11x11 patch, cut into non-overlapping window_size x window_size
% tiles (one row of X per tile) and written to TrainingMNIST/<label>/<i>.png.

window_size   = 7;
% A window size that does not divide 28 would give fractional window counts.
assert(mod(28, window_size) == 0, 'window_size must divide the 28-pixel image side.');
x_num_windows = 28/window_size;
y_num_windows = 28/window_size;

% Images are stored one per row, so count rows explicitly; length() would
% return the largest dimension instead.
num_train   = floor(size(dataset.train_x, 1)/10);

% Full-image noise covariance. It is only consumed by the commented-out
% mvnrnd noise model inside the loop; kept so that model can be re-enabled.
noise_power = 1;
k           = 784;
U           = randn(k,k);
Sigma       = noise_power*(1/k)*(U*U');
blocksigma  = zeros(k,k);
Uprime      = randn(5,5);
blocksigma(k/2-2:k/2+2, k/2-2:k/2+2) = (1/5)*(Uprime*Uprime');
Sigma       = Sigma + blocksigma;

X           = zeros(num_train*x_num_windows*y_num_windows, window_size*window_size);
itr         = 1;

% Create output folders only when missing: mkdir warns on existing folders.
if exist('TrainingMNIST', 'dir') ~= 7
    mkdir("TrainingMNIST");
end

for i=1:num_train
    label     = find(dataset.train_y(i,:) == 1) - 1;
    % e         = mvnrnd(zeros(k,1), Sigma, 1);
    % noisy_img = double(dataset.train_x(i,:)) + e;
    % Named noisy_img (not image) to avoid shadowing the built-in image().
    noisy_img = reshape(double(dataset.train_x(i,:)), [28,28])';
    % figure;imagesc(noisy_img);
    noisy_img(10:20,10:20) = noisy_img(10:20,10:20) + 10*randn(11,11);

    % Flatten each tile into one row of X, scanning tile rows then columns.
    for x=1:x_num_windows
        rows = window_size*(x-1)+1 : window_size*x;
        for y=1:y_num_windows
            cols = window_size*(y-1)+1 : window_size*y;
            X(itr, :) = reshape(noisy_img(rows, cols), [1,window_size*window_size]);
            itr = itr + 1;
        end
    end

    if(mod(i,10) == 0)
        fprintf("%d images processed \n", i);
    end

    subfolder = strcat("TrainingMNIST/", string(label));
    if exist(char(subfolder), 'dir') ~= 7
        mkdir(subfolder);
    end
    filename = strcat(subfolder, "/", string(i), ".png");
    imwrite(uint8(noisy_img), filename, "png");
end

%% Perform PCA to reduce dimensionality
% Centres the tile matrix X, eigendecomposes its covariance and keeps the
% k_pca eigenvectors with the largest eigenvalues. new_X holds the tiles
% expressed in that basis (one row per tile, k_pca columns).

mean_X               = mean(X,1);
X                    = X - mean_X;
% Stored as cov_X: assigning to a variable called cov would shadow the
% built-in cov(), so re-running this section without clear would index the
% matrix instead of calling the function.
cov_X                = cov(X);
[U,Lambda]           = eig(cov_X);
% Drop any imaginary round-off before sorting.
U                    = real(U);
Lambda               = real(Lambda);
% Ascending sort: the leading components end up in the last columns of U.
[Lambda, ind]        = sort(diag(Lambda));
U                    = U(:, ind);
k_pca                = 25;
k_ica                = 25;
principal_components = U(:,end-k_pca+1:end);
new_X                = X*principal_components;

%% Run algorithms
% Runs the two baseline ICA solvers, builds the contrast-function objects and
% their C matrices, and preallocates the per-method mixing (A_*) and dual
% (B1_*) matrices filled by the deflation loop.
% NOTE(review): jade, fastica, cgf, symmetric_chf and kurtosis are not
% defined in this file; presumably they come from the project / FastICA path.

% Baseline solvers (both take the data with one column per sample).
A_jade = jade(new_X', k_ica);
fprintf("JADE completed\n");
[A_fastica, ~] = fastica(new_X', 'numOfIC', k_ica);
fprintf("FASTICA completed\n");

% Contrast-function objects built from the PCA-reduced data.
cgf_fn      = cgf(new_X);
chf_fn      = symmetric_chf(new_X);
kurtosis_fn = kurtosis(new_X);

% CGF and CHF both use the JADE estimate's Gram matrix as C.
jade_gram  = A_jade*A_jade';
C_cgf      = jade_gram;
C_chf      = jade_gram;
C_kurtosis = kurtosis_fn.estimate_C(20);

pinv_C_cgf      = pinv(C_cgf);
pinv_C_chf      = pinv(C_chf);
pinv_C_kurtosis = pinv(C_kurtosis);

% Zero-initialised mixing-matrix estimates and their dual matrices.
[A_cgf, A_chf, A_kurtosis, A_meta]     = deal(zeros(k_pca, k_ica));
[B1_cgf, B1_chf, B1_kurtosis, B1_meta] = deal(zeros(k_ica, k_pca));

% Solver and scoring settings.
verbose_flag       = 0;
maxruns            = 20;
num_trials         = 200;
random_projections = randn(2*num_trials, k_ica);

% Estimate the mixing matrices one column at a time. For each method the
% matrix M = I - A*B1 is rebuilt from the columns found so far and passed to
% ICA_power together with a start vector shared by all three methods.
% NOTE(review): ICA_power is not defined in this file; its argument
% semantics are taken from the call sites only.
identity_pca = eye(k_pca);
ones_arg     = ones(k_pca, k_ica);

for col = 1:k_ica
    fprintf('Column %d\n',col);
    u_init = randn(k_pca,1);

    % --- CGF contrast ---
    fprintf('CGF run\n');
    M_cgf = identity_pca - A_cgf*B1_cgf;
    [u_cgf,~,~,~] = ICA_power(new_X, maxruns, ones_arg, verbose_flag, ...
                              cgf_fn, u_init, C_cgf, M_cgf);
    A_cgf(:,col) = u_cgf';
    % Dual row: pinv(C)*a scaled so that its inner product with a is 1.
    u1 = pinv_C_cgf*A_cgf(:,col);
    B1_cgf(col,:) = u1/(u1'*A_cgf(:,col));

    % --- CHF contrast ---
    fprintf('CHF run\n');
    M_chf = identity_pca - A_chf*B1_chf;
    [u_chf,~,~,~] = ICA_power(new_X, maxruns, ones_arg, verbose_flag, ...
                              chf_fn, u_init, C_chf, M_chf);
    A_chf(:,col) = u_chf';
    u1 = pinv_C_chf*A_chf(:,col);
    B1_chf(col,:) = u1/(u1'*A_chf(:,col));

    % --- Kurtosis contrast ---
    fprintf('Kurtosis run\n');
    M_kurtosis = identity_pca - A_kurtosis*B1_kurtosis;
    [u_kurtosis,~,~,~] = ICA_power(new_X, maxruns, ones_arg, verbose_flag, ...
                                   kurtosis_fn, u_init, C_kurtosis, M_kurtosis);
    A_kurtosis(:,col) = u_kurtosis';
    u1 = pinv_C_kurtosis*A_kurtosis(:,col);
    B1_kurtosis(col,:) = u1/(u1'*A_kurtosis(:,col));

    fprintf('=====================\n');
    fprintf('=====================\n');
end

% Score every mixing-matrix estimate with Delta_Score_Total, report the
% scores, pick the estimate with the smallest score as the "meta" result and
% finally normalise the rows of all six matrices.
% NOTE(review): Delta_Score_Total is not defined in this file.
fprintf('Meta run\n');

ind_scores_chf      = Delta_Score_Total(new_X, C_chf,      A_chf,      random_projections, num_trials);
ind_scores_cgf      = Delta_Score_Total(new_X, C_cgf,      A_cgf,      random_projections, num_trials);
ind_scores_kurtosis = Delta_Score_Total(new_X, C_kurtosis, A_kurtosis, random_projections, num_trials);
ind_scores_jade     = Delta_Score_Total(new_X, A_jade*A_jade', A_jade, random_projections, num_trials);
ind_scores_fastica  = Delta_Score_Total(new_X, A_fastica*A_fastica', A_fastica, random_projections, num_trials);

fprintf("Independence Score CHF : %.10f\n", ind_scores_chf);
fprintf("Independence Score CGF : %.10f\n", ind_scores_cgf);
fprintf("Independence Score Kurtosis : %.10f\n", ind_scores_kurtosis);
fprintf("Independence Score JADE : %.10f\n", ind_scores_jade);
fprintf("Independence Score FASTICA : %.10f\n", ind_scores_fastica);

% Candidates and their scores, listed in the same order.
A_matrices = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
scores     = [ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
              ind_scores_jade, ind_scores_fastica];

% The lowest score wins.
[ind_scores_meta, min_index] = min(scores);
A_meta = A_matrices{min_index};
fprintf("Independence Score Meta : %.10f\n", ind_scores_meta);

% Row-normalise every estimate (normr), including the selected one.
row_normalised = cellfun(@normr, ...
                         {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta}, ...
                         'UniformOutput', false);
[A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta] = row_normalised{:};

% %% Orthogonalize matrices
% A_cgf      = orth(A_cgf);
% A_chf      = orth(A_chf);
% A_kurtosis = orth(A_kurtosis);
% A_jade     = orth(A_jade);
% A_fastica  = orth(A_fastica);
% A_meta     = orth(A_meta);

%% Denoise test data
% For each mixing-matrix estimate: add patch noise to the first num_test test
% images, project every tile onto the PCA basis, unmix it, soft-threshold the
% sources, remix and map back to pixel space. Each reconstruction is written
% to <method>/<label>/<index>.png and the mean per-image squared error
% against the stored test image is printed.
% NOTE(review): fresh noise is drawn inside the method loop, so each method
% sees a different noise realisation; confirm this is intended.

method_names     = ["CGF", "CHF", "Kurtosis", "JADE", "FastICA", "Meta"];
method_matrices  = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta};
num_test         = 100;   % number of test images denoised and evaluated
shrink_threshold = 50;    % soft-threshold applied to the source estimates
tile_len         = window_size*window_size;
tiles_per_image  = x_num_windows*y_num_windows;

X_test_images = dataset.test_x;
Y_test_images = dataset.test_y;

for j=1:numel(method_names)

    img_title = method_names(j);
    A         = method_matrices{j};

    % mkdir(img_title);
    % NOTE(review): the pseudo-inverse is multiplied by its own Gram matrix
    % before being inverted back; the intent is not evident from this file,
    % so the computation is preserved exactly.
    inv_A = pinv(A);
    inv_A = inv_A*(inv_A'*inv_A);
    A     = pinv(inv_A);

    % One row per tile; sized for the images actually processed.
    denoised_test_images = zeros(num_test*tiles_per_image, tile_len);
    itr = 1;
    for i=1:num_test
        % e           = mvnrnd(zeros(k,1), Sigma, 1);
        noisy_image = reshape(double(X_test_images(i,:)), [28,28])';
        noisy_image(10:20,10:20) = noisy_image(10:20,10:20) + 10*randn(11,11);

        for x=1:x_num_windows
            rows = window_size*(x-1)+1 : window_size*x;
            for y=1:y_num_windows
                cols = window_size*(y-1)+1 : window_size*y;

                x_test = reshape(noisy_image(rows, cols), [1,tile_len]) - mean_X;
                s_test = inv_A*(x_test*principal_components)';
                % s_test = sign(s_test).*max(0, (abs(s_test) - 25)/2 + 0.5*sqrt( (abs(s_test) + 25).^2 - 4*10 ) );
                s_test = sign(s_test).*max(0, abs(s_test) - shrink_threshold);
                s_test = real(s_test);

                denoised_test_images(itr,:) = (A*s_test)'*principal_components' + mean_X;
                itr = itr + 1;
            end
        end
        if(mod(i,1000) == 0)
            fprintf("%d test images denoised \n", i);
        end
    end

    %% Reassemble, score and save the denoised images

    squared_error = 0;
    for index=1:num_test
        denoised_index = tiles_per_image*(index-1);
        original_image = reshape(double(X_test_images(index,:)), [28,28])';
        label          = find(Y_test_images(index,:) == 1) - 1;

        % Put the tiles back in the order they were extracted.
        itr = 1;
        denoised_image = zeros(28,28);
        for x=1:x_num_windows
            rows = window_size*(x-1)+1 : window_size*x;
            for y=1:y_num_windows
                cols = window_size*(y-1)+1 : window_size*y;
                denoised_image(rows, cols) = reshape(denoised_test_images(denoised_index+itr,:), [window_size,window_size]);
                itr = itr + 1;
            end
        end

        % Quantise both images to uint8 (as saved), but take the difference
        % in double: uint8 arithmetic saturates, clipping negative
        % differences to 0 and squares to 255.
        pixel_error   = double(uint8(original_image)) - double(uint8(denoised_image));
        squared_error = squared_error + sum(pixel_error(:).^2);

        subfolder = strcat(img_title, "/", string(label));
        if exist(char(subfolder), 'dir') ~= 7
            mkdir(subfolder);
        end

        % Files are named by test-image index (1..num_test).
        filename = strcat(subfolder, "/", string(index), ".png");
        imwrite(uint8(denoised_image), filename, "png");
    end

    % Mean over exactly the num_test images that were scored.
    fprintf("%s squared-error : %.5f\n", img_title, squared_error/num_test);

end

% fprintf("%s squared-error : %.5f\n", img_title, sum((x(:) - y(:)).^2));


