/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.pipelines.debugging;

import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.logic.constructs.template.Template;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.pipelines.Pipe;
import cz.cvut.fel.ida.pipelines.Pipeline;
import cz.cvut.fel.ida.pipelines.debugging.TemplateDebugger;
import cz.cvut.fel.ida.pipelines.pipes.generic.FirstFromPairPipe;
import cz.cvut.fel.ida.pipelines.pipes.generic.StreamifyPipe;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.setup.Sources;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.Map;
import java.util.function.Consumer;
import java.util.logging.Logger;
import java.util.stream.Stream;

/**
 * Debugger that runs the end-to-end training pipeline and exposes the resulting template as a
 * stream.
 *
 * <p>When constructed with a {@link Template}, it can also push new weights into that template
 * and redraw it via {@link #debugWeights(Map)}.
 *
 * <p>If {@code intermediateDebug} (inherited from {@link TemplateDebugger}) is set, both
 * constructors switch on these debug flags of the supplied {@link Settings}:
 * <ul>
 *   <li>pipeline</li>
 *   <li>template</li>
 *   <li>sample-training</li>
 *   <li>grounding</li>
 * </ul>
 */
public class TrainingDebugger
extends TemplateDebugger {
    private static final Logger LOG = Logger.getLogger(TrainingDebugger.class.getName());

    /**
     * Applies a weight map to the debugged template and redraws it.
     *
     * <p>This is {@code null} when the debugger was created without a template.
     */
    private final Consumer<Map<Integer, Weight>> templateRedrawCallback;

    /**
     * Creates a debugger driven by the given sources.
     *
     * <p>No template is available, so {@link #debugWeights(Map)} has nothing to redraw for
     * instances created this way.
     *
     * @param sources  input sources handed to the parent debugger
     * @param settings settings handed to the parent; their debug flags may be switched on here
     */
    public TrainingDebugger(Sources sources, Settings settings) {
        super(sources, settings);
        this.templateRedrawCallback = null;
        enableTrainingDebugFlags(settings);
    }

    /**
     * Creates a debugger bound to a concrete template.
     *
     * <p>Later calls to {@link #debugWeights(Map)} update this template's weights and redraw it.
     *
     * @param settings settings handed to the parent; their debug flags may be switched on here
     * @param template template whose weights are updated and which is redrawn on each
     *                 {@link #debugWeights(Map)} call
     */
    public TrainingDebugger(Settings settings, Template template) {
        super(settings);
        this.templateRedrawCallback = weightMap -> {
            template.updateWeightsFrom(weightMap);
            // 'drawer' is an inherited field, read at call time rather than at construction.
            this.drawer.draw(template);
        };
        enableTrainingDebugFlags(settings);
    }

    /**
     * Switches on the detailed debug flags when intermediate debugging is requested.
     *
     * <p>The flags are: pipeline, template, sample training and grounding. This method is
     * private so that calling it from the constructors cannot dispatch to a subclass override.
     *
     * @param settings settings object whose debug flags are mutated in place
     */
    private void enableTrainingDebugFlags(Settings settings) {
        if (!this.intermediateDebug) {
            return;
        }
        settings.debugPipeline = true;
        settings.debugTemplate = true;
        settings.debugSampleTraining = true;
        settings.debugGrounding = true;
    }

    /**
     * Builds the training pipeline that ends in a stream of templates.
     *
     * <p>The end-to-end training pipeline yields
     * {@code Pair<Pair<Template, NeuralModel>, Progress>}. Two {@link FirstFromPairPipe} stages
     * then peel off the first element twice, leaving the {@link Template}. A final
     * {@link StreamifyPipe} adapts the result to the declared {@code Stream<Template>} output.
     *
     * @return this debugger's pipeline, with start and end registered
     */
    @Override
    public Pipeline<Sources, Stream<Template>> buildPipeline() {
        Pipeline<Sources, Pair<Pair<Template, NeuralModel>, Progress>> sourcesPairPipeline = this.pipeline.registerStart(this.end2endTrainigBuilder.buildPipeline());
        // NOTE(review): the raw pipe types below are decompiler output. The generic parameters
        // of FirstFromPairPipe/Pipe are not visible here, so they are left as-is.
        FirstFromPairPipe pairTemplatePipe1 = this.pipeline.register(new FirstFromPairPipe("FirstFromPairPipe1"));
        sourcesPairPipeline.connectAfter(pairTemplatePipe1);
        Pipe pairTemplatePipe = this.pipeline.register(pairTemplatePipe1.connectAfter(new FirstFromPairPipe("FirstFromPairPipe2")));
        this.pipeline.registerEnd(pairTemplatePipe.connectAfter(new StreamifyPipe()));
        return this.pipeline;
    }

    /**
     * Applies the given weights to the debugged template and redraws it.
     *
     * <p>If this debugger was constructed without a template, there is nothing to redraw. In
     * that case a warning is logged and the call is ignored; previously it failed with a
     * {@link NullPointerException}.
     *
     * @param weightMap weights keyed by integer index. Presumably these are weight ids; TODO
     *                  confirm against {@code Template#updateWeightsFrom}.
     */
    public void debugWeights(Map<Integer, Weight> weightMap) {
        if (this.templateRedrawCallback == null) {
            LOG.warning("debugWeights called on a TrainingDebugger created without a template; ignoring.");
            return;
        }
        this.templateRedrawCallback.accept(weightMap);
    }
}

