clear;
close all;

% Put the CPLEX MATLAB directories on the path, choosing the install
% location for the current platform. Adjust these paths if CPLEX is
% installed elsewhere (NOTE(review): presumably required by the solver
% helpers called below, e.g. iterative_rounding / other_exact — confirm).
if ispc
    addpath('C:\Program Files\IBM\ILOG\CPLEX_Studio1210\cplex\matlab\x64_win64');
    addpath('C:\Program Files\IBM\ILOG\CPLEX_Studio1210\cplex\examples\src\matlab');
else
    addpath('/Applications/CPLEX_Studio1210/cplex/matlab/x86-64_osx');
    addpath('/Applications/CPLEX_Studio1210/cplex/examples/src/matlab');
end

% Seed the random number generator so runs are reproducible.
randomSeed=12345;
rng(randomSeed);

% Input of loadData is either 'credit', 'adult', 'adultR', 'adultGR', or 'compasWB'.
% 'svar' is the sensitive variable: one value per data point, used below to
% split the points into groups.
datasetName = 'credit';
[dataAll, svarAll, groupNames] = loadData(datasetName);

max_sample = 500;
num_trials = 5;
% Use the first max_sample points, or every point if the dataset is smaller
% (indexing past the end would otherwise raise an error).
samples = 1:min([max_sample, size(dataAll, 1), numel(svarAll)]);

data = dataAll;
svar = svarAll;
svar = svar(samples);
groups = unique(svar);

% Normalize the data. normalizeData is called with a single argument here,
% so no (fair) PCA option is requested -- hence the name datawoPCA
% ("without PCA"). NOTE(review): confirm against normalizeData.
datawoPCA = normalizeData(data);

% Pairwise Euclidean distances between the sampled points. Every sampled
% point is also a candidate center, so dists is square:
% dists(p, c) = ||x_p - x_c||.
sampleData = datawoPCA(samples, :);   % hoisted: identical on every iteration
numPoints = size(sampleData, 1);
numCenters = numPoints;
dists = zeros(numPoints, numCenters);
for i = 1:numCenters
    % Index the center from the same sampled matrix as the points, so the
    % columns stay aligned with `samples` even if it is not 1:max_sample.
    dists(:, i) = vecnorm(sampleData - sampleData(i, :), 2, 2);
end

% Parameter grids swept by the experiment: numbers of centers (ks), plus
% lambdas for our bicriteria algorithm, epsilons for their bicriteria
% algorithm and gammas for their exact algorithm.
ks       = 5:5:50;
epsilons = 0.1:0.1:0.5;
lambdas  = 0.2:0.2:1;
gammas   = 0.1:0.1:0.4;

% Run-time accumulators: row = index into ks, column = index into the
% algorithm's own parameter grid. Times are summed over trials and divided
% by num_trials before saving.
our_bi_time      = zeros([numel(ks), numel(lambdas)]);
our_exact_time   = zeros([numel(ks), numel(lambdas)]);
their_bi_time    = zeros([numel(ks), numel(epsilons)]);
their_exact_time = zeros([numel(ks), numel(gammas)]);

% Run our bicriteria algorithm, their bicriteria algorithm and their exact
% algorithm num_trials times for every k and every parameter value.
% Run times accumulate across trials (divided by num_trials before saving).
% The center/objective arrays are re-created at the start of each trial,
% so after the loop they hold only the final trial's results.
% our_exact_centers / our_exact_obj are allocated here and filled by the
% "our exact" loop that follows.
for r = 1:num_trials
    our_bi_centers = cell(length(ks),length(lambdas));
    our_bi_obj = zeros(length(ks),length(lambdas),length(groups));
    our_bi_cen_num = zeros(length(ks),length(lambdas));
    our_exact_centers = cell(length(ks),length(lambdas));
    our_exact_obj = zeros(length(ks),length(lambdas),length(groups));
    their_bi_centers = cell(length(ks),length(epsilons));
    their_bi_obj = zeros(length(ks),length(epsilons),length(groups));
    their_bi_cen_num = zeros(length(ks),length(epsilons));
    their_exact_centers = cell(length(ks),length(gammas));
    their_exact_obj = zeros(length(ks),length(gammas),length(groups));
    for i = 1:length(ks)
        disp(ks(i))
        disp('our bicriteria is running...')
        for j = 1:length(lambdas)
            disp(strcat(num2str(j), '/', num2str(length(lambdas))))
            % Time with an explicit timer ID: a bare tic/toc pair would be
            % reset by any tic executed inside the called function.
            timerStart = tic;
            our_bi_centers{i,j} = iterative_rounding(dists,svar,ks(i),lambdas(j),1);
            our_bi_time(i,j) = our_bi_time(i,j) + toc(timerStart);
            % Drop zero entries; 0 is not a valid column index into dists.
            our_bi_centers{i,j} = our_bi_centers{i,j}(our_bi_centers{i,j} ~= 0);
            our_bi_cen_num(i,j) = length(our_bi_centers{i,j});
            % Per-group cost: mean distance from each member of the group
            % to its nearest selected center.
            for t = 1:length(groups)
                our_bi_obj(i,j,t) = sum(min(dists(svar==groups(t),our_bi_centers{i,j}),[],2)) / sum(svar==groups(t));
            end
        end
        disp('their bicriteria is running...')
        for j = 1:length(epsilons)
            disp(strcat(num2str(j), '/', num2str(length(epsilons))))
            timerStart = tic;
            their_bi_centers{i,j} = other_bicriteria(dists,svar,ks(i),epsilons(j));
            their_bi_time(i,j) = their_bi_time(i,j) + toc(timerStart);
            their_bi_centers{i,j} = their_bi_centers{i,j}(their_bi_centers{i,j} ~= 0);
            their_bi_cen_num(i,j) = length(their_bi_centers{i,j});
            for t = 1:length(groups)
                their_bi_obj(i,j,t) = sum(min(dists(svar==groups(t),their_bi_centers{i,j}),[],2)) / sum(svar==groups(t));
            end
        end
        disp('their exact is running...')
        for j = 1:length(gammas)
            disp(strcat(num2str(j), '/', num2str(length(gammas))))
            timerStart = tic;
            their_exact_centers{i,j} = other_exact(dists,svar,ks(i),gammas(j));
            their_exact_time(i,j) = their_exact_time(i,j) + toc(timerStart);
            for t = 1:length(groups)
                their_exact_obj(i,j,t) = sum(min(dists(svar==groups(t),their_exact_centers{i,j}),[],2)) / sum(svar==groups(t));
            end
        end
    end
end

% Our exact algorithm. exhaustive_search is given the bicriteria centers
% left in our_bi_centers, i.e. those from the final trial of the previous
% loop. Run times accumulate over num_trials runs; centers and per-group
% objectives are overwritten each trial.
for r = 1:num_trials
    for i = 1:length(ks)
        disp(ks(i))
        disp('our exact is running...')
        for j = 1:length(lambdas)
            % Explicit timer ID so a tic inside the solver cannot reset it.
            timerStart = tic;
            our_exact_centers{i,j} = exhaustive_search(dists,svar,ks(i),our_bi_centers{i,j});
            our_exact_time(i,j) = our_exact_time(i,j) + toc(timerStart);
            % Per-group cost: mean distance from each member of the group
            % to its nearest selected center.
            for t = 1:length(groups)
                our_exact_obj(i,j,t) = sum(min(dists(svar==groups(t),our_exact_centers{i,j}),[],2)) / sum(svar==groups(t));
            end
        end
    end
end

% Convert the accumulated run times into per-trial averages.
our_bi_time = our_bi_time / num_trials;
our_exact_time = our_exact_time / num_trials;
their_bi_time = their_bi_time / num_trials;
their_exact_time = their_exact_time / num_trials;

% save with no variable list writes the entire workspace. Create the output
% folder first: otherwise save errors out and the results of the whole run
% are lost.
outputDir = '../OutputApproxFairKMeans/';
if ~exist(outputDir, 'dir')
    mkdir(outputDir);
end
save([outputDir, datasetName, '_', num2str(max_sample), '_9_26_2022']);