%% Experiment parameters
d       = 50;      % dimension of the data
M       = 20;      % degree of the graph
N       = 100;     % total number of training samples (2 * 50)
p_ratio = 0.5;     % pruning rate: fraction of the K columns removed from each of W and V
N_p     = N / 2;   % number of positive samples
N_g     = N / 2;   % number of negative samples
K       = 200;     % neurons per weight matrix (2*K neurons in total)
N_test  = 20000;   % number of test samples
L       = 50;      % number of discriminative and non-discriminative patterns
T       = 500;     % maximum number of training iterations
mag     = 0.5;     % magnitude of each pattern
sigma   = 0.2;     % noise level

% Draw the pattern set O (built by generate_O from d, L, mag, sigma).
O = generate_O( d , L, mag, sigma );

[ X, Y , ~ ] = generate_GX( N , O , M ); % generate the training inputs X and targets Y (third output unused)

% Shared initialization: both the pruned and the unpruned runs start from
% (columns of) W_0 and V_0, so their comparison is not affected by the draw.
W_0 = 1 / 10 * randn( d , K ); % initialize W: d-by-K Gaussian entries with standard deviation 0.1

V_0 = 1 / 10 * randn( d , K ); % initialize V: d-by-K Gaussian entries with standard deviation 0.1

eta = 1 * K; % step size (scaled with the width K)

training_error = nan( T , 2 ); % per-iteration training loss; column 1 = pruned model, column 2 = unpruned model; NaN after an early stop

% Training with pruning.
% Warm-up phase: take 5 plain gradient steps from the shared initialization.
% The neuron magnitudes reached here decide which columns get pruned.
% The timer started here also covers the warm-up and the pruning step.
tic;

W = W_0;
V = V_0;

t = 0;
while t < 5
    t = t + 1;
    [ g_W , g_V ] = Gradient_GNN( W , V , X , Y);
    W = W - eta * g_W;
    V = V - eta * g_V;
end

fprintf( 'With pruning, Loop Number =        ' );

% Magnitude pruning with rewinding.
% Rank the neurons (columns) of the warmed-up W and V by squared Euclidean
% norm, remove the thr smallest from each matrix, and restart the surviving
% columns from their initial values in W_0 / V_0.
%
% Selecting by rank, rather than thresholding at the thr-th smallest norm
% with a strict '>', has two benefits:
%   - it keeps exactly K - thr columns even when norms tie;
%   - it also works for thr = 0, where indexing the thr-th sorted value fails.
% thr is clamped to [0, K] so that any p_ratio yields a valid selection.
thr = min( max( floor( K * p_ratio ) , 0 ) , K ); % number of columns removed from each matrix

% sum(W.^2, 1) gives the same per-column quantity as diag(W'*W), without
% forming the K-by-K product.
[ ~ , rank_W ] = sort( sum( W.^2 , 1 ) , 'descend' );
W = W_0( : , sort( rank_W( 1 : K - thr ) ) ); % keep the surviving columns in their original order

[ ~ , rank_V ] = sort( sum( V.^2 , 1 ) , 'descend' );
V = V_0( : , sort( rank_V( 1 : K - thr ) ) );

for t = 1 : T

    % Erase the previous counter (6 digits + newline) and print the new one.
    fprintf('\b\b\b\b\b\b\b%6d\n', t);

    % One gradient step on the pruned network over the whole training set.
    [ g_W , g_V ] = Gradient_GNN( W , V , X , Y);
    W = W - eta * g_W;
    V = V - eta * g_V;

    % Mean hinge loss max(1 - y*Y, 0) of the updated model on the training set.
    error_t = mean( max( 1 - generate_y( W , V , X ) .* Y , 0 ) );
    training_error( t , 1 ) = error_t;

    % Stop when the loss is zero, or when both gradients are exactly zero.
    if error_t == 0 || ~( norm( g_W , 'fro' ) ~= 0 || norm( g_V , 'fro' ) ~= 0 )
        break;
    end

end

time_pruning = toc;

% Training without pruning (baseline).
% The full-width network starts from the same initialization W_0, V_0 and
% follows the same step size, loss and stopping rule as the pruned run.

tic;

W_prime = W_0;
V_prime = V_0;

fprintf( 'Without pruning, Loop Number =        ' );

for t = 1 : T

    % Erase the previous counter (6 digits + newline) and print the new one.
    fprintf('\b\b\b\b\b\b\b%6d\n', t);

    [ g_W , g_V ] = Gradient_GNN( W_prime , V_prime , X , Y);
    W_prime = W_prime - eta * g_W;
    V_prime = V_prime - eta * g_V;

    % Mean hinge loss of the updated model, stored in column 2.
    error_t = mean( max( 1 - generate_y( W_prime , V_prime , X ) .* Y , 0 ) );
    training_error( t , 2 ) = error_t;

    % Stop when the loss is zero, or when both gradients are exactly zero.
    if error_t == 0 || ~( norm( g_W , 'fro' ) ~= 0 || norm( g_V , 'fro' ) ~= 0 )
        break;
    end

end

time_nonpruning = toc;

% Plot the per-iteration training loss of both runs, one curve per column of
% training_error, against the iteration index 1..T. NaN entries left by an
% early stop are not drawn.
plot( training_error );
xlabel( 'Number of iterations' );
ylabel( 'Training loss' );
legend( 'Model after magnitude pruning' , 'Original model' );

% Evaluate both trained models on a fresh test set.
% The test set has N_test samples, drawn with the same pattern set O and
% graph degree M as the training set. It previously used N, the
% training-set size, which left N_test unused and made the test error very
% noisy.
% The test error is the fraction of samples whose predicted sign disagrees
% with the sign of the target.
[ X_test , Y_test ] = generate_GX( N_test , O , M );

Y_est = generate_y( W , V , X_test );

test_error = mean(  sign(Y_est) ~= sign( Y_test ) );
 
fprintf('With pruning, the training time  = %.2f seconds, and the test error = %f\n' , time_pruning, test_error );


Y_est = generate_y( W_prime , V_prime , X_test );

test_error = mean(  sign(Y_est) ~= sign( Y_test ) );
 
fprintf('Without pruning, the training time  = %.2f seconds, the test error = %f\n' , time_nonpruning, test_error );






