/*
 * Decompiled with CFR 0.152.
 */
package networks.structure.building.builders;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
import networks.computation.evaluation.functions.CrossProduct;
import networks.computation.evaluation.values.Value;
import networks.computation.training.strategies.Hyperparameters.DropoutRateStrategy;
import networks.structure.components.neurons.BaseNeuron;
import networks.structure.components.neurons.Neurons;
import networks.structure.components.neurons.WeightedNeuron;
import networks.structure.components.neurons.types.FactNeuron;
import networks.structure.components.types.DetailedNetwork;
import networks.structure.components.weights.Weight;
import networks.structure.metadata.inputMappings.NeuronMapping;
import networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import networks.structure.metadata.states.AggregationState;
import networks.structure.metadata.states.State;
import networks.structure.metadata.states.States;
import networks.structure.metadata.states.StatesCache;
import settings.Settings;
import utils.generic.Pair;

/**
 * Creates and post-processes the per-neuron states of a {@link DetailedNetwork}.
 * <p>
 * Responsibilities visible in this class: wrapping neuron states into composite states for
 * parallel training, registering linked-input structure states, inferring value dimensions of
 * neurons whose value is still unknown, propagating layers / dropout settings and sharing flags
 * through the network, setting parent counts, and merging the cumulative structure states of
 * each neuron into the final {@link StatesCache}.
 */
public class StatesBuilder {
    private static final Logger LOG = Logger.getLogger(StatesBuilder.class.getName());
    private final Settings settings;

    public StatesBuilder(Settings settings) {
        this.settings = settings;
    }

    /**
     * Replaces the neuron's state with a composite state for parallel (minibatch) training.
     * <p>
     * Nothing is changed when {@code settings.parallelTraining} is off or when the neuron's raw
     * state already is a {@code States.ComputationStateComposite}. Otherwise the composite is
     * created by {@code State.createCompositeState} from the neuron's computation view 0 and
     * {@code settings.minibatchSize}.
     *
     * @param neuron the neuron whose state may be replaced
     * @return {@code true} if a composite state was installed, {@code false} otherwise
     */
    public boolean makeParallel(BaseNeuron neuron) {
        if (!this.settings.parallelTraining || neuron.getRawState() instanceof States.ComputationStateComposite) {
            return false;
        }
        State.Neural.Computation state = neuron.getComputationView(0);
        States.ComputationStateComposite<State.Neural.Computation> compositeState = State.createCompositeState(state, this.settings.minibatchSize);
        neuron.setState(compositeState);
        return true;
    }

    /**
     * Adds an inputs structure state for every neuron listed in the network's extra input mapping.
     * Weighted mappings become {@code States.WeightedInputs}, plain mappings {@code States.Inputs};
     * mappings of any other type are ignored.
     *
     * @param neuralNetwork the network whose {@code extraInputMapping} is converted into states
     */
    void addLinkedInputsToNetworkStates(DetailedNetwork<State.Structure> neuralNetwork) {
        neuralNetwork.extraInputMapping.forEach((neuron, inputs) -> {
            // The weighted mapping is tested first so that it can never be shadowed by the plain
            // NeuronMapping branch (should WeightedNeuronMapping be a subtype of it).
            if (inputs instanceof WeightedNeuronMapping) {
                States.WeightedInputs weightedInputsState = new States.WeightedInputs((WeightedNeuronMapping)inputs);
                neuralNetwork.addState((Neurons)neuron, weightedInputsState);
            } else if (inputs instanceof NeuronMapping) {
                States.Inputs inputsState = new States.Inputs((NeuronMapping<Neurons>)inputs);
                neuralNetwork.addState((Neurons)neuron, inputsState);
            }
        });
    }

    /**
     * Infers the value dimensions of all neurons whose computation view 0 has no value yet.
     * Neurons are visited in {@code allNeuronsTopologic} order, so each neuron's inputs are
     * expected to have been processed before it.
     *
     * @param detailedNetwork the network whose neurons are processed
     */
    public void inferValues(DetailedNetwork<State.Structure> detailedNetwork) {
        for (int i = 0; i < detailedNetwork.allNeuronsTopologic.size(); ++i) {
            BaseNeuron neuron = (BaseNeuron)detailedNetwork.allNeuronsTopologic.get(i);
            if (neuron.getComputationView(0).getValue() != null) {
                continue;
            }
            if (neuron instanceof WeightedNeuron) {
                this.inferWeightedDimension(detailedNetwork, neuron);
            } else {
                this.inferUnweightedDimension(detailedNetwork, neuron);
            }
        }
    }

    /**
     * Infers the value dimension of a weighted neuron by multiplying each input value with its
     * weight and combining the products (summing them, or evaluating the aggregation on all of
     * them for {@link CrossProduct} aggregations). Fact neurons are skipped.
     * <p>
     * If the first input's value or weight is missing, or its weight-value product cannot be
     * formed, a message is logged and no dimension is set up.
     */
    private void inferWeightedDimension(DetailedNetwork<State.Structure> detailedNetwork, BaseNeuron<Neurons, State.Neural> neuron) {
        if (neuron instanceof FactNeuron) {
            return;
        }
        WeightedNeuron weightedNeuron = (WeightedNeuron)neuron;
        Pair inputs = detailedNetwork.getInputs(weightedNeuron);
        Iterator neuronIterator = (Iterator)inputs.r;
        Iterator weightIterator = (Iterator)inputs.s;
        boolean crossProduct = neuron.getAggregation() instanceof CrossProduct;
        ArrayList<Value> inputValues = new ArrayList<Value>();

        Value value = null;
        Value weight = null;
        if (neuronIterator.hasNext()) {
            BaseNeuron first = (BaseNeuron)neuronIterator.next();
            value = first.getComputationView(0).getValue();
            weight = weightValueOf((Weight)weightIterator.next(), first);
        }
        if (value == null || weight == null) {
            LOG.warning("Value dimension cannot be inferred for " + neuron);
            return;
        }
        Value sum = weight.times(value);
        if (sum == null) {
            LOG.severe("Weight-Value dimension mismatch at neuron:" + neuron);
            // Continuing with a null sum would fail later (sum.plus / setupValueDimensions).
            return;
        }
        // Store the weighted product, exactly as done for every further input below, so that a
        // CrossProduct aggregation sees uniformly weighted inputs.
        inputValues.add(sum);

        while (neuronIterator.hasNext()) {
            BaseNeuron next = (BaseNeuron)neuronIterator.next();
            Value nextValue = next.getComputationView(0).getValue();
            Value nextWeight = weightValueOf((Weight)weightIterator.next(), next);
            if (nextValue == null || nextWeight == null) {
                LOG.severe("Value dimension cannot be inferred!" + neuron);
                continue;
            }
            Value increment = nextWeight.times(nextValue);
            if (increment == null) {
                LOG.severe("Weight-Value dimension mismatch at neuron:" + neuron);
                continue;
            }
            if (crossProduct) {
                inputValues.add(increment);
                continue;
            }
            // plus() is used only as a dimension-compatibility probe; the result is discarded.
            if (sum.plus(increment) == null) {
                LOG.severe("Input Values dimension mismatch at neuron:" + neuron);
                continue;
            }
            sum.incrementBy(increment);
        }
        if (crossProduct) {
            sum = neuron.getAggregation().evaluate(inputValues);
            AggregationState.CrossProducState crossProducState = (AggregationState.CrossProducState)neuron.getComputationView(0).getAggregationState();
            crossProducState.initMapping(inputValues);
        }
        neuron.getComputationView(0).setupValueDimensions(sum);
    }

    /**
     * Returns the value of an input weight, using {@code Value.ONE} when the weight object is missing.
     *
     * @param weight the weight attached to {@code input}, possibly {@code null}
     * @param input  the input neuron (used only for the log message)
     * @return the weight's value (which may itself be {@code null}), or {@code Value.ONE}
     */
    private static Value weightValueOf(Weight weight, BaseNeuron input) {
        if (weight == null) {
            LOG.finer("Weight for input missing, deducing unit weight for: " + input.name);
            return Value.ONE;
        }
        return weight.value;
    }

    /**
     * Infers the value dimension of an unweighted neuron by summing its input values (or
     * evaluating the aggregation on all of them for {@link CrossProduct} aggregations).
     * <p>
     * If the neuron has no inputs or its first input has no value, a message is logged and no
     * dimension is set up.
     */
    private void inferUnweightedDimension(DetailedNetwork<State.Structure> detailedNetwork, BaseNeuron<Neurons, State.Neural> neuron) {
        Iterator<Neurons> inputs = detailedNetwork.getInputs(neuron);
        if (!inputs.hasNext()) {
            LOG.warning("Value dimension cannot be inferred for " + neuron);
            return;
        }
        boolean crossProduct = neuron.getAggregation() instanceof CrossProduct;
        ArrayList<Value> inputValues = new ArrayList<Value>();
        Value first = inputs.next().getComputationView(0).getValue();
        if (first == null) {
            LOG.severe("Value dimension cannot be inferred!" + neuron);
            // Continuing with a null sum would fail later (sum.plus / setupValueDimensions).
            return;
        }
        inputValues.add(first);
        // Work on a clone so that incrementBy() below does not mutate the input's own value.
        Value sum = first.clone();

        while (inputs.hasNext()) {
            Value result = inputs.next().getComputationView(0).getValue();
            if (result == null) {
                LOG.severe("Value dimension cannot be inferred!" + neuron);
                continue;
            }
            if (crossProduct) {
                inputValues.add(result);
                continue;
            }
            // plus() is used only as a dimension-compatibility probe; the result is discarded.
            if (sum.plus(result) == null) {
                LOG.severe("Input Values dimension mismatch at neuron:" + neuron);
                continue;
            }
            sum.incrementBy(result);
        }
        if (crossProduct) {
            sum = neuron.getAggregation().evaluate(inputValues);
            AggregationState.CrossProducState crossProducState = (AggregationState.CrossProducState)neuron.getComputationView(0).getAggregationState();
            crossProducState.initMapping(inputValues);
        }
        neuron.getComputationView(0).setupValueDimensions(sum);
    }

    /**
     * Walks the neurons in reverse topological order, assigning layers (a neuron with layer 0
     * gets layer 1, and each input gets at least its consumer's layer + 1) and applying the
     * {@link DropoutRateStrategy} to each visited neuron.
     *
     * @param detailedNetwork the network whose neurons are processed
     */
    void setupDropoutStates(DetailedNetwork<State.Structure> detailedNetwork) {
        DropoutRateStrategy dropoutRateStrategy = new DropoutRateStrategy(this.settings);
        // NOTE(review): the loop stops at i > 0, so the first neuron in topological order is never
        // visited (no dropout set, no layer default) - confirm this is intended.
        for (int i = detailedNetwork.allNeuronsTopologic.size() - 1; i > 0; --i) {
            BaseNeuron neuron = (BaseNeuron)detailedNetwork.allNeuronsTopologic.get(i);
            if (neuron.layer == 0) {
                neuron.layer = 1;
            }
            dropoutRateStrategy.setDropout(neuron);
            Iterator inputs = detailedNetwork.getInputs(neuron);
            while (inputs.hasNext()) {
                Neurons next = (Neurons)inputs.next();
                if (next.getLayer() < neuron.layer + 1) {
                    next.setLayer(neuron.layer + 1);
                }
            }
        }
    }

    /**
     * Propagates the "shared" flag from shared neurons to their inputs, walking the neurons in
     * reverse topological order so that flags set here are seen when the inputs are visited.
     * Every shared neuron encountered is also passed to {@link #makeParallel(BaseNeuron)}.
     *
     * @param detailedNetwork the network whose neurons are processed
     * @return the number of shared neurons visited
     */
    int makeSharedStatesRecursively(DetailedNetwork<State.Structure> detailedNetwork) {
        int sharedCount = 0;
        // NOTE(review): as in setupDropoutStates, index 0 is never visited - confirm intended.
        for (int i = detailedNetwork.allNeuronsTopologic.size() - 1; i > 0; --i) {
            BaseNeuron neuron = (BaseNeuron)detailedNetwork.allNeuronsTopologic.get(i);
            if (!neuron.isShared) {
                continue;
            }
            ++sharedCount;
            this.makeParallel(neuron);
            Iterator inputs = detailedNetwork.getInputs(neuron);
            while (inputs.hasNext()) {
                ((Neurons)inputs.next()).setShared(true);
            }
        }
        return sharedCount;
    }

    /**
     * Sets the expected number of parents for every neuron whose computation state tracks parents.
     * <p>
     * If no parent count has been set yet (0), it is set to the neuron's number of outputs in this
     * network. If a different non-zero count is already stored, the neuron is marked as shared,
     * made parallel (when parallel training is on), and a {@code States.NetworkParents} state with
     * this network's output count is added for it.
     *
     * @param network the network whose {@code outputMapping} is used
     */
    void setupParentStateNumbers(DetailedNetwork<State.Structure> network) {
        Map<BaseNeuron, NeuronMapping<Neurons>> neuronOutputs = network.outputMapping;
        neuronOutputs.forEach((neuron, outputs) -> {
            State.Neural.Computation state = neuron.getComputationView(0);
            if (!(state instanceof State.Neural.Computation.HasParents)) {
                return;
            }
            State.Neural.Computation.HasParents parentsState = (State.Neural.Computation.HasParents)((Object)state);
            int parents = parentsState.getParents(null);
            int outputCount = outputs.getLastList().size();
            if (parents != 0 && parents != outputCount) {
                neuron.setShared(true);
                neuron.sharedAfterCreation = true;
                if (this.settings.parallelTraining) {
                    this.makeParallel(neuron);
                }
                Object rawState = neuron.getRawState();
                States.NetworkParents networkParents = new States.NetworkParents((State.Neural<Value>)rawState, outputCount);
                network.addState((Neurons)neuron, networkParents);
            } else if (parents == 0) {
                parentsState.setParents(null, outputCount);
            }
        });
    }

    /**
     * Merges the structure states collected for one neuron into a single final state.
     * <p>
     * A single state is returned as is. Supported merges: a parents state combined with exactly
     * one of an input map or a weighted-input map (and no output map). Any other combination is
     * logged and yields {@code null}.
     *
     * @param structures the structure states collected for a neuron
     * @return the merged state, or {@code null} if the list is empty or the combination is unsupported
     */
    private State.Structure createFinalState(List<State.Structure> structures) {
        if (structures.isEmpty()) {
            return null;
        }
        if (structures.size() == 1) {
            return structures.get(0);
        }
        State.Structure.Parents parents = null;
        State.Structure.InputNeuronMap inputNeuronMap = null;
        State.Structure.WeightedInputsMap weightedInputsMap = null;
        boolean outputs = false;
        for (State.Structure structure : structures) {
            if (structure instanceof State.Structure.Parents) {
                parents = (State.Structure.Parents)((Object)structure);
            } else if (structure instanceof State.Structure.InputNeuronMap) {
                inputNeuronMap = (State.Structure.InputNeuronMap)structure;
            } else if (structure instanceof State.Structure.WeightedInputsMap) {
                weightedInputsMap = (State.Structure.WeightedInputsMap)structure;
            } else if (structure instanceof State.Structure.OutputNeuronMap) {
                outputs = true;
            }
        }
        if (parents != null && !outputs) {
            if (inputNeuronMap != null && weightedInputsMap == null) {
                States.NetworkParents networkParents = new States.NetworkParents(parents.getParentCounter(), parents.getParentCount());
                // InputsParents is an inner class of NetworkParents, so it needs an enclosing instance.
                return networkParents.new InputsParents(inputNeuronMap.getInputMapping());
            }
            if (weightedInputsMap != null && inputNeuronMap == null) {
                States.NetworkParents networkParents = new States.NetworkParents(parents.getParentCounter(), parents.getParentCount());
                return networkParents.new WeightedInputsParents(weightedInputsMap.getWeightedMapping());
            }
        }
        LOG.warning("Unsupported combination of structure states, no final state created for: " + structures);
        return null;
    }

    /**
     * Merges every neuron's cumulative structure states and installs them as the network's
     * {@code neuronStates} cache.
     * <p>
     * In {@code TOPOLOGIC} iteration mode the resulting array is indexed by position in
     * {@code allNeuronsTopologic}; otherwise by {@code neuron.getIndex()}.
     *
     * @param neuralNetwork the network to finalize (modified in place)
     * @return the same network instance
     */
    public DetailedNetwork<State.Structure> setupFinalStatesCache(DetailedNetwork<State.Structure> neuralNetwork) {
        Map<Neurons, List<State.Structure>> cumulativeStates = neuralNetwork.cumulativeStates;
        if (cumulativeStates.isEmpty()) {
            return neuralNetwork;
        }
        final State.Structure[] structureStates;
        if (this.settings.iterationMode == Settings.IterationMode.TOPOLOGIC) {
            structureStates = new State.Structure[neuralNetwork.allNeuronsTopologic.size()];
            for (int i = 0; i < neuralNetwork.allNeuronsTopologic.size(); ++i) {
                BaseNeuron neuron = (BaseNeuron)neuralNetwork.allNeuronsTopologic.get(i);
                List<State.Structure> structures = cumulativeStates.get(neuron);
                if (structures != null) {
                    structureStates[i] = this.createFinalState(structures);
                }
            }
        } else {
            // Size by the largest neuron index, not only by the map size, so that non-contiguous
            // indices cannot overflow the array (identical size when indices are 0..size-1).
            int maxIndex = -1;
            for (Neurons neuron : cumulativeStates.keySet()) {
                maxIndex = Math.max(maxIndex, neuron.getIndex());
            }
            structureStates = new State.Structure[Math.max(cumulativeStates.size(), maxIndex + 1)];
            cumulativeStates.forEach((neuron, structures) -> structureStates[neuron.getIndex()] = this.createFinalState(structures));
        }
        neuralNetwork.neuronStates = StatesCache.getCache(this.settings, structureStates);
        return neuralNetwork;
    }
}

