function [valid, eval_metrics, dataset] = play_active_learning_new(metrics,choices,choices_gt,ratio,method,config)
% PLAY_ACTIVE_LEARNING_NEW  Simulate active-learning cell sorting on a
% fraction [ratio] of the cells using the query strategy in [method].
%
% The simulation is repeated config.repeat times. Each repeat starts from a
% fresh random seed set of config.n_init positive and config.n_init negative
% annotator labels, and the per-repeat evaluation metrics are averaged with
% avg_eval_lst.
%
% [INPUT]
%   [metrics]    : the original features for all cells from feature
%                  engineering (num_features x num_cells, e.g. 76 x num_cells).
%                  Transposed internally to num_cells x num_features.
%   [choices]    : the labels from one annotator (1 x num_cells, values 1 /
%                  -1). Used to seed the initial labeled set. Same as
%                  choices_gt if we are not comparing with the consensus
%                  labeling.
%   [choices_gt] : the ground-truth labeling (1 x num_cells). Used to answer
%                  the queries and to evaluate.
%   [ratio]      : fraction of the cells to sort (scalar, 0..1). The loop
%                  runs floor(ratio*num_cells) query rounds, each asking
%                  step_al for config.n cells.
%   [method]     : query algorithm (struct)
%       - name : [random, algo-rank, cal, dcal, multi-arm] (string)
%       - (algorithm specific params)
%           -- dcal      : weight
%           -- multi-arm : wt, gamma, reward_func, reward_name
%       - continue   : if true, fine tune with fine_tune; if false, train
%                      from scratch with train_classifier (default 0).
%       - pretrained : pretrained dataset with additional field [threshold]
%                      (only read when continue is true).
%       - mdl        : pretrained model (only read when continue is true).
%   [config]     : configuration (struct). It is passed to initialization,
%                  so preprocessing fields are consumed there.
%       - DO_ZSCORING        : true to z-score the input
%       - n                  : number of cells returned by the query
%                              algorithm per round for the user to label
%       - repeat             : number of repeats of the random seed
%                              initialization
%       - lam                : regularization passed to train_classifier
%                              (default "auto")
%       - balance            : default 0
%       - balance_pretrained : copied to pretrained.balance (default 0)
%       - align              : if true (and continue), align dataset
%                              features to the pretrained features
%                              (default 0)
%       - n_init             : number of positive and of negative seed
%                              labels per repeat (default 3)
%       - eval_ratio         : evaluate every floor(eval_ratio*num_cells)
%                              rounds, at least every round (default 0.001,
%                              i.e. every 0.1% of the cells)
% [OUTPUT]
%   [valid]        : NOTE(review): despite its name and earlier docs ("True
%                    if all final predictions have probability > 0.5"),
%                    this is dataset.labels_ml_prob of the final repeat, not
%                    a bool. Check callers before changing it.
%   [eval_metrics] : (struct) metrics recorded after the first round and
%                    then every eval interval, averaged over repeats by
%                    avg_eval_lst. Fields come from get_accuracy
%                    (previously documented as ACC, TPR, TNR, Precision,
%                    Recall, Fscore, ROC) - confirm there.
%   [dataset]      : the dataset (struct) of the final repeat.
%       - features       : N x d
%       - labels_ex      : N x 1 human / expert labels
%       - labels_ml      : N x 1 cell classifier / ML labels
%       - labels_ml_prob : N x 1 ML label probabilities
%
% Side effects: each query round appends to the global labels_saved
% (fields labels_ex, labels_ml_prob, q_idxs), and each repeat prints
% "repeat <k>>>>".

global labels_saved;

% Defaults for optional settings.
if ~isfield(config, 'lam'), config.lam = "auto"; end
if ~isfield(method, 'continue'), method.continue = 0; end
if ~isfield(config, 'balance'), config.balance = 0; end
if ~isfield(config, 'balance_pretrained'), config.balance_pretrained = 0; end
if ~isfield(config, 'align'), config.align = 0; end
if ~isfield(config, 'n_init'), config.n_init = 3; end
if ~isfield(config, 'eval_ratio'), config.eval_ratio = 0.001; end

% The history log is appended to below. Make sure its fields exist so the
% first concatenation does not fail on an uninitialized (empty) global.
labels_saved = init_labels_saved(labels_saved);

features  = metrics';
labels    = choices';
labels_gt = choices_gt';

% DEFINE NUMBER OF QUERY ROUNDS AND EVALUATION INTERVAL
num_cells  = size(metrics,2);
stop_cell  = floor(ratio*num_cells);
% mod(i, 0) returns i in MATLAB. A zero interval (fewer than 1/eval_ratio
% cells) would therefore skip every evaluation after the first round.
eval_every = max(1, floor(config.eval_ratio*num_cells));

% The candidate pools for the random seed set are the same for every repeat.
labels_1_indices      = find(labels == 1);
labels_minus1_indices = find(labels == -1);

eval_lst = cell(1,config.repeat);
for nexp = 1:config.repeat
    % INITIALIZE DATASET
    if ~(method.continue)
        dataset = initialization(features, config);
    else
        dataset = initialization(features, config, method.mdl);
    end
    [method, dataset] = parse_method(method, dataset);
    % INITIALIZE EVALUATION METRICS
    eval_metrics = init_eval_metrics();
    % INITIALIZE CELL CLASSIFIER
    if ~(method.continue)
        dataset = seed_labels(dataset, labels, labels_1_indices, ...
                              labels_minus1_indices, config.n_init);
        [dataset, mdl] = train_classifier(dataset,config.lam);
    else
        pretrained = method.pretrained;
        pretrained.balance = config.balance_pretrained;
        % align the distribution between dataset and pretrained
        if config.align
            dataset.features = align_features(pretrained.features, dataset.features);
        end
        dataset = seed_labels(dataset, labels, labels_1_indices, ...
                              labels_minus1_indices, config.n_init);
        [dataset, ~ , mdl] = fine_tune(dataset, pretrained);
    end
    % ACTSORT: query -> label -> retrain
    for i=1:stop_cell
        % select next data to be labeled
        [q_idxs, ~, dataset] = step_al(dataset, config.n, mdl, method.name);
        % log the state returned by step_al, before the query is annotated
        labels_saved.labels_ex      = [labels_saved.labels_ex; dataset.labels_ex'];
        labels_saved.labels_ml_prob = [labels_saved.labels_ml_prob; dataset.labels_ml_prob'];
        labels_saved.q_idxs         = [labels_saved.q_idxs; q_idxs];
        % label the data
        % NOTE(review): queries are answered from choices_gt while the seed
        % set uses choices; confirm this is intended when they differ.
        label = labels_gt(q_idxs);
        % add the label to the training dataset
        dataset = annotate(dataset, q_idxs, label);
        % train the classifier
        if ~method.continue
            [dataset, mdl] = train_classifier(dataset,config.lam);
        else
            [dataset, pretrained, mdl] = fine_tune(dataset, pretrained);
        end
        if i==1 || mod(i, eval_every) == 0
            eval_metrics = get_accuracy(dataset, labels_gt, eval_metrics);
        end
    end
    eval_lst{nexp} = eval_metrics;
    fprintf("repeat %i>>>", nexp);
end
eval_metrics = avg_eval_lst(eval_lst);
% NOTE(review): kept as the probability vector for caller compatibility.
valid = dataset.labels_ml_prob;
end

function dataset = seed_labels(dataset, labels, pos_idx, neg_idx, k)
% Reveal the annotator labels of k random positive and k random negative cells.
if numel(pos_idx) < k || numel(neg_idx) < k
    error('play_active_learning_new:notEnoughSeeds', ...
          'Need at least %d positive and %d negative labels to seed; found %d and %d.', ...
          k, k, numel(pos_idx), numel(neg_idx));
end
% Sample in index space: randsample(v, k) treats a scalar v as the population
% 1:v, which would pick wrong cells when only one candidate exists.
q_iscells = pos_idx(randsample(numel(pos_idx), k));
q_nocells = neg_idx(randsample(numel(neg_idx), k));
dataset.labels_ex(q_iscells) = labels(q_iscells);
dataset.labels_ex(q_nocells) = labels(q_nocells);
end

function s = init_labels_saved(s)
% Ensure the global history log is a struct with the fields appended to above.
if isempty(s), s = struct(); end
names = {'labels_ex', 'labels_ml_prob', 'q_idxs'};
for f = 1:numel(names)
    if ~isfield(s, names{f}), s.(names{f}) = []; end
end
end
