/*
 * Decompiled with CFR 0.152.
 */
package constructs.template;

import com.sun.istack.internal.Nullable;
import constructs.Atom;
import constructs.Conjunction;
import constructs.example.QueryAtom;
import constructs.example.ValuedFact;
import constructs.template.components.BodyAtom;
import constructs.template.components.WeightedRule;
import constructs.template.types.GraphTemplate;
import grounding.bottomUp.HerbrandModel;
import ida.ilp.logic.HornClause;
import ida.ilp.logic.Literal;
import ida.ilp.logic.Predicate;
import ida.utils.collections.MultiMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import learning.Model;
import networks.computation.evaluation.values.Value;
import networks.computation.training.NeuralModel;
import networks.structure.components.weights.Weight;

/**
 * A template: a set of weighted rules, valued facts and constraints, exposed as a
 * {@link Model} over {@link QueryAtom}s.
 *
 * <p>Templates created with the no-arg constructor get an id of the form {@code "template<N>"}
 * taken from a static counter. The counter is incremented without synchronization.
 */
public class Template implements Model<QueryAtom> {
    private static final Logger LOG = Logger.getLogger(Template.class.getName());

    /** Next numeric suffix for ids of newly constructed templates (not synchronized). */
    static int counter = 0;

    String id;
    public LinkedHashSet<WeightedRule> rules;
    public LinkedHashSet<ValuedFact> facts;
    public LinkedHashSet<Conjunction> constraints;

    /** Cache filled by {@link #getAllFacts()} / {@link #inferTemplateFacts()}; null until computed. */
    @Nullable
    Set<Literal> inferredLiterals;

    @Nullable
    public Map<HornClause, List<WeightedRule>> hornClauses;

    /** Creates an empty template with a fresh sequential id. */
    public Template() {
        this.id = "template" + counter++;
        this.rules = new LinkedHashSet<>();
        this.facts = new LinkedHashSet<>();
        this.constraints = new LinkedHashSet<>();
    }

    /**
     * Creates a template that shares (does not copy) the rule, fact and constraint collections of
     * {@code other}; mutations made through either template are visible in both.
     *
     * @param other the template whose collections and id are reused
     */
    public Template(Template other) {
        // Reuse the source id so that getId()/toString() of the copy are never null.
        this.id = other.id;
        this.rules = other.rules;
        this.facts = other.facts;
        this.constraints = other.constraints;
    }

    /**
     * Creates a template with a fresh id holding the given rules and facts (insertion order kept,
     * duplicates dropped).
     *
     * @param rules weighted rules to add
     * @param facts valued facts to add
     */
    public Template(List<WeightedRule> rules, List<ValuedFact> facts) {
        this();
        this.rules.addAll(rules);
        this.facts.addAll(facts);
    }

    /**
     * Adds the given constraints to this template, keeping any constraints already present.
     *
     * @param constr constraints to add; {@code null} is treated as an empty list
     */
    public void addConstraints(List<Conjunction> constr) {
        if (constr == null) {
            return;
        }
        if (this.constraints == null) {
            this.constraints = new LinkedHashSet<>();
        }
        this.constraints.addAll(constr);
    }

    @Override
    public String getId() {
        return this.id;
    }

    /**
     * Not implemented for plain templates.
     *
     * @param query ignored
     * @return always {@code null}
     */
    @Override
    public Value evaluate(QueryAtom query) {
        return null;
    }

    /**
     * Collects every weight referenced by the rules: rule weights, rule offsets, head offsets and
     * body-atom conjunct weights, with duplicates removed (first occurrence wins).
     *
     * <p>Side effect: every rule offset and head offset found is marked with {@code isOffset = true}.
     *
     * @return the distinct weights in encounter order
     */
    @Override
    public List<Weight> getAllWeights() {
        List<Weight> weightList = new ArrayList<>();
        for (WeightedRule rule : this.rules) {
            Weight ruleWeight = rule.getWeight();
            if (ruleWeight != null) {
                weightList.add(ruleWeight);
            }
            Weight offset = rule.getOffset();
            if (offset != null) {
                offset.isOffset = true;
                weightList.add(offset);
            }
            Weight headOffset = rule.getHead().getOffset();
            if (headOffset != null) {
                headOffset.isOffset = true;
                weightList.add(headOffset);
            }
            for (BodyAtom bodyAtom : rule.getBody()) {
                Weight conjunctWeight = bodyAtom.getConjunctWeight();
                if (conjunctWeight != null) {
                    weightList.add(conjunctWeight);
                }
            }
        }
        return this.filterUnique(weightList);
    }

    /** Removes duplicates (by {@code equals}) while preserving encounter order. */
    private List<Weight> filterUnique(List<Weight> weightList) {
        return weightList.stream().distinct().collect(Collectors.toList());
    }

    /**
     * Copies trained values from a neural model back into this template's learnable weights,
     * matching them by {@code index}.
     *
     * <p>Learnable weights whose index has no counterpart in the neural model keep their current
     * value, and a warning is logged.
     *
     * @param neuralModel the model holding the trained weights
     */
    public void updateWeightsFrom(NeuralModel neuralModel) {
        Map<Integer, Weight> neuralWeights = neuralModel.mapWeightsToIds();
        for (Weight weight : this.getAllWeights()) {
            if (!weight.isLearnable()) {
                continue;
            }
            Weight trained = neuralWeights.get(Integer.valueOf((int) weight.index));
            if (trained == null) {
                LOG.warning("No neural weight with index " + weight.index
                        + " found; keeping the template value.");
                continue;
            }
            weight.value = trained.value;
        }
    }

    /** @return the live (mutable) set of valued facts of this template */
    public LinkedHashSet<ValuedFact> getValuedFacts() {
        return this.facts;
    }

    /**
     * Returns the literals inferred from this template's facts and rules, plus the literals of the
     * facts themselves. The result is cached after the first non-null computation.
     *
     * @return the cached literal set, or {@code null} when the template has no facts
     */
    public Set<Literal> getAllFacts() {
        if (this.inferredLiterals == null) {
            this.inferredLiterals = this.inferTemplateFacts();
            if (this.inferredLiterals != null) {
                this.inferredLiterals.addAll(this.facts.stream().map(Atom::getLiteral).collect(Collectors.toList()));
            }
        }
        return this.inferredLiterals;
    }

    /**
     * Replaces the fact set. Does not invalidate the {@code inferredLiterals} cache.
     *
     * @param facts the new fact set
     */
    public void setFacts(LinkedHashSet<ValuedFact> facts) {
        this.facts = facts;
    }

    /**
     * Runs bottom-up inference ({@link HerbrandModel#inferModel}) over the rules and facts and adds
     * every resulting literal to the {@code inferredLiterals} cache (creating it if needed).
     *
     * @return the (shared, mutable) cache set, or {@code null} if there are no facts
     */
    public Set<Literal> inferTemplateFacts() {
        if (this.facts == null || this.facts.isEmpty()) {
            return null;
        }
        if (this.inferredLiterals == null) {
            this.inferredLiterals = new HashSet<>();
        }
        HerbrandModel herbrandModel = new HerbrandModel();
        Set<Literal> factLiterals = this.facts.stream().map(Atom::getLiteral).collect(Collectors.toSet());
        Set<HornClause> ruleClauses = this.rules.stream().map(WeightedRule::toHornClause).collect(Collectors.toSet());
        MultiMap<Predicate, Literal> inferred = herbrandModel.inferModel(ruleClauses, factLiterals);
        inferred.values().forEach(this.inferredLiterals::addAll);
        return this.inferredLiterals;
    }

    /**
     * Wraps this template in a {@link GraphTemplate} and prunes it for the given query. Logs a
     * warning on every call.
     *
     * @param query the query to prune for
     * @return the pruned graph template
     */
    public GraphTemplate prune(QueryAtom query) {
        LOG.warning("Inefficient template pruning");
        return new GraphTemplate(this).prune(query);
    }

    /**
     * Adds all rules, facts and constraints of another template to this one. Adding a template to
     * itself is a no-op.
     *
     * @param template the template to merge in
     */
    public void addAllFrom(Template template) {
        if (template == this) {
            return;
        }
        this.rules.addAll(template.rules);
        this.facts.addAll(template.facts);
        this.constraints.addAll(template.constraints);
    }

    @Override
    public String toString() {
        return this.id + ", rules: " + this.rules.size() + ", facts: " + this.facts.size() + ", constraints: " + this.constraints.size();
    }
}

