/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import cz.cvut.fel.ida.algebra.functions.Activation;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.functions.CrossProduct;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.building.factories.WeightFactory;
import cz.cvut.fel.ida.logic.constructs.example.ValuedFact;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.constructs.template.components.HeadAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.WeightedRule;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.States;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomFact;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.FactNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.NegationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedAtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedRuleNeuron;
import cz.cvut.fel.ida.setup.Settings;
import java.util.logging.Logger;

/**
 * Creates the neurons of a ground neural network and registers each created neuron in
 * {@link #neuronMaps}.
 * <p>
 * Every newly created neuron receives the current value of an internal counter as its index, and the
 * counter is then incremented. Fact neurons that are reused (see {@link #createFactNeuron}) do not
 * consume an index.
 * <p>
 * NOTE: {@link #neuronMaps} is not assigned by the constructor. It must be set before any
 * {@code create*} method is called; those methods dereference it unconditionally. No synchronization
 * is performed on the counter or on the maps.
 */
public class NeuronFactory {
    private static final Logger LOG = Logger.getLogger(NeuronFactory.class.getName());
    private final WeightFactory weightFactory;
    Settings settings;
    /** Index assigned to the next created neuron. */
    private int counter = 0;
    /** Shared offset weight ("fixedAtomOffset") used for atom neurons when default offsets are not learnable. */
    private final Weight atomOffset;
    /** Shared offset weight ("fixedRuleOffset") used for rule neurons when default offsets are not learnable. */
    private final Weight ruleOffset;
    /** Value given to fact neurons whose fact carries no explicit value. */
    private final Value defaultFactValue;
    public NeuronMaps neuronMaps;

    /**
     * Builds a factory and pre-constructs the shared default offset weights for atom and rule neurons.
     *
     * @param weightFactory factory used to construct all offset weights
     * @param settings      source of default activations, offsets and naming options
     */
    public NeuronFactory(WeightFactory weightFactory, Settings settings) {
        this.weightFactory = weightFactory;
        this.settings = settings;
        this.atomOffset = weightFactory.construct("fixedAtomOffset", new ScalarValue(settings.defaultAtomNeuronOffset), true, true);
        this.ruleOffset = weightFactory.construct("fixedRuleOffset", new ScalarValue(settings.defaultRuleNeuronOffset), true, true);
        if (settings.defaultFactValue != 1.0) {
            this.defaultFactValue = new ScalarValue(settings.defaultFactValue);
        } else {
            // reuse the shared constant instead of allocating an equal value
            this.defaultFactValue = Value.ONE;
        }
    }

    /**
     * Creates an atom neuron that carries an offset and registers it under {@code groundHead}.
     * <p>
     * The activation is taken from {@code head}, falling back to {@code settings.atomNeuronActivation}.
     * The offset is taken from {@code head}. If {@code head} has none, the configured default atom
     * offset is used (see {@link #resolveDefaultOffset}). This can result in {@code null}, meaning no
     * offset.
     *
     * @param head       template head atom providing the activation and offset
     * @param groundHead ground literal; its {@code toString()} becomes the neuron's name
     * @return the newly created neuron
     */
    public WeightedAtomNeuron createWeightedAtomNeuron(HeadAtom head, Literal groundHead) {
        Activation activation = head.getActivation() != null ? head.getActivation() : Activation.getActivationFunction(this.settings.atomNeuronActivation);
        State.Neural.Computation state = State.createBaseState(this.settings, activation);
        Weight offset = head.getOffset();
        if (offset == null) {
            offset = this.resolveDefaultOffset(this.settings.defaultAtomOffsetsLearnable, this.settings.defaultAtomNeuronOffset, this.atomOffset);
        }
        WeightedAtomNeuron<State.Neural.Computation> atomNeuron = new WeightedAtomNeuron<State.Neural.Computation>(groundHead.toString(), offset, this.counter++, state);
        this.neuronMaps.atomNeurons.put(groundHead, atomNeuron);
        LOG.finest(() -> "Created atom neuron: " + atomNeuron);
        return atomNeuron;
    }

    /**
     * Creates an atom neuron without an offset and registers it under {@code groundHead}.
     *
     * @param head       template head atom; its activation is used when non-null, otherwise
     *                   {@code settings.atomNeuronActivation}
     * @param groundHead ground literal; its {@code toString()} becomes the neuron's name
     * @return the newly created neuron
     */
    public AtomNeuron createUnweightedAtomNeuron(HeadAtom head, Literal groundHead) {
        Activation activation = head.getActivation() != null ? head.getActivation() : Activation.getActivationFunction(this.settings.atomNeuronActivation);
        State.Neural.Computation state = State.createBaseState(this.settings, activation);
        AtomNeuron<State.Neural.Computation> atomNeuron = new AtomNeuron<State.Neural.Computation>(groundHead.toString(), this.counter++, state);
        this.neuronMaps.atomNeurons.put(groundHead, atomNeuron);
        LOG.finest(() -> "Created atom neuron: " + atomNeuron);
        return atomNeuron;
    }

    /**
     * Creates an aggregation neuron for a ground head rule and registers it under that rule.
     * <p>
     * The aggregation comes from the weighted rule, falling back to {@code settings.aggNeuronActivation}.
     * The neuron's name is the full ground string when {@code settings.fullAggNeuronStrings} is set.
     * Otherwise it is the rule's original string.
     *
     * @param groundHeadRule the ground head rule the neuron aggregates over
     * @return the newly created neuron
     */
    public AggregationNeuron createAggNeuron(GroundHeadRule groundHeadRule) {
        WeightedRule weightedRule = groundHeadRule.weightedRule;
        Aggregation aggregation = weightedRule.getAggregationFcn() != null ? weightedRule.getAggregationFcn() : Aggregation.getAggregation(this.settings.aggNeuronActivation);
        State.Neural.Computation state = State.createBaseState(this.settings, aggregation);
        AggregationNeuron<State.Neural.Computation> aggregationNeuron = new AggregationNeuron<State.Neural.Computation>(this.settings.fullAggNeuronStrings ? groundHeadRule.toFullString() : weightedRule.getOriginalString(), this.counter++, state);
        this.neuronMaps.aggNeurons.put(groundHeadRule, aggregationNeuron);
        LOG.finest(() -> "Created aggregation neuron: " + aggregationNeuron);
        return aggregationNeuron;
    }

    /**
     * Creates a rule neuron without an offset and registers it under {@code groundRule}.
     *
     * @param groundRule the ground rule; the activation is resolved as in {@link #resolveRuleActivation}
     * @return the newly created neuron
     */
    public RuleNeuron createRuleNeuron(GroundRule groundRule) {
        WeightedRule weightedRule = groundRule.weightedRule;
        Activation activation = this.resolveRuleActivation(weightedRule);
        State.Neural.Computation state = State.createBaseState(this.settings, activation);
        RuleNeuron<State.Neural.Computation> ruleNeuron = new RuleNeuron<State.Neural.Computation>(this.settings.fullRuleNeuronStrings ? groundRule.toFullString() : weightedRule.getOriginalString(), this.counter++, state);
        this.neuronMaps.ruleNeurons.put(groundRule, ruleNeuron);
        LOG.finest(() -> "Created rule neuron: " + ruleNeuron);
        return ruleNeuron;
    }

    /**
     * Creates a rule neuron that carries an offset and registers it under {@code groundRule}.
     * <p>
     * The offset is taken from the weighted rule. If the rule has none, the configured default
     * <em>rule</em> offset is used (see {@link #resolveDefaultOffset}). This can result in
     * {@code null}, meaning no offset.
     *
     * @param groundRule the ground rule; the activation is resolved as in {@link #resolveRuleActivation}
     * @return the newly created neuron
     */
    public WeightedRuleNeuron createWeightedRuleNeuron(GroundRule groundRule) {
        WeightedRule weightedRule = groundRule.weightedRule;
        Activation activation = this.resolveRuleActivation(weightedRule);
        Weight offset = weightedRule.getOffset();
        if (offset == null) {
            // Fixed: previously fell back to the shared *atom* offset, ignoring defaultRuleNeuronOffset.
            offset = this.resolveDefaultOffset(this.settings.defaultRuleOffsetsLearnable, this.settings.defaultRuleNeuronOffset, this.ruleOffset);
        }
        State.Neural.Computation state = State.createBaseState(this.settings, activation);
        WeightedRuleNeuron<State.Neural.Computation> weightedRuleNeuron = new WeightedRuleNeuron<State.Neural.Computation>(this.settings.fullRuleNeuronStrings ? groundRule.toFullString() : weightedRule.getOriginalString(), offset, this.counter++, state);
        this.neuronMaps.ruleNeurons.put(groundRule, weightedRuleNeuron);
        LOG.finest(() -> "Created weightedRule neuron: " + weightedRuleNeuron);
        return weightedRuleNeuron;
    }

    /**
     * Returns the fact neuron for {@code fact.literal}, creating and registering it on first request.
     * <p>
     * A newly created neuron holds the fact's value, or the factory's default fact value when the fact
     * has none. Its name is {@code fact.toString()} and its offset is {@code fact.getOffset()}. A
     * neuron already registered for the same literal is returned unchanged and consumes no new index.
     *
     * @param fact the valued fact to materialize
     * @return the existing or newly created fact neuron
     */
    public FactNeuron createFactNeuron(ValuedFact fact) {
        FactNeuron result = this.neuronMaps.factNeurons.get(fact.literal);
        if (result == null) {
            States.SimpleValue simpleValue = new States.SimpleValue(fact.getValue() == null ? this.defaultFactValue : fact.getValue());
            FactNeuron factNeuron = new FactNeuron(fact.toString(), fact.getOffset(), this.counter++, simpleValue);
            this.neuronMaps.factNeurons.put(fact.literal, factNeuron);
            LOG.finest(() -> "Created fact neuron: " + factNeuron);
            return factNeuron;
        }
        return result;
    }

    /**
     * Creates a negation neuron over {@code atomFact} and adds it to {@code neuronMaps.negationNeurons}.
     * <p>
     * This method does not check for an existing negation neuron over the same input.
     *
     * @param atomFact the neuron being negated
     * @param negation activation to use; when {@code null}, {@code settings.negation} is used
     * @return the newly created neuron
     */
    public NegationNeuron createNegationNeuron(AtomFact atomFact, Activation negation) {
        Activation activation = negation != null ? negation : Activation.getActivationFunction(this.settings.negation);
        State.Neural.Computation state = State.createBaseState(this.settings, activation);
        NegationNeuron<State.Neural.Computation> negationNeuron = new NegationNeuron<State.Neural.Computation>(atomFact, this.counter++, state);
        this.neuronMaps.negationNeurons.add(negationNeuron);
        LOG.finest(() -> "Created negation neuron: " + negationNeuron);
        return negationNeuron;
    }

    /**
     * Resolves the activation of a rule neuron.
     * <p>
     * The activation is the rule's own activation or {@code settings.ruleNeuronActivation}. It is
     * wrapped in a {@link CrossProduct} when the rule is a cross-product rule.
     */
    private Activation resolveRuleActivation(WeightedRule weightedRule) {
        Activation activation = weightedRule.getActivationFcn() != null ? weightedRule.getActivationFcn() : Activation.getActivationFunction(this.settings.ruleNeuronActivation);
        if (weightedRule.isCrossProduct()) {
            activation = new CrossProduct(activation);
        }
        return activation;
    }

    /**
     * Resolves the offset for a neuron whose template element declares none.
     *
     * @param learnable    whether default offsets of this neuron kind are learnable (from settings)
     * @param defaultValue configured default offset value for this neuron kind
     * @param sharedFixed  shared offset weight built in the constructor for this neuron kind
     * @return a newly constructed weight when {@code learnable}; otherwise {@code sharedFixed} when
     *         {@code defaultValue} is non-zero; otherwise {@code null} (no offset)
     */
    private Weight resolveDefaultOffset(boolean learnable, double defaultValue, Weight sharedFixed) {
        if (learnable) {
            // The last flag is true exactly when the default is non-zero, matching the original branching.
            // NOTE(review): the meaning of WeightFactory.construct's boolean flags is defined elsewhere — confirm.
            return this.weightFactory.construct(new ScalarValue(defaultValue), false, defaultValue != 0.0);
        }
        return defaultValue != 0.0 ? sharedFixed : null;
    }
}

