% Print a per-config summary (acceptance rate and token-rate change) for
% every benchmark dataset found in the chosen results directory.
clearvars;
close all;

% Results directory: set to 'tmp' or 'std' to pick which experiment to report on.
directory = 'tmp';
datasets = {'gsm8k', 'openai_humaneval', 'natural_reasoning', 'mbpp', 'drop'};

% Iterating over a 1xN cell yields one 1x1 cell per pass; unwrap it with {1}.
for ds = datasets
    do_dataset(ds{1}, directory);
end

function do_dataset(dataset, directory)
% DO_DATASET Print per-config acceptance rate and token-rate change for one dataset.
%   do_dataset(DATASET, DIRECTORY) reads every (non-directory) file in
%   DIRECTORY whose name contains DATASET, treating each file as one seed.
%   For every config_name found in the first matching file it prints:
%     BE - mean acceptance_rate, averaged over seeds, +- standard error
%     TR - mean percentage change of token_rate relative to the baseline
%          config, averaged over seeds, +- standard error
%   The baseline config is 'single_draft_5_1' when DIRECTORY is 'tmp' and
%   'single_draft_4_1' otherwise.
%
%   If no file matches, a warning is issued and the function returns without
%   printing a summary, so callers looping over datasets can continue.

fprintf('Dataset: %s\n=============================\n', dataset);

% Collect the matching files once; skip subdirectories such as '.' and '..'.
listing = dir(directory);
listing = listing(~[listing.isdir]);
matches = listing(contains({listing.name}, dataset));
num_seeds = numel(matches);
if num_seeds == 0
    warning('do_dataset:noFiles', ...
        'No files matching "%s" found in "%s".', dataset, directory);
    return
end

% The two experiment directories use different baseline configurations.
if strcmp(directory, 'tmp')
    baseline_name = 'single_draft_5_1';
else
    baseline_name = 'single_draft_4_1';
end

for s = 1:num_seeds
    fname = fullfile(directory, matches(s).name);
    data = readtable(fname);

    if s == 1
        % The set of configs to report is taken from the first matching file.
        configs = unique(data.config_name);
        be_mat = zeros(numel(configs), num_seeds);
        tr_mat = zeros(numel(configs), num_seeds);
    end

    baseline_tr = data.token_rate(strcmp(data.config_name, baseline_name));
    if isempty(baseline_tr)
        error('do_dataset:noBaseline', ...
            'Baseline config "%s" not found in "%s".', baseline_name, fname);
    end

    for j = 1:numel(configs)
        rows = strcmp(data.config_name, configs{j});
        be_mat(j, s) = mean(data.acceptance_rate(rows));
        % Row-wise percentage change against the baseline rows.
        % NOTE(review): assumes each config has the same number of rows as the
        % baseline, in the same order (e.g. one row per prompt). Confirm.
        tr_mat(j, s) = mean(100 * (data.token_rate(rows) - baseline_tr) ./ baseline_tr);
    end
end

% Standard error of the mean across seeds (columns), then the across-seed mean.
std_be = std(be_mat, 0, 2) / sqrt(num_seeds);
std_tr = std(tr_mat, 0, 2) / sqrt(num_seeds);
be_mat = mean(be_mat, 2);
tr_mat = mean(tr_mat, 2);

for i = 1:numel(configs)
    fprintf('%16s  BE: %4.2f+-%4.2f  TR: %5.2f+-%4.2f\n', configs{i}, ...
        be_mat(i), std_be(i), tr_mat(i), std_tr(i));
end

end
