function lambda_choose_1 = median_iter_CV(data, theta_initial, values, K)
% MEDIAN_ITER_CV  Choose a penalty value by repeated random hold-out validation.
%
%   lambda_choose_1 = median_iter_CV(data, theta_initial, values, K)
%
%   Inputs
%     data          - cell array with at least 3 columns. Rows 1..M, with
%                     M = size(data,1) - 1, are used; the last row is ignored
%                     here. For row m:
%                       data{m,1}  n_m-by-p matrix of observations
%                       data{m,2}  labels in 1..K, used for the row-1
%                                  covariance and for scoring
%                       data{m,3}  labels in 1..K, used for the
%                                  median-aggregated statistics
%                     NOTE(review): column 3 presumably holds possibly
%                     contaminated labels -- confirm against the caller.
%     theta_initial - p-by-(K-1) matrix; current estimate used to build the
%                     median-aggregated gradient.
%     values        - non-empty vector of candidate penalty values.
%     K             - number of classes.
%
%   Output
%     lambda_choose_1 - the candidate with the highest mean validation
%                       accuracy (the first one on ties).
%
%   For every candidate, 5 random splits are drawn. In each split about 1/5
%   of the rows of every used data row are held out; the remaining rows are
%   used to fit the estimator, and the held-out rows (pooled over all m) are
%   used to score it.
%
%   The class means, class proportions and gradient are estimated from the
%   training part of each split only, so that held-out rows never influence
%   the estimator they are scored against.
%
%   Relies on ISTA and bayes_value, which are defined outside this function.

    if isempty(values)
        error('median_iter_CV:emptyValues', ...
              'values must contain at least one candidate.');
    end

    n_rep = 5;         % number of random hold-out repetitions per candidate
    holdout_div = 5;   % round(n_m / holdout_div) rows are held out per data row

    M = size(data, 1) - 1;
    p = size(data{1, 1}, 2);

    lambda_accuracy_1 = zeros(1, length(values));
    for i = 1:length(values)
        lambda = values(i);
        cv_accuracy_1 = zeros(1, n_rep);
        for j = 1:n_rep
            % ---- split every used data row into train / validation ----
            data_tmp = cell(M, 3);
            valid_X = [];
            valid_Y = [];
            for m = 1:M
                X_m = data{m, 1};
                Y_m = data{m, 2};
                BY_m = data{m, 3};
                n_m = size(X_m, 1);
                rowrank_m = randperm(n_m, round(n_m / holdout_div));
                train_rows = setdiff(1:n_m, rowrank_m);
                data_tmp{m, 1} = X_m(train_rows, :);
                data_tmp{m, 2} = Y_m(train_rows);
                data_tmp{m, 3} = BY_m(train_rows);
                held_Y = Y_m(rowrank_m);
                valid_X = [valid_X; X_m(rowrank_m, :)]; %#ok<AGROW>
                valid_Y = [valid_Y; held_Y(:)];         %#ok<AGROW>
            end

            % ---- pooled within-class covariance on row 1 (column-2 labels) ----
            Sigma_1 = local_pooled_cov(data_tmp{1, 1}, data_tmp{1, 2}, K, p);

            % ---- per-row class means / proportions, then coordinate-wise median ----
            mu_store = cell(M, 1);
            pi_store = zeros(M, K);
            grad_store = cell(M, 1);
            for m = 1:M
                X = data_tmp{m, 1};
                Y = data_tmp{m, 3};
                mu_local = zeros(p, K);
                for k = 1:K
                    X_class = X(Y == k, :);
                    mu_local(:, k) = mean(X_class, 1)';
                    pi_store(m, k) = size(X_class, 1);
                end
                mu_store{m} = mu_local;
                pi_store(m, :) = pi_store(m, :) / sum(pi_store(m, :));

                % local gradient term Sigma_m * theta_initial
                grad_store{m} = local_pooled_cov(X, Y, K, p) * theta_initial;
            end

            mu_tilde = zeros(p, K);
            for k = 1:K
                mu_k = zeros(p, M);
                for m = 1:M
                    mu_k(:, m) = mu_store{m}(:, k);
                end
                mu_tilde(:, k) = median(mu_k, 2);
            end
            pi_tilde = median(pi_store, 1);

            % mean differences relative to class 1 (column for class 1 dropped)
            delta_tilde = mu_tilde - mu_tilde(:, 1);
            delta_tilde(:, 1) = [];

            grad_tilde = zeros(p, K - 1);
            for k = 1:K-1
                grad_k = zeros(p, M);
                for m = 1:M
                    grad_k(:, m) = grad_store{m}(:, k);
                end
                grad_tilde(:, k) = median(grad_k, 2);
            end

            % ---- one update step and validation accuracy ----
            delta = Sigma_1 * theta_initial - grad_tilde + delta_tilde;
            theta_update = ISTA(Sigma_1, delta, lambda, 0.01);
            theta_1 = [zeros(p, 1) theta_update];
            pred_value_1 = bayes_value(valid_X, theta_1, mu_tilde, pi_tilde, K);
            % predicted class = column with the largest score in each row
            [~, index] = max(pred_value_1, [], 2);
            cv_accuracy_1(j) = mean(index == valid_Y);
        end
        lambda_accuracy_1(i) = mean(cv_accuracy_1);
    end

    % max returns the first maximiser, and still returns an index when all
    % accuracies are NaN (where an equality search would come back empty).
    [~, best] = max(lambda_accuracy_1);
    lambda_choose_1 = values(best);
end

function Sigma = local_pooled_cov(X, Y, K, p)
% LOCAL_POOLED_COV  Sum over classes of n_k * cov(X_k), divided by size(X,1).
%   Classes with fewer than two rows contribute nothing: cov of a single row
%   would be read as a vector (scalar variance) and cov of an empty set is
%   NaN, either of which would silently corrupt the p-by-p sum.
    cov_class = zeros(p, p);
    for k = 1:K
        X_class = X(Y == k, :);
        n_k = size(X_class, 1);
        if n_k >= 2
            cov_class = cov_class + n_k * cov(X_class);
        end
    end
    Sigma = cov_class / size(X, 1);
end