/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights;

import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.weights.WeightVisitor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/**
 * {@link WeightVisitor} that accumulates one update {@link Value} per weight, keyed by the weight's
 * {@code index}.
 *
 * <p>Only weights whose {@code isLearnable} flag is true are accumulated; other weights are ignored.
 * The first visit of a given index stores a clone of the incoming value. Later visits of the same
 * index add to that stored value in place via {@code incrementBy}. Call {@link #clearUpdates()} to
 * reset the accumulated state.
 */
public class WeightUpdater
implements WeightVisitor {
    private static final Logger LOG = Logger.getLogger(WeightUpdater.class.getName());
    /**
     * Accumulated update per weight index. An entry is {@code null} at every index that has not been
     * visited since construction or the last {@link #clearUpdates()}.
     */
    public Value[] weightUpdates;
    /** Weights that received at least one update since the last clear, in first-visit order. */
    public List<Weight> updatedWeightsOnly;

    /**
     * Creates an updater with room for weight indices {@code 0..maxWeightIndex}.
     *
     * @param learnableWeights weights expected to be updated; used only for sanity checks, whose
     *     findings are logged rather than thrown
     * @param maxWeightIndex the largest weight index that may be visited
     */
    public WeightUpdater(List<Weight> learnableWeights, int maxWeightIndex) {
        this.check4mistakes(learnableWeights, maxWeightIndex);
        this.weightUpdates = new Value[maxWeightIndex + 1];
        this.updatedWeightsOnly = new ArrayList<Weight>(maxWeightIndex + 1);
    }

    /**
     * Logs (at SEVERE) inconsistencies in the learnable weight list. It reports:
     * <ul>
     *   <li>more weights than available indices;</li>
     *   <li>indices outside {@code [0, maxWeightIndex]};</li>
     *   <li>weights that are not learnable;</li>
     *   <li>duplicate indices.</li>
     * </ul>
     * This check is purely diagnostic. It does not throw for bad indices.
     */
    private void check4mistakes(List<Weight> learnableWeights, int maxWeightIndex) {
        if (maxWeightIndex < learnableWeights.size() - 1) {
            LOG.severe("Weight indices are off (there are more learnable weight than all weights?)!!");
        }
        boolean[] seen = new boolean[maxWeightIndex + 1];
        for (Weight weight : learnableWeights) {
            int index = weight.index;
            boolean inRange = index >= 0 && index <= maxWeightIndex;
            if (!inRange) {
                LOG.severe("Weight index exceeding number of all extracted allWeights!");
            }
            if (!weight.isLearnable.booleanValue()) {
                LOG.severe("Fixed weights leaking through into WeightUpdater!! (should have been filtered before)");
            }
            if (!inRange) {
                // An out-of-range index cannot be tracked in `seen`. Indexing it anyway would throw
                // ArrayIndexOutOfBoundsException and turn this diagnostic pass into a crash.
                continue;
            }
            if (seen[index]) {
                LOG.severe("Weight index seen twice! Input weight list is not unique! Some weight will try to be updated twice!");
            }
            seen[index] = true;
        }
    }

    /**
     * Accumulates {@code value} into the update slot of {@code weight}, if the weight is learnable.
     * The first visit of an index stores a clone, so the caller's {@code value} is not aliased.
     * Later visits add to the stored value in place.
     */
    @Override
    public void visit(Weight weight, Value value) {
        if (weight.isLearnable.booleanValue()) {
            int index = weight.index;
            Value weightUpdate = this.weightUpdates[index];
            if (weightUpdate != null) {
                weightUpdate.incrementBy(value);
            } else {
                this.weightUpdates[index] = value.clone();
                this.updatedWeightsOnly.add(weight);
            }
        }
    }

    /**
     * Discards all accumulated updates. It nulls every slot of {@link #weightUpdates} and empties
     * {@link #updatedWeightsOnly}.
     */
    public void clearUpdates() {
        Arrays.fill(this.weightUpdates, null);
        this.updatedWeightsOnly.clear();
    }
}

