%% MNIST: Comparison of Subsampling vs. Bootstrap Estimators
% Reset the session before running: clear the command window, drop all
% workspace variables, and close every open figure.
clc;
clear;
close all;

% Banner for the console log.
fprintf('\n===== MNIST Experiment: Test Data Estimation Comparison (Subsampling vs. Bootstrap) =====\n');

% Set the root-level default axes font size to 25 so the figures created
% below pick it up.
set(groot, 'DefaultAxesFontSize', 25);

%% Load MNIST Data
% 'mnist-matlab/mnist.mat' is loaded into the struct mnist_data. The code
% below reads:
%   mnist_data.training.images, mnist_data.training.labels
% and the (currently disabled) held-out variant would read:
%   mnist_data.test.images, mnist_data.test.labels
fprintf('Loading MNIST data...\n');
mnist_data = load('mnist-matlab/mnist.mat');  % Adjust path if needed

% Training split: convert to double and flatten every image into one column.
% All leading dimensions are collapsed; the last dimension indexes samples
% (presumably 28x28 images -> 784 features; confirm against the .mat file).
train_split = mnist_data.training;
labels_train = train_split.labels;
images_train = double(train_split.images);
num_train = size(images_train, ndims(images_train));
images_train = reshape(images_train, [], num_train);

% NOTE(review): the "test" variables below alias the training split, so every
% downstream "test" estimate is computed on the same samples as the
% "ground truth". The held-out variant is kept here, disabled:
% test_split = mnist_data.test;
% labels_test = test_split.labels;
% images_test = double(test_split.images);
% num_test = size(images_test, ndims(images_test));
% images_test = reshape(images_test, [], num_test);
labels_test = labels_train;
images_test = images_train;

% Digit classes handled by the experiment (0 through 9).
classes = 0:9;
num_classes = numel(classes);

% Estimator and confidence-interval settings.
alpha   = 5;      % Learning rate parameter passed to both estimators
verbose = 1;      % Verbose flag passed to both estimators
z_val   = 1.96;   % Normal quantile used for the 95% confidence intervals
b_boot  = 100;    % Number of bootstrap samples

% Per-class result containers, all num_classes-by-1.
% Cell arrays (one entry per class):
%   composite_images_*  composite RGB images (subsampling / bootstrap)
%   gt_lines            reference eigenvector from the training data
%   est_lines_*         estimated eigenvector from the test data
%   CI_lower_lines_*    lower ends of the coordinate-wise intervals
%   CI_upper_lines_*    upper ends of the coordinate-wise intervals
[composite_images_sub, composite_images_boot, gt_lines, ...
    est_lines_sub, CI_lower_lines_sub, CI_upper_lines_sub, ...
    est_lines_boot, CI_lower_lines_boot, CI_upper_lines_boot] = deal(cell(num_classes, 1));

% Wall-clock runtime (seconds) of each estimator, per class.
[runtime_sub, runtime_boot] = deal(zeros(num_classes, 1));

%% Process Each Class
% For each digit class this loop:
%   1. takes the leading singular vector of the class covariance computed
%      from the training split as the reference ("ground truth") direction;
%   2. runs the subsampling and bootstrap estimators on the test split and
%      forms coordinate-wise intervals  estimate +/- z_val * sqrt(variance);
%   3. renders, per estimator, a composite RGB image: the estimate tinted
%      red in proportion to the interval width, next to the reference.
% Results are written into the per-class containers allocated earlier in the
% script (gt_lines, est_lines_*, CI_*_lines_*, composite_images_*, runtime_*).
%
% get_var_estimates_subsampling / get_var_estimates_bootstrap are defined
% outside this file; only their result fields .oja_vec and .variance are
% read here.
for i = 1:num_classes
    class_digit = classes(i);
    
    %% Ground Truth from Training Data (for this class)
    idx_train = (labels_train == class_digit);
    X_train = images_train(:, idx_train)';  % each row is an image
    [n_train, d_train] = size(X_train);
    fprintf('Class %d (Train): %d samples, %d features.\n', class_digit, n_train, d_train);
    X_train = X_train - mean(X_train,1);
    Sigma_train = cov(X_train);
    [U_train, ~, ~] = svd(Sigma_train);
    gt_line = U_train(:,1);
    
    %% Estimation from Test Data (for this class)
    idx_test = (labels_test == class_digit);
    X_test = images_test(:, idx_test)';  % each row is an image
    [n_test, d_test] = size(X_test);
    fprintf('Class %d (Test): %d samples, %d features.\n', class_digit, n_test, d_test);
    X_test = X_test - mean(X_test,1);
    
    % Fail fast: the composites below reshape every d_test-vector into a
    % square image, so check this before spending time in the estimators.
    img_size = sqrt(d_test);
    if floor(img_size) ~= img_size
        error('Number of features (%d) is not a perfect square.', d_test);
    end
    
    % Set up test data parameters (trueV is the training reference direction,
    % not a true population quantity).
    Sigma_test = cov(X_test);
    [U_test, S_test, ~] = svd(Sigma_test);
    eig_vals_test = diag(S_test);
    if length(eig_vals_test) > 1
        eigengap_test = eig_vals_test(1) - eig_vals_test(2);
    else
        eigengap_test = eig_vals_test(1);
    end
    % Square root of the covariance, built from the SVD already computed.
    % The sample covariance is singular whenever some features are constant
    % within the class; sqrtm() can then warn or return complex entries.
    % Singular values are non-negative, so this form is always real.
    Sigma_test_sqrt = (U_test .* sqrt(eig_vals_test')) * U_test';
    Sigma_test_sqrt = (Sigma_test_sqrt + Sigma_test_sqrt') / 2;  % enforce exact symmetry
    
    data_params_test = struct();
    data_params_test.n = n_test;
    data_params_test.d = d_test;
    data_params_test.Sigma_true = Sigma_test;
    data_params_test.Sigma_true_sqrtm = Sigma_test_sqrt;
    data_params_test.trueV = gt_line;
    data_params_test.eigengap = eigengap_test;
    
    % Estimator parameters:
    m1 = 3;
    m2 = floor(max(log(n_test), log(d_test)));
    B = floor(n_test/(m1*m2));
    
    %% Run Subsampling Estimator
    tic;
    result_sub = get_var_estimates_subsampling(X_test, n_test, d_test, alpha, data_params_test, m1, m2, B, verbose);
    runtime_sub(i) = toc;
    est_sub = result_sub.oja_vec(:);
    var_sub = result_sub.variance(:);  % one variance per coordinate
    % Align the sign of the estimate with the reference. sign() returns 0 for
    % an exactly orthogonal estimate, which would wipe the estimate out, so
    % fall back to +1 in that case.
    flip_sub = sign(gt_line' * est_sub);
    if flip_sub == 0
        flip_sub = 1;
    end
    est_sub = est_sub * flip_sub;
    % 95% confidence interval (subsampling)
    CI_lower_sub = est_sub - z_val * sqrt(var_sub);
    CI_upper_sub = est_sub + z_val * sqrt(var_sub);
    
    % Store line data (subsampling)
    gt_lines{i} = gt_line;
    est_lines_sub{i} = est_sub;
    CI_lower_lines_sub{i} = CI_lower_sub;
    CI_upper_lines_sub{i} = CI_upper_sub;
    
    %% Run Bootstrap Estimator (b_boot resamples)
    tic;
    result_boot = get_var_estimates_bootstrap(X_test, n_test, d_test, alpha, data_params_test, b_boot, verbose);
    runtime_boot(i) = toc;
    est_boot = result_boot.oja_vec(:);
    var_boot = result_boot.variance(:);
    % Same sign alignment (and zero guard) as for the subsampling estimate.
    flip_boot = sign(gt_line' * est_boot);
    if flip_boot == 0
        flip_boot = 1;
    end
    est_boot = est_boot * flip_boot;
    % 95% confidence interval (bootstrap)
    CI_lower_boot = est_boot - z_val * sqrt(var_boot);
    CI_upper_boot = est_boot + z_val * sqrt(var_boot);
    
    % Store line data (bootstrap)
    est_lines_boot{i} = est_boot;
    CI_lower_lines_boot{i} = CI_lower_boot;
    CI_upper_lines_boot{i} = CI_upper_boot;
    
    fprintf('Class %d: Subsampling time = %.4f sec, Bootstrap time = %.4f sec\n', class_digit, runtime_sub(i), runtime_boot(i));
    
    %% Create Composite Images for this Class
    % Each composite shows the test estimate (red tint proportional to the
    % interval width) on the left and the training reference on the right.
    blend_factor = 0.7;  % maximum opacity of the red tint
    red_overlay = cat(3, ones(img_size), zeros(img_size), zeros(img_size));
    % Vector -> square image -> rescaled to [0,1] by mat2gray -> 3 channels.
    as_rgb = @(v) repmat(mat2gray(reshape(v, img_size, img_size)), [1,1,3]);
    
    % Ground truth image (training), shared by both composites.
    gt_rgb = as_rgb(gt_line);
    
    % --- Subsampling composite ---
    alpha_mask_sub = blend_factor * as_rgb(CI_upper_sub - CI_lower_sub);
    composite_test_sub = (1 - alpha_mask_sub) .* as_rgb(est_sub) + alpha_mask_sub .* red_overlay;
    composite_images_sub{i} = [composite_test_sub, gt_rgb];
    
    % --- Bootstrap composite ---
    alpha_mask_boot = blend_factor * as_rgb(CI_upper_boot - CI_lower_boot);
    composite_test_boot = (1 - alpha_mask_boot) .* as_rgb(est_boot) + alpha_mask_boot .* red_overlay;
    composite_images_boot{i} = [composite_test_boot, gt_rgb];
    
end

%% Figure 1: Composite Grid - Subsampling vs. Bootstrap
% Two figures, each a 2x5 tiled grid with one composite image per class:
% the first shows the subsampling composites, the second the bootstrap ones.
tile_opts = {'TileSpacing', 'compact', 'Padding', 'compact'};
class_titles = arrayfun(@(c) sprintf('Class %d', c), classes, 'UniformOutput', false);

% Subsampling grid.
figure;
t1 = tiledlayout(2, 5, tile_opts{:});
for i = 1:num_classes
    nexttile;
    imshow(composite_images_sub{i});
    title(class_titles{i});
end
sgtitle('MNIST: Subsampling Estimator (Test) vs. Ground Truth (Train)');
fprintf('Composite grid (subsampling) plotted.\n');

% Bootstrap grid.
figure;
t2 = tiledlayout(2, 5, tile_opts{:});
for i = 1:num_classes
    nexttile;
    imshow(composite_images_boot{i});
    title(class_titles{i});
end
sgtitle('MNIST: Bootstrap Estimator (Test, b=100) vs. Ground Truth (Train)');
fprintf('Composite grid (bootstrap) plotted.\n');

%% Figure 2: Line Plots for All Classes
% One figure per class, showing:
%   - x-axis: coordinate index.
%   - Shaded blue band: subsampling 95% CI for every coordinate.
%   - Red circles: training ground-truth eigenvector at its top_k_plot
%     largest-magnitude coordinates.
% The console also reports, per class, the fraction of the top_k_coverage
% largest-magnitude ground-truth coordinates that fall inside the
% subsampling CI.
% (Bootstrap band and estimate lines are available as disabled overlays.)
top_k_coverage = 100;  % coordinates used for the coverage statistic
top_k_plot = 20;       % coordinates drawn as ground-truth markers
for i = 1:num_classes
    figure;
    nexttile;  % no layout exists in the new figure, so this creates one
    
    % Retrieve stored line data as column vectors.
    gt_line = gt_lines{i}(:);
    CI_lower_sub = CI_lower_lines_sub{i}(:);
    CI_upper_sub = CI_upper_lines_sub{i}(:);
    
    % Take the coordinate count from the stored vector rather than from a
    % variable left over from the per-class loop.
    n_coords = numel(gt_line);
    x_coords = 1:n_coords;
    
    % Magnitudes in descending order; s(k) is the k-th largest. k is capped
    % at n_coords so short vectors do not cause an out-of-range index.
    s = sort(abs(gt_line), 'descend');
    k_cov = min(top_k_coverage, n_coords);
    k_plot = min(top_k_plot, n_coords);
    
    % Coverage over the top-k_cov coordinates (ties at the threshold are
    % all included).
    idx_cov = abs(gt_line) >= s(k_cov);
    coverage_fraction = mean((gt_line(idx_cov) >= CI_lower_sub(idx_cov)) & (gt_line(idx_cov) <= CI_upper_sub(idx_cov)));
    fprintf('Class %d: fraction of top-%d coordinates covered by subsampling CI: %.2f%%\n', classes(i), k_cov, 100 * coverage_fraction);
    
    % Coordinates shown as ground-truth markers.
    idx_gt = abs(gt_line) >= s(k_plot);
    
    % Plot subsampling CI shading as a closed polygon: lower bound left to
    % right, then upper bound right to left.
    x_fill = [x_coords, fliplr(x_coords)];
    y_fill_sub = [CI_lower_sub.', fliplr(CI_upper_sub.')];
    fill(x_fill, y_fill_sub, [0.2 0.2 0.8], 'EdgeColor', 'none', 'FaceAlpha', 0.5);
    hold on;
    % Optional overlays (disabled):
    % y_fill_boot = [CI_lower_lines_boot{i}(:).', fliplr(CI_upper_lines_boot{i}(:).')];
    % fill(x_fill, y_fill_boot, [1 0.41 0.16], 'EdgeColor', 'none', 'FaceAlpha', 0.5);
    % plot(x_coords, est_lines_sub{i}, 'b-', 'LineWidth', 2);
    % plot(x_coords, est_lines_boot{i}, 'g-', 'LineWidth', 2);
    
    % Plot ground truth at the selected coordinates.
    plot(x_coords(idx_gt), gt_line(idx_gt), 'ro', 'LineWidth', 2, 'MarkerSize', 6);
    
    xlabel('Coordinate Index');
    ylabel('Eigenvector Value');
    title(sprintf('95%% Confidence Interval and Offline Eigenvector for MNIST dataset Class %d', classes(i)));
    legend('Subsampling 95% CI', sprintf('Top %d coordinates', k_plot), 'Location', 'best');
    grid on;
    hold off;
end
fprintf('Line plots generated for all classes.\n');

%% Print Overall Runtime Information
% One summary line per class with both estimators' wall-clock times.
fprintf('\nOverall Runtimes (per class):\n');
% fprintf reuses the format for each group of three values; transposing the
% table makes each group one row (class, subsampling time, bootstrap time).
runtime_table = [classes(:), runtime_sub(:), runtime_boot(:)].';
fprintf('Class %d: Subsampling = %.4f sec, Bootstrap = %.4f sec\n', runtime_table);