%% This script creates data for pretraining.
clear; clc;

%% Settings
nn         = 5;   % layer count: selects the %d_layer_connection / gene_pathway input files and names the output file
split_seed = 0;   % seed for the random train/validation split, so reruns produce the same split

%% Load subject lists and data
sdmt_info  = readtable('../../joint_analysis/sdmt_list_final.csv');
nback_info = readtable('../../joint_analysis/nback_list_final.csv');

% Only the SNP table (for its IID column) is needed from this file; data_gene
% and class are taken from the gene-pathway file loaded below.
% NOTE(review): assumes rows of data_snp line up with rows of data_gene/class
% in the gene-pathway file -- confirm.
load('../../data/data_final_pgc.mat', 'data_snp');

load(sprintf('../data/%d_layer_connection.mat', nn), 'mask_con');

load(sprintf('../data/gene_pathway_layer_connection_and_input_data_%d.mat', nn), ...
     'data_gene', 'class', 'mask_gene_encode', 'pool_n');

%% Identify the location of genetic data for subjects who also have imaging data.
% IDs in data_snp.IID are matched against round(familyno*100) -- presumably
% the ID encoding used when the genetic data were built; TODO confirm.
imaging_ids = unique([round(nback_info.familyno*100); round(sdmt_info.familyno*100)]);
[found, test_id] = ismember(imaging_ids, data_snp.IID);
if ~all(found)
    % ismember reports 0 for unmatched IDs; indexing with 0 would crash below.
    warning('pretrain:missingGenotype', ...
            '%d imaging subject(s) have no genetic data and are skipped.', sum(~found));
end
test_id = test_id(found);

% All remaining rows (subjects without imaging data) form the training pool.
all_rows = 1:size(data_gene, 1);
train_id = all_rows(~ismember(all_rows, test_id));

% Hold out a validation set from the training pool that has the same size as
% the imaging set and the same number of class==1 subjects.
rng(split_seed);
[train_id, cv_id] = split(train_id, class(train_id)', numel(test_id), sum(class(test_id)));

%% Assemble the data sets; the gene scores are the network input.
% NOTE: the naming is crossed. Imaging subjects (test_id) are stored as *_cv,
% and the random held-out split (cv_id) is stored as *_test. The saved names
% are kept unchanged; presumably downstream code reads them, so confirm
% before renaming.
gene_train  = data_gene(train_id, :);
gene_cv     = data_gene(test_id, :);
gene_test   = data_gene(cv_id, :);

class_train = class(train_id)';
class_cv    = class(test_id)';
class_test  = class(cv_id)';

% A is the gene-ontology-based graph adjacency matrix, and mask_gene is the
% binary matrix mapping gene scores to pathways.
A         = mask_con;
mask_gene = mask_gene_encode;

save(sprintf('../../data/train_test_data_gcn_%d_layers_joint_model.mat', nn), ...
     'gene_train', 'gene_test', 'gene_cv', 'A', 'mask_gene', ...
     'class_train', 'class_test', 'class_cv', 'pool_n');

function [tr_id, cv_id] = split(id, labels, n, n_d)
% SPLIT  Draw a class-stratified random held-out subset from a pool of IDs.
%   [TR_ID, CV_ID] = SPLIT(ID, LABELS, N, N_D) randomly picks N entries of ID
%   for CV_ID: exactly N_D of them have LABELS == 1 and N - N_D have
%   LABELS == 0. All remaining entries of ID are returned in TR_ID.
%
%   ID     - vector of candidate IDs (row indices).
%   LABELS - vector with one label per element of ID, where LABELS(k) is the
%            class of ID(k). Row or column orientation is accepted.
%   N      - total number of entries to hold out.
%   N_D    - number of held-out entries with LABELS == 1.
%
%   CV_ID is in random order. TR_ID keeps the original order of ID. Both
%   follow the orientation of ID. Entries whose label is neither 0 nor 1 are
%   never held out.
%
%   Errors if fewer than N_D class-1 or N - N_D class-0 entries are available.

if numel(labels) ~= numel(id)
    error('split:sizeMismatch', 'ID and LABELS must have the same number of elements.');
end

dis = find(labels(:) == 1);   % positions (into ID) of class-1 entries
con = find(labels(:) == 0);   % positions (into ID) of class-0 entries
n_c = n - n_d;

if n_d < 0 || n_c < 0 || n_d > numel(dis) || n_c > numel(con)
    error('split:notEnoughSamples', ...
          'Cannot draw %d class-1 and %d class-0 entries from %d and %d available.', ...
          n_d, n_c, numel(dis), numel(con));
end

% randperm(N, K) already draws K distinct positions uniformly at random.
pick_con = con(randperm(numel(con), n_c));
pick_dis = dis(randperm(numel(dis), n_d));

% Force columns before concatenating so the result works for any input orientation.
cv_pos = [pick_con(:); pick_dis(:)];
cv_pos = cv_pos(randperm(numel(cv_pos)));   % interleave the two classes randomly

cv_id = id(cv_pos);
tr_id = id(~ismember(id, cv_id));
end
