function err = MSDRL_learning(Mdl,xTr,yTr,xTe,yTe)
%MSDRL_LEARNING Combine per-source logistic models with a data-dependent
%simplex weight q and return the test misclassification rate.
%
%   err = MSDRL_learning(Mdl,xTr,yTr,xTe,yTe)
%
%   Inputs
%     Mdl - struct; only Mdl.distribs (the number S of source data sets)
%           is read.
%     xTr - cell array with S entries; xTr{s} is an n_s-by-d matrix.
%     yTr - cell array with S entries; yTr{s} holds the n_s labels of
%           xTr{s}. Label value 2 is recoded to -1 (NOTE(review): presumably
%           the raw labels are {1,2} - confirm against the caller).
%     xTe - t-by-d test matrix; it is also the target sample used by the
%           density-ratio estimates.
%     yTe - t test labels, coded like yTr (row or column vector).
%
%   Output
%     err - fraction of test rows whose sign(sum_s q(s) * xTe*mu_s) differs
%           from the recoded yTe (a zero score counts as an error).
%
%   Procedure
%     1. Fit f_s (model_lr) on all of source s, and f_A^s, f_B^s on a random
%        half split A/B of source s.
%     2. Estimate density-ratio weights of each half against xTe
%        (PenalizedLL).
%     3. Build the sample-split estimators G_A, G_B of Gamma and average
%        them into G.
%     4. Clip G's eigenvalues at 1e-6 and solve min q'*G*q over the simplex.
%     5. Predict with sign(sum_s q(s) * f_s(x)).

S = Mdl.distribs;

% Recode label 2 -> -1 so labels match the {-1,+1} loss used in model_lr.
% Column orientation avoids implicit expansion in the final comparison.
yTe = yTe(:);
yTe(yTe == 2) = -1;

muFull = cell(S,1);   % f_s fitted on all of source s
muA    = cell(S,1);   % f_A^s fitted on split A of source s
muB    = cell(S,1);   % f_B^s fitted on split B of source s
xTrA   = cell(S,1);  yTrA = cell(S,1);
xTrB   = cell(S,1);  yTrB = cell(S,1);
betaA  = cell(S,1);  betaB = cell(S,1);
for s = 1:S
    ys = yTr{s}(:);
    ys(ys == 2) = -1;
    % Construct initial ML estimator F
    muFull{s} = model_lr(xTr{s}, ys);
    % Random half split of source s (first floor(n/2) permuted rows -> A)
    ns   = size(xTr{s},1);
    perm = randperm(ns);
    half = floor(ns/2);
    idxA = perm(1:half);
    idxB = perm(half+1:end);
    xTrA{s} = xTr{s}(idxA,:);  yTrA{s} = ys(idxA);
    xTrB{s} = xTr{s}(idxB,:);  yTrB{s} = ys(idxB);
    % Sample-splitting density ratio estimators w_A, w_B
    betaA{s} = PenalizedLL(xTrA{s}, xTe);
    betaB{s} = PenalizedLL(xTrB{s}, xTe);
    % Sample-splitting ML estimators f_A, f_B
    muA{s} = model_lr(xTrA{s}, yTrA{s});
    muB{s} = model_lr(xTrB{s}, yTrB{s});
end

% Column k of MA/MB holds the coefficients of f_A^k / f_B^k, so f(MA,x)
% evaluates every split-A model at once.
MA = [muA{:}];
MB = [muB{:}];

% D_A(l,k) = mean_i betaB_l(i) * f_A^k(x_i) * (f_A^l(x_i) - y_i) over split B
% of source l (cross-fitting: A models are evaluated on B data), and
% symmetrically for D_B.
DA = zeros(S,S);
DB = zeros(S,S);
for l = 1:S
    FAonB = f(MA, xTrB{l});                      % |B_l|-by-S
    FBonA = f(MB, xTrA{l});                      % |A_l|-by-S
    rB = betaB{l}(:) .* (FAonB(:,l) - yTrB{l});  % weighted residuals
    rA = betaA{l}(:) .* (FBonA(:,l) - yTrA{l});
    DA(l,:) = (rB.' * FAonB) / numel(rB);
    DB(l,:) = (rA.' * FBonA) / numel(rA);
end

% G_T(l,k) = mean_j f^l(xTe_j) * f^k(xTe_j); then the bias-corrected
% estimators G_A, G_B and their average G.
FAte = f(MA, xTe);
FBte = f(MB, xTe);
t    = size(xTe,1);
GTA  = (FAte.' * FAte) / t;
GTB  = (FBte.' * FBte) / t;
GA   = GTA - (DA + DA.');
GB   = GTB - (DB + DB.');
G    = (GA + GB) / 2;

% Make Gamma positive definite: symmetrize (removes round-off asymmetry so
% eig returns an orthonormal basis), floor the eigenvalues at 1e-6, rebuild.
G = (G + G.') / 2;
[V, Lam] = eig(G);
G = V * diag(max(diag(Lam), 1e-6)) * V.';
G = (G + G.') / 2;

% Data-dependent optimal weight on the probability simplex
cvx_begin quiet
    variable q(S,1)
    minimize( q'*G*q )
    subject to
        sum(q) == 1;
        q >= 0;
cvx_end

% Aggregated predictor h(x) = sum_s q(s) * f_s(x)
h     = f([muFull{:}], xTe) * q;
yPred = sign(h);
err   = sum(yPred ~= yTe) / numel(yTe);
end

%function h = f(Mdl,x)
%    h = 1/sum(exp(phi(Mdl,x,(1:Mdl.labels))*Mdl.mu-ones(Mdl.labels,1)*phi(Mdl,x,1)*Mdl.mu));
%end
function h = f(mu,x)
%F Linear model score: the rows of x multiplied by the coefficients mu.
%   h = f(mu,x) is the matrix product x*mu; with several coefficient
%   columns in mu, column k of h holds the scores of model k.
h = mtimes(x, mu);
end

% function mu = model_lr(x,y)
% mu =[];
% cvx_begin quiet
% variable mu(size(x,2),1)
% minimize( -lr_mu(x,y,mu) )
% cvx_end
%
% end
%
% function value = lr_mu(x,y,mu)
% value=0;
% for i=1:size(x,1)
%     value = value+log(1+exp(-y(i)*x(i,:)*mu));
% end
% end

function mu = model_lr(x, y)
%MODEL_LR Unregularized logistic-regression coefficients fitted with CVX.
%
%   mu = model_lr(x, y) returns the d-by-1 vector minimizing
%       sum_i log(1 + exp(-y(i) * x(i,:)*mu)),
%   where x is n-by-d and y holds n labels (NOTE(review): expected in
%   {-1,+1}; the caller recodes label 2 to -1). y may be a row or a column.
n = size(x, 1);
d = size(x, 2);

cvx_begin quiet
    variable mu(d)
    % log(1+exp(z)) is written as log_sum_exp over the column [0; z]:
    % CVX rejects log() of a convex argument, but log_sum_exp is a
    % recognized convex atom. y(:) keeps the product elementwise (n-by-1)
    % even when y arrives as a row vector.
    minimize( sum( log_sum_exp( [zeros(1, n); -(y(:) .* (x * mu))'] ) ) )
cvx_end
end

function beta = PenalizedLL(xTr,xTe)
%PENALIZEDLL Density-ratio weights for the rows of xTr relative to xTe.
%
%   beta = PenalizedLL(xTr,xTe) fits a logistic regression that separates
%   the rows of xTr (group 0) from the rows of xTe (group 1), with an
%   intercept and a column-norm-weighted L1 penalty on the other
%   coefficients, minimized with CVX. It returns the 1-by-n row vector
%       beta(i) = (n/t) * P(G=1|x_i) / P(G=0|x_i) = (n/t) * exp([1 xTr(i,:)]*b),
%   with n = size(xTr,1) and t = size(xTe,1).

n = size(xTr,1);
t = size(xTe,1);
p = size(xTr,2) + 1;   % number of coefficients, intercept included

% Stacked design matrix and group labels
X = [ones(n,1), xTr; ones(t,1), xTe];
G = [zeros(n,1); ones(t,1)];
normX = zeros(1,p);
for i = 1:p
    normX(i) = norm(X(:,i));
end
lambda = 0.1*sqrt(log(p)/(n+t));
% lambda = 0;

cvx_begin
    variable b(p,1)
    minimize(log_b(X,G,b) + lambda*normX(2:end)*abs(b(2:end))/sqrt(n+t))
cvx_end

% With pG1 = exp(z)/(1+exp(z)), the odds pG1/(1-pG1) equal exp(z) exactly.
% Evaluating exp(z) directly avoids Inf/Inf = NaN for large z and the
% division by zero when pG1 rounds to 1.
beta = (n/t) * exp(X(1:n,:)*b).';
end

function value = log_b(X,G,b)
%LOG_B Mean negative log-likelihood of a logistic regression of G on X.
%
%   value = log_b(X,G,b) returns
%       (1/N) * sum_k [ log(1 + exp(X(k,:)*b)) - G(k)*X(k,:)*b ],
%   with N = size(X,1), and 0 when X has no rows.
%   b may be numeric or a CVX expression:
%     - CVX: the softplus term is written with log_sum_exp so the objective
%       is DCP-compliant (CVX rejects log() of a convex expression).
%     - numeric: softplus is evaluated as max(z,0) + log1p(exp(-|z|)),
%       which does not overflow for large z.
N = size(X,1);
if N == 0
    value = 0;
    return;
end
z = X*b;
if isa(b, 'cvx')
    softplus = sum(log_sum_exp([zeros(1,N); z']));
else
    softplus = sum(max(z,0) + log1p(exp(-abs(z))));
end
value = (softplus - G(:)'*z) / N;
end