classdef logistic
    %LOGISTIC L2-regularized logistic loss on a fixed data set.
    %   f = LOGISTIC(X, y, mu) stores the data matrix X (n-by-p), the
    %   labels y (n entries) and the regularization weight mu, and exposes
    %
    %       f.loss(w)    = -(1/n)*sum(log(phi(y.*(X*w)))) + (mu/2)*(w'*w)
    %       f.grad(w)    = X'*(y.*(phi(y.*(X*w)) - 1))/n + mu*w
    %       f.hessian(w) = X'*diag(s.*(1-s))*X/n + mu*eye(p),
    %                      with s = phi(y.*(X*w))
    %
    %   where phi(t) = 1/(1+exp(-t)). w is expected to be a p-by-1 column.
    %   NOTE(review): the margin form y.*(X*w) presumes labels in {-1,+1};
    %   confirm against callers.

    properties
        X;   % data matrix, n-by-p (one row per sample)
        y;   % labels, stored as an n-by-1 column
        mu;  % L2 regularization weight
    end

    methods(Static)
        function out = phi(t)
            %PHI Elementwise logistic (sigmoid) function 1/(1+exp(-t)).
            %   The two branches are algebraically equal; each is chosen so
            %   that exp() is only evaluated at a non-positive argument and
            %   therefore cannot overflow.
            out = 0*t;
            idx = t > 0;
            out(idx) = 1./(1 + exp(-t(idx)));
            e = exp(t(~idx));
            out(~idx) = e./(e + 1);
        end

        function out = logit(t)
            %LOGIT Elementwise log of the logistic function,
            %   log phi(t) = -log(1+exp(-t)).
            %   (Despite the name this is log-sigmoid, not the log-odds
            %   "logit" function.) exp() is only evaluated at non-positive
            %   arguments, and log1p keeps full relative accuracy when that
            %   exponential is tiny, e.g. logit(40) is about -exp(-40)
            %   instead of rounding to 0.
            out = 0*t;
            idx = t > 0;
            out(idx) = -log1p(exp(-t(idx)));
            tn = t(~idx);
            out(~idx) = tn - log1p(exp(tn));
        end
    end

    methods
        function obj = logistic(X, y, mu)
            %LOGISTIC Construct the loss for data X, labels y and weight mu.
            %   mu defaults to 0 (no regularization) when omitted or empty.
            %   y is stored as a column so that y.*(X*w) is an elementwise
            %   product; a row y would otherwise implicitly expand to an
            %   n-by-n matrix and silently corrupt loss/grad/hessian.
            if nargin < 3 || isempty(mu)
                mu = 0;
            end
            if size(X, 1) ~= numel(y)
                error('logistic:sizeMismatch', ...
                    'X has %d rows but y has %d elements.', size(X, 1), numel(y));
            end
            obj.X = X;
            obj.y = y(:);
            obj.mu = mu;
        end

        function out = loss(obj, w)
            %LOSS Regularized logistic loss at w (p-by-1):
            %   -(1/n)*sum(log phi(y.*(X*w))) + (mu/2)*||w||^2.
            n = size(obj.X, 1);
            margin = obj.y.*(obj.X*w);
            out = -sum(obj.logit(margin))/n + obj.mu*(w'*w)/2;
        end

        function out = grad(obj, w)
            %GRAD Gradient of LOSS at w, returned as a p-by-1 column.
            n = size(obj.X, 1);
            margin = obj.y.*(obj.X*w);
            % d/dm [-log phi(m)] = phi(m) - 1, chained through m = y.*(X*w)
            z = obj.y.*(obj.phi(margin) - 1);
            out = obj.X'*z/n + obj.mu*w;
        end

        function out = hessian(obj, w)
            %HESSIAN Hessian of LOSS at w, returned as a p-by-p matrix.
            [n, p] = size(obj.X);
            s = obj.phi(obj.y.*(obj.X*w));
            d = s.*(1 - s);            % per-sample curvature weights, in [0, 1/4]
            % Scaling rows of X by sqrt(d) gives X'*diag(d)*X without
            % forming the n-by-n diagonal matrix.
            Xs = obj.X.*sqrt(d);
            out = Xs'*Xs/n + obj.mu*eye(p);
        end
    end
end

