/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.logic.grounding;

import com.sun.istack.internal.NotNull;
import cz.cvut.fel.ida.learning.Example;
import cz.cvut.fel.ida.logic.Clause;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.constructs.example.QueryAtom;
import cz.cvut.fel.ida.logic.constructs.example.ValuedFact;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundHeadRule;
import cz.cvut.fel.ida.logic.constructs.template.components.GroundRule;
import cz.cvut.fel.ida.logic.constructs.template.types.GraphTemplate;
import cz.cvut.fel.ida.logic.subsumption.Matching;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * A grounded template: ground rules indexed by the literal they derive, plus the ground facts of
 * the example.
 *
 * <p>Every instance gets a fresh name of the form {@code g<N>} from a shared, non-synchronized
 * counter.
 */
public class GroundTemplate
extends GraphTemplate
implements Example {
    /** Suffix source for auto-generated names; incremented once per constructed instance. */
    static int counter = 0;
    /**
     * Ground rules keyed by the ground head literal they derive, then by {@link GroundHeadRule};
     * the innermost set holds the individual {@link GroundRule} groundings.
     */
    @NotNull
    public LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> groundRules;
    /** Ground facts keyed by their literal. */
    @NotNull
    public Map<Literal, ValuedFact> groundFacts;
    /** Head literals of {@link #groundRules}, i.e. the literals derivable by some ground rule. */
    public Set<Literal> derivedGroundFacts;

    /** Creates an empty template with a fresh auto-generated name; all collections stay unset. */
    public GroundTemplate() {
        this.name = "g" + counter++;
    }

    /**
     * Creates a template over the given rules and facts.
     *
     * <p>The maps are stored by reference (not copied). {@link #derivedGroundFacts} is set to a new
     * set holding the head literals of {@code groundRules}.
     *
     * @param groundRules ground rules indexed by head literal, then by {@link GroundHeadRule}
     * @param groundFacts ground facts indexed by literal
     */
    public GroundTemplate(LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> groundRules, Map<Literal, ValuedFact> groundFacts) {
        this();
        this.groundRules = groundRules;
        this.groundFacts = groundFacts;
        this.derivedGroundFacts = this.getFactsFromGroundRules(groundRules);
    }

    /**
     * Shallow copy: shares (does not copy) the rule map, fact map and derived-fact set of
     * {@code other}. Only the name is fresh.
     *
     * @param other the template whose collections are shared
     */
    public GroundTemplate(GroundTemplate other) {
        this();
        this.groundRules = other.groundRules;
        this.groundFacts = other.groundFacts;
        this.derivedGroundFacts = other.derivedGroundFacts;
    }

    @Override
    public String getName() {
        return this.name;
    }

    /** The identifier of this template is its name. */
    @Override
    public String getId() {
        return this.getName();
    }

    /** Returns the number of distinct head literals in {@link #groundRules}. */
    @Override
    public Integer getNeuronCount() {
        return this.groundRules.size();
    }

    /** Returns a new mutable set containing the head literals (keys) of {@code groundRules}. */
    private Set<Literal> getFactsFromGroundRules(LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> groundRules) {
        return new HashSet<>(groundRules.keySet());
    }

    /**
     * Returns a new template containing what this template has beyond {@code memory}.
     *
     * <p>The result contains every ground rule of this template that {@code memory} does not list
     * under the same head literal and {@link GroundHeadRule}. It also contains every ground fact
     * whose literal is not a fact key of {@code memory}.
     *
     * <p>All nested maps and sets of the result are fresh copies, so neither this template nor
     * {@code memory} is modified. Head literals and {@link GroundHeadRule} keys are kept even if
     * all of their groundings were subtracted, so the result may contain empty sets.
     *
     * @param memory the template whose rules and facts are subtracted; must not be {@code null}
     * @return a new, freshly named template holding the difference
     */
    public GroundTemplate diffAgainst(GroundTemplate memory) {
        GroundTemplate diff = new GroundTemplate();
        diff.groundRules = new LinkedHashMap<>();
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> entry : this.groundRules.entrySet()) {
            // Build the copy first and insert it afterwards: Map.put returns the *previous*
            // value (null here), not the inserted one, so its result must not be used.
            LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> ruleCopies = new LinkedHashMap<>();
            for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> entry2 : entry.getValue().entrySet()) {
                ruleCopies.put(entry2.getKey(), new LinkedHashSet<>(entry2.getValue()));
            }
            diff.groundRules.put(entry.getKey(), ruleCopies);
        }
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> entry : memory.groundRules.entrySet()) {
            LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> ownRules = diff.groundRules.get(entry.getKey());
            if (ownRules == null) {
                continue; // head literal only present in memory: nothing to subtract
            }
            for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> entry2 : entry.getValue().entrySet()) {
                LinkedHashSet<GroundRule> ownGroundings = ownRules.get(entry2.getKey());
                if (ownGroundings != null) {
                    ownGroundings.removeAll(entry2.getValue());
                }
            }
        }
        diff.groundFacts = new HashMap<>(this.groundFacts);
        diff.groundFacts.keySet().removeAll(memory.groundFacts.keySet());
        // Keep the same invariant the (groundRules, groundFacts) constructor establishes.
        diff.derivedGroundFacts = this.getFactsFromGroundRules(diff.groundRules);
        return diff;
    }

    /**
     * Returns a copy of this template whose rules are restricted to the support of
     * {@code queryAtom}.
     *
     * <p>The query's seed head literals are those that share the predicate of the query head and
     * for which {@code Matching.subsumption(queryClause, headClause)} holds. Presumably this means
     * the query head subsumes the ground head (TODO confirm the argument order against
     * {@link Matching}). The support keeps all groundings of the seed literals. It also keeps,
     * transitively, all groundings of every body literal that some ground rule derives.
     *
     * <p>The copy is shallow apart from {@link #groundRules}: ground facts are shared with this
     * template. {@link #derivedGroundFacts} of the copy is recomputed from the pruned rules. This
     * template itself is not modified.
     *
     * @param queryAtom the query whose head literal selects the rules to keep
     * @return a new, freshly named template containing only the support of the query
     */
    @Override
    public GroundTemplate prune(QueryAtom queryAtom) {
        GroundTemplate groundTemplate = new GroundTemplate(this);
        LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> support = new LinkedHashMap<>();
        Matching matching = new Matching();
        // One closed list for the whole traversal: once an atom has been expanded, its whole
        // closure is already in `support`, so expanding it again would add nothing.
        Set<Literal> closedList = new HashSet<>();
        for (Map.Entry<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> entry : this.groundRules.entrySet()) {
            Literal head = entry.getKey();
            if (!queryAtom.headAtom.literal.predicate().equals(head.predicate())
                    || !matching.subsumption(new Clause(queryAtom.headAtom.literal), new Clause(head)).booleanValue()) {
                continue;
            }
            LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> ruleMap = support.computeIfAbsent(head, f -> new LinkedHashMap<>());
            for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> groundings : entry.getValue().entrySet()) {
                LinkedHashSet<GroundRule> weightedRules = ruleMap.computeIfAbsent(groundings.getKey(), f -> new LinkedHashSet<>());
                for (GroundRule grounding : groundings.getValue()) {
                    weightedRules.add(grounding);
                    this.recursePrune(grounding, support, closedList);
                }
            }
        }
        groundTemplate.groundRules = support;
        groundTemplate.derivedGroundFacts = this.getFactsFromGroundRules(support);
        // Return the pruned copy; returning `this` would silently discard the pruning.
        return groundTemplate;
    }

    /**
     * Depth-first copy of every ground rule reachable from the body of {@code grounding} into
     * {@code support}.
     *
     * @param grounding  the ground rule whose body literals are expanded
     * @param support    accumulator of kept rules, indexed like {@link #groundRules}
     * @param closedList literals already expanded; updated in place to stop cycles and repeated work
     */
    private void recursePrune(GroundRule grounding, LinkedHashMap<Literal, LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>>> support, Set<Literal> closedList) {
        for (Literal bodyAtom : grounding.groundBody) {
            // Set.add returns false if the atom has already been expanded.
            if (!closedList.add(bodyAtom)) {
                continue;
            }
            LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> validRules = this.groundRules.get(bodyAtom);
            if (validRules == null) {
                // No ground rule derives this atom (presumably a plain ground fact): nothing to keep.
                continue;
            }
            LinkedHashMap<GroundHeadRule, LinkedHashSet<GroundRule>> nextRules = support.computeIfAbsent(bodyAtom, f -> new LinkedHashMap<>());
            for (Map.Entry<GroundHeadRule, LinkedHashSet<GroundRule>> validGroundings : validRules.entrySet()) {
                LinkedHashSet<GroundRule> weightedRules = nextRules.computeIfAbsent(validGroundings.getKey(), f -> new LinkedHashSet<>());
                for (GroundRule nextGrounding : validGroundings.getValue()) {
                    weightedRules.add(nextGrounding);
                    this.recursePrune(nextGrounding, support, closedList);
                }
            }
        }
    }
}

