/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.transforming;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.neural.networks.computation.iteration.modes.Topologic;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons.NeuronVisitor;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.types.DetailedNetwork;
import cz.cvut.fel.ida.neural.networks.structure.transforming.NetworkReducing;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Network reducer that collapses parallel (repeated) edges from the same input neuron.
 *
 * <p>Two independent reductions are switched on by {@link Settings}:
 * <ul>
 *   <li>{@code removeIdenticalUnweightedInputs}: for an unweighted neuron whose aggregation reports
 *       {@code isInputSymmetric()}, every repeated occurrence of an input neuron is removed, keeping
 *       only the first occurrence.</li>
 *   <li>{@code mergeIdenticalWeightedInputs}: for a weighted neuron whose aggregation reports
 *       {@code isInputSymmetric()} and none of whose input weights is learnable, all occurrences of
 *       the same input neuron are collapsed into the first one, which receives a new weight whose
 *       value is the sum of the original weight values.</li>
 * </ul>
 *
 * <p>The given network is modified in place and the same instance is returned.
 */
public class ParallelEdgeMerger implements NetworkReducing {
    private static final Logger LOG = Logger.getLogger(ParallelEdgeMerger.class.getName());
    private final Settings settings;

    /**
     * @param settings source of the {@code removeIdenticalUnweightedInputs} and
     *                 {@code mergeIdenticalWeightedInputs} switches
     */
    public ParallelEdgeMerger(Settings settings) {
        this.settings = settings;
    }

    /**
     * Merges parallel edges of the neurons visited top-down from {@code outputStart.neuron}.
     *
     * @param inet        the network to modify in place
     * @param outputStart query whose {@code neuron} is the root of the top-down traversal
     * @return {@code inet} itself, after modification
     */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, QueryNeuron outputStart) {
        MergingVisitor mergingVisitor = new MergingVisitor(inet);
        Topologic topologic = new Topologic(inet);
        // TDownVisitor is an inner class of Topologic, hence the qualified instance creation.
        Topologic.TDownVisitor tDownVisitor = topologic.new TDownVisitor(outputStart.neuron, mergingVisitor);
        tDownVisitor.topdown();
        return inet;
    }

    /**
     * Merges parallel edges below each of the given output neurons.
     *
     * <p>Each output is processed with {@link #reduce(DetailedNetwork, QueryNeuron)}. Neurons shared
     * by several outputs are visited more than once. That is harmless: a second visit makes the same
     * decisions as the first and finds nothing left to merge.
     *
     * @param inet        the network to modify in place
     * @param outputStart the output queries; {@code null} or empty leaves the network untouched
     * @return {@code inet} itself, after modification (never {@code null})
     */
    @Override
    public NeuralNetwork reduce(DetailedNetwork<State.Structure> inet, List<QueryNeuron> outputStart) {
        if (outputStart == null) {
            return inet;
        }
        for (QueryNeuron output : outputStart) {
            reduce(inet, output);
        }
        return inet;
    }

    /** Nothing to finalize: this reducer keeps no state between {@code reduce} calls. */
    @Override
    public void finish() {
    }

    /**
     * Neuron visitor that performs the actual edge removal / merging on each visited neuron.
     * It reads the enclosing merger's {@code settings}.
     */
    class MergingVisitor extends NeuronVisitor.Weighted {

        /**
         * @param network the network whose input edges are modified. The weighted-merge path casts it
         *                to {@link DetailedNetwork}.
         */
        public MergingVisitor(NeuralNetwork<State.Structure> network) {
            super(network, null, null);
        }

        /**
         * Removes repeated occurrences of the same input neuron, keeping the first occurrence.
         *
         * <p>Runs only when {@code settings.removeIdenticalUnweightedInputs} is set and the neuron's
         * aggregation reports {@code isInputSymmetric()}.
         * NOTE(review): dropping a duplicate also changes input multiplicity, which matters for
         * aggregations such as sum or average. Confirm this is intended for every aggregation that
         * reports input symmetry.
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
            if (!settings.removeIdenticalUnweightedInputs) {
                return;
            }
            if (!neuron.getAggregation().isInputSymmetric()) {
                return;
            }
            Set<T> seen = new HashSet<>();
            Iterator<T> inputs = network.getInputs(neuron);
            while (inputs.hasNext()) {
                // Set.add returns false for an input already seen: drop that repeated edge.
                if (!seen.add(inputs.next())) {
                    inputs.remove();
                }
            }
        }

        /**
         * Collapses all edges from the same input neuron into one edge carrying the summed weight.
         *
         * <p>Runs only when {@code settings.mergeIdenticalWeightedInputs} is set, the aggregation
         * reports {@code isInputSymmetric()}, and no input weight of the neuron is learnable.
         * NOTE(review): replacing w1*x + w2*x by (w1+w2)*x is exact for sum-like aggregations. Confirm
         * it is intended for every aggregation that reports input symmetry.
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
            if (!settings.mergeIdenticalWeightedInputs) {
                return;
            }
            if (!neuron.getAggregation().isInputSymmetric()) {
                return;
            }

            // Pass 1: group edge weights by input neuron; any learnable weight vetoes the whole neuron.
            Pair<Iterator<T>, Iterator<Weight>> inputs = network.getInputs(neuron);
            Iterator<T> inputNeurons = inputs.r;
            Iterator<Weight> inputWeights = inputs.s;
            Map<T, List<Weight>> weightsByInput = new LinkedHashMap<>();
            while (inputNeurons.hasNext()) {
                T input = inputNeurons.next();
                Weight weight = inputWeights.next();
                if (weight.isLearnable()) {
                    return;
                }
                weightsByInput.computeIfAbsent(input, key -> new ArrayList<>()).add(weight);
            }

            // Pass 2: give the first edge of each input the merged weight, then drop the later edges.
            Set<T> mergedInputs = new HashSet<>();
            inputs = network.getInputs(neuron);
            inputNeurons = inputs.r;
            inputWeights = inputs.s;
            while (inputNeurons.hasNext()) {
                T input = inputNeurons.next();
                Weight weight = inputWeights.next();
                if (mergedInputs.contains(input)) {
                    // An earlier edge from this input already carries the merged weight.
                    inputNeurons.remove();
                    inputWeights.remove();
                    continue;
                }
                Weight mergedWeight = mergeWeights(weightsByInput.get(input));
                // null: merging refused (shared weight); equal: a single edge, nothing to merge.
                if (mergedWeight == null || mergedWeight.equals(weight)) {
                    continue;
                }
                ((DetailedNetwork<?>) network).replaceInputWeight(neuron, input, mergedWeight);
                mergedInputs.add(input);
            }
        }

        /**
         * Combines the weights of all parallel edges from one input neuron into a single weight.
         *
         * @param weightList weights of every edge from one input neuron. Always non-empty when
         *                   called from {@link #visit(WeightedNeuron)}.
         * @return the sole weight itself if the list has one element; {@code null} if any weight is
         *     shared or learnable; otherwise a new {@code Weight} with index {@code -2}, name
         *     {@code "MERGED:_<name1>_<name2>..."} and value equal to the sum of the weight values
         */
        private Weight mergeWeights(List<Weight> weightList) {
            if (weightList.size() == 1) {
                return weightList.get(0);
            }
            Value sum = null;
            StringBuilder mergedName = new StringBuilder("MERGED:");
            for (Weight weight : weightList) {
                if (weight.isShared || weight.isLearnable()) {
                    return null;
                }
                mergedName.append("_").append(weight.name);
                if (sum == null) {
                    // NOTE(review): Value.ONE looks like a shared constant. Starting from a fresh
                    // ScalarValue keeps incrementBy below from touching it. Confirm clone() semantics.
                    if (weight.value == Value.ONE) {
                        sum = new ScalarValue(1.0);
                    } else {
                        sum = weight.value.clone();
                    }
                } else {
                    sum.incrementBy(weight.value);
                }
            }
            // NOTE(review): the meaning of the two `true` flags is defined by the Weight constructor,
            // presumably marking the merged weight as fixed / non-learnable. Confirm.
            return new Weight(-2, mergedName.toString(), sum, true, true);
        }
    }
}

