%/usr/bin/env matlab
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Created on 14:49, May. 15th, 2023
% 
% @author: Anonymous
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [acc_pred]=lasso_glm(params, X, y)
    % Use the inner `lassoglm` function to fit the given brain recordings.
    %
    % At every time point, `n_labels` one-vs-rest binomial elastic-net classifiers are trained on
    % the train-set (one per label: "this label" vs. "every other label"). Each classifier is then
    % evaluated on the validation-set and the test-set.
    %
    % Args:
    %     params: struct - Model parameters initialized by lasso_glm_params. The fields read here are
    %         `l1_penalty`, `l2_penalty` and `acc_mode` ('default' or 'lvbj').
    %     X: (3[tuple],) - The input data {X_train, X_validation, X_test},
    %         each item is of shape (n_samples, seq_len, n_channels).
    %     y: (3[tuple],) - The target labels {y_train, y_validation, y_test}, each item is a vector
    %         of length n_samples (row or column).
    %
    % Returns:
    %     acc_pred: (2, seq_len) - The predicted accuracy of each lasso classifier trained
    %         at each time point. The first row is validation-accuracy, and the second row is test-accuracy.
    %
    % Raises:
    %     error 'lasso_glm:UnknownAccMode' if `params.acc_mode` is neither 'default' nor 'lvbj'.

    %% Prepare data & parameters for lasso glm.
    % Initialize train-set & validation-set & test-set from `X` & `y`.
    X_train = X{1}; X_validation = X{2}; X_test = X{3};
    % Flatten labels to column vectors. Otherwise, comparing them with the (column) predictions
    % below could broadcast into a matrix, and `size(y, 1)` would be 1 for row vectors.
    y_train = y{1}(:); y_validation = y{2}(:); y_test = y{3}(:);
    % Initialize `seq_len`, we train lasso classifier for each time point.
    seq_len = size(X_train, 2);
    assert(size(X_train, 2) == size(X_validation, 2));
    assert(size(X_train, 2) == size(X_test, 2));
    % Initialize `n_labels`, we train `n_labels` binomial lasso classifier for each time point.
    labels = unique(y_train); n_labels = length(labels);
    assert(length(unique(y_train)) == length(unique(y_validation)));
    assert(length(unique(y_train)) == length(unique(y_test)));
    % Initialize other parameters for lasso glm.
    l1_penalty = params.l1_penalty; l2_penalty = params.l2_penalty;
    acc_mode = params.acc_mode;
    % Fail fast on an unknown accuracy mode. Otherwise every classifier would be trained
    % and the function would silently return all-NaN accuracies.
    if ~any(strcmp(acc_mode, {'default', 'lvbj'}))
        error('lasso_glm:UnknownAccMode', 'Unknown `acc_mode`: %s.', acc_mode);
    end
    % `lassoglm` penalizes `Lambda * ((1 - Alpha) / 2 * ||W||_2^2 + Alpha * ||W||_1)`, so this mapping
    % gives the penalty `l1_penalty * ||W||_1 + l2_penalty * ||W||_2^2`.
    % Note: `lassoglm` requires `Alpha` in (0, 1], i.e. `l1_penalty` must be positive.
    lambda = 2 * l2_penalty + l1_penalty; alpha = l1_penalty / lambda;

    %% Correct data for lasso glm.
    % TODO: Execute artifact correction.

    %% Execute lasso glm training.
    % Initialize `acc_pred` as nans.
    acc_pred = nan(2, seq_len);
    % Train lasso glm for each time point.
    for time_idx = 1:seq_len
        % Select `X_i` for current time point and scale it to improve the generalization ability.
        % X_i - (n_samples, n_channels)
        % NOTE(review): validation/test sets are scaled by their own percentiles (not the train-set's),
        % as in the original implementation.
        X_train_i = scale_by_percentile(slice_time_point(X_train, time_idx));
        X_validation_i = scale_by_percentile(slice_time_point(X_validation, time_idx));
        X_test_i = scale_by_percentile(slice_time_point(X_test, time_idx));

        % Initialize `pred_i` as nans.
        % pred_i - (n_labels, n_samples)
        pred_validation_i = nan(n_labels, numel(y_validation)); pred_test_i = nan(n_labels, numel(y_test));
        % Train one one-vs-rest lasso glm per label for current time point.
        for label_train_idx = 1:n_labels
            % Fit lasso glm for current label using train-set.
            [W_i,fitinfo_i] = lassoglm(X_train_i, (y_train == labels(label_train_idx)), 'binomial', ...
                'Standardize', false, 'Alpha', alpha, 'Lambda', lambda);
            % Test the fitted lasso glm on validation-set & test-set (logistic link).
            pred_validation_i(label_train_idx,:) = 1 ./ (1 + exp(-(X_validation_i * W_i + fitinfo_i.Intercept)));
            pred_test_i(label_train_idx,:) = 1 ./ (1 + exp(-(X_test_i * W_i + fitinfo_i.Intercept)));
        end
        % Calculate `acc_i` from `pred_i` (row 1: validation-set, row 2: test-set).
        if strcmp(acc_mode, 'default')
            acc_pred(1,time_idx) = accuracy_default(pred_validation_i, y_validation, labels);
            acc_pred(2,time_idx) = accuracy_default(pred_test_i, y_test, labels);
        else
            acc_pred(1,time_idx) = accuracy_lvbj(pred_validation_i, y_validation, labels);
            acc_pred(2,time_idx) = accuracy_lvbj(pred_test_i, y_test, labels);
        end
    end
end

function X_i = slice_time_point(X, time_idx)
    % Select time point `time_idx` of `X` (n_samples, seq_len, n_channels) as an (n_samples, n_channels) matrix.
    % Note: `reshape` keeps the sample axis even when n_samples == 1, where `squeeze` would return a column.
    X_i = reshape(X(:,time_idx,:), size(X, 1), size(X, 3));
end

function X_i = scale_by_percentile(X_i)
    % Divide each channel (column) of `X_i` by the 95th percentile of its absolute value.
    % Note: percentile-normalization is better than std-normalization.
    % `dim = 1` is explicit so a single-sample (row) input is still scaled per channel.
    scale = prctile(abs(X_i), 95, 1);
    % Leave channels with a zero percentile unscaled instead of producing NaN / Inf features.
    scale(scale == 0) = 1;
    X_i = X_i ./ scale;
end

function acc = accuracy_default(pred, y_true, labels)
    % Fraction of samples whose most-confident one-vs-rest classifier corresponds to the true label.
    % pred - (n_labels, n_samples); y_true - (n_samples, 1); labels - (n_labels, 1)
    [~,y_pred_idxs] = max(pred, [], 1);
    % y_pred - (n_samples, 1)
    y_pred = labels(y_pred_idxs(:));
    acc = mean(y_true == y_pred);
end

function acc = accuracy_lvbj(pred, y_true, labels)
    % Fraction of classes for which the class's own classifier has the highest mean prediction
    % over that class's samples.
    % pred - (n_labels, n_samples); y_true - (n_samples, 1); labels - (n_labels, 1)
    n_labels = numel(labels);
    % classifier_ratio - (n_labels, n_labels); column j holds the mean prediction of every classifier on class j.
    classifier_ratio = nan(n_labels, n_labels);
    for label_idx = 1:n_labels
        % Note: We have assumed that each category has balanced samples!
        classifier_ratio(:,label_idx) = mean(pred(:,(y_true == labels(label_idx))), 2);
    end
    % classifier_idxs - (1, n_labels)
    [~,classifier_idxs] = max(classifier_ratio, [], 1);
    acc = mean(classifier_idxs == 1:n_labels);
end

