clc; clear;
addpath(genpath('./FastICA_25/'));

%% Load data
filenames = ["flats.jpg", "bumper.jpg", "city.jpg", "raisin.jpg"];
% filenames = ["mnist_images/image1.jpg", "mnist_images/image2.jpg", "mnist_images/image3.jpg", "mnist_images/image4.jpg"];

[pics, file_names, npics, inverted, pic_size] = read_sources(filenames);

% Stack every image as one column of S (pixels x images). S is preallocated
% with the class of the first image instead of being grown inside the loop,
% and a size mismatch is reported explicitly rather than as an opaque
% concatenation error.
if npics < 1
    error('read_sources returned no images; nothing to separate.');
end
n_pixels = numel(pics{1});
S = zeros(n_pixels, npics, 'like', pics{1});
for i = 1:npics
    s = pics{i}(:);
    if numel(s) ~= n_pixels
        error('Image %d has %d pixels but image 1 has %d; all sources must share one size.', ...
              i, numel(s), n_pixels);
    end
    S(:, i) = s;
    disp(['Range of value for ', file_names{i}, ': ', num2str(min(s)), ' -> ', num2str(max(s))]);
end

%% Reshape and plot original images
% Show each source column of S as an image on a 2-column grid. The row count
% uses ceil so an odd number of images still gets a slot for the last one
% (floor(npics/2) would give npics=3 only 2 slots and npics=1 zero rows).
n_rows = ceil(npics/2);
figure;
for i = 1:npics
    subplot(n_rows, 2, i);
    original_image = reshape(S(:, i), pic_size);
    imshow(original_image, []);
    title(['Original Image ', num2str(i)]);
end

% Problem dimensions. n is the number of samples (pixels), i.e. rows of S;
% length(S) would return max(size(S)), which is only right while the pixel
% count exceeds the image count.
n     = size(S, 1);
k     = npics;
k_ica = npics;
k_pca = npics;

%% Mix images
% Observation model X = S*B' + e: B is a random square mixing matrix and e is
% zero-mean Gaussian noise with covariance Sigma = noise_power/npics * U*U'
% (positive semi-definite by construction).
noise_power = 1e-3;
U           = randn(npics,npics);
Sigma       = noise_power*(1/npics)*(U*U');
B           = randn(npics, npics);
e           = mvnrnd(zeros(npics,1), Sigma, n);

X = S * B' + e;
% Standardize each observed channel (column). std(X, 0, 1) is the column-wise
% sample std: the 0 is the normalization weight (N-1) and the 1 is the
% dimension. The old std(X,1) meant weight 1 (normalize by N), not dim 1.
% Using N-1 matches cov(), so data_covariance has an exactly unit diagonal.
mean_X  = mean(X, 1);
std_X   = std(X, 0, 1);
new_X   = (X - mean_X)./std_X;
data_covariance = cov(new_X);

%% Reshape and plot mixed images
% Show each standardized mixture (column of new_X) as an image on a 2-column
% grid; ceil keeps a slot for the last image when npics is odd.
n_rows = ceil(npics/2);
figure;
for i = 1:npics
    subplot(n_rows, 2, i);
    mixed_image = reshape(new_X(:, i), pic_size);
    imshow(mixed_image, []);
    title(['Mixed Image ', num2str(i)]);
end

%% Run algorithms

% Baseline ICA estimates. jade/fastica receive new_X' (channels as rows).
% The call order below is kept as-is, since any of these may draw from the
% global random stream.
A_jade                = jade(new_X',k_ica); fprintf("JADE completed\n");
[A_fastica,~]         = fastica(new_X','numOfIC',k_ica); fprintf("FASTICA completed\n");
cgf_fn                = cgf(new_X);
chf_fn                = symmetric_chf(new_X);
% NOTE(review): must resolve to the project's kurtosis (its result provides
% estimate_C), not the Statistics Toolbox function of the same name.
kurtosis_fn           = kurtosis(new_X);

% The CGF and CHF runs both use JADE's A*A' as their C matrix, so it and its
% pseudo-inverse are computed once and shared. The kurtosis run gets its C
% from estimate_C(20).
C_jade                = A_jade*A_jade';
C_cgf                 = C_jade;
C_chf                 = C_jade;
C_kurtosis            = kurtosis_fn.estimate_C(20);
pinv_C_cgf            = pinv(C_jade);
pinv_C_chf            = pinv_C_cgf;
pinv_C_kurtosis       = pinv(C_kurtosis);

% Per-contrast estimates filled column by column in the deflation loop:
% A (k_pca x k_ica) holds recovered columns, B1 (k_ica x k_pca) the matching
% normalized rows used to build the deflation projector. (The former
% A_meta/B1_meta initializations were dead: A_meta is assigned by the meta
% selection and B1_meta was never used.)
A_cgf                 = zeros(k_pca,k_ica);
B1_cgf                = zeros(k_ica,k_pca);
A_chf                 = zeros(k_pca,k_ica);
B1_chf                = zeros(k_ica,k_pca);
A_kurtosis            = zeros(k_pca,k_ica);
B1_kurtosis           = zeros(k_ica,k_pca);
verbose_flag          = 0;
maxruns               = 20;
num_trials            = 200;
random_projections    = randn(2*num_trials,k_ica);

% Deflationary extraction. For each column index, every contrast (CGF, CHF,
% kurtosis - in that order) runs ICA_power from one random start vector shared
% by all three, restricted by M = I - A*B1 built from the columns it has
% already found. The new column a is stored in A, and B1 gets the row
% w/(w'*a) with w = pinv(C)*a, so that B1(col,:)*A(:,col) = 1.
contrast_fns  = {cgf_fn, chf_fn, kurtosis_fn};
contrast_C    = {C_cgf, C_chf, C_kurtosis};
contrast_Cinv = {pinv_C_cgf, pinv_C_chf, pinv_C_kurtosis};
run_messages  = {'CGF run\n', 'CHF run\n', 'Kurtosis run\n'};
A_est         = {A_cgf, A_chf, A_kurtosis};
B1_est        = {B1_cgf, B1_chf, B1_kurtosis};
all_ones      = ones(k_pca, k_ica);

for col = 1:k_ica
    fprintf('Column %d\n', col);
    u_start = randn(k_pca, 1);

    for m = 1:numel(contrast_fns)
        fprintf(run_messages{m});
        A_m  = A_est{m};
        B1_m = B1_est{m};

        M_m = eye(k_pca) - A_m*B1_m;
        [u_m,~,~,~] = ICA_power(new_X, maxruns, all_ones, verbose_flag, ...
                                contrast_fns{m}, u_start, contrast_C{m}, M_m);

        A_m(:,col)  = u_m';
        w           = contrast_Cinv{m}*A_m(:,col);
        B1_m(col,:) = w/(w'*A_m(:,col));

        A_est{m}  = A_m;
        B1_est{m} = B1_m;
    end

    fprintf('=====================\n');
    fprintf('=====================\n');
end

[A_cgf, A_chf, A_kurtosis]    = A_est{:};
[B1_cgf, B1_chf, B1_kurtosis] = B1_est{:};

fprintf('Meta run\n');

% Score each candidate with Delta_Score_Total(new_X, C, A, ...); for JADE and
% FastICA the C argument is A*A'. The evaluation order (CHF, CGF, Kurtosis,
% JADE, FastICA) is the same as the print order below.
score_args = { C_chf,                A_chf;
               C_cgf,                A_cgf;
               C_kurtosis,           A_kurtosis;
               A_jade*A_jade',       A_jade;
               A_fastica*A_fastica', A_fastica };
score_vals = cell(size(score_args, 1), 1);
for r = 1:size(score_args, 1)
    score_vals{r} = Delta_Score_Total(new_X, score_args{r, 1}, score_args{r, 2}, ...
                                      random_projections, num_trials);
end
[ind_scores_chf, ind_scores_cgf, ind_scores_kurtosis, ...
 ind_scores_jade, ind_scores_fastica] = score_vals{:};

fprintf("Independence Score CHF : %.10f\n", ind_scores_chf);
fprintf("Independence Score CGF : %.10f\n", ind_scores_cgf);
fprintf("Independence Score Kurtosis : %.10f\n", ind_scores_kurtosis);
fprintf("Independence Score JADE : %.10f\n", ind_scores_jade);
fprintf("Independence Score FASTICA : %.10f\n", ind_scores_fastica);

% Meta estimate: whichever candidate attains the smallest score.
A_matrices = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
scores     = [ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
              ind_scores_jade, ind_scores_fastica];
[ind_scores_meta, min_index] = min(scores);
A_meta = A_matrices{min_index};
fprintf("Independence Score Meta : %.10f\n", ind_scores_meta);

%% Reshape and plot unmixed images
% For each estimate A, the demixing matrix A' * inv(data_covariance) (computed
% with mrdivide) is applied to the standardized mixtures, and the result is
% plotted.
%
% ICA recovers sources only up to sign, so each unmixed component is flipped
% to correlate positively with the ground-truth source it matches best. This
% replaces the former hard-coded flip of CHF component 1. That flip only suited
% one particular (unseeded) random run and was never applied to the other
% methods. The flipped components are also what is stored in X_unmixed_*.
unmix_A      = {A_chf, A_cgf, A_kurtosis, A_jade, A_fastica};
unmix_titles = {'Unmixed Image with CHF ', 'Unmixed Image with CGF ', ...
                'Unmixed Image with Kurtosis ', 'Unmixed Image with JADE ', ...
                'Unmixed Image with FastICA '};
unmix_out    = cell(size(unmix_A));
n_rows       = ceil(npics/2);   % ceil keeps a slot for an odd last image

% Unit-norm, centered sources; eps guards against a constant source column.
S_centered = double(S) - mean(double(S), 1);
S_unit     = S_centered ./ max(sqrt(sum(S_centered.^2, 1)), eps);

for m = 1:numel(unmix_A)
    sinr_optimal_demixing = (unmix_A{m}')/data_covariance;
    X_unmixed = new_X*sinr_optimal_demixing';

    % Component-by-source correlation matrix; use it to pick each
    % component's best-matching source and flip the component if that
    % correlation is negative.
    Z      = X_unmixed - mean(X_unmixed, 1);
    Z_unit = Z ./ max(sqrt(sum(Z.^2, 1)), eps);
    R      = Z_unit' * S_unit;
    [~, best_src] = max(abs(R), [], 2);
    match_corr    = R(sub2ind(size(R), (1:size(R, 1))', best_src));
    flip          = ones(size(match_corr));
    flip(match_corr < 0) = -1;
    X_unmixed = X_unmixed .* flip';

    unmix_out{m} = X_unmixed;

    figure;
    for i = 1:npics
        subplot(n_rows, 2, i);
        unmixed_image = reshape(X_unmixed(:, i), pic_size);
        imshow(unmixed_image, []);
        title([unmix_titles{m}, num2str(i)]);
    end
end

[X_unmixed_chf, X_unmixed_cgf, X_unmixed_kurtosis, ...
 X_unmixed_jade, X_unmixed_fastica] = unmix_out{:};

% A_cgf      = normr(A_cgf);
% A_chf      = normr(A_chf);
% A_kurtosis = normr(A_kurtosis);
% A_jade     = normr(A_jade);
% A_fastica  = normr(A_fastica);
% A_meta     = normr(A_meta);