classdef (Abstract) RLagent
% RLagent - abstract superclass for actor-critic reinforcement learning
%    agents; holds actor, critic, their target copies, the experience
%    buffer, and the (default-completed) options
%
% Syntax:
%   obj = RLagent(actorNN,criticNN,options)
%
% Inputs:
%   actorNN - neural network
%   criticNN - neural network
%   options.rl - options for RL:
%        .gamma: .99 (default) Discount factor applied to future rewards 
%               during training. Admissible range: [0,1].
%        .tau: .005 (default) Smoothing factor for target network updates.
%               Admissible range: [0,1].
%        .expNoise: .2 (default) Exploration noise std. deviation.
%        .expNoiseTarget: .2 (default) Non-negative noise parameter
%               (presumably the noise applied to target actions - TODO
%               confirm against the subclasses).
%        .expNoiseType: 'OU' (default) Exploration noise type specified as 
%               Ornstein-Uhlenbeck ('OU') or Gaussian ('gaussian').
%        .expDecayFactor: 1 (default) Exploration noise decay factor.
%               Linearly decreasing for <0 and exponentially for [0,1). 
%               Admissible range: [-1,1].
%        .batchsize: 64 (default) Batchsize for neural network updates.
%        .buffersize: 1e6 (default) Experience buffer size
%        .noise: .1 (default) Perturbation noise radius for adv. and
%               set-based training methods.
%        .earlyStop: inf (default) Number of episodes for which the reward
%               did not increase before early stopping training.
%        .printFreq: 50 (default) Printing frequency of training log.
%        .visRate: 50 (default) Visualisation frequency of learning
%               progress.
%        .actor.nn - Evaluation parameters for the actor network
%           .poly_method: 'bounds' (default) Regression polynomial
%           .use_approx_error: true (default) Use approximation errors
%           .train - Training parameters for the actor network:
%               .use_gpu: true if a GPU is available (default) Use GPU 
%                       for training.
%               .optim: nnAdamOptimizer(1e-4,.9,.999,1e-8,0) (default)
%                       Actor optimizer.
%               .backprop: true (default) Training boolean for actor.
%               .method: 'point' (default) Training method for actor:
%                   'point' Standard point-based training 
%                   'set' Set-based training [1]
%                   'rand' Random adv. samples from perturbation ball
%                   'extreme' Adv. samples from edges of perturbation ball
%                   'naive' Adv. samples from naive algorithm [2]
%                   'grad' Adv. samples from grad algorithm [2]
%               .eta: 0.1 (default) Weighting factor for set-based 
%                   training of the actor.
%               .omega: 0.5 (default) Weighting factor in [0,1] 
%                   (meaning defined by the training routines - TODO
%                   confirm).
%               .advOps - Parameters for adversarial training algs:
%                   .numSamples: 200 (default) Number of samples
%                   .alpha: 4 (default) Distribution parameter
%                   .beta: 4 (default) Distribution parameter
%               .zonotope_weight_update: 'outer_product' (default)
%                   Zonotope weight update for learnable params \theta
%        .critic.nn - Evaluation parameters for the critic network:
%           .poly_method: 'bounds' (default) Regression polynomial
%           .use_approx_error: true (default) Use approximation errors
%           .train - Training parameters for the critic network:
%               .use_gpu: true if a GPU is available (default) Use GPU 
%                       for training.
%               .optim: nnAdamOptimizer(1e-3,.9,.999,1e-8,1e-2) (default)
%                       Critic optimizer.
%               .backprop: true (default) Training boolean for critic.
%               .method: 'point' (default) Training method for critic:
%                   'point' Standard point-based training 
%                   'set' Set-based training [1]; only admitted if the
%                         actor is trained with method 'set' as well
%               .eta: 0.01 (default) Weighting factor for set-based 
%                   training of the critic.
%               .zonotope_weight_update: 'outer_product' (default)
%                   Zonotope weight update for learnable params \theta
%   
% Outputs:
%   obj - generated object
% 
% References:
%   [1] Wendl, M. et al. Training Verifiably Robust Agents Using Set-Based 
%       Reinforcement Learning, 2024
%   [2] Pattanaik, A. et al. Robust Deep Reinforcement Learning with 
%       Adversarial Attacks, Int. Conf. on Autonomous Agents and Multiagent 
%       Systems (AAMAS) 2018   
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: actor

% Authors:       Manuel Wendl
% Written:       18-August-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

properties
    % actor object built from a copy of actorNN
    actor
    % critic object built from a copy of criticNN
    critic
    % second actor object built from a separate copy of actorNN
    targetActor
    % second critic object built from a separate copy of criticNN
    targetCritic
    % experience buffer with capacity options.rl.buffersize
    buffer
    % options struct after default completion and actor/critic construction
    options
    % not set by the constructor (presumably filled during training)
    learnHistory
end

methods
    % constructor
    function obj = RLagent(actorNN,criticNN,options)
        % Validate function arguments
        inputArgsCheck({ ...
            {actorNN,'att','neuralNetwork'}; ...
            {criticNN,'att','neuralNetwork'}; ...
            {options,'att','struct'}; ...
            });
        % Set default training parameters
        options = aux_validateRlOptions(options);

        % instantiate actor and critic; each constructor receives a copy of
        % the given network and returns the (possibly updated) options,
        % which are passed on to the next constructor call
        [obj.actor, numGensActor, options] = actor(actorNN.copyNeuralNetwork(),options);
        [obj.critic, options] = critic(criticNN.copyNeuralNetwork(),numGensActor,options);

        % Instantiate target networks from further copies of the same
        % input networks
        [obj.targetActor, numGensActor, options] = actor(actorNN.copyNeuralNetwork(),options);
        [obj.targetCritic, options] = critic(criticNN.copyNeuralNetwork(),numGensActor,options);

        % instantiate buffer
        obj.buffer = buffer(options.rl.buffersize);

        % store options
        obj.options = options;
    end

    % saveobj - returns the object that is written to a MAT-file: the
    %    content of the experience buffer is dropped and the seed of the
    %    current random number generator state is stored in
    %    options.rl.trainSeed
    function sobj = saveobj(obj)
        sobj = obj;
        % do not store the collected experiences
        sobj.buffer.array = [];
        % rng without arguments returns the current generator settings
        s = rng;
        sobj.options.rl.trainSeed = s.Seed;
    end

end

% methods that have to be implemented by every concrete agent (signatures
% only; the semantics are defined by the subclasses)
methods  (Access=protected, Abstract)
    [obj,learnHistory] = trainNetworkStep(obj,randBatch, noiseBatchG, learnHistory, episode)
    obj = gatherNetworks(obj)
    obj = deleteAllGradients(obj)
end


end


% Auxiliary functions -----------------------------------------------------

% set default values for the RL agent options and validate them
function options = aux_validateRlOptions(options)
% aux_validateRlOptions - completes the struct options.rl with default
%    values for all missing fields and, if CHECKS_ENABLED, validates the
%    resulting fields
%
% Inputs:
%    options - struct; the (optional) field 'rl' holds the user-defined
%              reinforcement learning options
%
% Outputs:
%    options - struct in which options.rl, options.rl.actor.nn(.train),
%              options.rl.actor.nn.train.advOps, and
%              options.rl.critic.nn(.train) contain all default fields
%
% Throws a CORAerror ('CORA:wrongFieldValue') if a checked field has an
% inadmissible value.

% default values (persistent so that the default objects are only built
% during the first call)
persistent defaultRlFields
if isempty(defaultRlFields)
    defaultRlFields = {
        'gamma', .99;        
        'tau', .005;
        'expNoise', .2;
        'expNoiseTarget', .2;
        'expNoiseType', 'OU';
        'expDecayFactor', 1;
        'batchsize', 64;
        'buffersize', 1e6;
        'noise', .1;
        'earlyStop', inf;
        'printFreq', 50;
        'visRate', 50;
        };
end

persistent defaultNnFields
if isempty(defaultNnFields)
    defaultNnFields = {
        'bound_approx', true;
        'num_generators', [];
        'max_gens_post', inf;
        'add_approx_error_to_GI', false;
        'plot_multi_layer_approx_info', false;
        'reuse_bounds', false;
        'max_bounds', 5;
        'do_pre_order_reduction', false;
        'remove_GI', true;
        'force_approx_lin_at', Inf;
        'propagate_bounds', false;
        'sort_exponents', false;
        'maxpool_type', 'project';
        'order_reduction_sensitivity', false;
        'graph', graph();
        'use_approx_error', true;
        'train', struct();
        'poly_method', 'bounds';
        'interval_center', false;
        };
end

persistent defaultActorTrainFields
if isempty(defaultActorTrainFields)
    defaultActorTrainFields = {
        'use_gpu', aux_isGPUavailable();
        'optim', nnAdamOptimizer(1e-4,.9,.999,1e-8,0);
        'backprop', true;
        'method', 'point';
        'eta', 0.1;
        'omega', 0.5;
        'exact_backprop', false;
        'num_init_gens', inf;
        'zonotope_weight_update', 'outer_product';
        };
end

persistent defaultCriticTrainFields
if isempty(defaultCriticTrainFields)
    defaultCriticTrainFields = {
        'use_gpu', aux_isGPUavailable();
        'optim', nnAdamOptimizer(1e-3,.9,.999,1e-8,1e-2);
        'backprop', true;
        'method', 'point';
        'eta', 0.01;
        'exact_backprop', false;
        'num_init_gens', inf;
        'zonotope_weight_update', 'outer_product';
        };
end

persistent defaultAdvOps
if isempty(defaultAdvOps)
    defaultAdvOps = {
       'numSamples', 200;
       'alpha', 4;
       'beta', 4;
    };
end

% check if any rl options are given
if ~isfield(options,'rl')
    options.rl = struct;
end

% make sure that options.rl.actor.nn and options.rl.critic.nn exist; this
% also covers the case where the user provides options.rl.actor (or
% .critic) without the sub-field 'nn', which would otherwise result in a
% reference to a non-existent field below
options.rl = aux_ensureNnStruct(options.rl,'actor');
options.rl = aux_ensureNnStruct(options.rl,'critic');

% set default value of rl fields if required
options.rl = aux_fillDefaults(options.rl,defaultRlFields,options);

% set default value of actor and critic nn fields if required (this also
% creates the field 'train' if it is missing)
options.rl.actor.nn = ...
    aux_fillDefaults(options.rl.actor.nn,defaultNnFields,options);
options.rl.critic.nn = ...
    aux_fillDefaults(options.rl.critic.nn,defaultNnFields,options);

% set default value of actor nn.train fields if required
options.rl.actor.nn.train = aux_fillDefaults( ...
    options.rl.actor.nn.train,defaultActorTrainFields,options);

% set default values of advOps fields if required
if ~isfield(options.rl.actor.nn.train,'advOps')
    options.rl.actor.nn.train.advOps = struct;
end
options.rl.actor.nn.train.advOps = aux_fillDefaults( ...
    options.rl.actor.nn.train.advOps,defaultAdvOps,options);

% set default value of critic nn.train fields if required
options.rl.critic.nn.train = aux_fillDefaults( ...
    options.rl.critic.nn.train,defaultCriticTrainFields,options);

% validate the completed options
if CHECKS_ENABLED
    % name of the options variable in the calling workspace (only used for
    % the error messages)
    structName = inputname(1);

    % Check rl fields
    aux_checkFieldNumericDefInterval(options.rl,'gamma',interval(0,1),structName)
    aux_checkFieldNumericDefInterval(options.rl,'tau',interval(0,1),structName)
    aux_checkFieldNumericDefInterval(options.rl,'expNoise',interval(0,inf),structName)
    aux_checkFieldNumericDefInterval(options.rl,'expNoiseTarget',interval(0,inf),structName)
    aux_checkFieldStr(options.rl, 'expNoiseType', {'OU','gaussian'}, structName);
    aux_checkFieldNumericDefInterval(options.rl,'expDecayFactor',interval(-1,1),structName)
    aux_checkFieldNumericDefInterval(options.rl,'batchsize',interval(0,inf),structName)
    aux_checkFieldNumericDefInterval(options.rl,'noise',interval(0,inf),structName)

    % Check actor fields
    aux_checkNnEvalOptions(options.rl.actor.nn,structName);
    aux_checkFieldStr(options.rl.actor.nn.train,'method', ...
        {'point','set','rand','extreme','naive','grad'},structName);
    aux_checkFieldNumericDefInterval(options.rl.actor.nn.train,'eta',interval(0,inf),structName)
    aux_checkFieldNumericDefInterval(options.rl.actor.nn.train,'omega',interval(0,1),structName)
    aux_checkFieldStr(options.rl.actor.nn.train,'zonotope_weight_update', ...
        {'center','sample','extreme','outer_product','sum'},structName);

    % Check critic fields
    aux_checkNnEvalOptions(options.rl.critic.nn,structName);
    aux_checkFieldStr(options.rl.critic.nn.train,'method', ...
        {'point','set'},structName);
    aux_checkFieldNumericDefInterval(options.rl.critic.nn.train,'eta',interval(0,inf),structName)
    aux_checkFieldStr(options.rl.critic.nn.train,'zonotope_weight_update', ...
        {'center','sample','extreme','outer_product','sum'},structName);

    % Validate the train options: a set-based critic is only admitted in
    % combination with a set-based actor
    if strcmp(options.rl.critic.nn.train.method,'set')
        if ~strcmp(options.rl.actor.nn.train.method,'set')
            throw(CORAerror('CORA:wrongFieldValue', ...
        'critic.nn.train.method', 'set'))
        end
    end
end
end

function rlOptions = aux_ensureNnStruct(rlOptions, part)
% makes sure that rlOptions.(part).nn exists; existing content of
% rlOptions.(part) is kept
if ~isfield(rlOptions,part)
    rlOptions.(part) = struct('nn',struct());
elseif ~isfield(rlOptions.(part),'nn')
    rlOptions.(part).nn = struct();
end
end

function S = aux_fillDefaults(S, defaultFields, options)
% adds every field listed in the first column of the n-by-2 cell array
% defaultFields to the struct S, unless S already has it; a default given
% as a function handle is evaluated with the options struct
for i=1:size(defaultFields, 1)
    field = defaultFields{i, 1};
    if ~isfield(S, field)
        fieldValue = defaultFields{i, 2};
        if isa(fieldValue, "function_handle")
            fieldValue = fieldValue(options);
        end
        S.(field) = fieldValue;
    end
end
end

function aux_checkNnEvalOptions(nnOptions, structName)
% checks the evaluation fields and the optimizer class that are validated
% in the same way for the actor and the critic network; nnOptions is
% options.rl.actor.nn or options.rl.critic.nn
aux_checkFieldClass(nnOptions, 'bound_approx', ...
    {'logical', 'string'}, structName);
if isa(nnOptions.bound_approx, 'string')
    aux_checkFieldStr(nnOptions, 'bound_approx', {'sample'}, structName)
    CORAwarning('CORA:nn',"Choosing Bound estimation '%s' does not lead to safe verification!", ...
        nnOptions.bound_approx);
end
aux_checkFieldStr(nnOptions, 'poly_method', {'bounds'}, structName);
aux_checkFieldClass(nnOptions, 'num_generators', {'double'}, structName);
aux_checkFieldClass(nnOptions, 'add_approx_error_to_GI', ...
    {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'plot_multi_layer_approx_info', ...
    {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'reuse_bounds', {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'max_bounds', {'double'}, structName);
aux_checkFieldClass(nnOptions, 'do_pre_order_reduction', ...
    {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'max_gens_post', {'double'}, structName);
aux_checkFieldClass(nnOptions, 'remove_GI', {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'force_approx_lin_at', {'double'}, structName);
aux_checkFieldClass(nnOptions, 'propagate_bounds', {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'sort_exponents', {'logical'}, structName);
aux_checkFieldStr(nnOptions, 'maxpool_type', {'project', 'regression'}, structName);
aux_checkFieldClass(nnOptions, 'order_reduction_sensitivity', {'logical'}, structName);
aux_checkFieldClass(nnOptions, 'graph', {'graph'}, structName);
aux_checkFieldClass(nnOptions,'use_approx_error',{'logical'},structName);
aux_checkFieldClass(nnOptions.train,'optim', ...
    {'nnSGDOptimizer','nnAdamOptimizer'},structName);
end

function gpu_available = aux_isGPUavailable()
% returns true if gpuDeviceCount is on the path and reports at least one
% available GPU device; any error during the query is treated as 'no GPU'
    gpu_available = false;
    try
        % without the function gpuDeviceCount there is nothing to query
        if isempty(which('gpuDeviceCount'))
            return
        end
        gpu_available = gpuDeviceCount('available') > 0;
    catch
        gpu_available = false;
    end
end

function aux_checkFieldStr(optionsrl, field, admissibleValues, structName)
% checks that optionsrl.(field) is a char or string and that it is
% contained in admissibleValues; throws a CORAerror otherwise
value = optionsrl.(field);
isText = isa(value, 'char') || isa(value, 'string');
if isText && ismember(value, admissibleValues)
    % admissible value, nothing to report
    return
end
throw(CORAerror('CORA:wrongFieldValue', ...
    aux_getName(structName, field), admissibleValues))
end

function aux_checkFieldClass(optionsrl, field, admissibleClasses, structName)
% checks that the class name of optionsrl.(field) is listed in
% admissibleClasses; throws a CORAerror otherwise
actualClass = class(optionsrl.(field));
if any(strcmp(actualClass, admissibleClasses))
    % class is admissible, nothing to report
    return
end
throw(CORAerror('CORA:wrongFieldValue', ...
    aux_getName(structName, field), admissibleClasses))
end

function aux_checkFieldNumericDefInterval(optionsrl, field, interval, structName)
% checks that optionsrl.(field) is a numeric (or logical) scalar within
% [interval.inf, interval.sup]; throws a CORAerror otherwise
%
% The range test is formulated positively (value >= inf && value <= sup) so
% that NaN is rejected: every comparison with NaN is false, hence a test of
% the form (value < inf || value > sup) would let NaN pass. Empty and
% non-scalar values are rejected up front, as they cannot be used as
% operands of the short-circuit operators.
    value = optionsrl.(field);
    isValid = (isnumeric(value) || islogical(value)) && isscalar(value) ...
        && value >= interval.inf && value <= interval.sup;
    if ~isValid
        throw(CORAerror('CORA:wrongFieldValue', ...
            aux_getName(structName, field), interval))
    end
end

function msg = aux_getName(structName, field)
% builds the field name shown in error messages: '<structName>.nn.<field>'
    nameFormat = "%s.nn.%s";
    msg = sprintf(nameFormat, structName, field);
end


% ------------------------------ END OF CODE ------------------------------
