/*
 * Decompiled with CFR 0.152.
 */
package de.learnlib.algorithm.ostia;

import de.learnlib.algorithm.PassiveLearningAlgorithm;
import de.learnlib.algorithm.ostia.Blue;
import de.learnlib.algorithm.ostia.Edge;
import de.learnlib.algorithm.ostia.IntQueue;
import de.learnlib.algorithm.ostia.OSSTWrapper;
import de.learnlib.algorithm.ostia.Out;
import de.learnlib.algorithm.ostia.State;
import de.learnlib.algorithm.ostia.StateCopy;
import de.learnlib.query.DefaultQuery;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import net.automatalib.alphabet.Alphabet;
import net.automatalib.alphabet.GrowingAlphabet;
import net.automatalib.alphabet.GrowingMapAlphabet;
import net.automatalib.automaton.transducer.SubsequentialTransducer;
import net.automatalib.common.smartcollection.IntSeq;
import net.automatalib.common.util.Pair;
import net.automatalib.word.Word;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;

public class OSTIA<I, O>
implements PassiveLearningAlgorithm<SubsequentialTransducer<?, I, ?, O>, I, Word<O>> {
    private final Alphabet<I> inputAlphabet;
    private final GrowingAlphabet<O> outputAlphabet;
    private final State root;
    private boolean hasBeenComputed;

    public OSTIA(Alphabet<I> inputAlphabet) {
        this.inputAlphabet = inputAlphabet;
        this.outputAlphabet = new GrowingMapAlphabet<O>();
        this.root = new State(inputAlphabet.size());
        this.hasBeenComputed = false;
    }

    @Override
    public void addSamples(Collection<? extends DefaultQuery<I, Word<O>>> samples) {
        for (DefaultQuery<I, Word<O>> sample : samples) {
            Word<O> output = sample.getOutput();
            this.outputAlphabet.addAll(output.asList());
            OSTIA.buildPttOnward(this.root, sample.getInput().asIntSeq(this.inputAlphabet), IntQueue.asQueue(output.asIntSeq(this.outputAlphabet)));
        }
    }

    @Override
    public SubsequentialTransducer<?, I, ?, O> computeModel() {
        if (!this.hasBeenComputed) {
            this.hasBeenComputed = true;
            OSTIA.ostia(this.root);
        }
        return new OSSTWrapper<I, O>(this.root, this.inputAlphabet, this.outputAlphabet);
    }

    public static State buildPtt(int alphabetSize, Iterator<Pair<IntSeq, IntSeq>> informant) {
        State root = new State(alphabetSize);
        while (informant.hasNext()) {
            Pair<IntSeq, IntSeq> inout = informant.next();
            OSTIA.buildPttOnward(root, inout.getFirst(), IntQueue.asQueue(inout.getSecond()));
        }
        return root;
    }

    private static void buildPttOnward(State ptt, IntSeq input, @Nullable IntQueue output) {
        State pttIter = ptt;
        IntQueue outputIter = output;
        for (int i = 0; i < input.size(); ++i) {
            Edge edge;
            int symbol = input.get(i);
            if (pttIter.transitions[symbol] == null) {
                edge = new Edge();
                edge.out = outputIter;
                edge.target = new State(pttIter.transitions.length);
                pttIter.transitions[symbol] = edge;
                outputIter = null;
            } else {
                edge = pttIter.transitions[symbol];
                IntQueue commonPrefixEdge = edge.out;
                IntQueue commonPrefixEdgePrev = null;
                IntQueue commonPrefixInformant = outputIter;
                while (commonPrefixEdge != null && commonPrefixInformant != null && commonPrefixEdge.value == commonPrefixInformant.value) {
                    commonPrefixInformant = commonPrefixInformant.next;
                    commonPrefixEdgePrev = commonPrefixEdge;
                    commonPrefixEdge = commonPrefixEdge.next;
                }
                if (commonPrefixEdgePrev == null) {
                    edge.out = null;
                } else {
                    commonPrefixEdgePrev.next = null;
                }
                edge.target.prependButIgnoreMissingStateOutput(commonPrefixEdge);
                outputIter = commonPrefixInformant;
            }
            pttIter = edge.target;
        }
        if (pttIter.out != null && !IntQueue.eq(pttIter.out.str, outputIter)) {
            throw new IllegalArgumentException("For input '" + input + "' the state output is '" + pttIter.out + "' but training sample has remaining suffix '" + outputIter + '\'');
        }
        pttIter.out = new Out(outputIter);
    }

    private static void addBlueStates(State parent, Queue<Blue> blue) {
        for (int i = 0; i < parent.transitions.length; ++i) {
            Edge transition = parent.transitions[i];
            if (transition == null) continue;
            assert (!OSTIA.contains(blue, transition.target));
            assert (transition.target != parent);
            blue.add(new Blue(parent, i));
        }
    }

    public static void ostia(State transducer) {
        LinkedList<Blue> blue = new LinkedList<Blue>();
        LinkedHashSet<State> red = new LinkedHashSet<State>();
        assert (OSTIA.isTree(transducer, new HashSet<State>()));
        red.add(transducer);
        OSTIA.addBlueStates(transducer, blue);
        assert (OSTIA.uniqueItems(blue));
        assert (OSTIA.disjoint(blue, red));
        assert (OSTIA.validateBlueAndRed(transducer, red, blue));
        block0: while (!blue.isEmpty()) {
            @NonNull Blue next = (Blue)blue.poll();
            @Nullable State blueState = next.state();
            assert (blueState != null);
            assert (OSTIA.isTree(blueState, new HashSet<State>()));
            assert (OSTIA.uniqueItems(blue));
            assert (!OSTIA.contains(blue, blueState));
            assert (OSTIA.disjoint(blue, red));
            for (State redState : red) {
                if (!OSTIA.ostiaMerge(next, redState, blue, red)) continue;
                assert (OSTIA.disjoint(blue, red));
                assert (OSTIA.uniqueItems(blue));
                continue block0;
            }
            assert (OSTIA.isTree(blueState, new HashSet<State>()));
            assert (OSTIA.uniqueItems(blue));
            OSTIA.addBlueStates(blueState, blue);
            assert (OSTIA.uniqueItems(blue));
            assert (!OSTIA.contains(blue, blueState));
            assert (OSTIA.disjoint(blue, red));
            red.add(blueState);
            assert (OSTIA.disjoint(blue, red));
            assert (OSTIA.validateBlueAndRed(transducer, red, blue));
        }
    }

    private static boolean ostiaMerge(Blue blue, State redState, Queue<Blue> blueToVisit, Set<State> red) {
        HashMap<State, StateCopy> merged = new HashMap<State, StateCopy>();
        ArrayList<Blue> reachedBlueStates = new ArrayList<Blue>();
        if (OSTIA.ostiaFold(redState, null, blue.parent, blue.symbol, merged, reachedBlueStates)) {
            for (Map.Entry mergedRedState : merged.entrySet()) {
                assert (mergedRedState.getKey() == ((StateCopy)mergedRedState.getValue()).original);
                ((StateCopy)mergedRedState.getValue()).assign();
            }
            for (Blue reachedBlueCandidate : reachedBlueStates) {
                if (!red.contains(reachedBlueCandidate.parent)) continue;
                assert (!OSTIA.contains(blueToVisit, reachedBlueCandidate.state()));
                blueToVisit.add(reachedBlueCandidate);
            }
            return true;
        }
        return false;
    }

    private static boolean ostiaFold(State red, @Nullable IntQueue pushedBack, State blueParent, int symbolIncomingToBlue, Map<State, StateCopy> mergedStates, List<Blue> reachedBlueStates) {
        Edge incomingTransition = blueParent.transitions[symbolIncomingToBlue];
        assert (incomingTransition != null);
        State blueState = incomingTransition.target;
        assert (red != blueState);
        assert (!mergedStates.containsKey(blueState));
        StateCopy mergedRedState = mergedStates.computeIfAbsent(red, StateCopy::new);
        StateCopy mergedBlueState = new StateCopy(blueState);
        Edge mergedIncomingTransition = mergedStates.computeIfAbsent((State)blueParent, (Function<State, StateCopy>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, <init>(de.learnlib.algorithm.ostia.State ), (Lde/learnlib/algorithm/ostia/State;)Lde/learnlib/algorithm/ostia/StateCopy;)()).transitions[symbolIncomingToBlue];
        assert (mergedIncomingTransition != null);
        mergedIncomingTransition.target = red;
        StateCopy prevBlue = mergedStates.put(blueState, mergedBlueState);
        assert (prevBlue == null);
        mergedBlueState.prepend(pushedBack);
        if (mergedBlueState.out != null) {
            if (mergedRedState.out == null) {
                mergedRedState.out = mergedBlueState.out;
            } else if (!IntQueue.eq(mergedRedState.out.str, mergedBlueState.out.str)) {
                return false;
            }
        }
        for (int i = 0; i < mergedRedState.transitions.length; ++i) {
            Edge transitionBlue = mergedBlueState.transitions[i];
            if (transitionBlue == null) continue;
            Edge transitionRed = mergedRedState.transitions[i];
            if (transitionRed == null) {
                mergedRedState.transitions[i] = new Edge(transitionBlue);
                reachedBlueStates.add(new Blue(red, i));
                continue;
            }
            IntQueue commonPrefixRed = transitionRed.out;
            IntQueue commonPrefixBlue = transitionBlue.out;
            IntQueue commonPrefixBluePrev = null;
            while (commonPrefixBlue != null && commonPrefixRed != null && commonPrefixBlue.value == commonPrefixRed.value) {
                commonPrefixBluePrev = commonPrefixBlue;
                commonPrefixBlue = commonPrefixBlue.next;
                commonPrefixRed = commonPrefixRed.next;
            }
            assert (commonPrefixBluePrev == null || commonPrefixBluePrev.next == commonPrefixBlue);
            if (commonPrefixRed == null) {
                if (commonPrefixBluePrev == null) {
                    transitionBlue.out = null;
                } else {
                    commonPrefixBluePrev.next = null;
                }
                assert (Objects.equals(Optional.ofNullable(mergedBlueState.transitions[i]).map(e -> e.target), Optional.ofNullable(blueState.transitions[i]).map(e -> e.target)));
                if (OSTIA.ostiaFold(transitionRed.target, commonPrefixBlue, blueState, i, mergedStates, reachedBlueStates)) continue;
                return false;
            }
            return false;
        }
        return true;
    }

    public static @Nullable IntSeq run(State init, IntSeq input) {
        ArrayList<Integer> output = new ArrayList<Integer>();
        State iter = init;
        for (int i = 0; i < input.size(); ++i) {
            Edge edge = iter.transitions[input.get(i)];
            if (edge == null) {
                return null;
            }
            iter = edge.target;
            IntQueue q = edge.out;
            while (q != null) {
                output.add(q.value);
                q = q.next;
            }
        }
        if (iter.out == null) {
            return null;
        }
        IntQueue q = iter.out.str;
        while (q != null) {
            output.add(q.value);
            q = q.next;
        }
        return IntSeq.of(output);
    }

    private static boolean disjoint(Queue<Blue> blue, Set<State> red) {
        for (Blue b : blue) {
            if (!red.contains(b.state())) continue;
            return false;
        }
        return true;
    }

    private static boolean contains(Queue<Blue> blue, @Nullable State state) {
        for (Blue b : blue) {
            if (!Objects.equals(state, b.state())) continue;
            return true;
        }
        return false;
    }

    private static boolean uniqueItems(Queue<Blue> blue) {
        HashSet<@Nullable State> unique = new HashSet<State>();
        for (Blue b : blue) {
            if (unique.add(b.state())) continue;
            return false;
        }
        return true;
    }

    private static boolean validateBlueAndRed(State root, Set<State> red, Queue<Blue> blue) {
        HashSet<State> reachable = new HashSet<State>();
        OSTIA.isTree(root, reachable);
        for (State r : red) {
            for (Edge edge : r.transitions) {
                assert (edge == null || OSTIA.contains(blue, edge.target) ^ red.contains(edge.target));
            }
            assert (reachable.contains(r));
        }
        for (Blue b : blue) {
            assert (red.contains(b.parent));
            assert (reachable.contains(b.state()));
        }
        return true;
    }

    private static boolean isTree(State root, Set<State> nodes) {
        ArrayDeque<State> toVisit = new ArrayDeque<State>();
        toVisit.add(root);
        boolean isTree = true;
        while (!toVisit.isEmpty()) {
            @NonNull State s2 = (State)toVisit.poll();
            if (nodes.add(s2)) {
                for (Edge edge : s2.transitions) {
                    if (edge == null) continue;
                    toVisit.add(edge.target);
                }
                continue;
            }
            isTree = false;
        }
        return isTree;
    }
}

