/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.evaluation.functions;

import java.util.List;
import java.util.function.Function;
import java.util.logging.Logger;
import networks.computation.evaluation.functions.Aggregation;
import networks.computation.evaluation.functions.CrossProduct;
import networks.computation.evaluation.functions.specific.Identity;
import networks.computation.evaluation.functions.specific.LukasiewiczSigmoid;
import networks.computation.evaluation.functions.specific.ReLu;
import networks.computation.evaluation.functions.specific.Sigmoid;
import networks.computation.evaluation.functions.specific.Signum;
import networks.computation.evaluation.functions.specific.Tanh;
import networks.computation.evaluation.values.Value;
import settings.Settings;

/**
 * Base class for activation functions used as aggregations.
 *
 * <p>All inputs are first summed into a single {@link Value}. Then either the
 * {@link #evaluation} function or the {@link #gradient} function is applied to that sum via
 * {@code Value.apply}.
 */
public abstract class Activation
extends Aggregation {
    private static final Logger LOG = Logger.getLogger(Activation.class.getName());

    /** Textual prefix that wraps an inner activation into a {@link CrossProduct}. */
    private static final String CROSSPRODUCT_PREFIX = "crossproduct-";

    /** Scalar function applied to the summed input by {@link #evaluate(Value)}. */
    Function<Double, Double> evaluation;

    /**
     * Scalar function applied to the summed input by {@link #differentiate(Value)}.
     * Presumably the derivative of {@link #evaluation} (TODO confirm in subclasses).
     */
    Function<Double, Double> gradient;

    /**
     * @param evaluation scalar function applied by {@link #evaluate(Value)}
     * @param gradient   scalar function applied by {@link #differentiate(Value)}
     */
    protected Activation(Function<Double, Double> evaluation, Function<Double, Double> gradient) {
        this.evaluation = evaluation;
        this.gradient = gradient;
    }

    /**
     * Always {@code true}.
     *
     * <p>Inputs are combined by summation, so their order is presumed irrelevant.
     */
    @Override
    public boolean isInputSymmetric() {
        return true;
    }

    /**
     * Applies {@link #evaluation} to an already-summed input.
     *
     * @param summedInputs the summed input value
     * @return the result of {@code summedInputs.apply(evaluation)}
     */
    public Value evaluate(Value summedInputs) {
        return summedInputs.apply(this.evaluation);
    }

    /**
     * Applies {@link #gradient} to an already-summed input.
     *
     * @param summedInputs the summed input value
     * @return the result of {@code summedInputs.apply(gradient)}
     */
    public Value differentiate(Value summedInputs) {
        return summedInputs.apply(this.gradient);
    }

    /**
     * Sums all inputs and applies {@link #evaluation} to the sum.
     *
     * @param inputs the input values; must contain at least one element
     * @return the activation of the summed inputs
     * @throws IllegalArgumentException if {@code inputs} is null or empty
     */
    @Override
    public Value evaluate(List<Value> inputs) {
        return this.evaluate(accumulateInputs(inputs));
    }

    /**
     * Sums all inputs and applies {@link #gradient} to the sum.
     *
     * @param inputs the input values; must contain at least one element
     * @return the gradient evaluated at the summed inputs
     * @throws IllegalArgumentException if {@code inputs} is null or empty
     */
    @Override
    public Value differentiate(List<Value> inputs) {
        return this.differentiate(accumulateInputs(inputs));
    }

    /**
     * Adds all inputs together, starting from a clone of the first one.
     *
     * <p>The clone keeps the accumulation from touching the caller's first element. This
     * assumes {@code Value.plus} accumulates in place, since its return value is ignored
     * here. TODO confirm.
     */
    private static Value accumulateInputs(List<Value> inputs) {
        if (inputs == null || inputs.isEmpty()) {
            throw new IllegalArgumentException("Activation requires at least one input value");
        }
        Value sum = inputs.get(0).clone();
        for (int i = 1, len = inputs.size(); i < len; ++i) {
            sum.plus(inputs.get(i));
        }
        return sum;
    }

    /**
     * Maps a configured activation enum to its shared singleton instance.
     *
     * @param activationFcn the configured activation
     * @return the matching instance from {@link Singletons}, or {@code null} (after a severe log
     *         message) if the constant is not handled here
     */
    public static Activation getActivationFunction(Settings.ActivationFcn activationFcn) {
        switch (activationFcn) {
            case SIGMOID: {
                return Singletons.sigmoid;
            }
            case TANH: {
                return Singletons.tanh;
            }
            case SIGNUM: {
                return Singletons.signum;
            }
            case IDENTITY: {
                return Singletons.identity;
            }
            case RELU: {
                return Singletons.relu;
            }
            case LUKASIEWICZ: {
                return Singletons.lukasiewiczSigmoid;
            }
        }
        LOG.severe("Unimplemented activation function");
        return null;
    }

    /**
     * Parses a textual activation name.
     *
     * <p>Plain names ({@code sigmoid}, {@code tanh}, {@code signum}, {@code relu},
     * {@code identity}, {@code lukasiewicz}) return the shared singleton instances.
     * {@code crossproduct} returns a new {@link CrossProduct} wrapping the Lukasiewicz sigmoid.
     * {@code crossproduct-<name>} returns a new {@link CrossProduct} wrapping the activation
     * parsed from {@code <name>}.
     *
     * @param agg the name to parse
     * @return the parsed aggregation, or {@code null} if the name is not recognised or the
     *         inner part of a {@code crossproduct-} name is not a plain activation
     */
    public static Aggregation parseActivation(String agg) {
        switch (agg) {
            case "sigmoid": {
                return Singletons.sigmoid;
            }
            case "tanh": {
                return Singletons.tanh;
            }
            case "signum": {
                return Singletons.signum;
            }
            case "relu": {
                return Singletons.relu;
            }
            case "identity": {
                return Singletons.identity;
            }
            case "lukasiewicz": {
                return Singletons.lukasiewiczSigmoid;
            }
            case "crossproduct": {
                return new CrossProduct(Singletons.lukasiewiczSigmoid);
            }
        }
        if (agg.startsWith(CROSSPRODUCT_PREFIX)) {
            String inner = agg.substring(CROSSPRODUCT_PREFIX.length());
            Aggregation innerActivation = Activation.parseActivation(inner);
            // Guard the cast. A nested "crossproduct..." parses to a CrossProduct, which
            // previously threw ClassCastException. An unknown inner name parses to null,
            // which previously produced a CrossProduct wrapping null.
            if (innerActivation instanceof Activation) {
                return new CrossProduct((Activation) innerActivation);
            }
            LOG.warning("Unsupported inner activation for crossproduct: " + inner);
            return null;
        }
        return null;
    }

    /**
     * Shared activation instances returned by {@link #getActivationFunction} and
     * {@link #parseActivation}.
     *
     * <p>NOTE(review): these fields are public and non-final, so any code can reassign them.
     * Consider making them {@code final} once no caller writes to them.
     */
    public static class Singletons {
        public static LukasiewiczSigmoid lukasiewiczSigmoid = new LukasiewiczSigmoid();
        public static Sigmoid sigmoid = new Sigmoid();
        public static Signum signum = new Signum();
        public static ReLu relu = new ReLu();
        public static Identity identity = new Identity();
        public static Tanh tanh = new Tanh();
    }
}

