function [error_iter, error_iter_1, ell_2, ell_2_1, F_score, F_score_1] = valid_msda(data, K, T, values, theta_true)
%VALID_MSDA Iterative multiclass sparse discriminant estimation with lambda
%chosen on a validation set, using two aggregation rules (mean and median).
%
%   Inputs
%     data       - cell array with M+2 rows. Rows 1..M are training machines,
%                  row M+1 is the validation set, row M+2 is the test set.
%                  Column 1 holds the feature matrix and column 2 the labels.
%                  Column 3 is read as the label vector of the training
%                  machines in the aggregated statistics (see NOTE below).
%     K          - number of classes; labels are compared against 1..K.
%     T          - number of refinement iterations.
%     values     - candidate lambda values for the group-lasso penalty.
%     theta_true - reference coefficient matrix used for the ell_2 metric.
%
%   Outputs (all (T+1)-by-1; entry 1 is the initial estimator returned by
%   valid_initial, entry t+1 is the estimator after iteration t)
%     error_iter   - test misclassification rate, mean-aggregated variant.
%     error_iter_1 - test misclassification rate, median-aggregated variant.
%     ell_2        - sum(norms(theta - theta_true, 2, 1))/K, mean variant.
%     ell_2_1      - same quantity, median variant.
%     F_score      - F1_score(2*K, theta), mean variant.
%     F_score_1    - F1_score(2*K, theta), median variant.
%
%   NOTE(review): machine 1 contributes data{1,2} as labels to the pooled
%   statistics and to Sigma_1, while data{m,3} is used for machines 2..M in
%   the pooled statistics and for every machine in the median statistics.
%   This column choice is kept exactly as it was; confirm against the code
%   that builds `data` that it is intentional.
%
%   Relies on valid_initial, norms and F1_score, which are defined outside
%   this file.

    M = size(data, 1) - 2;          % number of training machines
    test_X  = data{M+2, 1};         % last row: test set
    test_Y  = data{M+2, 2};
    valid_X = data{M+1, 1};         % second to last row: validation set
    valid_Y = data{M+1, 2};
    p = size(test_X, 2);

    % evaluation indices
    error_iter   = zeros(T+1, 1);
    error_iter_1 = zeros(T+1, 1);
    ell_2        = zeros(T+1, 1);
    ell_2_1      = zeros(T+1, 1);
    F_score      = zeros(T+1, 1);
    F_score_1    = zeros(T+1, 1);

    % initial estimator, shared by both variants (three outputs are still
    % requested so valid_initial sees the same nargout as before)
    [theta_initial_0, error_initial, ~] = valid_initial(data, K, values);
    error_iter(1)   = error_initial;
    error_iter_1(1) = error_initial;
    ell_2(1)        = sum(norms(theta_initial_0 - theta_true, 2, 1))/K;
    ell_2_1(1)      = ell_2(1);
    F_score(1)      = F1_score(2*K, theta_initial_0);
    F_score_1(1)    = F_score(1);

    % pooled class means, class proportions and within-class covariance
    X_t = data{1,1};
    Y_t = data{1,2};
    for m = 2:M
        X_t = [X_t; data{m,1}]; %#ok<AGROW>
        Y_t = [Y_t; data{m,3}]; %#ok<AGROW>
    end
    [mu_hat, n_class, Sigma] = class_stats(X_t, Y_t, K);
    pi_hat    = n_class / sum(n_class);
    delta_hat = mu_hat(:, 2:K) - mu_hat(:, 1);   % mean contrasts against class 1

    % within-class covariance of the first machine only
    [~, ~, Sigma_1] = class_stats(data{1,1}, data{1,2}, K);

    %%% Mean!!
    log_pi_hat = log(pi_hat);
    theta_cur = theta_initial_0;
    for t = 1:T
        % right-hand side of the local surrogate problem
        delta = (Sigma_1 - Sigma) * theta_cur + delta_hat;
        theta_cur = select_theta(Sigma_1, delta, values, valid_X, valid_Y, log_pi_hat);

        % metrics of the t-th step
        error_iter(t + 1) = misclass_rate(test_X, test_Y, theta_cur, log_pi_hat);
        ell_2(t + 1)      = sum(norms(theta_cur - theta_true, 2, 1))/K;
        F_score(t + 1)    = F1_score(2*K, theta_cur);
    end

    %%% Median!!
    % coordinate-wise median over machines of the local class means and
    % local class proportions
    mu_store = zeros(p, K, M);      % fix: one slice per machine (was sized M-2)
    pi_store = zeros(M, K);
    for m = 1:M
        [mu_local, n_local] = class_stats(data{m,1}, data{m,3}, K);
        mu_store(:, :, m) = mu_local;
        pi_store(m, :)    = n_local / sum(n_local);
    end
    mu_tilde    = median(mu_store, 3);
    pi_tilde    = median(pi_store, 1);   % not renormalised, as before
    delta_tilde = mu_tilde(:, 2:K) - mu_tilde(:, 1);

    log_pi_tilde = log(pi_tilde);
    theta_cur = theta_initial_0;
    for t = 1:T
        % coordinate-wise median over machines of Sigma_local * theta.
        % Sigma_local does not depend on t; it is recomputed rather than
        % cached so that at most one p-by-p local covariance is held at once.
        grad_store = zeros(p, K-1, M);
        for m = 1:M
            [~, ~, Sigma_local] = class_stats(data{m,1}, data{m,3}, K);
            grad_store(:, :, m) = Sigma_local * theta_cur;
        end
        grad_tilde = median(grad_store, 3);
        b_tilde = Sigma_1 * theta_cur - grad_tilde + delta_tilde;

        theta_cur = select_theta(Sigma_1, b_tilde, values, valid_X, valid_Y, log_pi_tilde);

        % metrics of the t-th step
        error_iter_1(t + 1) = misclass_rate(test_X, test_Y, theta_cur, log_pi_tilde);
        ell_2_1(t + 1)      = sum(norms(theta_cur - theta_true, 2, 1))/K;
        F_score_1(t + 1)    = F1_score(2*K, theta_cur);
    end
end

function [mu, n_class, Sigma] = class_stats(X, Y, K)
%CLASS_STATS Per-class means, per-class counts and pooled within-class covariance.
%   mu      - size(X,2)-by-K, column k is the mean of the rows with Y == k.
%   n_class - 1-by-K, number of rows with Y == k.
%   Sigma   - sum_k n_class(k) * cov(X_k), divided by size(X,1). Only
%             computed when the third output is requested.
%   NOTE: cov() returns a scalar for a single-row input and NaN for an empty
%   one, so every class needs at least two rows for Sigma to be meaningful.
    p = size(X, 2);
    mu = zeros(p, K);
    n_class = zeros(1, K);
    want_cov = nargout > 2;
    scatter = zeros(p, p);
    for k = 1:K
        X_class = X(Y == k, :);
        n_class(k) = size(X_class, 1);
        mu(:, k) = mean(X_class, 1)';
        if want_cov
            scatter = scatter + n_class(k) * cov(X_class);
        end
    end
    Sigma = scatter / size(X, 1);
end

function theta_best = select_theta(Sigma, B, values, valid_X, valid_Y, log_prior)
%SELECT_THETA Fit one estimator per lambda and keep the one with the lowest
%validation misclassification rate (first lambda wins on ties).
    n_lambda = length(values);
    if n_lambda == 0
        error('valid_msda:emptyValues', 'values must contain at least one lambda.');
    end
    candidates = cell(n_lambda, 1);
    cv_error = zeros(n_lambda, 1);
    for i = 1:n_lambda
        candidates{i} = group_lasso_cd(Sigma, B, values(i));
        cv_error(i) = misclass_rate(valid_X, valid_Y, candidates{i}, log_prior);
    end
    [~, best] = min(cv_error);      % min returns the first minimiser
    theta_best = candidates{best};
end

function theta = group_lasso_cd(Sigma, B, lambda)
%GROUP_LASSO_CD Blockwise (row-wise) coordinate descent for
%   min_theta  0.5*trace(theta'*Sigma*theta) - trace(B'*theta)
%              + lambda * sum_j norm(theta(j,:), 2),
%   started from zero, at most 100 sweeps, stopping once the largest
%   absolute coefficient change within a sweep is below 1e-5.
    max_sweeps = 100;
    tol = 1e-5;
    [p, q] = size(B);
    theta = zeros(p, q);
    for sweep = 1:max_sweeps
        max_change = 0;
        for j = 1:p
            s_jj = Sigma(j, j);
            row_old = theta(j, :);
            % Fix: the whole partial residual (row j's own contribution
            % added back) is divided by Sigma(j,j). Previously only B(j,:)
            % was divided, which is only correct for a unit diagonal.
            row_bar = (B(j, :) - Sigma(j, :) * theta + s_jj * row_old) / s_jj;
            bar_norm = norm(row_bar, 2);
            v = bar_norm - lambda / s_jj;
            if v > 0
                theta(j, :) = row_bar * (v / bar_norm);   % group soft-threshold
            else
                theta(j, :) = zeros(1, q);
            end
            max_change = max(max_change, max(abs(theta(j, :) - row_old)));
        end
        if max_change < tol
            break
        end
    end
end

function err = misclass_rate(X, Y, theta, log_prior)
%MISCLASS_RATE Share of rows whose highest-scoring class differs from Y.
%   Class 1 is the baseline with zero coefficients; log_prior (1-by-K) is
%   added to every row of the scores.
    scores = X * [zeros(size(theta, 1), 1) theta] + log_prior;
    [~, predicted] = max(scores, [], 2);
    err = 1 - mean(predicted == Y);
end