/*
 * Decompiled with CFR 0.152.
 */
package de.learnlib.algorithm.procedural.spmm.manager;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import de.learnlib.AccessSequenceTransformer;
import de.learnlib.algorithm.procedural.SymbolWrapper;
import de.learnlib.algorithm.procedural.spmm.ATManager;
import de.learnlib.query.DefaultQuery;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import net.automatalib.alphabet.ProceduralInputAlphabet;
import net.automatalib.automaton.transducer.MealyMachine;
import net.automatalib.common.util.Pair;
import net.automatalib.util.automaton.cover.Covers;
import net.automatalib.word.Word;
import net.automatalib.word.WordBuilder;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * An {@link ATManager} that collects access and terminating sequences from counterexamples and from hypothesis
 * procedures, and keeps, per call symbol, the shortest sequence found so far.
 * <p>
 * Whenever {@link #scanProcedures(Map, Map, Collection)} finds a shorter terminating sequence, all stored sequences
 * are re-minimized by replacing the body of every well-matched invocation with the currently known terminating
 * sequence of the invoked procedure (if that is shorter).
 *
 * @param <I>
 *         input symbol type
 * @param <O>
 *         output symbol type
 */
public class OptimizingATManager<I, O> implements ATManager<I, O> {

    /** Shortest known access sequence per call symbol; each stored sequence ends with that call symbol. */
    private final Map<I, Word<I>> accessSequences;
    /** Shortest known terminating sequence per call symbol, i.e. the symbols between a call and its return. */
    private final Map<I, Word<I>> terminatingSequences;
    private final ProceduralInputAlphabet<I> inputAlphabet;
    /** The output symbol that marks a rejected (erroneous) input. */
    private final O errorOutput;

    /**
     * Creates a new manager with empty access and terminating sequence tables.
     *
     * @param inputAlphabet
     *         the procedural alphabet (call, internal, and return symbols)
     * @param errorOutput
     *         the output symbol that indicates an erroneous input
     */
    public OptimizingATManager(ProceduralInputAlphabet<I> inputAlphabet, O errorOutput) {
        this.inputAlphabet = inputAlphabet;
        this.errorOutput = errorOutput;
        this.accessSequences = Maps.newHashMapWithExpectedSize(inputAlphabet.getNumCalls());
        this.terminatingSequences = Maps.newHashMapWithExpectedSize(inputAlphabet.getNumCalls());
    }

    @Override
    public Word<I> getAccessSequence(I procedure) {
        assert this.accessSequences.containsKey(procedure);
        return this.accessSequences.get(procedure);
    }

    @Override
    public Word<I> getTerminatingSequence(I procedure) {
        assert this.terminatingSequences.containsKey(procedure);
        return this.terminatingSequences.get(procedure);
    }

    /**
     * Extracts (possibly shorter) terminating and access sequences from the given counterexample.
     *
     * @return a pair of (call symbols that received their first access sequence, call symbols that received their
     * first terminating sequence)
     */
    @Override
    public Pair<Set<I>, Set<I>> scanCounterexample(DefaultQuery<I, Word<O>> counterexample) {
        final Set<I> newCalls =
                Sets.newHashSetWithExpectedSize(this.inputAlphabet.getNumCalls() - this.accessSequences.size());
        final Set<I> newTerms =
                Sets.newHashSetWithExpectedSize(this.inputAlphabet.getNumCalls() - this.terminatingSequences.size());

        // terminating sequences first: access-sequence extraction replaces completed invocations with them
        this.extractPotentialTerminatingSequences(counterexample, newTerms);
        this.extractPotentialAccessSequences(counterexample, newCalls);

        return Pair.of(newCalls, newTerms);
    }

    /**
     * Tries to derive shorter terminating sequences from the given procedural hypotheses. This is repeated until no
     * further improvement is found. If any sequence improved, all stored access and terminating sequences are
     * re-minimized.
     *
     * @return the procedures that received their first terminating sequence
     *
     * @throws IllegalArgumentException
     *         if {@code procedures} is non-empty and {@code inputs} contains no wrapper of the return symbol
     */
    @Override
    public Set<I> scanProcedures(Map<I, ? extends MealyMachine<?, SymbolWrapper<I>, ?, O>> procedures,
                                 Map<I, ? extends AccessSequenceTransformer<SymbolWrapper<I>>> providers,
                                 Collection<SymbolWrapper<I>> inputs) {
        final Set<I> newTS = Sets.newHashSetWithExpectedSize(procedures.size());

        if (procedures.isEmpty()) {
            return newTS;
        }

        final I returnDelegate = this.inputAlphabet.getReturnSymbol();
        final SymbolWrapper<I> returnSymbol = inputs.stream()
                                                    .filter(sw -> Objects.equals(sw.getDelegate(), returnDelegate))
                                                    .findAny()
                                                    .orElseThrow(() -> new IllegalArgumentException(
                                                            "inputs do not contain the return symbol"));

        boolean foundImprovements = false;
        boolean stable = false;

        // Fixpoint iteration: hypothesis-derived terminating sequences are expanded with the currently known
        // terminating sequences, so improving one procedure may enable improvements for another.
        while (!stable) {
            stable = true;
            for (Map.Entry<I, ? extends MealyMachine<?, SymbolWrapper<I>, ?, O>> entry : procedures.entrySet()) {
                final I procedure = entry.getKey();
                final MealyMachine<?, SymbolWrapper<I>, ?, O> automaton = entry.getValue();
                final Word<I> currentTS = this.terminatingSequences.get(procedure);
                assert providers.containsKey(procedure);
                final Word<I> hypTS =
                        this.getShortestHypothesisTS(automaton, providers.get(procedure), inputs, returnSymbol);

                if (hypTS != null && (currentTS == null || hypTS.size() < currentTS.size())) {
                    if (currentTS == null) {
                        newTS.add(procedure);
                    }
                    this.terminatingSequences.put(procedure, hypTS);
                    stable = false;
                    foundImprovements = true;
                }
            }
        }

        if (foundImprovements) {
            this.optimizeSequences(this.accessSequences);
            this.optimizeSequences(this.terminatingSequences);
        }

        return newTS;
    }

    /**
     * Computes the shortest expanded terminating sequence for the given hypothesis. Candidates are the transformed
     * state-cover words after which the hypothesis does not emit the error output on the return symbol.
     *
     * @return the shortest candidate, or {@code null} if no state admits a non-erroneous return
     */
    private <S> @Nullable Word<I> getShortestHypothesisTS(MealyMachine<S, SymbolWrapper<I>, ?, O> hyp,
                                                          AccessSequenceTransformer<SymbolWrapper<I>> asTransformer,
                                                          Collection<SymbolWrapper<I>> inputs,
                                                          SymbolWrapper<I> returnSymbol) {
        final Iterator<Word<SymbolWrapper<I>>> iter = Covers.stateCoverIterator(hyp, inputs);
        Word<I> result = null;

        while (iter.hasNext()) {
            final Word<SymbolWrapper<I>> cover = iter.next();
            final Word<SymbolWrapper<I>> as = asTransformer.transformAccessSequence(cover);
            final Word<SymbolWrapper<I>> asReturn = as.append(returnSymbol);

            // the hypothesis rejects returning from this state, so it cannot yield a terminating sequence
            if (Objects.equals(this.errorOutput, hyp.computeOutput(asReturn).lastSymbol())) {
                continue;
            }

            final Word<I> ts = this.inputAlphabet.expand(as.transform(SymbolWrapper::getDelegate),
                                                         this.terminatingSequences::get);
            if (result == null || ts.size() < result.size()) {
                result = ts;
            }
        }

        return result;
    }

    /**
     * Replaces every sequence in {@code sequences} by its minified form (see {@link #minifyWellMatched(Word)}) if
     * that form is strictly shorter.
     */
    private void optimizeSequences(Map<I, Word<I>> sequences) {
        for (Map.Entry<I, Word<I>> entry : sequences.entrySet()) {
            final Word<I> currentSequence = entry.getValue();
            final Word<I> minimized = this.minifyWellMatched(currentSequence);
            if (minimized.size() < currentSequence.size()) {
                // Entry#setValue is the only modification the Map contract permits while iterating the entry set
                entry.setValue(minimized);
            }
        }
    }

    /**
     * Records, for every call in the counterexample whose matching return does not produce the error output, the
     * enclosed symbols as a terminating sequence, if none is known yet or the new one is shorter.
     *
     * @param newProcedures
     *         receives the call symbols that obtained their first terminating sequence
     */
    private void extractPotentialTerminatingSequences(DefaultQuery<I, Word<O>> counterexample,
                                                      Set<I> newProcedures) {
        final Word<I> input = counterexample.getInput();
        final Word<O> output = counterexample.getOutput();

        for (int i = 0; i < input.size(); i++) {
            final I sym = input.getSymbol(i);
            if (!this.inputAlphabet.isCallSymbol(sym)) {
                continue;
            }

            final int returnIdx = this.inputAlphabet.findReturnIndex(input, i + 1);
            if (returnIdx <= 0 || Objects.equals(this.errorOutput, output.getSymbol(returnIdx))) {
                continue;
            }

            final Word<I> potentialTermSeq = input.subWord(i + 1, returnIdx);
            final Word<I> currentTermSeq = this.terminatingSequences.get(sym);
            if (currentTermSeq == null) {
                newProcedures.add(sym);
                this.terminatingSequences.put(sym, potentialTermSeq);
            } else if (potentialTermSeq.size() < currentTermSeq.size()) {
                this.terminatingSequences.put(sym, potentialTermSeq);
            }
        }
    }

    /**
     * Walks the counterexample and records the prefix ending at each call symbol as an access sequence, if none is
     * known yet or the new one is shorter. Completed invocations in the prefix are replaced by the known
     * terminating sequence of the invoked procedure. Extraction stops at the first call that produces the error
     * output.
     *
     * @param newCalls
     *         receives the call symbols that obtained their first access sequence
     */
    private void extractPotentialAccessSequences(DefaultQuery<I, Word<O>> counterexample, Set<I> newCalls) {
        final Word<I> input = counterexample.getInput();
        final Word<O> output = counterexample.getOutput();
        final List<I> asBuilder = new ArrayList<>(input.size());

        for (int i = 0; i < input.size(); i++) {
            final I sym = input.getSymbol(i);
            asBuilder.add(sym);

            if (this.inputAlphabet.isCallSymbol(sym)) {
                if (Objects.equals(this.errorOutput, output.getSymbol(i))) {
                    return;
                }

                final Word<I> currentAccSeq = this.accessSequences.get(sym);
                if (currentAccSeq == null) {
                    newCalls.add(sym);
                    this.accessSequences.put(sym, Word.fromList(asBuilder));
                } else if (asBuilder.size() < currentAccSeq.size()) {
                    this.accessSequences.put(sym, Word.fromList(asBuilder));
                }
            } else if (this.inputAlphabet.isReturnSymbol(sym)) {
                final int callIdx = this.inputAlphabet.findCallIndex(asBuilder, asBuilder.size() - 1);
                if (callIdx < 0) {
                    // unmatched return: there is no invocation to shorten, and no admissible prefix beyond it
                    return;
                }

                final I procedure = asBuilder.get(callIdx);
                final Word<I> ts = this.terminatingSequences.get(procedure);
                if (ts == null) {
                    // No terminating sequence is known for this procedure (e.g. its return produced the error
                    // output and was skipped by extractPotentialTerminatingSequences), so no valid access sequence
                    // can be built past this point.
                    return;
                }

                // replace the body of the just-completed invocation with the procedure's terminating sequence
                asBuilder.subList(callIdx + 1, asBuilder.size()).clear();
                asBuilder.addAll(ts.asList());
                asBuilder.add(this.inputAlphabet.getReturnSymbol());
            }
        }
    }

    /**
     * Rebuilds {@code input}, replacing the body of every well-matched invocation (a call with a matching return)
     * by the currently known terminating sequence of the invoked procedure. Calls without a matching return, or
     * without a known terminating sequence, are copied unchanged.
     */
    private Word<I> minifyWellMatched(Word<I> input) {
        if (input.isEmpty()) {
            return Word.epsilon();
        }

        final WordBuilder<I> wb = new WordBuilder<>(input.size());

        for (int i = 0; i < input.size(); i++) {
            final I sym = input.getSymbol(i);
            wb.append(sym);

            if (!this.inputAlphabet.isCallSymbol(sym)) {
                continue;
            }

            final int returnIdx = this.inputAlphabet.findReturnIndex(input, i + 1);
            if (returnIdx <= -1) {
                continue;
            }

            final Word<I> ts = this.terminatingSequences.get(sym);
            if (ts != null) {
                wb.append(ts);
                wb.append(this.inputAlphabet.getReturnSymbol());
                i = returnIdx; // the loop increment then moves past the matching return
            }
        }

        return wb.toWord();
    }
}

