classdef nonlinearARX < contDynamics
% nonlinearARX class (time-discrete nonlinear ARX model)
%
% Syntax:
%    % only dynamic equation
%    obj = nonlinearARX(fun,dt,dim_y,dim_u,n_t)
%    obj = nonlinearARX(name,fun,dt,dim_y,dim_u,n_t)
%
% Description:
%    Generates a discrete-time nonlinear ARX object (NARX) according  
%    to the following equation:
%       y(k) = f(y(k-1),...,y(k-n_y),u(k),...,u(k-n_u),e(k-1),...,e(k-n_e)) 
%               + e(k)
%
% Inputs:
%    fun    - function handle for the NARX equation with arguments (y,u)
%               y=[y(k-1); ...; y(k-n_t)]: stacked column vector of
%                  length dim_y*n_t
%               u=[u(k); ...; u(k-n_t)]: stacked column vector of
%                  length dim_u*(n_t+1)
%    name   - name of the model (optional, default: 'nonlinearARX')
%    dt     - sampling time
%    dim_y  - dimension of the output
%    dim_u  - dimension of the input
%    n_t    - number of past time steps which are considered
%
% Outputs:
%    obj - generated nonlinearARX object
%
% Example:
%    f = @(y,u) [y(1,1) + u(1,1) - y(2,1); ...
%                   y(3,1) + u(2,1)*cos(y(1,1)); ...
%                   y(5,1) + u(4,1)*sin(y(1,1))];
%    dt = 0.25;
%    sys = nonlinearARX(f,dt,3,2,2)
%
% See also: nonlinearSysDT, linearARX

% Authors:       Laura Luetzow
% Written:       24-April-2023
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------
  

properties (SetAccess = private, GetAccess = public)
    prev_ID = 0;                % last identifier assigned by generate_indPolyZonotope 
    mFile = [];                 % function handle dynamic equation
    jacobian = [];              % function handle jacobian matrix
    hessian = [];               % function handle hessian tensor
    thirdOrderTensor = [];      % function handle third-order tensor
    dt {mustBeNumeric} = [];    % sampling time
end

methods
    
    % class constructor
    function obj = nonlinearARX(varargin)
        
        % check number of input arguments: fun, dt, dim_y, dim_u, and n_t
        % are mandatory, a leading name is optional
        if nargin < 5
            throw(CORAerror('CORA:notEnoughInputArgs',5));
        elseif nargin > 6
            throw(CORAerror('CORA:tooManyInputArgs',6));
        end

        % assign arguments
        [name,fun,dt,dim_y,dim_u,p] = aux_parseInputArgs(varargin{:});

        % check arguments
        aux_checkInputArgs(name,dt,dim_y,dim_u,p);        

        % instantiate parent class (p is passed as the state dimension)
        obj@contDynamics(name,p,dim_u,dim_y);

        % assign object properties; the derivative handles refer to
        % functions named jacobian_<name>, hessianTensor_<name>, and
        % thirdOrderTensor_<name> (str2func avoids evaluating the
        % user-provided name as code)
        obj.dt = dt;
        obj.mFile = fun;
        obj.jacobian = str2func(['jacobian_' name]);
        obj.hessian = str2func(['hessianTensor_' name]);
        obj.thirdOrderTensor = str2func(['thirdOrderTensor_' name]);

        % initial value of the identifier counter used by
        % generate_indPolyZonotope (new identifiers start at 11)
        obj.prev_ID = 10;
    end    

    function setPrevID(sys, id_new)
        % overwrite the identifier counter
        % NOTE(review): the update persists for the caller only if
        % contDynamics is a handle class - confirm
        sys.prev_ID = id_new;
    end

    function set = generate_indPolyZonotope(sys, set)
        % transform set to a polynomial zonotope with new identifiers
        %
        % Converts set to a polyZonotope (if it is not one already),
        % replaces its identifiers by consecutive fresh identifiers
        % prev_ID+1, prev_ID+2, ..., and advances prev_ID to the last
        % identifier used.

        if ~isa(set, 'polyZonotope')
            set = polyZonotope(set);
        end
        id_old = set.id;
        id_new = sys.prev_ID+1 : sys.prev_ID+length(set.id);
        set = replaceId(set, id_old, id_new);

        % sets without identifiers (e.g., points) leave the counter as is
        if ~isempty(id_new)
            setPrevID(sys, id_new(end));
        end
    end


    function R0 = getR0(sys, Y, type)
        % stack the previous p output sets to get the NARX 
        % state set for time k+1
        %
        % Inputs:
        %    Y    - either a cell array of output sets (the first sys.dim
        %           entries are stacked) or a numeric matrix with one
        %           output per column (the last sys.dim columns are used,
        %           each converted to a polyZonotope with fresh ids)
        %    type - "standard" (default) or "poly"; with "poly", sets that
        %           are not polyZonotopes are converted to polyZonotopes
        %           with fresh identifiers before stacking
        %
        % Outputs:
        %    R0 - stacked set (stack for polyZonotopes, cartProd otherwise)
        
        if nargin < 3
            type = "standard";
        end

        if isa(Y, 'double')
            % number of time steps = number of columns (length() would
            % return the number of rows if dim_y exceeds the number of
            % columns)
            yVec = Y;
            k = size(yVec,2);
            Y = cell(sys.dim,1);
            for j = 1:sys.dim
                Y{j} = generate_indPolyZonotope(sys, yVec(:,k-sys.dim+j));
            end
        end

        for j=1:sys.dim
            if strcmp(type, "poly") && ~isa(Y{j}, 'polyZonotope')
                Y{j} = generate_indPolyZonotope(sys, Y{j});
            end
            if j==1
                R0 = Y{j};
            else
                if isa(Y{j}, 'polyZonotope')
                    R0 = stack(R0, Y{j});
                else
                    R0 = cartProd(R0, Y{j});
                end
            end
        end

    end

    function [u_stacked, U_stacked] = getStackedU(sys, u, U, type)
        % stack the current and previous p inputs sets to get the NARX 
        % input sets for each time point
        %
        % Inputs:
        %    u    - nominal inputs, one column per time step
        %    U    - (optional) uncertain input set, either one set used
        %           for all time steps or a cell array with one set per
        %           time step
        %    type - "standard" (default) or "poly"; with "poly" and a
        %           single set U, an independent copy (fresh identifiers)
        %           is generated for every time step
        %
        % Outputs:
        %    u_stacked - stacked nominal inputs [u(j-p); ...; u(j)] in
        %                column j; the first sys.dim columns stay zero
        %    U_stacked - cell array of stacked input sets (entries
        %                1..sys.dim stay empty), [] if U is not given

        if nargin < 4
            type = "standard";
        end

        u_stacked = zeros((sys.dim+1)*sys.nrOfInputs, size(u,2));
        if nargin >= 3 % compute stacked U-sets 
            if ~iscell(U)
                % U is constant set
                % --> create an independent set U for each time step

                U_const = U;
                U = cell(size(u,2),1);
                for j = 1:size(u,2)
                    if strcmp(type, "poly")
                        U{j} = generate_indPolyZonotope(sys, U_const);
                    else
                        U{j} = U_const;
                    end
                end
            end
            
            U_stacked = cell(size(u,2), 1);
            for j = sys.dim+1:size(u,2)
                % start with the current input and prepend older ones
                U_stacked_j = U{j} + u(:,j);
                for i=1:sys.dim
                    if isa(U{j-i}, "polyZonotope")
                        U_stacked_j = stack(U{j-i} + u(:,j-i),U_stacked_j);
                    else
                        U_stacked_j = cartProd(U{j-i} + u(:,j-i),U_stacked_j);
                    end
                end
                U_stacked{j} = U_stacked_j;
                u_stacked(:,j) = center(U_stacked_j);
            end
        else
            U_stacked = [];
            for j = sys.dim+1:size(u,2)
                u_stacked(:,j) = reshape(u(:,j-sys.dim:j),[],1);
            end
        end
    end

    % update system dynamics for the new augmented input [u; w] where w is
    % the process noise acting on all states 
    function obj = augment_u_with_w(obj)
        throw(CORAerror('CORA:notSupported','Not implemented for nonlinearARX.'))
    end

    % update system dynamics for the new augmented input [u; v] where v is
    % the measurement noise acting on all outputs 
    function obj = augment_u_with_v(obj)
        dim_y = obj.nrOfOutputs;
        dim_uold = obj.nrOfInputs;
        dim_unew = dim_uold + dim_y;

        % positions of the original inputs within the new stacked input
        % vector, which consists of obj.dim+1 blocks of length dim_unew
        idz_uold = [];
        for i = 0:obj.dim
            idz_uold = [idz_uold i*dim_unew+1:i*dim_unew+dim_uold]; %#ok<AGROW>
        end
        % positions of the noise entries v within the first block
        idz_unew = dim_uold + 1: dim_unew;

        % capture the current dynamics in a local variable: referencing
        % obj.mFile inside the new handle would, if obj is a handle
        % object, resolve to the new handle itself and recurse infinitely
        f_old = obj.mFile;
        obj.mFile = @(y,u) f_old(y,u(idz_uold)) + u(idz_unew);
        obj.nrOfInputs = dim_unew;
        % NOTE(review): jacobian/hessian/thirdOrderTensor handles are not
        % updated for the augmented input - confirm this is intended
    end
end
end


% Auxiliary functions -----------------------------------------------------

function [name,fun,dt,dim_y,dim_u,p] = aux_parseInputArgs(varargin)
% parse the constructor arguments: an optional leading name (char array
% or string) followed by fun, dt, dim_y, dim_u, and p (number of past
% time steps); throws a CORAerror if the number of remaining arguments
% is not exactly five

    % optional first argument: system name
    if ischar(varargin{1}) || isa(varargin{1}, 'string')
        name = char(varargin{1});
        varargin = varargin(2:end);
        nrArgsExpected = 6;
    else
        name = 'nonlinearARX'; % default name
        nrArgsExpected = 5;
    end

    % exactly fun, dt, dim_y, dim_u, and p have to remain; otherwise the
    % indexing below would fail or extra arguments would be ignored
    if numel(varargin) < 5
        throw(CORAerror('CORA:notEnoughInputArgs',nrArgsExpected));
    elseif numel(varargin) > 5
        throw(CORAerror('CORA:tooManyInputArgs',nrArgsExpected));
    end

    fun = varargin{1};
    dt = varargin{2};
    dim_y = varargin{3};
    dim_u = varargin{4};
    p = varargin{5};
end

function aux_checkInputArgs(name,dt,dim_y,dim_u,p)
% validate the parsed constructor arguments; an invalid argument raises
% a CORAerror, either directly or through inputArgsCheck

    % a default name is substituted when none is given, so only the type
    % of the name has to be validated here
    if ~ischar(name)
        throw(CORAerror('CORA:wrongInputInConstructor',...
            'System name has to be a char array.'));
    end

    % sampling time: positive numeric scalar
    inputArgsCheck({{dt,'att','numeric',{'positive','scalar'}}});

    % output dimension, input dimension, and number of past time steps
    % (in this order): numeric integer scalars
    % NOTE(review): positivity of these values is not enforced
    for argCell = {dim_y, dim_u, p}
        inputArgsCheck({{argCell{1},'att','numeric',{'integer','scalar'}}});
    end

end


% ------------------------------ END OF CODE ------------------------------
