/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.optimizers;

import java.util.List;
import java.util.logging.Logger;
import networks.computation.evaluation.values.ScalarValue;
import networks.computation.evaluation.values.Value;
import networks.computation.iteration.visitors.weights.WeightUpdater;
import networks.computation.training.NeuralModel;
import networks.computation.training.optimizers.Optimizer;
import networks.structure.components.weights.Weight;
import settings.Settings;

/**
 * Adam-style optimizer: keeps per-weight running averages of the gradient and of
 * its element-wise square, corrects both for their start-up bias and moves each
 * trainable weight by the ratio of the two, scaled by the learning rate.
 *
 * <p>Naming note: the running average of the gradient is stored in
 * {@code Weight.momentum}, while the running average of the squared gradient is
 * stored in {@code Weight.velocity}.
 *
 * <p>The step counter used for the bias correction is the only state held by
 * this object; the running averages live on the weights themselves.
 */
public class Adam implements Optimizer {
    private static final Logger LOG = Logger.getLogger(Adam.class.getName());

    /** Factor every computed step is multiplied by before it is added to a weight. */
    public Value learningRate;
    /** Decay factor of the running gradient average ({@code Weight.momentum}). */
    public ScalarValue beta1;
    /** Decay factor of the running squared-gradient average ({@code Weight.velocity}). */
    public ScalarValue beta2;
    /** Constant added to the square root of the corrected velocity before it is inverted. */
    public ScalarValue epsilon;

    /** Number of the step about to be performed, starting at 1; exponent of the bias-correction terms. */
    private long iterationCount = 1L;
    /** Multiplier used to flip the sign of the incoming weight updates. */
    private final ScalarValue minusOne = new ScalarValue(-1.0);

    /**
     * Creates an optimizer with {@code beta1 = 0.9}, {@code beta2 = 0.999} and
     * {@code epsilon = 1e-8}.
     *
     * @param learningRate factor applied to every step
     */
    public Adam(Value learningRate) {
        this(learningRate, 0.9, 0.999, 1.0E-8);
    }

    /**
     * Creates an optimizer with explicit hyper-parameters.
     *
     * @param learningRate factor applied to every step
     * @param i_beta1      decay factor of the running gradient average
     * @param i_beta2      decay factor of the running squared-gradient average
     * @param i_epsilon    constant added to the denominator of the step
     */
    public Adam(Value learningRate, double i_beta1, double i_beta2, double i_epsilon) {
        this.learningRate = learningRate;
        this.beta1 = new ScalarValue(i_beta1);
        this.beta2 = new ScalarValue(i_beta2);
        this.epsilon = new ScalarValue(i_epsilon);
    }

    /**
     * Performs one step using the model's weights and the updates collected by
     * the given updater; delegates to {@link #performGradientStep(List, Value[])}.
     *
     * @param neuralModel   model whose {@code weights} are to be changed
     * @param weightUpdater source of the {@code weightUpdates} array
     */
    @Override
    public void performGradientStep(NeuralModel neuralModel, WeightUpdater weightUpdater) {
        this.performGradientStep(neuralModel.weights, weightUpdater.weightUpdates);
    }

    /**
     * Performs one optimization step over all trainable weights and then
     * advances the step counter.
     *
     * <p>Weights that are fixed or have a negative index are skipped.
     *
     * @param weights       weights to update in place
     * @param weightUpdates updates addressed by {@code Weight.index}
     */
    @Override
    public void performGradientStep(List<Weight> weights, Value[] weightUpdates) {
        // Bias corrections 1 / (1 - beta^t); they depend only on the step number,
        // so they are computed once and shared by all weights.
        double momentumBias = 1.0 - Math.pow(this.beta1.value, this.iterationCount);
        double velocityBias = 1.0 - Math.pow(this.beta2.value, this.iterationCount);
        Value momentumCorrection = new ScalarValue(1.0 / momentumBias);
        Value velocityCorrection = new ScalarValue(1.0 / velocityBias);

        for (Weight weight : weights) {
            if (!weight.isFixed && weight.index >= 0) {
                this.updateWeight(weight, weightUpdates[weight.index], momentumCorrection, velocityCorrection);
            }
        }
        this.iterationCount++;
    }

    /**
     * Resets the step counter to 1. The running averages stored on the weights
     * are not touched here, and {@code settings} is not read.
     *
     * @param settings unused
     */
    @Override
    public void restart(Settings settings) {
        this.iterationCount = 1L;
    }

    /**
     * Refreshes the running averages of a single weight and applies the
     * resulting step to its value.
     *
     * @param weight             weight to update in place
     * @param weightUpdate       update recorded for this weight
     * @param momentumCorrection bias correction for {@code weight.momentum}
     * @param velocityCorrection bias correction for {@code weight.velocity}
     */
    private void updateWeight(Weight weight, Value weightUpdate, Value momentumCorrection, Value velocityCorrection) {
        // The incoming update is negated here; the -1 in the inverse scale below flips the sign back.
        Value gradient = weightUpdate.times((Value)this.minusOne);

        // velocity <- beta2 * velocity + (1 - beta2) * gradient^2 (element-wise square)
        weight.velocity = this.beta2.times(weight.velocity)
                .plus(Value.ONE.minus((Value)this.beta2).times(gradient.elementTimes(gradient)));
        // momentum <- beta1 * momentum + (1 - beta1) * gradient
        weight.momentum = this.beta1.times(weight.momentum)
                .plus(Value.ONE.minus((Value)this.beta1).times(gradient));

        Value correctedMomentum = weight.momentum.times(momentumCorrection);
        Value correctedVelocity = weight.velocity.times(velocityCorrection);

        // -1 / (sqrt(correctedVelocity) + epsilon)
        Value negatedInverseScale = correctedVelocity.apply(Math::sqrt)
                .plus((Value)this.epsilon)
                .apply(val -> -1.0 / val);

        Value step = correctedMomentum.elementTimes(negatedInverseScale);
        weight.value.incrementBy(step.times(this.learningRate));
    }
}

