/*
 * Decompiled with CFR 0.152.
 */
package pipelines.building;

import constructs.example.LogicSample;
import constructs.template.Template;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Stream;
import learning.LearningSample;
import learning.Model;
import learning.crossvalidation.Crossvalidation;
import learning.crossvalidation.TrainTestResults;
import learning.crossvalidation.splitting.Splitter;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.structure.building.Neuralizer;
import pipelines.Branch;
import pipelines.Merge;
import pipelines.MultiBranch;
import pipelines.MultiMerge;
import pipelines.Pipe;
import pipelines.Pipeline;
import pipelines.building.AbstractPipelineBuilder;
import pipelines.building.GroundingBuilder;
import pipelines.building.NeuralNetsBuilder;
import pipelines.building.SamplesProcessingBuilder;
import pipelines.building.TemplateProcessingBuilder;
import pipelines.building.TrainTestBuilder;
import pipelines.pipes.generic.DuplicateBranch;
import pipelines.pipes.generic.DuplicateListBranch;
import pipelines.pipes.generic.IdentityGenPipe;
import pipelines.pipes.generic.ListBranch;
import pipelines.pipes.generic.ListMerge;
import pipelines.pipes.generic.PairMerge;
import pipelines.pipes.specific.TemplateToNeuralPipe;
import settings.Settings;
import settings.Source;
import settings.Sources;
import utils.generic.Pair;

public class CrossvalidationBuilder
extends AbstractPipelineBuilder<Sources, TrainTestResults> {
    private static final Logger LOG = Logger.getLogger(CrossvalidationBuilder.class.getName());
    private Sources sources;

    public CrossvalidationBuilder(Settings settings, Sources sources) {
        super(settings);
        this.sources = sources;
    }

    @Override
    public Pipeline<Sources, TrainTestResults> buildPipeline() {
        return this.buildPipeline(this.sources);
    }

    public Pipeline<Sources, TrainTestResults> buildPipeline(Sources sources) {
        Pipeline<Sources, TrainTestResults> pipeline = new Pipeline<Sources, TrainTestResults>("CrossvalidationPipeline", this);
        1 resultsMultiMerge = pipeline.registerEnd(new MultiMerge<TrainTestResults, TrainTestResults>("ResultsAggregateMerge", this.settings.foldsCount, this.settings){

            @Override
            protected TrainTestResults merge(List<TrainTestResults> inputs) {
                Crossvalidation crossvalidation = new Crossvalidation(this.settings);
                return crossvalidation.aggregateResults(inputs);
            }
        });
        TrainTestBuilder trainTestBuilder = new TrainTestBuilder(this.settings, sources);
        if (sources.foldFiles) {
            2 foldsBranch = pipeline.registerStart(new MultiBranch<Sources, Sources>("FoldsBranch", sources.folds.size(), this.settings){

                @Override
                protected List<Sources> branch(Sources folds) {
                    return folds.folds;
                }
            });
            if (sources.folds.stream().allMatch(fold -> fold.trainTest)) {
                List trainTestPipelines = trainTestBuilder.buildPipelines(sources.folds.size());
                trainTestPipelines.forEach(pipeline::register);
                foldsBranch.connectAfter(trainTestPipelines);
                resultsMultiMerge.connectBefore(trainTestPipelines);
            } else if (sources.folds.stream().allMatch(fold -> fold.testOnly)) {
                List testFoldPipes = new Pipe<Sources, Source>("GetTestFold"){

                    @Override
                    public Source apply(Sources sources) {
                        return sources.test;
                    }
                }.parallel(sources.folds.size());
                testFoldPipes.forEach(pipeline::register);
                List samplesExtract = new SamplesProcessingBuilder(this.settings, sources.folds.get((int)0).test).buildPipelines(sources.folds.size());
                samplesExtract.forEach(pipeline::register);
                Pipe.connect(testFoldPipes, samplesExtract);
                if (sources.templateProvided) {
                    DuplicateBranch duplicateOrigin = pipeline.registerStart(new DuplicateBranch());
                    TemplateProcessingBuilder templateProcessor = new TemplateProcessingBuilder(this.settings, sources);
                    Pipeline<Sources, Template> sourcesTemplatePipeline = null;
                    List sourcesTemplatePipelines = null;
                    if (this.settings.commonTemplate) {
                        sourcesTemplatePipeline = pipeline.register(templateProcessor.buildPipeline());
                        duplicateOrigin.connectAfterL(sourcesTemplatePipeline);
                        duplicateOrigin.connectAfterR(foldsBranch);
                        foldsBranch.connectAfter(testFoldPipes);
                    } else {
                        sourcesTemplatePipelines = templateProcessor.buildPipelines(sources.folds.size());
                        sourcesTemplatePipelines.forEach(pipeline::register);
                        List parallelSourcesDuplicates = duplicateOrigin.parallel(sources.folds.size());
                        parallelSourcesDuplicates.forEach(pipeline::register);
                        pipeline.registerStart(foldsBranch);
                        foldsBranch.connectAfter(parallelSourcesDuplicates);
                        Branch.connectAfterL(parallelSourcesDuplicates, testFoldPipes);
                        Branch.connectAfterR(parallelSourcesDuplicates, sourcesTemplatePipelines);
                    }
                    if (this.settings.trainFoldsIsolation) {
                        TrainTestBuilder trainTestBuilder2 = trainTestBuilder;
                        trainTestBuilder2.getClass();
                        List logicTrainTestPipelines = new TrainTestBuilder.LogicTrainTestBuilder(trainTestBuilder2, this.settings).buildPipelines(sources.folds.size());
                        logicTrainTestPipelines.forEach(pipeline::register);
                        resultsMultiMerge.connectBefore(logicTrainTestPipelines);
                        MultiMerge<Stream<LogicSample>, Crossvalidation<LogicSample>> logicCrossvalidation = pipeline.register(this.assembleCV(LogicSample.class, sources.folds.size()));
                        logicCrossvalidation.connectBefore(samplesExtract);
                        if (this.settings.commonTemplate) {
                            PairMerge templateCVmerge = pipeline.register(new PairMerge());
                            templateCVmerge.connectBeforeL(sourcesTemplatePipeline);
                            templateCVmerge.connectBeforeR(logicCrossvalidation);
                            MultiBranch<Pair<Template, Crossvalidation<LogicSample>>, Pair<Template, Pair<Stream<LogicSample>, Stream<LogicSample>>>> emitFolds = pipeline.register(this.emitModelFolds(Template.class, LogicSample.class, sources.folds.size()));
                            templateCVmerge.connectAfter(emitFolds);
                            emitFolds.connectAfter(logicTrainTestPipelines);
                        } else {
                            ListMerge templatesMergeList = pipeline.register(new ListMerge(sources.folds.size(), this.settings));
                            templatesMergeList.connectBefore(sourcesTemplatePipelines);
                            Merge<List<Template>, Crossvalidation<LogicSample>, List<Pair<Template, Pair<Stream<LogicSample>, Stream<LogicSample>>>>> mergeCVtemplates = pipeline.register(this.emitModelsFolds(Template.class, LogicSample.class, sources.folds.size()));
                            mergeCVtemplates.connectBeforeL(templatesMergeList);
                            mergeCVtemplates.connectBeforeR(logicCrossvalidation);
                            ListBranch modelsFoldsBranch = pipeline.register(new ListBranch(sources.folds.size(), this.settings));
                            mergeCVtemplates.connectAfter(modelsFoldsBranch);
                            modelsFoldsBranch.connectAfter(logicTrainTestPipelines);
                        }
                    } else {
                        TrainTestBuilder trainTestBuilder3 = trainTestBuilder;
                        trainTestBuilder3.getClass();
                        List neuralTrainTestPipelines = new TrainTestBuilder.NeuralTrainTestBuilder(trainTestBuilder3, this.settings).buildPipelines(sources.folds.size());
                        neuralTrainTestPipelines.forEach(pipeline::register);
                        resultsMultiMerge.connectBefore(neuralTrainTestPipelines);
                        MultiMerge<Stream<NeuralSample>, Crossvalidation<NeuralSample>> neuralCrossvalidation = pipeline.register(this.assembleCV(NeuralSample.class, sources.folds.size()));
                        GroundingBuilder groundingBuilder = new GroundingBuilder(this.settings);
                        List groundingPipelines = groundingBuilder.buildPipelines(sources.folds.size());
                        NeuralNetsBuilder neuralNetsBuilder = new NeuralNetsBuilder(this.settings, new Neuralizer(this.settings, groundingBuilder.grounder.weightFactory));
                        List neuralizationPipelines = neuralNetsBuilder.buildPipelines(sources.folds.size());
                        groundingPipelines.forEach(pipeline::register);
                        neuralizationPipelines.forEach(pipeline::register);
                        List templateSamplesMerges = new PairMerge().parallel(sources.folds.size());
                        templateSamplesMerges.forEach(pipeline::register);
                        Pipe.connect(groundingPipelines, neuralizationPipelines);
                        neuralCrossvalidation.connectBefore(neuralizationPipelines);
                        Merge.connectBeforeR(templateSamplesMerges, samplesExtract);
                        Pipe.connect(templateSamplesMerges, groundingPipelines);
                        if (this.settings.commonTemplate) {
                            DuplicateBranch duplicateTemplate = pipeline.register(new DuplicateBranch());
                            TemplateToNeuralPipe templateToNeuralPipe = pipeline.register(new TemplateToNeuralPipe(this.settings));
                            DuplicateListBranch templateListBranch = pipeline.register(new DuplicateListBranch(sources.folds.size(), this.settings));
                            sourcesTemplatePipeline.connectAfter(duplicateTemplate);
                            duplicateTemplate.connectAfterL(templateListBranch);
                            duplicateTemplate.connectAfterR(templateToNeuralPipe);
                            Merge.connectBeforeL(templateSamplesMerges, templateListBranch.outputs);
                            PairMerge pairMerge = pipeline.register(new PairMerge());
                            pairMerge.connectBeforeL(templateToNeuralPipe);
                            pairMerge.connectBeforeR(neuralCrossvalidation);
                            MultiBranch<Pair<NeuralModel, Crossvalidation<NeuralSample>>, Pair<NeuralModel, Pair<Stream<NeuralSample>, Stream<NeuralSample>>>> emitFolds = pipeline.register(this.emitModelFolds(NeuralModel.class, NeuralSample.class, sources.folds.size()));
                            pairMerge.connectAfter(emitFolds);
                            emitFolds.connectAfter(neuralTrainTestPipelines);
                        } else {
                            List parallelTemplateBranches = new DuplicateBranch().parallel(sources.folds.size());
                            parallelTemplateBranches.forEach(pipeline::register);
                            List templatesPipes = new IdentityGenPipe().parallel(sources.folds.size());
                            templatesPipes.forEach(pipeline::register);
                            List template2NeuralModels = new TemplateToNeuralPipe(this.settings).parallel(sources.folds.size());
                            template2NeuralModels.forEach(pipeline::register);
                            ListMerge neuralModelsMerge = pipeline.register(new ListMerge(sources.folds.size(), this.settings));
                            Merge<List<NeuralModel>, Crossvalidation<NeuralSample>, List<Pair<NeuralModel, Pair<Stream<NeuralSample>, Stream<NeuralSample>>>>> emitModelsFolds = pipeline.register(this.emitModelsFolds(NeuralModel.class, NeuralSample.class, sources.folds.size()));
                            ListBranch modelsFoldsBranch = pipeline.register(new ListBranch(sources.folds.size(), this.settings));
                            Pipe.connect(sourcesTemplatePipelines, parallelTemplateBranches);
                            Branch.connectAfterL(parallelTemplateBranches, templatesPipes);
                            Branch.connectAfterR(parallelTemplateBranches, template2NeuralModels);
                            Merge.connectBeforeL(templateSamplesMerges, templatesPipes);
                            neuralModelsMerge.connectBefore(template2NeuralModels);
                            emitModelsFolds.connectBeforeL(neuralModelsMerge);
                            emitModelsFolds.connectBeforeR(neuralCrossvalidation);
                            emitModelsFolds.connectAfter(modelsFoldsBranch);
                            modelsFoldsBranch.connectAfter(neuralTrainTestPipelines);
                        }
                    }
                } else {
                    MultiMerge<Stream<LogicSample>, Crossvalidation<LogicSample>> mergeFolds2CV = pipeline.register(this.assembleCV(LogicSample.class, sources.folds.size()));
                    MultiBranch<Crossvalidation<LogicSample>, Pair<Stream<LogicSample>, Stream<LogicSample>>> emitTrainTest = pipeline.register(this.emitTrainTest(LogicSample.class, sources.folds.size()));
                    TrainTestBuilder trainTestBuilder4 = trainTestBuilder;
                    trainTestBuilder4.getClass();
                    List structTrainTestPipeline = new TrainTestBuilder.StructureTrainTestBuilder(trainTestBuilder4, this.settings).buildPipelines(sources.folds.size());
                    structTrainTestPipeline.forEach(pipeline::register);
                    mergeFolds2CV.connectBefore(samplesExtract);
                    emitTrainTest.connectBefore(mergeFolds2CV);
                    emitTrainTest.connectAfter(structTrainTestPipeline);
                    resultsMultiMerge.connectBefore(structTrainTestPipeline);
                }
            }
        } else {
            4 getTrainSource = pipeline.register(new Pipe<Sources, Source>("GetTrainFold"){

                @Override
                public Source apply(Sources sources) {
                    return sources.train;
                }
            });
            Pipeline<Source, Stream<LogicSample>> samplesExtract = pipeline.register(new SamplesProcessingBuilder(this.settings, sources.train).buildPipeline(sources.train));
            if (sources.templateProvided) {
                Pipeline<Sources, Template> getTemplate = pipeline.register(new TemplateProcessingBuilder(this.settings, sources).buildPipeline());
                DuplicateBranch sourcesDuplicateBranch = pipeline.registerStart(new DuplicateBranch());
                sourcesDuplicateBranch.connectAfterL(getTemplate);
                sourcesDuplicateBranch.connectAfterR(getTrainSource);
                getTrainSource.connectAfter(samplesExtract);
                if (this.settings.trainFoldsIsolation) {
                    TrainTestBuilder trainTestBuilder5 = trainTestBuilder;
                    trainTestBuilder5.getClass();
                    List logicTrainTestPipelines = new TrainTestBuilder.LogicTrainTestBuilder(trainTestBuilder5, this.settings).buildPipelines(this.settings.foldsCount);
                    Pipe<Stream<LogicSample>, Crossvalidation<LogicSample>> logicCrossvalidation = pipeline.register(this.cvFromStream(LogicSample.class));
                    logicTrainTestPipelines.forEach(pipeline::register);
                    PairMerge templateCVmerge = pipeline.register(new PairMerge());
                    MultiBranch<Pair<Template, Crossvalidation<LogicSample>>, Pair<Template, Pair<Stream<LogicSample>, Stream<LogicSample>>>> emitFolds = pipeline.register(this.emitModelFolds(Template.class, LogicSample.class, this.settings.foldsCount));
                    samplesExtract.connectAfter(logicCrossvalidation);
                    templateCVmerge.connectBeforeL(getTemplate);
                    templateCVmerge.connectBeforeR(logicCrossvalidation);
                    templateCVmerge.connectAfter(emitFolds);
                    emitFolds.connectAfter(logicTrainTestPipelines);
                    resultsMultiMerge.connectBefore(logicTrainTestPipelines);
                } else {
                    TrainTestBuilder trainTestBuilder6 = trainTestBuilder;
                    trainTestBuilder6.getClass();
                    List neuralTrainTestPipelines = new TrainTestBuilder.NeuralTrainTestBuilder(trainTestBuilder6, this.settings).buildPipelines(this.settings.foldsCount);
                    neuralTrainTestPipelines.forEach(pipeline::register);
                    GroundingBuilder groundingBuilder = new GroundingBuilder(this.settings);
                    List groundingPipelines = groundingBuilder.buildPipelines(this.settings.foldsCount);
                    groundingPipelines.forEach(pipeline::register);
                    NeuralNetsBuilder neuralNetsBuilder = new NeuralNetsBuilder(this.settings, new Neuralizer(this.settings, groundingBuilder.grounder.weightFactory));
                    List neuralizationPipelines = neuralNetsBuilder.buildPipelines(this.settings.foldsCount);
                    neuralizationPipelines.forEach(pipeline::register);
                    DuplicateBranch duplicateTemplate = pipeline.register(new DuplicateBranch());
                    DuplicateListBranch templateListBranch = pipeline.register(new DuplicateListBranch(this.settings.foldsCount, this.settings));
                    ListBranch foldsBranch = pipeline.register(new ListBranch(this.settings.foldsCount, this.settings));
                    5 samplesSplitterPipe = pipeline.register(new Pipe<Stream<LogicSample>, List<Stream<LogicSample>>>("SplitterPipe", this.settings){

                        @Override
                        public List<Stream<LogicSample>> apply(Stream<LogicSample> logicSampleStream) {
                            Splitter<LogicSample> splitter = Splitter.getSplitter(this.settings);
                            return splitter.partition(logicSampleStream, this.settings.foldsCount);
                        }
                    });
                    List templateSamplesMerges = new PairMerge().parallel(this.settings.foldsCount);
                    templateSamplesMerges.forEach(pipeline::register);
                    MultiMerge<Stream<NeuralSample>, Crossvalidation<NeuralSample>> neuralCrossvalidation = pipeline.register(this.assembleCV(NeuralSample.class, this.settings.foldsCount));
                    TemplateToNeuralPipe templateToNeuralPipe = pipeline.register(new TemplateToNeuralPipe(this.settings));
                    PairMerge modelCVmerge = pipeline.register(new PairMerge());
                    MultiBranch<Pair<NeuralModel, Crossvalidation<NeuralSample>>, Pair<NeuralModel, Pair<Stream<NeuralSample>, Stream<NeuralSample>>>> emitFolds = this.emitModelFolds(NeuralModel.class, NeuralSample.class, this.settings.foldsCount);
                    getTemplate.connectAfter(duplicateTemplate);
                    duplicateTemplate.connectAfterL(templateListBranch);
                    samplesExtract.connectAfter(samplesSplitterPipe);
                    samplesSplitterPipe.connectAfter(foldsBranch);
                    Merge.connectBeforeL(templateSamplesMerges, templateListBranch.outputs);
                    Merge.connectBeforeR(templateSamplesMerges, foldsBranch.outputs);
                    Pipe.connect(templateSamplesMerges, groundingPipelines);
                    Pipe.connect(groundingPipelines, neuralizationPipelines);
                    neuralCrossvalidation.connectBefore(neuralizationPipelines);
                    duplicateTemplate.connectAfterR(templateToNeuralPipe);
                    modelCVmerge.connectBeforeL(templateToNeuralPipe);
                    modelCVmerge.connectBeforeR(neuralCrossvalidation);
                    modelCVmerge.connectAfter(emitFolds);
                    emitFolds.connectAfter(neuralTrainTestPipelines);
                    resultsMultiMerge.connectBefore(neuralTrainTestPipelines);
                }
            } else {
                Pipe<Stream<LogicSample>, Crossvalidation<LogicSample>> cvFromStream = pipeline.register(this.cvFromStream(LogicSample.class));
                MultiBranch<Crossvalidation<LogicSample>, Pair<Stream<LogicSample>, Stream<LogicSample>>> emitTrainTest = pipeline.register(this.emitTrainTest(LogicSample.class, sources.folds.size()));
                TrainTestBuilder trainTestBuilder7 = trainTestBuilder;
                trainTestBuilder7.getClass();
                List structTrainTestPipelines = new TrainTestBuilder.StructureTrainTestBuilder(trainTestBuilder7, this.settings).buildPipelines(sources.folds.size());
                structTrainTestPipelines.forEach(pipeline::register);
                cvFromStream.connectBefore(samplesExtract);
                emitTrainTest.connectBefore(cvFromStream);
                emitTrainTest.connectAfter(structTrainTestPipelines);
                resultsMultiMerge.connectBefore(structTrainTestPipelines);
            }
        }
        return pipeline;
    }

    protected <S extends LearningSample> Pipe<Stream<S>, Crossvalidation<S>> cvFromStream(Class<S> s) {
        return new Pipe<Stream<S>, Crossvalidation<S>>("CVFromStream"){

            @Override
            public Crossvalidation<S> apply(Stream<S> logicSampleStream) {
                Crossvalidation cv = new Crossvalidation(this.settings, this.settings.foldsCount);
                cv.loadFolds(logicSampleStream);
                return cv;
            }
        };
    }

    protected <S extends LearningSample> MultiMerge<Stream<S>, Crossvalidation<S>> assembleCV(Class<S> s, int foldsCount) {
        return new MultiMerge<Stream<S>, Crossvalidation<S>>("MergeFolds2CV", foldsCount, this.settings){

            @Override
            protected Crossvalidation<S> merge(List<Stream<S>> inputs) {
                Crossvalidation cv = new Crossvalidation(this.settings, inputs.size());
                cv.loadFolds(inputs);
                return cv;
            }
        };
    }

    protected <S extends LearningSample> MultiBranch<Crossvalidation<S>, Pair<Stream<S>, Stream<S>>> emitTrainTest(Class<S> s, int foldCount) {
        return new MultiBranch<Crossvalidation<S>, Pair<Stream<S>, Stream<S>>>("BranchCV2Folds", this.sources.folds.size(), this.settings){

            @Override
            protected List<Pair<Stream<S>, Stream<S>>> branch(Crossvalidation<S> cv) {
                ArrayList folds = new ArrayList(cv.foldCount);
                for (int i = 0; i < cv.foldCount; ++i) {
                    folds.add(new Pair(cv.folds.get((int)i).train.stream(), cv.folds.get((int)i).test.stream()));
                }
                return folds;
            }
        };
    }

    protected <T extends Model, S extends LearningSample> MultiBranch<Pair<T, Crossvalidation<S>>, Pair<T, Pair<Stream<S>, Stream<S>>>> emitModelFolds(Class<T> t, Class<S> s, int foldCount) {
        return new MultiBranch<Pair<T, Crossvalidation<S>>, Pair<T, Pair<Stream<S>, Stream<S>>>>("EmitFoldsWithModel", foldCount, this.settings){

            @Override
            protected List<Pair<T, Pair<Stream<S>, Stream<S>>>> branch(Pair<T, Crossvalidation<S>> cv) {
                ArrayList pairList = new ArrayList(((Crossvalidation)cv.s).foldCount);
                for (int i = 0; i < ((Crossvalidation)cv.s).foldCount; ++i) {
                    pairList.add(new Pair(cv.r, new Pair(((Crossvalidation)cv.s).folds.get((int)i).train.stream(), ((Crossvalidation)cv.s).folds.get((int)i).test.stream())));
                }
                return pairList;
            }
        };
    }

    protected <T extends Model, S extends LearningSample> Merge<List<T>, Crossvalidation<S>, List<Pair<T, Pair<Stream<S>, Stream<S>>>>> emitModelsFolds(Class<T> t, Class<S> s, int foldsCount) {
        return new Merge<List<T>, Crossvalidation<S>, List<Pair<T, Pair<Stream<S>, Stream<S>>>>>("EmitFoldsWithModels"){

            @Override
            protected List<Pair<T, Pair<Stream<S>, Stream<S>>>> merge(List<T> templates, Crossvalidation<S> cv) {
                ArrayList folds = new ArrayList(templates.size());
                for (int i = 0; i < templates.size(); ++i) {
                    folds.add(new Pair(templates.get(i), new Pair(cv.folds.get((int)i).train.stream(), cv.folds.get((int)i).test.stream())));
                }
                return folds;
            }
        };
    }
}

