classdef Vol_class < VBayesLab_vol

    properties

    end

    methods
        function obj = Vol_class(data,varargin)
            % VOL_CLASS Configure the estimator from name-value options and fit it.
            %
            %   obj = Vol_class(data)
            %   obj = Vol_class(data,'Name',value,...)
            %
            %   data     - handed unchanged to obj.fit, which reads data.all, or
            %              data.train and data.test when TrainTest is true.
            %   varargin - optional name-value pairs; every name listed in
            %              paramNames overrides the property of the same name,
            %              whose current (inherited) value is used as default.
            %
            %   The constructor derives NumParams from GarchType and the orders
            %   P, O, Q, resolves the initial mean (random / zeros / fixed) and
            %   stores the result of obj.fit(data) in obj.Post.
            %
            %   Errors:
            %     Vol_class:unknownGarchType - GarchType is not one of
            %         'garch', 'egarch', 'figarch', 'rgarch'.
            %     (no identifier)            - a user-supplied MeanInit is not a
            %         NumParams-by-1 column vector.

            obj.Method = 'Vol_class';

            % Parse additional options
            if nargin > 2
                paramNames = ...
                    {'NumSample'        'LearningRate'      'GradWeight'       'GradClipInit' ...
                    'MaxIter'           'MaxPatience'       'WindowSize'       'Verbose' ...
                    'InitMethod'        'StdForInit'        'Seed'             'MeanInit' ...
                    'SigInitScale'      'LBPlot'            'GradientMax'      'DataTrain' ...
                    'Setting'           'StepAdaptive'      'SaveParams'...
                    'P'                 'O'                 'Q'                'GarchType'...
                    'Sampler'           'TrainTest'         'Optimizer'        'doCV'...
                    'useHfunc'};

                paramDflts = ...
                    {obj.NumSample      obj.LearningRate    obj.GradWeight    obj.GradClipInit ...
                    obj.MaxIter         obj.MaxPatience     obj.WindowSize    obj.Verbose ...
                    obj.InitMethod      obj.StdForInit      obj.Seed          obj.MeanInit ...
                    obj.SigInitScale    obj.LBPlot          obj.GradientMax   obj.DataTrain ...
                    obj.Setting         obj.StepAdaptive    obj.SaveParams ...
                    obj.P               obj.O               obj.Q             obj.GarchType...
                    obj.Sampler         obj.TrainTest       obj.Optimizer     obj.doCV...
                    obj.useHfunc};

                % Output order must match paramNames exactly.
                [obj.NumSample,...
                    obj.LearningRate,...
                    obj.GradWeight,...
                    obj.GradClipInit,...
                    obj.MaxIter,...
                    obj.MaxPatience,...
                    obj.WindowSize,...
                    obj.Verbose,...
                    obj.InitMethod,...
                    obj.StdForInit,...
                    obj.Seed,...
                    obj.MeanInit,...
                    obj.SigInitScale,...
                    obj.LBPlot,...
                    obj.GradientMax,...
                    obj.DataTrain,...
                    obj.Setting,...
                    obj.StepAdaptive,...
                    obj.SaveParams,...
                    obj.P,...
                    obj.O,...
                    obj.Q,...
                    obj.GarchType,...
                    obj.Sampler,...
                    obj.TrainTest,...
                    obj.Optimizer,...
                    obj.doCV,...
                    obj.useHfunc] = internal.stats.parseArgs(paramNames, paramDflts, varargin{:});

            end

            % Number of model parameters implied by the model type and orders.
            switch obj.GarchType
                case {'garch','egarch'}
                    obj.NumParams = obj.P+obj.O+obj.Q+1;
                case 'figarch'
                    obj.NumParams = obj.P+obj.Q+2;
                case 'rgarch'
                    obj.NumParams = obj.P+obj.Q+7;
                otherwise
                    % Fail here instead of continuing with an unset NumParams.
                    error('Vol_class:unknownGarchType', ...
                        'Unknown GarchType ''%s''. Use ''garch'', ''egarch'', ''figarch'' or ''rgarch''.', ...
                        char(obj.GarchType))
            end

            obj.HFuntion = h_fun;

            if isempty(obj.MeanInit)
                % No initial mean supplied: draw one at random.
                mu = normrnd(0,obj.StdForInit,obj.NumParams,1);
                obj.InitMethod = 'Random';
                obj.MeanInit = mu;
            elseif all(obj.MeanInit(:) == 0)
                % A zero (scalar or all-zero array) requests a zero start.
                obj.InitMethod = 'Zeros';
                obj.MeanInit = zeros(obj.NumParams,1);
            else
                % A user-supplied mean must be a NumParams-by-1 column vector.
                if ~isequal(size(obj.MeanInit),[obj.NumParams 1])
                    error(['Initial mu must be of dimension (', num2str(obj.NumParams,'%i') 'x1)!'])
                end
                obj.InitMethod = 'Fixed';
            end


            % Main function to run QBVI
            obj.Post   = obj.fit(data);
        end

        function[tpar] = itransform(obj,par)
            % ITRANSFORM Forward PAR to garch_itransform together with the
            % model orders (P, O, Q) and the model type stored on the object.
            p_order    = obj.P;
            o_order    = obj.O;
            q_order    = obj.Q;
            model_type = obj.GarchType;

            tpar = garch_itransform(par,p_order,o_order,q_order,model_type);
        end

        function[SigSymm] = Symmetrize(obj)
            % SYMMETRIZE Return the average of Post.Sig and its transpose.
            % A warning is issued when the two differ in any entry.
            Sig_post  = obj.Post.Sig;
            Sig_postT = Sig_post';

            mismatch = (Sig_post ~= Sig_postT);
            if any(mismatch(:)) % not symmetric posterior
                warning([obj.Optimizer '.Post.Sig is not symmetric, so it is being symmetrized.'])
            end

            SigSymm = 1/2*(Sig_post + Sig_postT);
        end

        function[dis] = get_xy_distribution(obj,Num_xp,k,trans)
            % GET_XY_DISTRIBUTION Evaluate one marginal density curve per parameter.
            %
            %   Num_xp - number of grid points per parameter
            %   k      - half-width of each grid, in standard deviations
            %   trans  - false: normal pdf built from Post.mu and sqrt(Post.Sig2);
            %            true : kernel density of 80000 draws from
            %                   N(Post.mu, symmetrized Post.Sig), each mapped
            %                   through obj.itransform.
            %
            %   dis    - struct with fields x, y (Num_xp-by-NumParams), k,
            %            Num_xp and transform (logical(trans)).
            nCols = obj.NumParams;
            x     = zeros(Num_xp,nCols);
            y     = zeros(Num_xp,nCols);

            if ~trans
                % Use posterior normal pdf
                centre = obj.Post.mu;
                spread = sqrt(obj.Post.Sig2);

                for j = 1:nCols
                    lo     = centre(j)-k*spread(j);
                    hi     = centre(j)+k*spread(j);
                    x(:,j) = linspace(lo,hi,Num_xp);
                    y(:,j) = normpdf(x(:,j),centre(j),spread(j));
                end
            else
                % Use a kernel density on transformed posterior samples
                nDraws  = 80000;
                rawDraw = mvnrnd(obj.Post.mu,obj.Symmetrize(),nDraws);
                mapped  = zeros(nDraws,nCols);

                for n = 1:nDraws
                    mapped(n,:) =  obj.itransform(rawDraw(n,:)');
                end

                centre = mean(mapped);
                spread = std(mapped);

                for j = 1:nCols
                    lo     = centre(j)-k*spread(j);
                    hi     = centre(j)+k*spread(j);
                    x(:,j) = linspace(lo,hi,Num_xp);
                    y(:,j) = ksdensity(mapped(:,j),x(:,j));
                end
            end

            dis.x = x;
            dis.y = y;
            dis.k = k;
            dis.Num_xp = Num_xp;
            dis.transform = logical(trans);
        end

        %% QBVI main function

        function Post = fit(obj,data)
            % FIT Run the iterative variational update and collect the results.
            %
            %   Post = obj.fit(data)
            %
            %   data - struct. With TrainTest true, data.train and data.test are
            %          used (data.test must be non-empty); otherwise data.all is
            %          used as the training set.
            %
            %   Each iteration draws samples via obj.sampler(mu,Sig) (one
            %   parameter vector per column), evaluates obj.HFuntion on every
            %   sample, forms score-function gradient terms with the control
            %   variate c12 obtained in the previous pass, clips them, and
            %   updates (mu,Sig) through obj.LBgradients / obj.update. The loop
            %   stops when patience exceeds MaxPatience or iter exceeds MaxIter.
            %
            %   Post fields: LB0, LB, LB_smooth, LB_max, LB_indx, mu, tpar, llh,
            %   pri, par, Sig, Sig2, tarch_ll, ht; with SaveParams also iter and
            %   it (per-iteration traces); with TrainTest also iter_best, train,
            %   test and perf. iter_best is 0 when the smoothed lower bound was
            %   never evaluated (loop ended before iter exceeded WindowSize).

            d_theta = obj.NumParams;

            % Extract sampling setting
            eps0            = obj.LearningRate;
            S               = obj.NumSample;
            window_size     = obj.WindowSize;
            init_scale      = obj.SigInitScale;
            stepsize_adapt  = obj.StepAdaptive;
            hfunc           = obj.HFuntion;
            setting         = obj.Setting;

            if obj.TrainTest
                if isempty(data.test)
                    error('Empty test data.')
                end
            else
                data.train = data.all;
            end

            % Initialization
            iter      = 0;
            patience  = 0;
            stop      = false;
            iter_best = 0;   % defined even if no smoothed LB is ever computed
            LB              = zeros(1,obj.MaxIter+1);
            LB_test         = zeros(1,obj.MaxIter+1);
            LB_smooth       = zeros(1,obj.MaxIter+1);
            LB_smooth_test  = zeros(1,obj.MaxIter+1);

            % Prior: mean Prior(1), covariance Prior(2)*I
            S0          = setting.Prior(2).*eye(d_theta);
            iS0         = inv(S0);
            mu0         = setting.Prior(1).*ones(d_theta,1);


            % Initialization of mu
            mu = obj.MeanInit;

            % Initialization of Sig
            Sig     = init_scale*eye(d_theta);
            Sig_inv = inv(Sig);

            % Gaussian log-density parameterised by the precision matrix
            log_pdf = @(theta,mu,Sig_inv) -d_theta/2*log(2*pi)+1/2*log(det(Sig_inv))-1/2*(theta-mu)'*Sig_inv*(theta-mu);

            nPar                        = d_theta+d_theta*d_theta;
            gra_log_q_lambda            = zeros(S,nPar);
            grad_log_q_h_function       = zeros(S,nPar);
            grad_log_q_h_function_cv    = zeros(S,nPar);
            c12                         = zeros(1,nPar);
            lb_log_h                    = zeros(S,1);

            % Per-sample diagnostics. They are always recorded inside the
            % main loop because the TrainTest summaries read them whether
            % or not SaveParams is set.
            h_theta_s = zeros(S,1);
            llh_s     = zeros(S,1);
            llp_s     = zeros(S,1);
            if obj.TrainTest
                lb_log_h_test  = zeros(S,1);
                h_theta_test_s = zeros(S,1);
                llh_test_s     = zeros(S,1);
                llp_test_s     = zeros(S,1);
            end

            % ---- Iteration 0: gradient estimate at the initial (mu,Sig) ----
            theta_all = obj.sampler(mu,Sig);

            for s = 1:S
                % Parameters from Normal distribution
                theta = theta_all(:,s);

                % Log q_lambda
                log_q_lambda = log_pdf(theta,mu,Sig_inv);

                % h(theta) and its log-likelihood part
                [h_theta,llh] = hfunc(data.train,theta,obj);

                % h(theta) - log q(theta)
                h_function = h_theta - log_q_lambda;

                if obj.useHfunc
                    f = h_function;
                else
                    f = llh;
                end

                % Compute the lowerbound
                lb_log_h(s) = h_function;

                aux                           = (theta-mu);
                gra_log_q_mu                  = aux;
                gra_log_q_Sig                 = obj.fun_gra_log_q_Sig(Sig_inv,aux);
                gra_log_q_lambda(s,:)         = [gra_log_q_mu;gra_log_q_Sig(:)]';
                grad_log_q_h_function(s,:)    = gra_log_q_lambda(s,:)*f;
                grad_log_q_h_function_cv(s,:) = gra_log_q_lambda(s,:).*(f-c12);
            end

            % c12 computed here is applied in the NEXT pass over the samples.
            c12 = obj.control_variates(grad_log_q_h_function,gra_log_q_lambda);
            Y12 = mean(grad_log_q_h_function_cv)';

            % Gradient clipping at the beginning
            Y12 = obj.grad_clipping(Y12,1);

            [gradLB_mu_momentum,gradLB_Sig_momentum] = obj.LBgradients(Y12,d_theta,[],[],Sig,[],Sig_inv,iS0,mu0,mu);

            LB0 = mean(lb_log_h);
            if obj.Verbose ~=0
                disp(['Iter: 0000 |LB: ', num2str(LB0)])
            end

            % Prepare for the next iterations
            mu_best     = mu;
            Sig_best    = Sig;

            while ~stop

                iter = iter+1;
                if iter>stepsize_adapt
                    stepsize = eps0*stepsize_adapt/iter;
                else
                    stepsize = eps0;
                end


                [mu,Sig,Sig_inv,Sig_old] = obj.update(stepsize,mu,Sig,Sig_inv,gradLB_Sig_momentum,gradLB_mu_momentum);


                gra_log_q_lambda            = zeros(S,nPar);
                grad_log_q_h_function       = zeros(S,nPar);
                grad_log_q_h_function_cv    = zeros(S,nPar);

                theta_all = obj.sampler(mu,Sig);

                for s = 1:S
                    % Parameters from Normal distribution
                    theta = theta_all(:,s);

                    % Log q_lambda
                    log_q_lambda = log_pdf(theta,mu,Sig_inv);

                    % h(theta), log-likelihood and log-prior
                    [h_theta,llh,pri] = hfunc(data.train,theta,obj);

                    % h(theta) - log q(theta)
                    h_function = h_theta - log_q_lambda;

                    if obj.useHfunc
                        f = h_function;
                    else
                        f = llh;
                    end

                    % Compute the lowerbound
                    lb_log_h(s) = h_function;

                    % Record per-sample diagnostics for this iteration
                    h_theta_s(s) = h_theta;
                    llh_s(s)     = llh;
                    llp_s(s)     = pri;

                    if obj.TrainTest
                        [h_theta_test,llh_test,llp_test] = hfunc(data.test,theta,obj);
                        h_theta_test_s(s) = h_theta_test;
                        llh_test_s(s)     = llh_test;
                        llp_test_s(s)     = llp_test;
                        lb_log_h_test(s)  = h_theta_test - log_q_lambda;
                    end

                    aux                           = (theta-mu);
                    gra_log_q_mu                  = aux;
                    gra_log_q_Sig                 = obj.fun_gra_log_q_Sig(Sig_inv,aux);
                    gra_log_q_lambda(s,:)         = [gra_log_q_mu;gra_log_q_Sig(:)]';
                    grad_log_q_h_function(s,:)    = gra_log_q_lambda(s,:)*f;
                    grad_log_q_h_function_cv(s,:) = gra_log_q_lambda(s,:).*(f-c12);

                end

                c12 = obj.control_variates(grad_log_q_h_function,gra_log_q_lambda);
                Y12 = mean(grad_log_q_h_function_cv)';

                % Clipping the gradient
                Y12 = obj.grad_clipping(Y12,0);

                [gradLB_mu_momentum,gradLB_Sig_momentum] = obj.LBgradients(Y12,d_theta,gradLB_Sig_momentum,gradLB_mu_momentum,Sig,Sig_old,Sig_inv,iS0,mu0,mu);

                % Lower bound
                LB(iter) = mean(lb_log_h);

                if obj.TrainTest
                    % Keep the test lower bound so its smoothed version
                    % below is computed from real values.
                    LB_test(iter)           = mean(lb_log_h_test);

                    train.LB(iter)          = LB(iter);
                    train.llh(iter)         = mean(llh_s);
                    train.h_theta(iter)     = mean(h_theta_s);
                    train.llp(iter)         = mean(llp_s);
                    [train.nll(iter),ht_tr] = vol_nll(data.train,mu,obj.P,obj.O,obj.Q,obj.GarchType);
                    train.mse(iter)         = mean((data.train.^2-ht_tr).^2);

                    test.LB(iter)           = LB_test(iter);
                    test.llh(iter)          = mean(llh_test_s);
                    test.h_theta(iter)      = mean(h_theta_test_s);
                    test.llp(iter)          = mean(llp_test_s);
                    [test.nll(iter),ht_te]  = vol_nll(data.test,mu,obj.P,obj.O,obj.Q,obj.GarchType);
                    test.mse(iter)          = mean((data.test.^2-ht_te).^2);
                end

                % Smooth the lowerbound and store best results
                if iter>window_size
                    LB_smooth(iter-window_size) = mean(LB(iter-window_size:iter));
                    if obj.TrainTest
                        LB_smooth_test(iter-window_size) = mean(LB_test(iter-window_size:iter));
                    end

                    if LB_smooth(iter-window_size)>=max(LB_smooth(1:iter-window_size))
                        mu_best  = mu;
                        Sig_best = Sig;
                        patience = 0;
                        iter_best = iter;
                    else
                        patience = patience + 1;
                    end

                end

                if (patience>obj.MaxPatience)||(iter>obj.MaxIter)
                    stop = true;
                end

                % Display training information
                obj.print_training_info(stop,iter,LB,LB_smooth)

                % Save params at each iteration
                if obj.SaveParams
                    it.mu(iter,:)       = mu;
                    it.S(iter,:)        = Sig(:)';
                    it.par(iter,:)      = obj.itransform(mu);

                    [nll,ht]            = garch_nll_ht(data.train,mu,obj);
                    it.tarch_ll(iter,:) = -nll;
                    it.mse(iter,:)      = garch_mse(data.train.^2,ht);
                    it.llh(iter,1)      = mean(llh_s);
                    it.pri(iter,1)      = mean(llp_s);  % log-prior of THIS iteration's samples
                end

            end

            % Store output (the entry of the final, stopping iteration is dropped)
            LB_smooth       = LB_smooth(1:(iter-window_size-1));
            LB              = LB(1:iter-1);

            Post.LB0                = LB0;
            Post.LB                 = LB;
            Post.LB_smooth          = LB_smooth;
            [Post.LB_max,Post.LB_indx] = max(LB_smooth);
            Post.mu                 = mu_best;
            Post.tpar               = Post.mu;
            [~,Post.llh,Post.pri]   = hfunc(data.train,Post.mu,obj);
            Post.par                = obj.itransform(mu_best);
            Post.Sig                = Sig_best;
            Post.Sig2               = diag(Post.Sig);

            [Post.tarch_ll,Post.ht] = garch_nll_ht(data.train,Post.mu,obj);

            if obj.SaveParams
                Post.iter = iter;
                Post.it   = it;   % per-iteration traces collected above
            end

            if obj.TrainTest
                Post.iter_best  = iter_best;
                Post.train      = train;
                Post.test       = test;
                Post.train.n    = size(data.train,1);
                Post.test.n     = size(data.test,1);

                tmp                     = movmean(train.LB,[window_size 0]);
                Post.train.LB_smooth    = tmp(window_size+1:end-1);
                [Post.train.LB_max,ind] = max(Post.train.LB_smooth);
                tmp                     = movmean(test.LB,[window_size 0]);
                Post.test.LB_smooth     = tmp(window_size+1:end-1);
                % Test LB is reported at the index where the TRAIN LB peaks
                Post.test.LB_max        = Post.test.LB_smooth(ind);

                [Post.train.nll_best,Post.train.ht_best]  = vol_nll(data.train,mu_best,obj.P,obj.O,obj.Q,obj.GarchType);
                [Post.test.nll_best,Post.test.ht_best]    = vol_nll(data.test,mu_best,obj.P,obj.O,obj.Q,obj.GarchType);

                Post.train.mse_best      = mean((data.train.^2-Post.train.ht_best).^2);
                Post.test.mse_best       = mean((data.test.^2-Post.test.ht_best).^2);

                Post.perf = [Post.train.LB_max, Post.train.nll_best,    Post.train.mse_best,...
                            Post.test.LB_max,   Post.test.nll_best,     Post.test.mse_best];
            end



            % Plot LB
            if(obj.LBPlot)
                obj.plot_lb(LB_smooth);
            end

            if obj.Verbose >0
                fprintf([obj.Optimizer ', Max. LB : %.4f.\n\n'],Post.LB_max)
            end
        end


    end
end

%% Local function definitions

function[fun] = h_fun
% H_FUN Return a handle computing h(theta) = log p(y|theta) + log p(theta).
%
%   fun = h_fun
%   [h, llh, log_prior] = fun(data, theta, obj)
%
%   data  - numeric array; its last column is used as the series y (the
%           'rgarch' model receives the whole array instead)
%   theta - parameter column vector
%   obj   - object/struct providing GarchType, P, O, Q and Setting.Prior,
%           where Prior(1) is the prior mean and Prior(2) the prior variance
%           of an isotropic normal prior
%
%   Errors with Vol_class:unknownGarchType for an unsupported GarchType.

fun = @(data,theta,obj) grad_h_func(data,theta,obj);

    function [h_func,llh,log_prior] = grad_h_func(data,theta,obj)

        % Extract additional settings
        d          = length(theta);
        prior_mean = obj.Setting.Prior(1);
        sigma2     = obj.Setting.Prior(2);

        % Extract data
        y = data(:,end);

        % Compute log likelihood (the fun_*_nll helpers return its negative)
        switch obj.GarchType
            case 'garch'
                llh = -fun_tarch_nll(y,theta,obj.P,obj.O,obj.Q,1);
            case 'egarch'
                llh = -fun_egarch_nll(y,theta,obj.P,obj.O,obj.Q,1);
            case 'figarch'
                llh = -fun_figarch_nll(y,theta,obj.P,obj.Q,0.5,1);
            case 'rgarch'
                llh = -fun_rgarch_nll(data,theta,obj.P,obj.Q,1);
            otherwise
                % Without this branch llh would be left undefined.
                error('Vol_class:unknownGarchType', ...
                    'Unknown GarchType ''%s''. Use ''garch'', ''egarch'', ''figarch'' or ''rgarch''.', ...
                    char(obj.GarchType))
        end

        % Compute log prior: N(prior_mean, sigma2*I). Using Prior(1) keeps
        % this in line with the prior mean used by the optimiser; for a
        % zero prior mean the value is unchanged.
        dev       = theta - prior_mean;
        log_prior = -d/2*log(2*pi)-d/2*log(sigma2)-(dev'*dev)/sigma2/2;

        % Compute h(theta) = log p(y|theta) + log p(theta)
        h_func = llh + log_prior;

    end

end


function[nll,ht] = garch_nll_ht(data,par,obj)
% GARCH_NLL_HT Dispatch to the fun_*_nll routine matching obj.GarchType.
%
%   data - series passed unchanged to the model routine
%   par  - parameter vector passed unchanged to the model routine
%   obj  - object/struct providing GarchType, P, Q and (for 'garch' and
%          'egarch') O
%
%   Returns the first two outputs of the selected routine as [nll, ht].
%   Errors with Vol_class:unknownGarchType for an unsupported GarchType.

switch obj.GarchType
    case 'garch'
        [nll,ht]   = fun_tarch_nll(data,par,obj.P,obj.O,obj.Q,1);
    case 'egarch'
        [nll,ht]   = fun_egarch_nll(data,par,obj.P,obj.O,obj.Q,1);
    case 'figarch'
        [nll,ht]   = fun_figarch_nll(data,par,obj.P,obj.Q,0.5,1);
    case 'rgarch'
        [nll,ht]   = fun_rgarch_nll(data,par,obj.P,obj.Q,1);
    otherwise
        % Without this branch the outputs would be left unassigned.
        error('Vol_class:unknownGarchType', ...
            'Unknown GarchType ''%s''. Use ''garch'', ''egarch'', ''figarch'' or ''rgarch''.', ...
            char(obj.GarchType))
end

end


function[mse] = garch_mse(y_true,y)
% GARCH_MSE Mean of the squared element-wise differences between Y_TRUE and Y.
residual    = y_true - y;
sq_residual = residual.^2;
mse         = mean(sq_residual);
end