/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import com.sun.istack.internal.NotNull;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.example.ValuedFact;
import cz.cvut.fel.ida.logic.constructs.template.components.BodyAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralBuilder;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.building.builders.StatesBuilder;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralSets;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.AggregationState;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomFact;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.FactNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.NegationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedAtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedRuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.types.TopologicNetwork;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.LinkedMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.setup.Settings;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.InputMismatchException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Builds the neural structure (atom, aggregation, rule, fact and negation neurons) for ground
 * rules and facts and finalizes it into a {@link DetailedNetwork}.
 *
 * <p>Neurons are created through {@code neuralBuilder.neuronFactory} and looked up in its
 * {@link NeuronMaps}. When an atom or aggregation neuron is already present in those maps it is
 * marked as shared and is NOT modified in place; new input links for it are recorded in
 * {@code neuronMaps.extraInputMapping} instead.
 */
public class NeuralNetBuilder {
    private static final Logger LOG = Logger.getLogger(NeuralNetBuilder.class.getName());
    public NeuralBuilder neuralBuilder;
    private final Settings settings;

    /**
     * @param settings global settings consulted during finalization
     * @param neuralBuilder the builder (neuron/network factories, states builder) to use
     */
    public NeuralNetBuilder(Settings settings, NeuralBuilder neuralBuilder) {
        this.neuralBuilder = neuralBuilder;
        this.settings = settings;
    }

    /**
     * Creates a builder backed by a fresh {@link NeuralBuilder} for the given settings.
     *
     * @param settings global settings
     */
    public NeuralNetBuilder(Settings settings) {
        this(settings, new NeuralBuilder(settings));
    }

    /**
     * Creates (or reuses) the neurons for one ground head literal and all rules deriving it.
     *
     * <p>If no atom neuron exists for {@code head}, one is created; it is weighted iff at least one
     * lifted rule carries a weight different from {@code Weight.unitWeight}. Each lifted rule gets
     * an aggregation neuron feeding the atom neuron, and each grounding of that rule gets a rule
     * neuron feeding the aggregation neuron. Rule neurons are not connected to their body inputs
     * here; that is done by {@link #connectAllNeurons(NeuralSets)}.
     *
     * <p>If the atom or aggregation neuron already exists, it is marked as shared and the new
     * inputs are appended to its entry in {@code extraInputMapping} (which starts as a copy of the
     * previous extra mapping, or of the neuron's current inputs if there was none).
     *
     * @param head the ground head literal
     * @param rules lifted rules mapped to their groundings; all groundings are expected to have
     *     {@code head} as their ground head
     * @param createdNeurons collector that receives every newly created neuron
     * @throws IllegalArgumentException if no atom neuron exists for {@code head} and {@code rules}
     *     is empty, so there is no lifted head to create it from
     */
    public void loadNeuronsFromRules(Literal head, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> rules, NeuralSets createdNeurons) {
        NeuronMaps neuronMaps = this.neuralBuilder.neuronFactory.neuronMaps;
        boolean newAtomNeuron = false;
        boolean weightedAtomNeuron = false;
        AtomNeurons headAtomNeuron = neuronMaps.atomNeurons.get(head);
        if (headAtomNeuron == null) {
            if (rules.isEmpty()) {
                throw new IllegalArgumentException("Cannot create an atom neuron for " + head + " without any rules deriving it.");
            }
            newAtomNeuron = true;
            GroundHeadRule liftedRule = null;
            for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> entry : rules.entrySet()) {
                liftedRule = entry.getKey();
                LinkedHashSet<GroundRule> groundings = entry.getValue();
                // sanity check on the first grounding only; an empty grounding set has nothing to check
                if (!groundings.isEmpty() && !head.equals(groundings.iterator().next().groundHead)) {
                    LOG.severe("Ground heads corresponding to the same atom neuron are different!");
                }
                if (!liftedRule.weightedRule.getWeight().equals(Weight.unitWeight)) {
                    weightedAtomNeuron = true;
                }
            }
            // the lifted head of the last rule iterated is used to create the atom neuron
            if (weightedAtomNeuron) {
                headAtomNeuron = this.neuralBuilder.neuronFactory.createWeightedAtomNeuron(liftedRule.weightedRule.getHead(), head);
                createdNeurons.weightedAtomNeurons.add((WeightedAtomNeuron) headAtomNeuron);
            } else {
                headAtomNeuron = this.neuralBuilder.neuronFactory.createUnweightedAtomNeuron(liftedRule.weightedRule.getHead(), head);
                createdNeurons.atomNeurons.add((AtomNeuron) headAtomNeuron);
            }
            if (headAtomNeuron.getComputationView(0).getAggregationState() instanceof AggregationState.CrossProducState) {
                neuronMaps.containsCrossproduct = true;
            }
        } else {
            headAtomNeuron.setShared(true);
            if (!rules.isEmpty()) {
                // start a fresh extra mapping for the shared atom neuron, seeded from the previous one if any
                if (headAtomNeuron instanceof WeightedNeuron) {
                    weightedAtomNeuron = true;
                    WeightedNeuronMapping previous = (WeightedNeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    if (previous != null) {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new WeightedNeuronMapping(previous));
                    } else {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new WeightedNeuronMapping(headAtomNeuron.getInputs(), asWeighted(headAtomNeuron).getWeights()));
                    }
                } else {
                    NeuronMapping previous = (NeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    if (previous != null) {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new NeuronMapping(previous));
                    } else {
                        neuronMaps.extraInputMapping.put(headAtomNeuron, new NeuronMapping(headAtomNeuron.getInputs()));
                    }
                }
            }
        }
        for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> rules2groundings : rules.entrySet()) {
            GroundHeadRule liftedRule = rules2groundings.getKey();
            LinkedHashSet<GroundRule> groundings = rules2groundings.getValue();
            boolean newAggNeuron = false;
            AggregationNeuron aggNeuron = neuronMaps.aggNeurons.get(liftedRule);
            if (aggNeuron == null) {
                newAggNeuron = true;
                aggNeuron = this.neuralBuilder.neuronFactory.createAggNeuron(liftedRule);
                if (aggNeuron.getComputationView(0).getAggregationState().getInputMask() != null) {
                    neuronMaps.containsMasking = true;
                }
                createdNeurons.aggNeurons.add(aggNeuron);
            } else {
                aggNeuron.isShared = true;
                if (!groundings.isEmpty()) {
                    NeuronMapping previous = (NeuronMapping) neuronMaps.extraInputMapping.get(aggNeuron);
                    if (previous != null) {
                        neuronMaps.extraInputMapping.put(aggNeuron, new NeuronMapping(previous));
                    } else {
                        neuronMaps.extraInputMapping.put(aggNeuron, new NeuronMapping(aggNeuron.getInputs()));
                    }
                }
            }
            // link aggregation neuron -> head atom neuron
            if (newAtomNeuron) {
                if (weightedAtomNeuron) {
                    asWeighted(headAtomNeuron).addInput(aggNeuron, liftedRule.weightedRule.getWeight());
                } else {
                    headAtomNeuron.addInput(aggNeuron);
                }
            } else {
                LOG.info("Warning-  modifying previous state - Creating input overmapping for this Atom neuron: " + headAtomNeuron);
                // the mapping type mirrors the one created above: only weighted neurons get a WeightedNeuronMapping
                if (weightedAtomNeuron) {
                    WeightedNeuronMapping headMapping = (WeightedNeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    headMapping.addLink(aggNeuron);
                    headMapping.addWeight(liftedRule.weightedRule.getWeight());
                } else {
                    NeuronMapping headMapping = (NeuronMapping) neuronMaps.extraInputMapping.get(headAtomNeuron);
                    headMapping.addLink(aggNeuron);
                }
            }
            // link rule neurons -> aggregation neuron
            for (GroundRule grounding : groundings) {
                RuleNeurons ruleNeuron = neuronMaps.ruleNeurons.get(grounding);
                if (ruleNeuron == null) {
                    if (grounding.weightedRule.detectWeights()) {
                        ruleNeuron = this.neuralBuilder.neuronFactory.createWeightedRuleNeuron(grounding);
                        createdNeurons.weightedRuleNeurons.add((WeightedRuleNeuron) ruleNeuron);
                    } else {
                        ruleNeuron = this.neuralBuilder.neuronFactory.createRuleNeuron(grounding);
                        createdNeurons.ruleNeurons.add((RuleNeuron) ruleNeuron);
                    }
                    if (ruleNeuron.getComputationView(0).getAggregationState() instanceof AggregationState.CrossProducState) {
                        neuronMaps.containsCrossproduct = true;
                    }
                } else {
                    LOG.severe("Inconsistency - Specific rule neuron already contained in neuronmap!! This should never happen...");
                }
                if (newAggNeuron) {
                    aggNeuron.addInput(ruleNeuron);
                } else {
                    LOG.info("Warning-  modifying previous state - Creating input overmapping for this Agg neuron: " + aggNeuron);
                    // the rule neuron is an input of the shared AGGREGATION neuron, so extend its mapping
                    NeuronMapping aggMapping = (NeuronMapping) neuronMaps.extraInputMapping.get(aggNeuron);
                    aggMapping.addLink(ruleNeuron);
                }
            }
        }
    }

    /**
     * Creates a fact neuron for every ground fact and registers fact neurons in {@code createdNeurons}.
     *
     * <p>All fact neurons currently held in the neuron maps are added to
     * {@code createdNeurons.factNeurons}, not only those created by this call. Afterwards
     * {@code groundFacts} is cleared.
     * NOTE(review): if {@code factNeurons} is a list, facts from earlier calls may be added again — confirm.
     *
     * @param groundFacts ground facts keyed by literal; emptied by this method
     * @param createdNeurons collector receiving the fact neurons
     */
    public void loadNeuronsFromFacts(Map<Literal, ValuedFact> groundFacts, NeuralSets createdNeurons) {
        for (ValuedFact fact : groundFacts.values()) {
            this.neuralBuilder.neuronFactory.createFactNeuron(fact);
        }
        createdNeurons.factNeurons.addAll(this.neuralBuilder.neuronFactory.neuronMaps.factNeurons.values());
        groundFacts.clear();
    }

    /**
     * Connects every rule neuron in the neuron maps to the neurons of its ground body literals.
     *
     * <p>Rule neurons whose input count already equals the size of their lifted body are skipped.
     * Each body literal is resolved to an atom neuron, falling back to a fact neuron; negated body
     * atoms are wrapped in a newly created negation neuron. Weighted rule neurons receive the body
     * atom's conjunct weight with each input.
     *
     * @param createdNeurons collector of created neurons (not used by this method)
     */
    public void connectAllNeurons(NeuralSets createdNeurons) {
        NeuronMaps neuronMaps = this.neuralBuilder.neuronFactory.neuronMaps;
        for (Map.Entry<GroundRule, RuleNeurons> entry : neuronMaps.ruleNeurons.entrySet()) {
            GroundRule groundRule = entry.getKey();
            RuleNeurons ruleNeuron = entry.getValue();
            if (ruleNeuron.inputCount() == groundRule.weightedRule.getBody().size()) continue;
            for (int i = 0; i < groundRule.groundBody.length; ++i) {
                BodyAtom liftedBodyAtom = groundRule.weightedRule.getBody().get(i);
                Literal literal = groundRule.groundBody[i];
                Weight weight = liftedBodyAtom.getConjunctWeight();
                AtomFact input = neuronMaps.atomNeurons.get(literal);
                if (input == null) {
                    FactNeuron factNeuron = neuronMaps.factNeurons.get(literal);
                    if (factNeuron == null) {
                        // NOTE(review): the missing input is only logged; a null input is still wired in below
                        LOG.severe("Error: no input found for this neuron!!: " + literal);
                    }
                    input = factNeuron;
                }
                if (liftedBodyAtom.isNegated()) {
                    NegationNeuron negationNeuron = this.neuralBuilder.neuronFactory.createNegationNeuron(input, liftedBodyAtom.getNegationActivation());
                    input = negationNeuron;
                }
                if (ruleNeuron instanceof WeightedNeuron) {
                    asWeighted(ruleNeuron).addInput(input, weight);
                } else {
                    ((RuleNeuron) ruleNeuron).addInput(input);
                }
            }
        }
    }

    /**
     * Creates the {@link DetailedNetwork} from the created neurons and runs post-processing.
     *
     * <p>Steps: resolve query literals to atom neurons; create the network (passing the extra input
     * mappings); infer values; then, depending on settings and neuron-map flags, set up dropout
     * states, flag cross products / input masking, add linked inputs, compute the output mapping
     * and parent state numbers, and mark shared states for parallel training.
     *
     * @param id identifier of the network
     * @param createdNeurons neurons created for this network
     * @param queryMatchingLiterals query literals, or {@code null} for no query neurons
     * @return the finalized network
     * @throws InputMismatchException if a query literal has no corresponding atom neuron
     */
    public DetailedNetwork finalizeStoredNetwork(String id, NeuralSets createdNeurons, List<Literal> queryMatchingLiterals) throws RuntimeException {
        ArrayList<AtomNeurons> queryNeurons = null;
        if (queryMatchingLiterals != null) {
            queryNeurons = new ArrayList<AtomNeurons>(queryMatchingLiterals.size());
            for (Literal queryMatchingLiteral : queryMatchingLiterals) {
                AtomNeurons qn = this.neuralBuilder.neuronFactory.neuronMaps.atomNeurons.get(queryMatchingLiteral);
                if (qn == null) {
                    String err = "Query: " + queryMatchingLiteral + " was not matched anywhere in the ground network - Cannot calculate its output!";
                    LOG.severe(err);
                    LOG.warning(" -> This most likely means that the template is wrong as there is no proof-path from the example to the query");
                    LOG.warning("   -> Check all the predicate signatures etc. to make sure the template matches your examples and that there is at least 1 inference chain to the query");
                    throw new InputMismatchException(err);
                }
                queryNeurons.add(qn);
            }
        }
        DetailedNetwork neuralNetwork = this.neuralBuilder.networkFactory.createDetailedNetwork(queryNeurons, createdNeurons, id, this.neuralBuilder.neuronFactory.neuronMaps.extraInputMapping);
        LOG.fine("DetailedNetwork created.");
        StatesBuilder statesBuilder = this.neuralBuilder.statesBuilder;
        statesBuilder.inferValues(neuralNetwork);
        LOG.fine("Neuron dimensions inferred.");
        if (this.settings.dropoutRate > 0.0) {
            statesBuilder.setupDropoutStates(neuralNetwork);
        }
        if (this.getNeuronMaps().containsCrossproduct) {
            neuralNetwork.containsCrossProducts = true;
        }
        if (this.getNeuronMaps().containsMasking) {
            neuralNetwork.containsInputMasking = true;
        }
        if (neuralNetwork.extraInputMapping != null && !neuralNetwork.extraInputMapping.isEmpty()) {
            statesBuilder.addLinkedInputsToNetworkStates(neuralNetwork);
        }
        if (this.settings.parentCounting || this.settings.neuralNetsPostProcessing) {
            neuralNetwork.outputMapping = this.calculateOutputs(neuralNetwork);
            if (this.settings.parentCounting) {
                statesBuilder.setupParentStateNumbers(neuralNetwork);
            }
        }
        if (this.settings.parallelTraining) {
            int sharedNeuronsCount = statesBuilder.makeSharedStatesRecursively(neuralNetwork);
            LOG.fine("Shared neurons marked.");
            neuralNetwork.setSharedNeuronsCount(sharedNeuronsCount);
        }
        return neuralNetwork;
    }

    /**
     * Builds the reverse (child -> parents) mapping of the network's input edges.
     *
     * <p>For every neuron in {@code network.allNeuronsTopologic}, each of its inputs gets the neuron
     * appended to its {@link NeuronMapping}. Iteration over one neuron's inputs stops at the first
     * {@code null} input.
     *
     * @param network the network to analyze
     * @return mapping from each input neuron to the neurons it feeds
     */
    public Map<BaseNeuron, LinkedMapping> calculateOutputs(TopologicNetwork<State.Structure> network) {
        HashMap<BaseNeuron, LinkedMapping> outputMapping = new HashMap<BaseNeuron, LinkedMapping>();
        for (BaseNeuron<Neurons, State.Neural> parent : network.allNeuronsTopologic) {
            BaseNeuron child;
            Iterator<Neurons> inputs = network.getInputs(parent);
            while (inputs.hasNext() && (child = (BaseNeuron) inputs.next()) != null) {
                LinkedMapping parentMapping = outputMapping.computeIfAbsent(child, f -> new NeuronMapping());
                parentMapping.addLink(parent);
            }
        }
        return outputMapping;
    }

    /** @return the neuron maps of the underlying neuron factory */
    public NeuronMaps getNeuronMaps() {
        return this.neuralBuilder.neuronFactory.neuronMaps;
    }

    /** @param neuronMaps neuron maps to install into the underlying neuron factory */
    public void setNeuronMaps(NeuronMaps neuronMaps) {
        this.neuralBuilder.neuronFactory.neuronMaps = neuronMaps;
    }

    /**
     * Casts a neuron that is known (by a preceding instanceof check or by construction) to be a
     * {@link WeightedNeuron}. The argument is typed {@code Object} so the cast is accepted
     * regardless of the static type of the caller's reference.
     */
    private static WeightedNeuron asWeighted(Object neuron) {
        return (WeightedNeuron) neuron;
    }
}

