clear; clc;
%% Load stored evaluations and compute per-method metrics
load('./results/sort50_ebrahimi.mat')
clearvars -except eval_lst methods
results = compute_evaluation_metrics(eval_lst);
% Only the method list and the computed metrics are needed from here on.
clearvars -except methods results
%% Plot cross-validation curves
% Indices into `methods` to draw. Row i of color_map is used for
% method_indices(i), so color_map needs at least numel(method_indices) rows.
method_indices = [1,2,3,6,8,10,12];

color_map = [0,      0.4470, 0.7410; % blue
             0.8500, 0.3250, 0.0980; % orange
             0.9290, 0.6940, 0.1250; % yellow
             0.4940, 0.1840, 0.5560; % purple
             0.4660, 0.6740, 0.1880; % green
             0.6350, 0.0780, 0.1840; % red
             0.3010 0.7450 0.9330];  % light blue
% Dims of results: num_methods x lam (lam == 4 is the auto option)
% x annotators x 558 (up to 50%)
lam_all = [1e-3,1e-2,1e-1,1,10];
ratio_lst = [10, 50, 200];
x = lam_all;

% Metrics to plot, one row each: {field in results, title prefix, y-limits}.
% The matching standard deviation is read from results.<field>_std.
% Further rows that can be enabled:
%   'Precision', 'Precision', [0.7,1]
%   'Recall',    'Recall',    [0.7,1]
%   'Fscore',    'F-score',   [0.7,1]
metric_specs = {
    'AUC', 'AUC',                [0.6,1]
    'ACC', 'Accuracy',           [0.6,1]
    'TPR', 'True Positive Rate', [0.6,1]
    'TNR', 'True Negative Rate', [0.6,1]
    };

for r = 1:numel(ratio_lst)
    ratio = ratio_lst(r);
    ratio_str = num2str(ratio);
    for m = 1:size(metric_specs, 1)
        field = metric_specs{m, 1};
        % Select index `ratio` along dim 4, then average over annotators (dim 3).
        avg_metric = squeeze(mean(squeeze(results.(field)(:,:,:,ratio)), 3));
        avg_std    = squeeze(mean(squeeze(results.([field '_std'])(:,:,:,ratio)), 3));
        plot_metric_cv(avg_metric, avg_std, methods, method_indices, x, color_map, ...
            [metric_specs{m, 2} '-' ratio_str], metric_specs{m, 3})
    end
end

function plot_metric_cv(avg_metric, std_metric, methods, method_indices, x, color_map, metric_name, ylim_range)
% PLOT_METRIC_CV Plot mean +/- std curves of one metric against x on a
% logarithmic x-axis and export the figure to "crossval_<metric_name>.pdf".
%
% Inputs:
%   avg_metric     - matrix read as avg_metric(k,:), one row per method; each
%                    row must have numel(x) entries
%   std_metric     - same layout as avg_metric; drawn as a shaded +/- band
%   methods        - cell array of method identifiers, passed to get_legend_name
%   method_indices - which rows of avg_metric / entries of methods to plot
%   x              - x-coordinates of the curves (plotted on a log scale)
%   color_map      - RGB rows; row i colours method_indices(i)
%   metric_name    - figure title, also used to build the output file name
%   ylim_range     - [ymin ymax] for the y-axis and the dashed line at x = 1
n = numel(method_indices);
if size(color_map, 1) < n
    error('plot_metric_cv:colorMap', ...
        'color_map has %d rows but %d methods are plotted.', size(color_map, 1), n);
end

fig = figure;
x = x(:).';                 % row vector so [x, fliplr(x)] forms a closed polygon
curve_handles = gobjects(1, n);
for i = 1:n
    k = method_indices(i);
    mean_y = avg_metric(k,:);
    std_y  = std_metric(k,:);
    curve_handles(i) = semilogx(x, mean_y, 'DisplayName', get_legend_name(methods{k}), ...
        'Color', color_map(i,:), 'LineStyle', '-', 'LineWidth', 0.5);
    hold on
    % Shaded mean +/- std band in the curve's colour at 20% opacity; hidden
    % from the legend so only the method curves are listed.
    fill([x, fliplr(x)], [mean_y + std_y, fliplr(mean_y - std_y)], color_map(i,:), ...
        'FaceAlpha', 0.2, 'EdgeColor', 'none', 'HandleVisibility', 'off');
end

% Dashed red vertical reference line at x = 1 (kept out of the legend).
line([1, 1], ylim_range, 'Color', 'r', 'LineWidth', 0.5, 'LineStyle', '--', ...
    'HandleVisibility', 'off');

ylim(ylim_range)
hold off
if n > 0
    legend(curve_handles)
end
title(metric_name)

figName = "crossval_"+metric_name+".pdf";
exportgraphics(fig, figName, 'ContentType', 'vector')
end