/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.components;

import com.sun.istack.internal.NotNull;
import com.sun.istack.internal.Nullable;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.Example;
import cz.cvut.fel.ida.neural.networks.computation.iteration.NeuronIterating;
import cz.cvut.fel.ida.neural.networks.computation.iteration.modes.DFSstack;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons.NeuronVisitor;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.networks.InputsGetter;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.networks.OutputsGetter;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.StatesCache;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;

/**
 * A neural network that is also an {@link Example}, identified by {@link #id}.
 *
 * <p>Inputs of non-shared neurons are read directly from the neuron itself. Inputs of shared
 * neurons, and outputs of every neuron, are resolved by applying one of the getter visitors
 * ({@link #inputsGetter}, {@link #weightedInputsGetter}, {@link #outputsGetter}) to the neuron's
 * state taken from {@link #neuronStates}.
 *
 * @param <N> the structural state type stored per neuron in {@link #neuronStates}
 */
public class NeuralNetwork<N extends State.Structure> implements Example {
    private static final Logger LOG = Logger.getLogger(NeuralNetwork.class.getName());

    /** Identifier of this network, exposed through {@link #getId()}. */
    @NotNull
    protected String id;
    /** Number of neurons, initialised from the {@code size} constructor argument. */
    protected int neuronCount;
    int edgeCount;
    public boolean hasSharedNeurons;
    public boolean containsInputMasking;
    public boolean containsCrossProducts;
    /** Cache of per-neuron states; when {@code null}, {@link #getState} yields {@code null}. */
    @Nullable
    public StatesCache<N> neuronStates;
    /** Resolves the inputs of shared neurons from their state. */
    InputsGetter inputsGetter;
    /** Resolves the inputs and weights of shared weighted neurons from their state. */
    InputsGetter.Weighted weightedInputsGetter;
    /** Resolves the outputs of neurons from their state. */
    OutputsGetter outputsGetter;

    /**
     * Creates a network with the given identifier and neuron count.
     *
     * @param id   identifier of the network
     * @param size number of neurons, stored as {@link #neuronCount}
     */
    public NeuralNetwork(String id, int size) {
        this.id = id;
        this.neuronCount = size;
    }

    /** Returns the identifier of this network. */
    @Override
    public String getId() {
        return this.id;
    }

    /** Returns the current value of {@link #neuronCount}. */
    @Override
    public Integer getNeuronCount() {
        return this.neuronCount;
    }

    /**
     * Replaces the identifier of this network.
     *
     * @param id the new identifier
     */
    public void setId(String id) {
        this.id = id;
    }

    /**
     * Looks up the cached state of a neuron.
     *
     * @param neuron the neuron whose state is requested
     * @return the state from {@link #neuronStates}, or {@code null} if no cache is set
     */
    public N getState(Neurons neuron) {
        if (this.neuronStates != null) {
            return this.neuronStates.getState(neuron);
        }
        return null;
    }

    /**
     * Initialises the state cache for the given state view; does nothing if no cache is set.
     *
     * @param stateView passed through to {@link StatesCache#initialize}
     */
    public void initializeStatesCache(int stateView) {
        if (this.neuronStates != null) {
            this.neuronStates.initialize(stateView);
        }
    }

    /**
     * Returns iterators over the inputs of a weighted neuron and over their weights.
     *
     * <p>For a shared neuron, both are taken from the mapping produced by
     * {@link #weightedInputsGetter} for the neuron's state. Otherwise they come straight from
     * {@code neuron.getInputs()} and {@code neuron.getWeights()}.
     *
     * @param neuron the neuron whose inputs are requested
     * @return a pair of (inputs iterator, weights iterator)
     */
    public <T extends Neurons, S extends State.Neural> Pair<Iterator<T>, Iterator<Weight>> getInputs(WeightedNeuron<T, S> neuron) {
        if (neuron.isShared) {
            WeightedNeuronMapping mapping = this.sharedWeightedMapping(neuron);
            Iterator<T> inputs = NeuralNetwork.<T>castIterator(mapping.iterator());
            Iterator<Weight> weights = NeuralNetwork.<Weight>castIterator(mapping.weightIterator());
            return new Pair<Iterator<T>, Iterator<Weight>>(inputs, weights);
        }
        return new Pair<Iterator<T>, Iterator<Weight>>(neuron.getInputs().iterator(), neuron.getWeights().iterator());
    }

    /**
     * Like {@link #getInputs(WeightedNeuron)}, but keeps only the inputs (and matching weights)
     * whose positions are listed in {@code inputMask}, in mask order.
     *
     * @param neuron    the neuron whose inputs are requested
     * @param inputMask positions into the full input list to select; repeated positions are
     *                  selected repeatedly
     * @return a pair of (masked inputs iterator, masked weights iterator)
     * @throws IndexOutOfBoundsException if a mask position is outside the input list
     */
    public <T extends Neurons, S extends State.Neural> Pair<Iterator<T>, Iterator<Weight>> getInputs(WeightedNeuron<T, S> neuron, int[] inputMask) {
        List<T> inputs;
        List<Weight> weights;
        if (neuron.isShared) {
            WeightedNeuronMapping mapping = this.sharedWeightedMapping(neuron);
            // Both iterators are obtained from the mapping before either is consumed.
            Iterator<?> inputIterator = mapping.iterator();
            Iterator<?> weightIterator = mapping.weightIterator();
            inputs = NeuralNetwork.<T>drain(inputIterator);
            weights = NeuralNetwork.<Weight>drain(weightIterator);
        } else {
            inputs = neuron.getInputs();
            weights = neuron.getWeights();
        }
        return new Pair<Iterator<T>, Iterator<Weight>>(
                applyMask(inputs, inputMask).iterator(),
                applyMask(weights, inputMask).iterator());
    }

    /**
     * Returns an iterator over the inputs of {@code neuron} whose positions are listed in
     * {@code inputMask}, in mask order.
     *
     * @param neuron    the neuron whose inputs are requested
     * @param inputMask positions into the full input list to select
     * @return iterator over the selected inputs
     * @throws IndexOutOfBoundsException if a mask position is outside the input list
     */
    public <T extends Neurons, S extends State.Neural> Iterator<T> getInputs(BaseNeuron<T, S> neuron, int[] inputMask) {
        List<T> inputs;
        if (neuron.isShared) {
            inputs = NeuralNetwork.<T>drain(this.sharedInputMapping(neuron).iterator());
        } else {
            inputs = neuron.getInputs();
        }
        return applyMask(inputs, inputMask).iterator();
    }

    /**
     * Returns an iterator over the inputs of {@code neuron}.
     *
     * <p>Shared neurons are resolved through {@link #inputsGetter}; others use
     * {@code neuron.getInputs()}.
     *
     * @param neuron the neuron whose inputs are requested
     * @return iterator over the inputs
     */
    public <T extends Neurons, S extends State.Neural> Iterator<T> getInputs(Neurons<T, S> neuron) {
        if (neuron.isShared()) {
            return NeuralNetwork.<T>castIterator(this.sharedInputMapping(neuron).iterator());
        }
        return neuron.getInputs().iterator();
    }

    /**
     * Returns an iterator over the inputs of {@code neuron}.
     *
     * <p>Shared neurons are resolved through {@link #inputsGetter}; others use
     * {@code neuron.getInputs()}.
     *
     * @param neuron the neuron whose inputs are requested
     * @return iterator over the inputs
     */
    public <T extends Neurons, S extends State.Neural> Iterator<T> getInputs(BaseNeuron<T, S> neuron) {
        if (neuron.isShared) {
            return NeuralNetwork.<T>castIterator(this.sharedInputMapping(neuron).iterator());
        }
        return neuron.getInputs().iterator();
    }

    /**
     * Returns an iterator over the outputs of {@code neuron}, resolved through
     * {@link #outputsGetter} from the neuron's state (regardless of whether it is shared).
     *
     * @param neuron the neuron whose outputs are requested
     * @return iterator over the outputs
     */
    public <T extends Neurons, S extends State.Neural> Iterator<Neurons> getOutputs(BaseNeuron<T, S> neuron) {
        N neuralState = this.getState(neuron);
        NeuronMapping mapping = (NeuronMapping) this.outputsGetter.visit(neuralState);
        return NeuralNetwork.<Neurons>castIterator(mapping.iterator());
    }

    /**
     * Creates a bottom-up DFS iterator over this network, starting from {@code outputNeuron}.
     *
     * @param <V>            unused; kept for source compatibility
     * @param vNeuronVisitor visitor applied to visited neurons
     * @param outputNeuron   the neuron the iteration starts from
     * @return a new {@link DFSstack} bottom-up iterator
     * @deprecated NOTE(review): the intended replacement is not identified in this file — confirm.
     */
    @Deprecated
    public <V> NeuronIterating getPreferredBUpIterator(NeuronVisitor.Weighted vNeuronVisitor, BaseNeuron<Neurons, State.Neural> outputNeuron) {
        // BUpIterator is an inner class of DFSstack; the decompiler rendered this qualified
        // creation as an explicit outer argument plus a dead getClass() null check.
        return new DFSstack().new BUpIterator(this, outputNeuron, vNeuronVisitor);
    }

    /**
     * Creates a top-down DFS iterator over this network, starting from {@code outputNeuron}.
     *
     * @param <V>            unused; kept for source compatibility
     * @param vNeuronVisitor visitor applied to visited neurons
     * @param outputNeuron   the neuron the iteration starts from
     * @return a new {@link DFSstack} top-down iterator
     * @deprecated NOTE(review): the intended replacement is not identified in this file — confirm.
     */
    @Deprecated
    public <V> NeuronIterating getPreferredTDownIterator(NeuronVisitor.Weighted vNeuronVisitor, BaseNeuron<Neurons, State.Neural> outputNeuron) {
        // TDownIterator is an inner class of DFSstack (see getPreferredBUpIterator).
        return new DFSstack().new TDownIterator(this, outputNeuron, vNeuronVisitor);
    }

    /** Resolves the weighted input mapping of a shared neuron from its cached state. */
    private WeightedNeuronMapping sharedWeightedMapping(Neurons neuron) {
        N neuralState = this.getState(neuron);
        return (WeightedNeuronMapping) this.weightedInputsGetter.visit(neuralState);
    }

    /** Resolves the (unweighted) input mapping of a shared neuron from its cached state. */
    private NeuronMapping sharedInputMapping(Neurons neuron) {
        N neuralState = this.getState(neuron);
        return (NeuronMapping) this.inputsGetter.visit(neuralState);
    }

    /** Re-types an iterator coming from a mapping whose element type is not tracked statically. */
    @SuppressWarnings("unchecked")
    private static <E> Iterator<E> castIterator(Iterator<?> iterator) {
        return (Iterator<E>) iterator;
    }

    /** Consumes the remaining elements of {@code iterator} into a new mutable list. */
    private static <E> List<E> drain(Iterator<?> iterator) {
        List<E> drained = new ArrayList<>();
        NeuralNetwork.<E>castIterator(iterator).forEachRemaining(drained::add);
        return drained;
    }

    /** Returns the elements of {@code source} at the positions in {@code mask}, in mask order. */
    private static <E> List<E> applyMask(List<E> source, int[] mask) {
        List<E> selected = new ArrayList<>(mask.length);
        for (int position : mask) {
            selected.add(source.get(position));
        }
        return selected;
    }

    /** Returns a short description with the id and neuron count. */
    @Override
    public String toString() {
        return "net:" + this.id + ", neurons: " + this.neuronCount;
    }
}

