/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.building;

import com.sun.istack.internal.NotNull;
import cz.cvut.fel.ida.logic.Clause;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.building.factories.WeightFactory;
import cz.cvut.fel.ida.logic.constructs.example.LogicSample;
import cz.cvut.fel.ida.logic.constructs.example.QueryAtom;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.grounding.GroundTemplate;
import cz.cvut.fel.ida.logic.grounding.GroundingSample;
import cz.cvut.fel.ida.logic.subsumption.Matching;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralNetBuilder;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuralProcessingSample;
import cz.cvut.fel.ida.neural.networks.structure.building.NeuronMaps;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralSets;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeurons;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import cz.cvut.fel.ida.utils.generic.Timing;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Turns grounded logic samples into {@link NeuralProcessingSample}s by asking a
 * {@link NeuralNetBuilder} to create neurons for ground rules and ground facts and to assemble
 * them into a {@link DetailedNetwork}.
 *
 * <p>The instance keeps simple running counters and a {@link Timing} on which {@code tic()} /
 * {@code toc()} are called around every {@code neuralize} invocation. The instance mutates its
 * own fields and the {@link NeuronMaps} it is given, so no thread-safety should be assumed.
 */
public class Neuralizer
implements Exportable {
    private static final Logger LOG = Logger.getLogger(Neuralizer.class.getName());
    private transient Settings settings;
    /** Builder that creates, connects and finalizes the neurons; may be reassigned during neuralization. */
    public transient NeuralNetBuilder neuralNetBuilder;
    /** Counts taken from the neurons created by the most recent {@code neuralize} call. */
    public NeuralSets.NeuronCounter neuronCounts;
    /** Total number of query neurons produced by {@link #neuralize(GroundingSample)}. */
    int queryNeuronsCreated;
    /** Incremented once for every literal whose rule map was loaded by the query-driven traversal. */
    int groundRulesProcessed;
    /** Incremented once per {@code neuralize} call that actually starts building a network. */
    int networksCreated;
    public Timing timing;

    /**
     * Creates a neuralizer with a fresh {@link NeuralNetBuilder} and {@link Timing}.
     *
     * @param settings configuration read during neuralization and passed to the builder
     */
    public Neuralizer(Settings settings) {
        this.settings = settings;
        this.neuralNetBuilder = new NeuralNetBuilder(settings);
        this.timing = new Timing();
    }

    /**
     * Same as {@link #Neuralizer(Settings)}, but replaces the weight factory used by the
     * underlying neural builder.
     *
     * @param settings      configuration read during neuralization and passed to the builder
     * @param weightFactory weight factory to install into the builder
     */
    public Neuralizer(Settings settings, WeightFactory weightFactory) {
        this(settings);
        this.neuralNetBuilder.neuralBuilder.weightFactory = weightFactory;
    }

    /**
     * Builds one network for a whole group of samples and returns one processing sample for every
     * (sample, matching query literal) pair.
     *
     * <p>The {@link NeuronMaps} are taken from the first sample's grounding wrap; if that wrap has
     * none yet, new maps are created from its ground template and stored into the wrap of every
     * sample in the list.
     *
     * @param groundTemplate ground template whose ground rules are searched for query matches
     * @param samples        samples to neuralize together; a {@code null} or empty list yields an
     *                       empty result without touching any counter or the timer
     * @return the created processing samples, in sample order and, within a sample, in the order
     *         its matching literals were found
     * @throws RuntimeException propagated from the underlying network building
     */
    public List<NeuralProcessingSample> neuralize(GroundTemplate groundTemplate, List<GroundingSample> samples) throws RuntimeException {
        if (samples == null || samples.isEmpty()) {
            // Previously samples.get(0) failed here after the timer had already been started.
            LOG.warning("No samples given for neuralization.");
            return new ArrayList<>();
        }
        this.timing.tic();
        ++this.networksCreated;

        GroundingSample firstSample = samples.get(0);
        NeuronMaps existingMaps = (NeuronMaps)firstSample.groundingWrap.getNeuronMaps();
        final NeuronMaps neuronMaps = existingMaps != null ? existingMaps : this.createNeuronMaps(firstSample);
        if (existingMaps == null) {
            // Share the freshly created maps with every sample of the group.
            samples.forEach(s -> s.groundingWrap.setNeuronMaps(neuronMaps));
        }
        this.neuralNetBuilder.setNeuronMaps(neuronMaps);
        NeuralSets createdNeurons = new NeuralSets();

        // Parallel lists: origSamples.get(i) is the sample whose query matched queryMatchingLiterals.get(i).
        List<Literal> queryMatchingLiterals = new ArrayList<>();
        List<LogicSample> origSamples = new ArrayList<>();
        for (LogicSample logicSample : samples) {
            for (Literal foundQuery : this.getQueryMatchingLiterals((QueryAtom)logicSample.query, groundTemplate.groundRules)) {
                queryMatchingLiterals.add(foundQuery);
                origSamples.add(logicSample);
            }
        }

        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryMatchingLiterals, neuronMaps, createdNeurons);

        List<NeuralProcessingSample> processingSamples = new ArrayList<>(queryMatchingLiterals.size());
        for (int i = 0; i < queryMatchingLiterals.size(); ++i) {
            LogicSample logicSample = origSamples.get(i);
            QueryAtom queryAtom = (QueryAtom)logicSample.query;
            AtomNeurons atomNeuron = this.neuralNetBuilder.getNeuronMaps().atomNeurons.get(queryMatchingLiterals.get(i));
            if (atomNeuron == null) {
                // Only reported; the query neuron below is still created with a null atom neuron.
                LOG.severe("No inference network created for " + queryAtom);
            }
            QueryNeuron queryNeuron = new QueryNeuron(queryAtom.ID, queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork);
            processingSamples.add(new NeuralProcessingSample(logicSample.target, queryNeuron, logicSample.type));
        }
        // NOTE(review): unlike neuralize(GroundingSample), queryNeuronsCreated is not updated here - confirm intent.
        this.neuronCounts = createdNeurons.getCounts();
        this.timing.toc();
        return processingSamples;
    }

    /**
     * Builds the network for a single sample and returns one processing sample per query neuron.
     *
     * <p>If the sample's grounding wrap has no {@link NeuronMaps} yet, new maps are created from
     * its ground template and stored back into the wrap.
     *
     * @param groundingSample the grounded sample to neuralize
     * @return processing samples carrying the sample's target and type, one per query neuron
     * @throws RuntimeException propagated from the underlying network building
     */
    public List<NeuralProcessingSample> neuralize(GroundingSample groundingSample) throws RuntimeException {
        this.timing.tic();
        ++this.networksCreated;

        NeuronMaps neuronMaps = (NeuronMaps)groundingSample.groundingWrap.getNeuronMaps();
        if (neuronMaps == null) {
            neuronMaps = this.createNeuronMaps(groundingSample);
            groundingSample.groundingWrap.setNeuronMaps(neuronMaps);
        }
        this.neuralNetBuilder.setNeuronMaps(neuronMaps);
        NeuralSets createdNeurons = new NeuralSets();

        List<QueryNeuron> queryNeurons = this.supervisedNeuralization(groundingSample, neuronMaps, createdNeurons);
        this.queryNeuronsCreated += queryNeurons.size();
        if (queryNeurons.isEmpty()) {
            LOG.severe("No inference network created for " + groundingSample.query);
        }
        List<NeuralProcessingSample> samples = queryNeurons.stream()
                .map(queryNeuron -> new NeuralProcessingSample(groundingSample.target, queryNeuron, groundingSample.type))
                .collect(Collectors.toList());
        this.neuronCounts = createdNeurons.getCounts();
        this.timing.toc();
        return samples;
    }

    /** Creates new neuron maps from the ground rules and ground facts of the sample's ground template. */
    private NeuronMaps createNeuronMaps(GroundingSample groundingSample) {
        GroundTemplate template = groundingSample.groundingWrap.getGroundTemplate();
        return new NeuronMaps(template.groundRules, template.groundFacts);
    }

    /**
     * Builds the network either from all ground rules (when {@code settings.forceFullNetworks} is
     * set) or only from the rules reachable from the given query literals.
     */
    private DetailedNetwork buildNetwork(GroundTemplate groundTemplate, List<Literal> queryMatchingLiterals, NeuronMaps neuronMaps, NeuralSets createdNeurons) {
        if (this.settings.forceFullNetworks) {
            return this.blindNeuralization(groundTemplate, neuronMaps, createdNeurons);
        }
        this.neuralNetBuilder = this.loadAllNeuronsStartingFromQueryLiterals(groundTemplate, queryMatchingLiterals, neuronMaps, createdNeurons);
        return this.getDetailedNetwork(neuronMaps, createdNeurons, groundTemplate, queryMatchingLiterals);
    }

    /**
     * Finds the literals matching the sample's query, builds the network and wraps each matching
     * literal's atom neuron into a {@link QueryNeuron}.
     *
     * <p>If the query matches nothing, the problem is logged and the JVM is terminated with exit
     * status 5.
     */
    private List<QueryNeuron> supervisedNeuralization(GroundingSample groundingSample, NeuronMaps neuronMaps, NeuralSets createdNeurons) throws RuntimeException {
        QueryAtom queryAtom = (QueryAtom)groundingSample.query;
        GroundTemplate groundTemplate = groundingSample.groundingWrap.getGroundTemplate();
        List<Literal> queryMatchingLiterals = this.getQueryMatchingLiterals(queryAtom, groundTemplate.groundRules);
        if (queryMatchingLiterals.isEmpty()) {
            LOG.severe("Query not matched anywhere in the template:" + queryAtom);
            // NOTE(review): terminating the JVM from library code is harsh; kept because callers may rely on it.
            System.exit(5);
        }
        LOG.finer("Obtained QueryMatchingLiterals: " + queryMatchingLiterals);
        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryMatchingLiterals, neuronMaps, createdNeurons);
        return this.getQueryNeurons(queryAtom, this.neuralNetBuilder.getNeuronMaps(), neuralNetwork, queryMatchingLiterals);
    }

    /**
     * Loads neurons from every entry of {@code neuronMaps.groundRules}, regardless of any query,
     * then clears that map and finalizes the network (with no query literals).
     */
    private DetailedNetwork blindNeuralization(GroundTemplate groundTemplate, NeuronMaps neuronMaps, NeuralSets currentNeuralSets) throws RuntimeException {
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> entry : neuronMaps.groundRules.entrySet()) {
            this.neuralNetBuilder.loadNeuronsFromRules(entry.getKey(), entry.getValue(), currentNeuralSets);
        }
        // Every rule has been consumed above, so none is left to be processed again.
        neuronMaps.groundRules.clear();
        return this.getDetailedNetwork(neuronMaps, currentNeuralSets, groundTemplate, null);
    }

    /**
     * Loads fact neurons when needed, connects all created neurons and finalizes the network.
     *
     * <p>Fact neurons are loaded when the builder's neuron maps hold no fact neurons yet, or when
     * the grounding mode is {@code SEQUENTIAL}.
     *
     * @param queryMatchingLiterals literals handed to the builder on finalization; may be {@code null}
     */
    private DetailedNetwork getDetailedNetwork(NeuronMaps neuronMaps, NeuralSets createdNeurons, GroundTemplate groundTemplate, List<Literal> queryMatchingLiterals) throws RuntimeException {
        boolean noFactNeuronsYet = this.neuralNetBuilder.neuralBuilder.neuronFactory.neuronMaps.factNeurons.isEmpty();
        if (noFactNeuronsYet || this.settings.groundingMode == Settings.GroundingMode.SEQUENTIAL) {
            this.neuralNetBuilder.loadNeuronsFromFacts(neuronMaps.groundFacts, createdNeurons);
        }
        LOG.fine("Neurons created: " + this.neuralNetBuilder.getNeuronMaps());
        this.neuralNetBuilder.connectAllNeurons(createdNeurons);
        LOG.fine("All neurons connected.");
        DetailedNetwork neuralNetwork = this.neuralNetBuilder.finalizeStoredNetwork(groundTemplate.getName(), createdNeurons, queryMatchingLiterals);
        LOG.fine("Final neural network created: " + neuralNetwork);
        return neuralNetwork;
    }

    /**
     * Depth-first traversal from {@code literal}: removes the literal's rule map from
     * {@code neuronMaps.groundRules}, loads neurons from it and recurses into every body literal
     * of every grounding.
     *
     * @param literal           literal to expand
     * @param closedSet         literals already visited; guards against revisiting and cycles
     * @param neuronMaps        maps whose {@code groundRules} entries are consumed by the traversal
     * @param currentNeuralSets collector passed to the builder for the created neurons
     */
    private void recursiveNeuronsCreation(@NotNull Literal literal, Set<Literal> closedSet, NeuronMaps neuronMaps, NeuralSets currentNeuralSets) {
        // Set.add returns false when the literal was already visited.
        if (!closedSet.add(literal)) {
            return;
        }
        // Properly parameterized: the raw types left by the decompiler do not compile in the loops below.
        LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> ruleMap = neuronMaps.groundRules.remove(literal);
        if (ruleMap == null) {
            return;
        }
        this.neuralNetBuilder.loadNeuronsFromRules(literal, ruleMap, currentNeuralSets);
        ++this.groundRulesProcessed;
        for (LinkedHashSet<GroundRule> groundings : ruleMap.values()) {
            for (GroundRule grounding : groundings) {
                for (Literal bodyAtom : grounding.groundBody) {
                    this.recursiveNeuronsCreation(bodyAtom, closedSet, neuronMaps, currentNeuralSets);
                }
            }
        }
    }

    /**
     * Returns the literals that the query stands for.
     *
     * <p>A query literal without variables is returned as the only element, without consulting
     * {@code groundRules}. Otherwise every key of {@code groundRules} with the same predicate that
     * passes the {@link Matching#subsumption} check against the query literal is returned, in
     * the map's iteration order.
     *
     * @param queryAtom   the query whose head literal is matched
     * @param groundRules ground rules keyed by their head literal
     * @return a new, possibly empty list of matching literals
     */
    @NotNull
    protected List<Literal> getQueryMatchingLiterals(QueryAtom queryAtom, @NotNull LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> groundRules) {
        Literal queryLiteral = queryAtom.headAtom.literal;
        List<Literal> matches = new ArrayList<>();
        if (!queryLiteral.containsVariable()) {
            matches.add(queryLiteral);
            return matches;
        }
        Matching matching = new Matching();
        for (Literal candidate : groundRules.keySet()) {
            // Cheap predicate comparison first; the subsumption check runs only for equal predicates.
            if (queryLiteral.predicate().equals(candidate.predicate())
                    && matching.subsumption(new Clause(queryLiteral), new Clause(candidate)).booleanValue()) {
                matches.add(candidate);
            }
        }
        return matches;
    }

    /**
     * Loads neurons for everything reachable from the given query literals, sharing one visited
     * set across all of them.
     *
     * @param groundTemplate    not used by this implementation
     * @param queryLiterals     starting points of the traversal
     * @param neuronMaps        maps whose {@code groundRules} entries are consumed by the traversal
     * @param currentNeuralSets collector passed to the builder for the created neurons
     * @return this neuralizer's {@link #neuralNetBuilder}
     */
    protected NeuralNetBuilder loadAllNeuronsStartingFromQueryLiterals(GroundTemplate groundTemplate, List<Literal> queryLiterals, NeuronMaps neuronMaps, NeuralSets currentNeuralSets) {
        Set<Literal> closedSet = new HashSet<>();
        for (Literal queryLiteral : queryLiterals) {
            // The traversal itself records the literal in closedSet.
            this.recursiveNeuronsCreation(queryLiteral, closedSet, neuronMaps, currentNeuralSets);
        }
        return this.neuralNetBuilder;
    }

    /**
     * Creates one {@link QueryNeuron} per matching literal, using the query's ID, position and
     * importance and the atom neuron registered for that literal.
     *
     * <p>A missing atom neuron is only logged; the query neuron is still created with a
     * {@code null} atom neuron.
     *
     * @param queryAtom             query supplying ID, position and importance
     * @param neuronMaps            maps in which atom neurons are looked up by literal
     * @param neuralNetwork         network attached to every created query neuron
     * @param queryMatchingLiterals literals to create query neurons for
     * @return a new list with one query neuron per literal, in the given order
     */
    @NotNull
    protected List<QueryNeuron> getQueryNeurons(QueryAtom queryAtom, NeuronMaps neuronMaps, NeuralNetwork neuralNetwork, List<Literal> queryMatchingLiterals) {
        List<QueryNeuron> queryNeurons = new ArrayList<>(queryMatchingLiterals.size());
        for (Literal queryLiteral : queryMatchingLiterals) {
            AtomNeurons atomNeuron = neuronMaps.atomNeurons.get(queryLiteral);
            if (atomNeuron == null) {
                LOG.severe("Query not matched!");
            }
            queryNeurons.add(new QueryNeuron(queryAtom.ID, queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork));
        }
        return queryNeurons;
    }
}

