function [dataset, pretrained, mdl] = fine_tune(dataset, pretrained)
% FINE_TUNE  Train the cell classifier (L1-regularized logistic regression)
% on the expert-labeled samples of DATASET together with those of
% PRETRAINED, relabel DATASET with the new model, and discard poorly fit
% samples from PRETRAINED.
%
% INPUT:
%   [dataset] a structure that contains the fields
%       - features       : N x d
%       - labels_ex      : N x 1 human / expert labels (1 / -1, 0 = unlabeled)
%       - labels_ml      : N x 1 cell classifier / ML labels
%       - labels_ml_prob : N x 1 ML probability that the label equals 1
%       - balance        : logical flag; when true (and pretrained.balance is
%                          false) this dataset's labeled samples are passed
%                          through balance_pretrained_dataset before training
%   [pretrained] a previously labeled dataset with the fields
%       - features       : M x d
%       - labels_ex      : M x 1 human / expert labels (1 / -1, 0 = unlabeled)
%       - labels_ml      : M x 1 cell classifier / ML labels
%       - labels_ml_prob : M x 1 ML probability that the label equals 1
%       - balance        : logical flag; when true (and dataset.balance is
%                          false) the pretrained labeled samples are balanced
%                          before training. When BOTH flags are true, the
%                          combined labeled set is balanced as a whole.
%       - threshold      : scalar throw-away threshold. A labeled pretrained
%                          sample whose error exceeds the threshold is thrown
%                          away; if no sample exceeds it, the sample with the
%                          maximum error is thrown away instead. If threshold
%                          is -1, only the sample with the maximum error is
%                          thrown away.
%                          Error = |y - p|, with y the 0/1 expert label and p
%                          the model's probability of class 1.
% OUTPUT:
%   [dataset]    : the input dataset with labels_ml (-1 / 1), labels_ml_prob
%                  and mdl overwritten by the new model's predictions
%   [pretrained] : the input pretrained dataset with labels_ex of the thrown
%                  away samples set to 0 (unchanged if it had no labels)
%   [mdl]        : the cell classifier, a logistic regression model
%                  (ClassificationLinear, classes [0, 1])

% binarize the expert labels of both datasets: keep labeled samples only,
% map -1 -> 0
[labeled_idxs_this, labeled_ex01_this] = binarize_expert_labels( ...
    dataset.labels_ex, 'The array labeled_ex01 is not binary.');
[labeled_idxs_pretrain, labeled_ex01_pretrain] = binarize_expert_labels( ...
    pretrained.labels_ex, 'The array labeled_ex01_pretrain is not binary.');

features_this     = dataset.features(labeled_idxs_this, :);
features_pretrain = pretrained.features(labeled_idxs_pretrain, :);

% assemble the training set, balancing whichever part(s) are flagged.
% Separate train_* variables are used so that the unbalanced pretrained
% labels remain available for the throw-away step below.
if dataset.balance && pretrained.balance
    % both flagged: balance the concatenated set as a whole
    [train_features, train_labels] = balance_binary_set( ...
        [features_this; features_pretrain], ...
        [labeled_ex01_this; labeled_ex01_pretrain]);
else
    train_features_this = features_this;
    train_labels_this   = labeled_ex01_this;
    train_features_pre  = features_pretrain;
    train_labels_pre    = labeled_ex01_pretrain;
    if dataset.balance
        [train_features_this, train_labels_this] = ...
            balance_binary_set(train_features_this, train_labels_this);
    elseif pretrained.balance
        [train_features_pre, train_labels_pre] = ...
            balance_binary_set(train_features_pre, train_labels_pre);
    end
    train_features = [train_features_this; train_features_pre];
    train_labels   = [train_labels_this;   train_labels_pre];
end

% fine-tune
mdl = fit_cell_classifier(train_features, train_labels);

% relabel the whole current dataset (labeled and unlabeled rows)
[pred, pred_probs] = predict(mdl, dataset.features);

dataset.labels_ml_prob = pred_probs(:, 2);   % column 2 <-> class 1 ('ClassNames' [0, 1])
dataset.labels_ml = pred;
dataset.labels_ml(dataset.labels_ml == 0) = -1; % because we label it in -1 and 1

dataset.mdl = mdl;

if ~isempty(labeled_idxs_pretrain)
    % throw away wrong sample(s) from the pretrained dataset
    threshold = pretrained.threshold;

    [~, pretrain_probs] = predict(mdl, features_pretrain);
    % error is 1 - p for label 1 and p for label 0, i.e. |y - p|;
    % one entry per labeled pretrained sample
    errors = abs(labeled_ex01_pretrain - pretrain_probs(:, 2));

    if threshold == -1
        throwaway_mask = false(size(errors));
    else
        throwaway_mask = errors > threshold;
    end
    % always throw away at least the worst-fit sample
    if ~any(throwaway_mask)
        [~, max_error_idx] = max(errors);
        throwaway_mask(max_error_idx) = true;
    end

    pretrained.labels_ex(labeled_idxs_pretrain(throwaway_mask)) = 0;
end
end


function [labeled_idxs, labels01] = binarize_expert_labels(labels_ex, not_binary_msg)
% Return the linear indices of labeled (nonzero) entries of LABELS_EX as a
% column, and their labels as a column with -1 mapped to 0. Asserts that the
% result is binary, using NOT_BINARY_MSG as the error message.
labeled_idxs = find(labels_ex(:) ~= 0);
labels01 = labels_ex(labeled_idxs);
labels01 = labels01(:);
labels01(labels01 == -1) = 0;
assert(all(labels01 == 0 | labels01 == 1), not_binary_msg);
end


function [features, labels01] = balance_binary_set(features, labels01)
% Run balance_pretrained_dataset on a 0/1-labeled set. That helper is given
% -1 / 1 labels (as in the original code), so labels are converted to -1 / 1
% and back to 0 / 1 afterwards.
% NOTE(review): assumes balance_pretrained_dataset returns a struct with the
% same 'features' / 'labels' fields - confirm against its definition.
balance_in.features = features;
balance_in.labels   = labels01;
balance_in.labels(balance_in.labels == 0) = -1;
balance_out = balance_pretrained_dataset(balance_in);
balance_out.labels(balance_out.labels == -1) = 0;
features = balance_out.features;
labels01 = balance_out.labels;
end


function mdl = fit_cell_classifier(features, labels01)
% Fit the L1-regularized (lasso) logistic regression cell classifier on 0/1
% labels with empirical class priors and automatically chosen Lambda.
mdl = fitclinear(features, labels01, ...
        'learner', 'logistic', 'regularization', 'lasso', ...
        'ClassNames', [0, 1], 'Prior', 'empirical', 'Lambda', "auto");
end