/*
 * Decompiled with CFR 0.152.
 */
package networks.structure.building;

import com.sun.istack.internal.NotNull;
import constructs.building.factories.WeightFactory;
import constructs.example.LogicSample;
import constructs.example.QueryAtom;
import constructs.template.components.GroundHeadRule;
import constructs.template.components.GroundRule;
import grounding.GroundTemplate;
import grounding.GroundingSample;
import ida.ilp.logic.Clause;
import ida.ilp.logic.Literal;
import ida.ilp.logic.subsumption.Matching;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import networks.structure.building.NeuralProcessingSample;
import networks.structure.building.NeuronMaps;
import networks.structure.building.builders.NeuralNetBuilder;
import networks.structure.components.NeuralNetwork;
import networks.structure.components.NeuronSets;
import networks.structure.components.neurons.QueryNeuron;
import networks.structure.components.neurons.types.AtomNeurons;
import networks.structure.components.types.DetailedNetwork;
import settings.Settings;

/**
 * Turns ground templates (their ground rules and ground facts) into neural networks using a
 * {@link NeuralNetBuilder}, and wraps the neurons matching each query into
 * {@link NeuralProcessingSample}s.
 */
public class Neuralizer {
    private static final Logger LOG = Logger.getLogger(Neuralizer.class.getName());
    private final Settings settings;
    /** Builder used for all network construction; may be replaced by {@link #loadAllNeuronsStartingFromQueryLiterals}. */
    public NeuralNetBuilder neuralNetBuilder;

    /**
     * @param settings global settings, also handed to the underlying {@link NeuralNetBuilder}
     */
    public Neuralizer(Settings settings) {
        this.settings = settings;
        this.neuralNetBuilder = new NeuralNetBuilder(settings);
    }

    /**
     * @param settings      global settings, also handed to the underlying {@link NeuralNetBuilder}
     * @param weightFactory weight factory installed into the builder's {@code neuralBuilder}
     */
    public Neuralizer(Settings settings, WeightFactory weightFactory) {
        this(settings);
        this.neuralNetBuilder.neuralBuilder.weightFactory = weightFactory;
    }

    /**
     * Neuralizes a batch of samples that share one ground template into a single network.
     * <p>
     * Each sample's query is matched against the template's ground rules; every matching literal
     * produces one {@link NeuralProcessingSample}, so a sample may contribute zero, one or several
     * outputs. All outputs reference the same network. The builder's neuron maps are written back
     * into {@code groundTemplate.neuronMaps} afterwards.
     *
     * @param groundTemplate template whose neuron maps and ground rules are used (and mutated)
     * @param samples        samples whose {@code query} is expected to be a {@link QueryAtom}
     * @return one processing sample per matched query literal, in sample order
     */
    public List<NeuralProcessingSample> neuralize(GroundTemplate groundTemplate, List<? extends LogicSample> samples) {
        this.neuralNetBuilder.setNeuronMaps(groundTemplate.neuronMaps);
        NeuronSets createdNeurons = new NeuronSets();

        // Parallel lists: queryMatchingLiterals.get(i) was produced by origSamples.get(i).
        List<Literal> queryMatchingLiterals = new ArrayList<>();
        List<LogicSample> origSamples = new ArrayList<>();
        for (LogicSample logicSample : samples) {
            List<Literal> foundQueries = this.getQueryMatchingLiterals((QueryAtom) logicSample.query, groundTemplate.groundRules);
            for (Literal foundQuery : foundQueries) {
                queryMatchingLiterals.add(foundQuery);
                origSamples.add(logicSample);
            }
        }

        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryMatchingLiterals, createdNeurons);

        List<NeuralProcessingSample> processingSamples = new ArrayList<>(queryMatchingLiterals.size());
        for (int i = 0; i < queryMatchingLiterals.size(); ++i) {
            LogicSample logicSample = origSamples.get(i);
            QueryAtom queryAtom = (QueryAtom) logicSample.query;
            AtomNeurons atomNeuron = this.neuralNetBuilder.getNeuronMaps().atomNeurons.get(queryMatchingLiterals.get(i));
            if (atomNeuron == null) {
                // NOTE(review): the query neuron is still created around a null neuron, as before.
                LOG.severe("No inference network created for " + queryAtom);
            }
            QueryNeuron queryNeuron = new QueryNeuron(queryAtom.ID, queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork);
            processingSamples.add(new NeuralProcessingSample(logicSample.target, queryNeuron));
        }
        groundTemplate.neuronMaps = this.neuralNetBuilder.getNeuronMaps();
        return processingSamples;
    }

    /**
     * Neuralizes a single grounding sample using the ground template held by its grounding wrap.
     * The builder's neuron maps are written back into that template afterwards.
     *
     * @param groundingSample sample whose {@code query} is expected to be a {@link QueryAtom}
     * @return one processing sample per query literal matched in the template
     */
    public List<NeuralProcessingSample> neuralize(GroundingSample groundingSample) {
        GroundTemplate groundTemplate = groundingSample.groundingWrap.getGroundTemplate();
        this.neuralNetBuilder.setNeuronMaps(groundTemplate.neuronMaps);
        NeuronSets currentNeuronSets = new NeuronSets();
        List<QueryNeuron> queryNeurons = this.supervisedNeuralization(groundingSample, currentNeuronSets);
        if (queryNeurons.isEmpty()) {
            LOG.severe("No inference network created for " + groundingSample.query);
        }
        groundingSample.groundingWrap.getGroundTemplate().neuronMaps = this.neuralNetBuilder.getNeuronMaps();
        return queryNeurons.stream()
                .map(queryNeuron -> new NeuralProcessingSample(groundingSample.target, queryNeuron))
                .collect(Collectors.toList());
    }

    /**
     * Builds the network for one grounding sample and returns a query neuron per matched literal.
     */
    private List<QueryNeuron> supervisedNeuralization(GroundingSample groundingSample, NeuronSets createdNeurons) {
        QueryAtom queryAtom = (QueryAtom) groundingSample.query;
        GroundTemplate groundTemplate = groundingSample.groundingWrap.getGroundTemplate();
        List<Literal> queryMatchingLiterals = this.getQueryMatchingLiterals(queryAtom, groundTemplate.groundRules);
        LOG.finer(() -> "Obtained QueryMatchingLiterals: " + queryMatchingLiterals);
        DetailedNetwork neuralNetwork = this.buildNetwork(groundTemplate, queryMatchingLiterals, createdNeurons);
        return this.getQueryNeurons(queryAtom, this.neuralNetBuilder.getNeuronMaps(), neuralNetwork, queryMatchingLiterals);
    }

    /**
     * Builds the network either from every stored ground rule (when {@code settings.forceFullNetworks})
     * or only from the rules reachable from {@code queryMatchingLiterals}.
     */
    private DetailedNetwork buildNetwork(GroundTemplate groundTemplate, List<Literal> queryMatchingLiterals, NeuronSets createdNeurons) {
        if (this.settings.forceFullNetworks) {
            return this.blindNeuralization(groundTemplate, createdNeurons);
        }
        this.neuralNetBuilder = this.loadAllNeuronsStartingFromQueryLiterals(groundTemplate, queryMatchingLiterals, createdNeurons);
        return this.getDetailedNetwork(createdNeurons, groundTemplate, queryMatchingLiterals);
    }

    /**
     * Loads neurons for every ground rule stored in {@code groundTemplate.neuronMaps.groundRules},
     * then empties that map and finalizes a network without query-literal information.
     */
    private DetailedNetwork blindNeuralization(GroundTemplate groundTemplate, NeuronSets currentNeuronSets) {
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> entry : groundTemplate.neuronMaps.groundRules.entrySet()) {
            this.neuralNetBuilder.loadNeuronsFromRules(entry.getKey(), entry.getValue(), currentNeuronSets);
        }
        groundTemplate.neuronMaps.groundRules.clear();
        return this.getDetailedNetwork(currentNeuronSets, groundTemplate, null);
    }

    /**
     * Loads fact neurons (when no fact neurons exist yet, or in SEQUENTIAL grounding mode), connects all
     * created neurons and finalizes the stored network.
     *
     * @param queryMatchingLiterals literals the network is built for; {@code null} for blind neuralization
     */
    private DetailedNetwork getDetailedNetwork(NeuronSets createdNeurons, GroundTemplate groundTemplate, List<Literal> queryMatchingLiterals) {
        if (this.neuralNetBuilder.neuralBuilder.neuronFactory.neuronMaps.factNeurons.isEmpty() || this.settings.groundingMode == Settings.GroundingMode.SEQUENTIAL) {
            this.neuralNetBuilder.loadNeuronsFromFacts(groundTemplate.neuronMaps.groundFacts);
        }
        // Suppliers avoid stringifying (potentially large) neuron maps when FINE is disabled.
        LOG.fine(() -> "Neurons created: " + this.neuralNetBuilder.getNeuronMaps());
        this.neuralNetBuilder.connectAllNeurons(createdNeurons);
        LOG.fine("All neurons connected.");
        DetailedNetwork neuralNetwork = this.neuralNetBuilder.finalizeStoredNetwork(groundTemplate.getId(), createdNeurons, queryMatchingLiterals);
        LOG.fine(() -> "Final neural network created: " + neuralNetwork);
        return neuralNetwork;
    }

    /**
     * Depth-first walk from {@code literal} through the ground-rule bodies. Each literal's rules are
     * removed from {@code groundTemplate.neuronMaps.groundRules} and turned into neurons, then every
     * body atom of those rules is visited. Literals already in {@code closedSet} are skipped.
     */
    private void recursiveNeuronsCreation(@NotNull Literal literal, GroundTemplate groundTemplate, Set<Literal> closedSet, NeuronSets currentNeuronSets) {
        if (!closedSet.add(literal)) {
            return; // already visited
        }
        LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> ruleMap = groundTemplate.neuronMaps.groundRules.remove(literal);
        if (ruleMap == null) {
            return; // no rules with this head (e.g. a fact) or already consumed
        }
        this.neuralNetBuilder.loadNeuronsFromRules(literal, ruleMap, currentNeuronSets);
        for (LinkedHashSet<GroundRule> groundings : ruleMap.values()) {
            for (GroundRule grounding : groundings) {
                for (Literal bodyAtom : grounding.groundBody) {
                    this.recursiveNeuronsCreation(bodyAtom, groundTemplate, closedSet, currentNeuronSets);
                }
            }
        }
    }

    /**
     * Finds the ground-rule head literals that a query refers to.
     * <p>
     * A variable-free query literal is returned as-is (whether or not it appears in {@code groundRules}).
     * Otherwise every head literal with the same predicate that is subsumed by the query is returned,
     * in {@code groundRules} iteration order.
     */
    @NotNull
    protected List<Literal> getQueryMatchingLiterals(QueryAtom queryAtom, @NotNull LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> groundRules) {
        Literal queryLiteral = queryAtom.headAtom.literal;
        List<Literal> queryLiterals = new ArrayList<>();
        if (!queryLiteral.containsVariable()) {
            queryLiterals.add(queryLiteral);
            return queryLiterals;
        }
        Matching matching = new Matching();
        Clause queryClause = new Clause(queryLiteral); // loop-invariant, built once
        for (Literal groundHead : groundRules.keySet()) {
            if (queryLiteral.predicate().equals(groundHead.predicate())
                    && matching.subsumption(queryClause, new Clause(groundHead))) {
                queryLiterals.add(groundHead);
            }
        }
        return queryLiterals;
    }

    /**
     * Creates neurons for everything reachable from each query literal, sharing one visited set
     * across all query literals.
     *
     * @return the builder that now holds the created neurons (this instance's builder)
     */
    protected NeuralNetBuilder loadAllNeuronsStartingFromQueryLiterals(GroundTemplate groundTemplate, List<Literal> queryLiterals, NeuronSets currentNeuronSets) {
        Set<Literal> closedSet = new HashSet<>();
        for (Literal queryLiteral : queryLiterals) {
            this.recursiveNeuronsCreation(queryLiteral, groundTemplate, closedSet, currentNeuronSets);
        }
        return this.neuralNetBuilder;
    }

    /**
     * Wraps the atom neuron of each matched query literal into a {@link QueryNeuron} attached to
     * {@code neuralNetwork}. Logs (but does not skip) literals without an atom neuron.
     */
    @NotNull
    protected List<QueryNeuron> getQueryNeurons(QueryAtom queryAtom, NeuronMaps neuronMaps, NeuralNetwork neuralNetwork, List<Literal> queryMatchingLiterals) {
        List<QueryNeuron> queryNeurons = new ArrayList<>(queryMatchingLiterals.size());
        for (Literal queryLiteral : queryMatchingLiterals) {
            AtomNeurons atomNeuron = neuronMaps.atomNeurons.get(queryLiteral);
            if (atomNeuron == null) {
                // NOTE(review): a QueryNeuron wrapping null is still added, as before.
                LOG.severe("Query not matched!");
            }
            queryNeurons.add(new QueryNeuron(queryAtom.ID, queryAtom.position, queryAtom.importance, atomNeuron, neuralNetwork));
        }
        return queryNeurons;
    }
}

