/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.evaluation.functions;

import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import networks.computation.evaluation.functions.Activation;
import networks.computation.evaluation.functions.Aggregation;
import networks.computation.evaluation.values.Value;
import networks.structure.metadata.states.AggregationState;

/**
 * Activation that first aggregates all of its inputs into a single {@link Value} and then applies a
 * wrapped {@link Activation} to that aggregated value.
 *
 * <p>NOTE(review): despite the class name, the aggregation in {@link #sumInputs(List)} accumulates the
 * inputs with {@code Value.incrementBy} (i.e. a sum), not a product. Confirm whether an element-wise
 * product was intended.
 */
public class ElementProduct
extends Activation {
    private static final Logger LOG = Logger.getLogger(ElementProduct.class.getName());

    /** The activation applied to the aggregated inputs. */
    Activation activation;

    /**
     * Returns the name of this function.
     *
     * <p>NOTE(review): the returned name is "DotProduct", which does not match the class name. It is
     * kept unchanged in case other code identifies functions by this name; confirm before renaming.
     *
     * @return the function name
     */
    @Override
    public String getName() {
        return "DotProduct";
    }

    /**
     * Creates a wrapper that aggregates inputs and then applies {@code activation}.
     *
     * @param activation the activation applied after aggregation; its {@code evaluation} and
     *     {@code gradient} are forwarded to the superclass constructor
     */
    public ElementProduct(Activation activation) {
        super(activation.evaluation, activation.gradient);
        this.activation = activation;
    }

    /**
     * Not supported for this function: logs a severe message and returns {@code null}.
     *
     * @return always {@code null}
     */
    @Override
    public Aggregation replaceWithSingleton() {
        LOG.severe("ElementProduct cannot be singleton.");
        return null;
    }

    /**
     * Aggregates {@code inputs} and evaluates the wrapped activation on the result.
     *
     * @param inputs the values to aggregate; must be non-empty
     * @return the wrapped activation's evaluation of the aggregated value; if the input sizes disagree,
     *     {@code null} is passed to the wrapped activation
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        Value sum = this.sumInputs(inputs);
        return this.activation.evaluate(sum);
    }

    /**
     * Aggregates {@code inputs} and differentiates the wrapped activation at the result.
     *
     * @param inputs the values to aggregate; must be non-empty
     * @return the wrapped activation's derivative at the aggregated value; if the input sizes disagree,
     *     {@code null} is passed to the wrapped activation
     */
    @Override
    public Value differentiate(List<Value> inputs) {
        Value sum = this.sumInputs(inputs);
        return this.activation.differentiate(sum);
    }

    /**
     * Creates a fresh aggregation state bound to this function.
     *
     * @return a new {@code AggregationState.CumulationState} for this instance
     */
    @Override
    public AggregationState getAggregationState() {
        return new AggregationState.CumulationState(this);
    }

    /**
     * Accumulates all inputs into a new value using {@code incrementBy}.
     *
     * @param inputs the values to accumulate; must be non-empty, since the first element is read
     *     unconditionally
     * @return the accumulated value, or {@code null} (after logging) if any input's {@code size()}
     *     differs from the first input's
     */
    private Value sumInputs(List<Value> inputs) {
        Value first = inputs.get(0);
        int[] size = first.size();
        // The first input defines the expected dimensions, so only the remaining inputs are checked.
        for (int i = 1; i < inputs.size(); ++i) {
            int[] inputSize = inputs.get(i).size();
            if (!Arrays.equals(size, inputSize)) {
                LOG.severe("ElementProduct dimensions mismatch! Expected " + Arrays.toString(size)
                        + " but input " + i + " has " + Arrays.toString(inputSize) + ".");
                return null;
            }
        }
        // Start from a zeroed clone of the first input; presumably this keeps the inputs themselves
        // unmodified (depends on Value.clone() semantics — confirm).
        Value sum = first.clone().zero();
        for (Value input : inputs) {
            sum.incrementBy(input);
        }
        return sum;
    }
}

