/*
 * Decompiled with CFR 0.152.
 */
package de.learnlib.algorithm.procedural.spa;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import de.learnlib.acex.AbstractBaseCounterexample;
import de.learnlib.acex.AcexAnalyzer;
import de.learnlib.acex.AcexAnalyzers;
import de.learnlib.algorithm.LearnerConstructor;
import de.learnlib.algorithm.LearningAlgorithm;
import de.learnlib.algorithm.procedural.spa.ATRManager;
import de.learnlib.algorithm.procedural.spa.ProceduralMembershipOracle;
import de.learnlib.algorithm.procedural.spa.manager.OptimizingATRManager;
import de.learnlib.oracle.MembershipOracle;
import de.learnlib.query.DefaultQuery;
import de.learnlib.util.MQUtil;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import net.automatalib.alphabet.GrowingMapAlphabet;
import net.automatalib.alphabet.ProceduralInputAlphabet;
import net.automatalib.alphabet.SupportsGrowingAlphabet;
import net.automatalib.automaton.fsa.DFA;
import net.automatalib.automaton.procedural.EmptySPA;
import net.automatalib.automaton.procedural.SPA;
import net.automatalib.automaton.procedural.StackSPA;
import net.automatalib.common.util.mapping.Mapping;
import net.automatalib.word.Word;
import net.automatalib.word.WordBuilder;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * A learner for systems of procedural automata ({@link SPA}s).
 * <p>
 * Every procedure (call symbol) is learned by its own DFA sub-learner. Each sub-learner is created from
 * {@code learnerConstructors} when a positive counterexample first reveals the procedure. The {@link ATRManager} is
 * shared with the per-procedure {@link ProceduralMembershipOracle}s. It supplies the terminating sequences used to
 * expand nested invocations during counterexample analysis.
 *
 * @param <I>
 *         input symbol type
 * @param <L>
 *         sub-learner type (a DFA learner whose alphabet can grow)
 */
public class SPALearner<I, L extends LearningAlgorithm.DFALearner<I> & SupportsGrowingAlphabet<I>>
implements LearningAlgorithm<SPA<?, I>, I, Boolean> {
    private final ProceduralInputAlphabet<I> alphabet;
    private final MembershipOracle<I, Boolean> oracle;
    private final Mapping<I, LearnerConstructor<L, I, Boolean>> learnerConstructors;
    private final AcexAnalyzer analyzer;
    private final ATRManager<I> atrManager;
    // one sub-learner per procedure discovered so far, keyed by its call symbol
    private final Map<I, L> subLearners;
    // internal symbols plus the call symbols of all procedures that already have a sub-learner
    private final Set<I> activeAlphabet;
    // call symbol of the main procedure; taken from the first symbol of the latest positive counterexample
    private I initialCallSymbol;

    /**
     * Creates a learner with default settings.
     * <p>
     * It uses the given learner constructor for every procedure, {@link AcexAnalyzers#BINARY_SEARCH_FWD} for
     * counterexample analysis and an {@link OptimizingATRManager}.
     *
     * @param alphabet
     *         the procedural input alphabet
     * @param oracle
     *         the (global) membership oracle
     * @param learnerConstructor
     *         the constructor used to create every procedure's sub-learner
     */
    public SPALearner(ProceduralInputAlphabet<I> alphabet, MembershipOracle<I, Boolean> oracle, LearnerConstructor<L, I, Boolean> learnerConstructor) {
        this(alphabet, oracle, i -> learnerConstructor, AcexAnalyzers.BINARY_SEARCH_FWD, new OptimizingATRManager<I>(alphabet));
    }

    /**
     * Creates a fully configured learner.
     *
     * @param alphabet
     *         the procedural input alphabet
     * @param oracle
     *         the (global) membership oracle
     * @param learnerConstructors
     *         maps each call symbol to the constructor for that procedure's sub-learner
     * @param analyzer
     *         the analyzer used to locate the responsible procedure in a counterexample
     * @param atrManager
     *         the manager providing (and refining) terminating sequences
     */
    public SPALearner(ProceduralInputAlphabet<I> alphabet, MembershipOracle<I, Boolean> oracle, Mapping<I, LearnerConstructor<L, I, Boolean>> learnerConstructors, AcexAnalyzer analyzer, ATRManager<I> atrManager) {
        this.alphabet = alphabet;
        this.oracle = oracle;
        this.learnerConstructors = learnerConstructors;
        this.analyzer = analyzer;
        this.atrManager = atrManager;
        this.subLearners = Maps.newHashMapWithExpectedSize(this.alphabet.getNumCalls());
        this.activeAlphabet = Sets.newHashSetWithExpectedSize(alphabet.getNumCalls() + alphabet.getNumInternals());
        this.activeAlphabet.addAll(alphabet.getInternalAlphabet());
    }

    /**
     * Does nothing. Sub-learners are created and started lazily, once positive counterexamples reveal procedures.
     */
    @Override
    public void startLearning() {
    }

    /**
     * Refines the hypothesis with the given (well-matched) counterexample.
     * <p>
     * If the counterexample is positive, it is first scanned for new procedures. Refinement rounds are then repeated
     * until the query is no longer a counterexample for the current hypothesis.
     *
     * @param defaultQuery
     *         the counterexample
     *
     * @return {@code true} if new procedures were added or any refinement took place, {@code false} otherwise
     */
    @Override
    public boolean refineHypothesis(DefaultQuery<I, Boolean> defaultQuery) {
        assert this.alphabet.isWellMatched(defaultQuery.getInput());
        boolean changed = this.extractUsefulInformationFromCounterExample(defaultQuery);
        while (this.refineHypothesisInternal(defaultQuery)) {
            changed = true;
        }
        return changed;
    }

    /**
     * Performs a single refinement round.
     *
     * @return {@code true} if some refinement took place, {@code false} if the query is not a counterexample for the
     * current hypothesis (or only became non-counterexample without refinement)
     */
    private boolean refineHypothesisInternal(DefaultQuery<I, Boolean> defaultQuery) {
        final SPA<?, I> hypothesis = this.getHypothesisModel();
        if (!MQUtil.isCounterexample(defaultQuery, hypothesis)) {
            return false;
        }

        // Improve terminating sequences and ensure TS conformance before analyzing the counterexample.
        // NOTE(review): re-checking the same 'hypothesis' object below presumes the sub-learners' hypothesis models
        // are live views that reflect the refinements performed here -- confirm for the sub-learners in use.
        final boolean tsRefinement = this.updateATRAndCheckTSConformance(hypothesis);
        if (!MQUtil.isCounterexample(defaultQuery, hypothesis)) {
            return tsRefinement;
        }

        final Word<I> input = defaultQuery.getInput();
        final List<Integer> returnIndices = this.determineReturnIndices(input);
        // Positive counterexamples are evaluated against the hypothesis, negative ones against the system.
        final Predicate<Word<I>> effectOracle = defaultQuery.getOutput() ? hypothesis::accepts : this.oracle::answerQuery;
        final int idx = this.analyzer.analyzeAbstractCounterexample(new Acex(input, effectOracle, returnIndices));

        final int returnIdx = returnIndices.get(idx);
        final int callIdx = this.alphabet.findCallIndex(input, returnIdx);
        final I procedure = input.getSymbol(callIdx);

        final Word<I> localTrace = this.alphabet.project(input.subWord(callIdx + 1, returnIdx), 0);
        final DefaultQuery<I, Boolean> localCE = new DefaultQuery<>(localTrace, defaultQuery.getOutput());

        // The refinement must not live inside the assert. With assertions disabled (the JVM default) it would be
        // skipped, this method would keep returning true without progress, and refineHypothesis could loop forever.
        final boolean localRefinement = this.subLearners.get(procedure).refineHypothesis(localCE);
        assert tsRefinement || localRefinement;
        return true;
    }

    /**
     * Returns the current hypothesis.
     *
     * @return an {@link EmptySPA} if no procedure has been discovered yet, otherwise a {@link StackSPA} composed of
     * the current sub-learner hypotheses
     */
    @Override
    public SPA<?, I> getHypothesisModel() {
        if (this.subLearners.isEmpty()) {
            return new EmptySPA<>(this.alphabet);
        }
        return new StackSPA<>(this.alphabet, this.initialCallSymbol, this.getSubModels());
    }

    /**
     * Scans a positive counterexample for previously unknown procedures and sets up a sub-learner for each of them.
     *
     * @return {@code true} if at least one new procedure was added, {@code false} otherwise (always {@code false} for
     * negative counterexamples)
     */
    private boolean extractUsefulInformationFromCounterExample(DefaultQuery<I, Boolean> defaultQuery) {
        if (!defaultQuery.getOutput()) {
            return false;
        }
        final Word<I> input = defaultQuery.getInput();
        this.initialCallSymbol = input.firstSymbol();
        final Set<I> newProcedures = this.atrManager.scanPositiveCounterexample(input);
        for (I sym : newProcedures) {
            final L newLearner = this.learnerConstructors.get(sym)
                                                         .constructLearner(new GrowingMapAlphabet<>(this.alphabet.getInternalAlphabet()),
                                                                           new ProceduralMembershipOracle<>(this.alphabet, this.oracle, sym, this.atrManager));
            // the new learner must know all previously discovered procedures
            for (I call : this.subLearners.keySet()) {
                newLearner.addAlphabetSymbol(call);
            }
            newLearner.startLearning();
            // register the learner before scanning, so the scan below has it among the available sub-learners
            this.subLearners.put(sym, newLearner);
            this.atrManager.scanProcedures(Collections.singletonMap(sym, newLearner.getHypothesisModel()), this.subLearners, this.activeAlphabet);
            this.activeAlphabet.add(sym);
            // make the new procedure available to every learner, including the new one
            for (L learner : this.subLearners.values()) {
                learner.addAlphabetSymbol(sym);
            }
        }
        if (!newProcedures.isEmpty()) {
            this.atrManager.scanProcedures(this.getSubModels(), this.subLearners, this.activeAlphabet);
            return true;
        }
        return false;
    }

    /**
     * Collects the current hypothesis of every sub-learner into a fresh map, keyed by call symbol.
     */
    private Map<I, DFA<?, I>> getSubModels() {
        final Map<I, DFA<?, I>> subModels = Maps.newHashMapWithExpectedSize(this.subLearners.size());
        for (Map.Entry<I, L> entry : this.subLearners.entrySet()) {
            subModels.put(entry.getKey(), entry.getValue().getHypothesisModel());
        }
        return subModels;
    }

    /**
     * Repeatedly enforces terminating-sequence conformance, rescanning the refined procedures after each round.
     *
     * @return {@code true} if at least one sub-learner was refined, {@code false} otherwise
     */
    private boolean updateATRAndCheckTSConformance(SPA<?, I> hypothesis) {
        boolean refinement = false;
        Map<I, ? extends DFA<?, I>> subModels = hypothesis.getProcedures();
        while (this.checkAndEnsureTSConformance(subModels)) {
            refinement = true;
            final Map<I, DFA<?, I>> refinedModels = this.getSubModels();
            subModels = refinedModels;
            this.atrManager.scanProcedures(refinedModels, this.subLearners, this.activeAlphabet);
        }
        return refinement;
    }

    /**
     * Returns the positions of all return symbols in {@code input}, in ascending order.
     */
    private List<Integer> determineReturnIndices(Word<I> input) {
        final List<Integer> returnIndices = new ArrayList<>();
        for (int i = 0; i < input.length(); ++i) {
            if (this.alphabet.isReturnSymbol(input.getSymbol(i))) {
                returnIndices.add(i);
            }
        }
        return returnIndices;
    }

    /**
     * Checks, for every known procedure, that its embedded terminating sequence is accepted.
     * <p>
     * The embedded sequence is {@code call · terminatingSequence · return}. It is accepted only if every nested
     * invocation is accepted by the respective sub-model.
     *
     * @return {@code true} if any sub-learner had to be refined, {@code false} otherwise
     */
    private boolean checkAndEnsureTSConformance(Map<I, ? extends DFA<?, I>> subModels) {
        boolean refinement = false;
        for (I procedure : this.subLearners.keySet()) {
            final Word<I> terminatingSequence = this.atrManager.getTerminatingSequence(procedure);
            final WordBuilder<I> embeddedTS = new WordBuilder<>(terminatingSequence.size() + 2);
            embeddedTS.append(procedure);
            embeddedTS.append(terminatingSequence);
            embeddedTS.append(this.alphabet.getReturnSymbol());
            refinement |= this.checkSingleTerminatingSequence(embeddedTS.toWord(), subModels);
        }
        return refinement;
    }

    /**
     * Checks every procedure invocation in {@code input} against the matching hypothesis.
     * <p>
     * For each invocation whose projected run is rejected, the procedure's sub-learner is refined with that run as a
     * positive example.
     *
     * @return {@code true} if any sub-learner was refined, {@code false} otherwise
     */
    private boolean checkSingleTerminatingSequence(Word<I> input, Map<I, ? extends DFA<?, I>> hypotheses) {
        boolean refinement = false;
        for (int i = 0; i < input.size(); ++i) {
            final I sym = input.getSymbol(i);
            if (!this.alphabet.isCallSymbol(sym)) {
                continue;
            }
            final int returnIdx = this.alphabet.findReturnIndex(input, i + 1);
            final Word<I> projectedRun = this.alphabet.project(input.subWord(i + 1, returnIdx), 0);
            // NOTE(review): assumes every call symbol occurring in a terminating sequence has a sub-model -- confirm
            final @NonNull DFA<?, I> hyp = hypotheses.get(sym);
            if (hyp.accepts(projectedRun)) {
                continue;
            }
            refinement = true;
            this.subLearners.get(sym).refineHypothesis(new DefaultQuery<>(projectedRun, true));
        }
        return refinement;
    }

    /**
     * Abstract counterexample whose indices correspond to the return positions of a global counterexample.
     * <p>
     * The effects at index {@code 0} ({@code false}) and index {@code returnIndices.size()} ({@code true}) are fixed.
     * Intermediate effects are computed by {@link #computeEffect(int)}.
     */
    private final class Acex
    extends AbstractBaseCounterexample<Boolean> {
        private final Word<I> input;
        private final Predicate<? super Word<I>> oracle;
        private final List<Integer> returnIndices;

        Acex(Word<I> input, Predicate<? super Word<I>> oracle, List<Integer> returnIndices) {
            super(returnIndices.size() + 1);
            this.input = input;
            this.oracle = oracle;
            this.returnIndices = returnIndices;
            super.setEffect(returnIndices.size(), true);
            super.setEffect(0, false);
        }

        /**
         * Rebuilds the prefix of {@code input} up to the return at {@code returnIndices.get(index)}.
         * <p>
         * Walking outwards from that return, each enclosing procedure's local content is projected to level 0,
         * expanded with the current terminating sequences and prefixed with its call symbol. The original suffix
         * starting at that return is then appended, and the resulting word is evaluated with {@code oracle}.
         */
        @Override
        protected Boolean computeEffect(int index) {
            final ArrayDeque<Word<I>> wordStack = new ArrayDeque<>();
            int idx = this.returnIndices.get(index);
            while (idx > 0) {
                final int callIdx = SPALearner.this.alphabet.findCallIndex(this.input, idx);
                final I callSymbol = this.input.getSymbol(callIdx);
                final Word<I> normalized = SPALearner.this.alphabet.project(this.input.subWord(callIdx + 1, idx), 0);
                final Word<I> expanded = SPALearner.this.alphabet.expand(normalized, SPALearner.this.atrManager::getTerminatingSequence);
                // push: outer procedures end up first when iterating the deque
                wordStack.push(expanded.prepend(callSymbol));
                idx = callIdx;
            }
            final WordBuilder<I> builder = new WordBuilder<>();
            wordStack.forEach(builder::append);
            builder.append(this.input.subWord(this.returnIndices.get(index)));
            return this.oracle.test(builder.toWord());
        }

        @Override
        public boolean checkEffects(Boolean eff1, Boolean eff2) {
            return Objects.equals(eff1, eff2);
        }
    }
}

