/*
 * Decompiled with CFR 0.152.
 */
package de.learnlib.algorithm.procedural.spmm;

import com.google.common.collect.Maps;
import de.learnlib.AccessSequenceTransformer;
import de.learnlib.algorithm.LearnerConstructor;
import de.learnlib.algorithm.LearningAlgorithm;
import de.learnlib.algorithm.procedural.SymbolWrapper;
import de.learnlib.algorithm.procedural.spmm.ATManager;
import de.learnlib.algorithm.procedural.spmm.MappingSPMM;
import de.learnlib.algorithm.procedural.spmm.ProceduralMembershipOracle;
import de.learnlib.algorithm.procedural.spmm.manager.OptimizingATManager;
import de.learnlib.oracle.MembershipOracle;
import de.learnlib.query.DefaultQuery;
import de.learnlib.util.MQUtil;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import net.automatalib.alphabet.DefaultProceduralInputAlphabet;
import net.automatalib.alphabet.GrowingMapAlphabet;
import net.automatalib.alphabet.ProceduralInputAlphabet;
import net.automatalib.alphabet.SupportsGrowingAlphabet;
import net.automatalib.automaton.procedural.EmptySPMM;
import net.automatalib.automaton.procedural.SPMM;
import net.automatalib.automaton.procedural.StackSPMM;
import net.automatalib.automaton.transducer.MealyMachine;
import net.automatalib.common.util.Pair;
import net.automatalib.common.util.mapping.Mapping;
import net.automatalib.util.automaton.Automata;
import net.automatalib.util.automaton.procedural.SPMMs;
import net.automatalib.word.Word;
import net.automatalib.word.WordBuilder;

/**
 * Learner for systems of procedural Mealy machines ({@link SPMM}s).
 * <p>
 * One local Mealy learner of type {@code L} is kept per discovered procedure (call symbol). Local learners operate on
 * {@link SymbolWrapper}-ed input symbols. They are created lazily, whenever the {@link ATManager} reports new call
 * symbols while scanning a counterexample. The global hypothesis is assembled from the local hypotheses with a
 * {@link StackSPMM} and translated back to the original alphabet with a {@link MappingSPMM}.
 *
 * @param <I> input symbol type
 * @param <O> output symbol type
 * @param <L> type of the local (per-procedure) Mealy learners
 */
public class SPMMLearner<I, O, L extends LearningAlgorithm.MealyLearner<SymbolWrapper<I>, O> & SupportsGrowingAlphabet<SymbolWrapper<I>>>
        implements LearningAlgorithm<SPMM<?, I, ?, O>, I, Word<O>> {
    /** The procedural alphabet (internal, call and return symbols) of the system under learning. */
    private final ProceduralInputAlphabet<I> alphabet;
    /** Output that local hypotheses must produce for every input following a return (see {@link #ensureReturnClosure()}). */
    private final O errorOutput;
    /** Global membership oracle; wrapped into a {@link ProceduralMembershipOracle} per local learner. */
    private final MembershipOracle<I, Word<O>> oracle;
    /** Supplies, per call symbol, the constructor for that procedure's local learner. */
    private final Mapping<I, LearnerConstructor<L, SymbolWrapper<I>, Word<O>>> learnerConstructors;
    /** Manages access/terminating sequences discovered in counterexamples and local hypotheses. */
    private final ATManager<I, O> atManager;
    /** Local learners, keyed by the call symbol of their procedure. */
    private final Map<I, L> learners;
    /** First input symbol of the most recent counterexample, used as the initial call of the hypothesis. */
    private I initialCallSymbol;
    /** First output symbol of the most recent counterexample, used as the initial output of the hypothesis. */
    private O initialOutputSymbol;
    /** Wrapper currently registered for each original input symbol. */
    private final Map<I, SymbolWrapper<I>> mapping;

    /**
     * Creates a learner that uses the same local learner constructor for every procedure and an
     * {@link OptimizingATManager}.
     *
     * @param alphabet           the procedural input alphabet
     * @param errorOutput        the error output symbol
     * @param oracle             the global membership oracle
     * @param learnerConstructor constructor used for every local learner
     */
    public SPMMLearner(ProceduralInputAlphabet<I> alphabet, O errorOutput, MembershipOracle<I, Word<O>> oracle, LearnerConstructor<L, SymbolWrapper<I>, Word<O>> learnerConstructor) {
        this(alphabet, errorOutput, oracle, i -> learnerConstructor, new OptimizingATManager<I, O>(alphabet, errorOutput));
    }

    /**
     * Creates a learner.
     *
     * @param alphabet            the procedural input alphabet
     * @param errorOutput         the error output symbol
     * @param oracle              the global membership oracle
     * @param learnerConstructors per-procedure constructors for the local learners
     * @param atManager           manager for access and terminating sequences
     */
    public SPMMLearner(ProceduralInputAlphabet<I> alphabet, O errorOutput, MembershipOracle<I, Word<O>> oracle, Mapping<I, LearnerConstructor<L, SymbolWrapper<I>, Word<O>>> learnerConstructors, ATManager<I, O> atManager) {
        this.alphabet = alphabet;
        this.errorOutput = errorOutput;
        this.oracle = oracle;
        this.learnerConstructors = learnerConstructors;
        this.atManager = atManager;
        this.learners = Maps.newHashMapWithExpectedSize(this.alphabet.getNumCalls());
        this.mapping = Maps.newHashMapWithExpectedSize(this.alphabet.size());
        // Internal symbols are wrapped with 'true', the return symbol with 'false'. Call symbols are only wrapped
        // once the ATManager reports them (see extractUsefulInformationFromCounterExample).
        for (I i : this.alphabet.getInternalAlphabet()) {
            this.mapping.put(i, new SymbolWrapper<>(i, true));
        }
        final I returnSymbol = this.alphabet.getReturnSymbol();
        this.mapping.put(returnSymbol, new SymbolWrapper<>(returnSymbol, false));
    }

    /**
     * Intentionally a no-op: no local learner exists until a counterexample reveals procedures (see
     * {@link #extractUsefulInformationFromCounterExample(DefaultQuery)}).
     */
    @Override
    public void startLearning() {
    }

    /**
     * Refines the hypothesis with the given counterexample.
     * <p>
     * First registers new procedures and terminating sequences found in the counterexample. Then repeatedly refines
     * the local learner responsible for the first mismatch until the counterexample is no longer one. Finally,
     * enforces return closure on all local hypotheses.
     *
     * @param defaultQuery the counterexample; its input must be return-matched (checked by assertion)
     *
     * @return {@code true} if new symbols/learners were registered or some local hypothesis was refined
     */
    @Override
    public boolean refineHypothesis(DefaultQuery<I, Word<O>> defaultQuery) {
        assert this.alphabet.isReturnMatched(defaultQuery.getInput());
        boolean changed = this.extractUsefulInformationFromCounterExample(defaultQuery);
        while (this.refineHypothesisInternal(defaultQuery)) {
            changed = true;
        }
        this.ensureReturnClosure();
        assert SPMMs.isValid(this.getHypothesisModel());
        return changed;
    }

    /**
     * Performs a single local refinement step.
     * <p>
     * Locates the first mismatching index and the procedure invocation containing it. It then projects the trace
     * onto that procedure and hands the resulting local counterexample to the procedure's learner.
     *
     * @return {@code false} if {@code defaultQuery} is no longer a counterexample, {@code true} otherwise
     */
    private boolean refineHypothesisInternal(DefaultQuery<I, Word<O>> defaultQuery) {
        final SPMM<?, I, ?, O> hypothesis = this.getHypothesisModel();
        if (!MQUtil.isCounterexample(defaultQuery, hypothesis)) {
            return false;
        }
        final Word<I> input = defaultQuery.getInput();
        final Word<O> output = defaultQuery.getOutput();
        final int mismatchIdx = this.detectMismatchingIdx(hypothesis, input, output);
        // NOTE(review): findCallIndex presumably yields the index of the call whose procedure is active at
        // mismatchIdx — confirm against ProceduralInputAlphabet.
        final int callIdx = this.alphabet.findCallIndex(input, mismatchIdx);
        final I procedure = input.getSymbol(callIdx);
        final Pair<Word<I>, Word<O>> localTraces = this.alphabet.project(input.subWord(callIdx + 1, mismatchIdx + 1), output.subWord(callIdx + 1, mismatchIdx + 1), 0);
        final DefaultQuery<SymbolWrapper<I>, Word<O>> localCE = this.constructLocalCE(localTraces.getFirst(), localTraces.getSecond());
        final L learner = this.learners.get(procedure);
        assert learner != null;
        final boolean localRefinement = learner.refineHypothesis(localCE);
        assert localRefinement;
        return true;
    }

    /**
     * Returns the current global hypothesis.
     *
     * @return an {@link EmptySPMM} while no local learner exists; otherwise a {@link MappingSPMM} over a
     * {@link StackSPMM} assembled from all local hypotheses
     */
    @Override
    public SPMM<?, I, ?, O> getHypothesisModel() {
        if (this.learners.isEmpty()) {
            return new EmptySPMM<I, O>(this.alphabet, this.errorOutput);
        }
        final GrowingMapAlphabet<SymbolWrapper<I>> internalAlphabet = new GrowingMapAlphabet<>();
        final GrowingMapAlphabet<SymbolWrapper<I>> callAlphabet = new GrowingMapAlphabet<>();
        final Map<I, MealyMachine<?, SymbolWrapper<I>, ?, O>> procedures = this.getSubModels();
        final Map<SymbolWrapper<I>, MealyMachine<?, SymbolWrapper<I>, ?, O>> mappedProcedures = Maps.newHashMapWithExpectedSize(procedures.size());
        for (Map.Entry<I, MealyMachine<?, SymbolWrapper<I>, ?, O>> e : procedures.entrySet()) {
            final SymbolWrapper<I> w = this.mapping.get(e.getKey());
            assert w != null;
            mappedProcedures.put(w, e.getValue());
            callAlphabet.add(w);
        }
        for (I i : this.alphabet.getInternalAlphabet()) {
            final SymbolWrapper<I> w = this.mapping.get(i);
            assert w != null;
            internalAlphabet.add(w);
        }
        final SymbolWrapper<I> returnSymbol = this.mapping.get(this.alphabet.getReturnSymbol());
        assert returnSymbol != null;
        final ProceduralInputAlphabet<SymbolWrapper<I>> mappedAlphabet = new DefaultProceduralInputAlphabet<>(internalAlphabet, callAlphabet, returnSymbol);
        final StackSPMM<?, SymbolWrapper<I>, ?, O> delegate = new StackSPMM<>(mappedAlphabet, this.mapping.get(this.initialCallSymbol), this.initialOutputSymbol, this.errorOutput, mappedProcedures);
        return new MappingSPMM<>(this.alphabet, this.errorOutput, this.mapping, delegate);
    }

    /**
     * Scans a counterexample for new procedures and terminating sequences and updates learners and mapping.
     * <p>
     * The counterexample's first input/output symbols become the hypothesis' initial call/output. For each new call
     * symbol, a local learner is constructed and started. Terminating sequences found in its initial hypothesis are
     * then registered. If no terminating wrapper is known for the call yet, a non-terminating ({@code false}) wrapper
     * is registered.
     *
     * @return {@code true} if a new learner was created or a newly terminating call was added to an existing learner
     */
    private boolean extractUsefulInformationFromCounterExample(DefaultQuery<I, Word<O>> defaultQuery) {
        final Word<I> input = defaultQuery.getInput();
        final Word<O> output = defaultQuery.getOutput();
        this.initialCallSymbol = input.firstSymbol();
        this.initialOutputSymbol = output.firstSymbol();
        final Pair<Set<I>, Set<I>> newSeqs = this.atManager.scanCounterexample(defaultQuery);
        final Set<I> newCalls = newSeqs.getFirst();
        final Set<I> newTerms = newSeqs.getSecond();
        boolean update = false;
        for (I call : newTerms) {
            // Registering a symbol only changes a hypothesis if some local learner exists to receive it.
            update |= this.registerSymbol(call, true);
        }
        for (I sym : newCalls) {
            update = true;
            final L newLearner = this.learnerConstructors.get(sym).constructLearner(new GrowingMapAlphabet<>(this.mapping.values()), new ProceduralMembershipOracle<I, O>(this.alphabet, this.oracle, sym, this.errorOutput, this.atManager));
            newLearner.startLearning();
            // Added before scanning so that the new learner is part of the map handed to scanProcedures below.
            this.learners.put(sym, newLearner);
            final Set<I> newTS = this.atManager.scanProcedures(Collections.singletonMap(sym, newLearner.getHypothesisModel()), this.learners, this.mapping.values());
            for (I call : newTS) {
                this.registerSymbol(call, true);
            }
            if (!this.mapping.containsKey(sym)) {
                this.registerSymbol(sym, false);
            }
        }
        return update;
    }

    /**
     * Wraps {@code symbol}, stores the wrapper in {@link #mapping} and adds it to every local learner's alphabet.
     * <p>
     * NOTE(review): judging from its uses in this class, the flag marks symbols after which a procedure may continue
     * (internal symbols, calls with a known terminating sequence) — confirm against {@link SymbolWrapper}.
     *
     * @return {@code true} iff at least one local learner received the new wrapper
     */
    private boolean registerSymbol(I symbol, boolean continuable) {
        final SymbolWrapper<I> wrapper = new SymbolWrapper<>(symbol, continuable);
        this.mapping.put(symbol, wrapper);
        for (L learner : this.learners.values()) {
            learner.addAlphabetSymbol(wrapper);
        }
        return !this.learners.isEmpty();
    }

    /** Collects the current hypothesis of every local learner, keyed by its call symbol. */
    private Map<I, MealyMachine<?, SymbolWrapper<I>, ?, O>> getSubModels() {
        final Map<I, MealyMachine<?, SymbolWrapper<I>, ?, O>> subModels = Maps.newHashMapWithExpectedSize(this.learners.size());
        for (Map.Entry<I, L> entry : this.learners.entrySet()) {
            subModels.put(entry.getKey(), entry.getValue().getHypothesisModel());
        }
        return subModels;
    }

    /** Translates a projected local trace over the original alphabet into a query over wrapped symbols. */
    private DefaultQuery<SymbolWrapper<I>, Word<O>> constructLocalCE(Word<I> input, Word<O> output) {
        final WordBuilder<SymbolWrapper<I>> wb = new WordBuilder<>(input.length());
        for (I i : input) {
            final SymbolWrapper<I> mapped = this.mapping.get(i);
            assert mapped != null;
            wb.append(mapped);
        }
        return new DefaultQuery<>(wb.toWord(), output);
    }

    /** Repeatedly enforces return closure on every local learner until its hypothesis is stable. */
    private void ensureReturnClosure() {
        for (L learner : this.learners.values()) {
            boolean stable = false;
            while (!stable) {
                stable = this.ensureReturnClosure(learner.getHypothesisModel(), this.mapping.values(), learner);
            }
        }
    }

    /**
     * Checks that, from every state of {@code hyp}, every input following the return symbol yields {@link #errorOutput}.
     * <p>
     * On the first violation, the learner is refined with a counterexample that expects {@link #errorOutput} for that
     * input.
     * <p>
     * {@code L}'s declared bounds do not include {@link AccessSequenceTransformer}, so a learner that does not
     * implement it fails here with a {@link ClassCastException}.
     *
     * @return {@code true} if no violation was found, {@code false} if the learner was refined
     */
    private <S, T> boolean ensureReturnClosure(MealyMachine<S, SymbolWrapper<I>, T, O> hyp, Collection<SymbolWrapper<I>> inputs, L learner) {
        @SuppressWarnings("unchecked") // see method documentation: the learner must also be an AccessSequenceTransformer
        final AccessSequenceTransformer<SymbolWrapper<I>> transformer = (AccessSequenceTransformer<SymbolWrapper<I>>) learner;
        final Set<Word<SymbolWrapper<I>>> cover = new HashSet<>();
        for (Word<SymbolWrapper<I>> sc : Automata.stateCover(hyp, inputs)) {
            cover.add(transformer.transformAccessSequence(sc));
        }
        final I returnSymbol = this.alphabet.getReturnSymbol();
        for (Word<SymbolWrapper<I>> as : cover) {
            final S state = hyp.getState(as);
            for (SymbolWrapper<I> ret : inputs) {
                if (!Objects.equals(ret.getDelegate(), returnSymbol)) {
                    continue;
                }
                final S succ = hyp.getSuccessor(state, ret);
                for (SymbolWrapper<I> next : inputs) {
                    final O succOut = hyp.getOutput(succ, next);
                    if (Objects.equals(this.errorOutput, succOut)) {
                        continue;
                    }
                    final Word<SymbolWrapper<I>> lp = as.append(ret);
                    final DefaultQuery<SymbolWrapper<I>, Word<O>> ce = new DefaultQuery<>(Word.epsilon(), lp.append(next), hyp.computeOutput(lp).append(this.errorOutput));
                    final boolean refined = learner.refineHypothesis(ce);
                    assert refined;
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the index of the first symbol whose output in {@code spmm} differs from {@code output}.
     * <p>
     * An undefined transition also counts as a mismatch.
     *
     * @throws IllegalArgumentException if no mismatch exists within the common length of {@code input} and {@code output}
     */
    private <S, T> int detectMismatchingIdx(SPMM<S, I, T, O> spmm, Word<I> input, Word<O> output) {
        final Iterator<I> inIter = input.iterator();
        final Iterator<O> outIter = output.iterator();
        S stateIter = spmm.getInitialState();
        int idx = 0;
        while (inIter.hasNext() && outIter.hasNext()) {
            final I i = inIter.next();
            final O o = outIter.next();
            final T t = spmm.getTransition(stateIter, i);
            if (t == null || !Objects.equals(o, spmm.getTransitionOutput(t))) {
                return idx;
            }
            stateIter = spmm.getSuccessor(t);
            ++idx;
        }
        throw new IllegalArgumentException("Non-counterexamples shouldn't be scanned for a mis-match");
    }
}

