/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.components.neurons.states;

import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.StateVisiting;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.networks.ParentsTransfer;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Backproper;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Dropouter;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Evaluator;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Invalidator;
import cz.cvut.fel.ida.neural.networks.structure.building.builders.StatesBuilder;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.AggregationState;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.NeuronMapping;
import cz.cvut.fel.ida.neural.networks.structure.metadata.inputMappings.WeightedNeuronMapping;
import cz.cvut.fel.ida.setup.Settings;
import java.util.Objects;
import java.util.logging.Logger;

public abstract class States
implements State {
    /** Logger used by the nested state classes (e.g. SimpleValue, ComputationStateStandard) to report misuse. */
    private static final Logger LOG = Logger.getLogger(States.class.getName());

    /**
     * Structural state that stores a neuron's parent count together with the neural state
     * ({@code parentCounter}) to which a {@link ParentsTransfer} visit is forwarded.
     *
     * <p>The two inner classes combine an input mapping with this parent bookkeeping; they read
     * and write the enclosing instance's {@code parentCount} / {@code parentCounter}.
     */
    public static class NetworkParents
    implements State.Structure<Value>,
    State.Structure.Parents {
        int parentCount;
        State.Neural<Value> parentCounter;

        /**
         * @param parentCounter neural state that receives forwarded parents-transfer visits
         * @param parentCount   initial number of parents
         */
        public NetworkParents(State.Neural<Value> parentCounter, int parentCount) {
            this.parentCount = parentCount;
            this.parentCounter = parentCounter;
        }

        /**
         * Publishes the stored parent count on the visitor, then lets the wrapped parent-counter
         * state accept that same visitor.
         *
         * @param visitor the parents-transfer visitor
         * @return always {@code null}
         */
        public Value accept(ParentsTransfer visitor) {
            visitor.parentsCount = parentCount;
            parentCounter.accept(visitor);
            return null;
        }

        /** @return the neural state that parents-transfer visits are forwarded to */
        @Override
        public State.Neural<Value> getParentCounter() {
            return parentCounter;
        }

        /** @return the stored parent count */
        @Override
        public int getParentCount() {
            return parentCount;
        }

        /** @param parentCount the new parent count */
        @Override
        public void setParentCount(int parentCount) {
            this.parentCount = parentCount;
        }

        /** No-op: this structural state holds nothing that is reset here. */
        @Override
        public void invalidate() {
            // intentionally empty
        }

        /**
         * Weighted-input structure that also exposes the parent bookkeeping of the enclosing
         * {@link NetworkParents} instance.
         */
        public class WeightedInputsParents
        extends WeightedInputs
        implements State.Structure.Parents {
            public WeightedInputsParents(WeightedNeuronMapping<Neurons> inputs) {
                super(inputs);
            }

            /** @return the enclosing instance's parent-counter state */
            @Override
            public State.Neural<Value> getParentCounter() {
                return NetworkParents.this.parentCounter;
            }

            /** @return the enclosing instance's parent count */
            @Override
            public int getParentCount() {
                return NetworkParents.this.parentCount;
            }

            /** Writes the parent count through to the enclosing instance. */
            @Override
            public void setParentCount(int parentCount) {
                NetworkParents.this.parentCount = parentCount;
            }

            /** @return the weighted mapping stored by {@link WeightedInputs} */
            @Override
            public WeightedNeuronMapping<Neurons> getWeightedMapping() {
                return inputs;
            }

            /** No-op. */
            @Override
            public void invalidate() {
                // intentionally empty
            }
        }

        /**
         * Plain-input structure that also exposes the parent bookkeeping of the enclosing
         * {@link NetworkParents} instance.
         */
        public class InputsParents
        extends Inputs
        implements State.Structure.Parents {
            public InputsParents(NeuronMapping<Neurons> inputs) {
                super(inputs);
            }

            /** @return the enclosing instance's parent-counter state */
            @Override
            public State.Neural<Value> getParentCounter() {
                return NetworkParents.this.parentCounter;
            }

            /** @return the enclosing instance's parent count */
            @Override
            public int getParentCount() {
                return NetworkParents.this.parentCount;
            }

            /** Writes the parent count through to the enclosing instance. */
            @Override
            public void setParentCount(int parentCount) {
                NetworkParents.this.parentCount = parentCount;
            }

            /** @return the input mapping stored by {@link Inputs} */
            @Override
            public NeuronMapping<Neurons> getInputMapping() {
                return inputs;
            }

            /** No-op. */
            @Override
            public void invalidate() {
                // intentionally empty
            }
        }
    }

    /**
     * Structural state holding the output-neuron mapping of a neuron.
     */
    public static class Outputs
    implements State.Structure.OutputNeuronMap {
        NeuronMapping<Neurons> outputs;

        /**
         * Creates an instance without a mapping. {@link #getOutputMapping()} returns {@code null}
         * until the (package-private) {@code outputs} field is assigned.
         */
        public Outputs() {
        }

        /**
         * Creates an instance wrapping the given mapping, mirroring the {@code Inputs} and
         * {@code WeightedInputs} constructors.
         *
         * @param outputs the output mapping to expose; stored as-is (may be {@code null})
         */
        public Outputs(NeuronMapping<Neurons> outputs) {
            this.outputs = outputs;
        }

        /** @return the currently stored output mapping, possibly {@code null} */
        @Override
        public NeuronMapping<Neurons> getOutputMapping() {
            return this.outputs;
        }

        /** No-op: the mapping is structural and is not reset here. */
        @Override
        public void invalidate() {
        }
    }

    /**
     * Structural state holding the weighted input-neuron mapping of a neuron.
     */
    public static class WeightedInputs
    implements State.Structure.WeightedInputsMap {
        WeightedNeuronMapping<Neurons> inputs;

        /** @param inputs the weighted mapping to expose; stored as-is (may be {@code null}) */
        public WeightedInputs(WeightedNeuronMapping<Neurons> inputs) {
            this.inputs = inputs;
        }

        /** No-op: the mapping is structural and is not reset here. */
        @Override
        public void invalidate() {
            // nothing to reset
        }

        /** @return the currently stored weighted mapping */
        @Override
        public WeightedNeuronMapping<Neurons> getWeightedMapping() {
            return inputs;
        }
    }

    /**
     * Structural state holding the (unweighted) input-neuron mapping of a neuron.
     */
    public static class Inputs
    implements State.Structure.InputNeuronMap {
        NeuronMapping<Neurons> inputs;

        /** @param inputs the mapping to expose; stored as-is (may be {@code null}) */
        public Inputs(NeuronMapping<Neurons> inputs) {
            this.inputs = inputs;
        }

        /** No-op: the mapping is structural and is not reset here. */
        @Override
        public void invalidate() {
            // nothing to reset
        }

        /** @return the currently stored input mapping */
        @Override
        public NeuronMapping<Neurons> getInputMapping() {
            return inputs;
        }
    }

    /**
     * Computation state for fact neurons: it only holds a value. It reports no gradient, has no
     * aggregation, and silently ignores every value/gradient write.
     */
    public static class SimpleValue
    implements State.Neural.Computation {
        Value value;

        /** @param factValue the value reported by {@link #getValue()} */
        public SimpleValue(Value factValue) {
            value = factValue;
        }

        /** @return the stored value */
        @Override
        public Value getValue() {
            return value;
        }

        /** @return always {@code null}; this state carries no gradient */
        @Override
        public Value getGradient() {
            return null;
        }

        /**
         * Replaces the stored value with {@code value.getForm()}.
         *
         * <p>NOTE(review): this overwrites the fact value given to the constructor; presumably
         * {@code getForm()} yields a value of matching dimensions rather than the same content —
         * confirm callers only invoke this when that is intended.
         */
        @Override
        public void setupValueDimensions(Value value) {
            this.value = value.getForm();
        }

        /** @return a new SimpleValue wrapping a clone of the stored value */
        @Override
        public State.Neural.Computation clone() {
            Value copy = value.clone();
            return new SimpleValue(copy);
        }

        /** No-op. */
        @Override
        public void invalidate() {
        }

        /** Ignored: fact values are not overwritten through this method. */
        @Override
        public void setValue(Value value) {
        }

        /** Ignored. */
        @Override
        public void storeValue(Value value) {
        }

        /** Ignored. */
        @Override
        public void setGradient(Value gradient) {
        }

        /** Ignored. */
        @Override
        public void storeGradient(Value gradient) {
        }

        /** Logs a severe message and returns {@code null}; fact states have no aggregation state. */
        @Override
        public AggregationState getAggregationState() {
            LOG.severe("Fact neurons cannot be evaluated, you can only obtain the value via getResult!");
            return null;
        }

        /** Logs a warning and returns {@code null}; fact states have no aggregation. */
        @Override
        public Aggregation getAggregation() {
            LOG.warning("FactNeurons have no aggregation.");
            return null;
        }
    }

    public static final class DropoutStore
    extends ComputationStateStandard
    implements State.Neural.Computation.HasDropout {
        public double dropoutRate;
        public boolean isDropped;
        private boolean dropoutProcessed;
        private Settings settings;

        public DropoutStore(Settings settings, double dropoutRate, Aggregation activationFunction) {
            super(activationFunction);
            this.settings = settings;
            this.dropoutRate = dropoutRate;
        }

        public DropoutStore(Settings settings, Aggregation activationFunction) {
            super(activationFunction);
            this.settings = settings;
            this.dropoutRate = settings.dropoutRate;
        }

        @Override
        public void invalidate() {
            super.invalidate();
            this.isDropped = false;
            this.dropoutProcessed = false;
        }

        @Override
        public DropoutStore clone() {
            DropoutStore clone = (DropoutStore)super.clone();
            clone.dropoutRate = this.dropoutRate;
            clone.isDropped = this.isDropped;
            clone.dropoutProcessed = this.dropoutProcessed;
            clone.settings = this.settings;
            return clone;
        }

        public boolean ready4expansion(Dropouter visitor) {
            return !this.dropoutProcessed;
        }

        @Override
        public double getDropoutRate(StateVisiting visitor) {
            return this.dropoutRate;
        }

        @Override
        public void setDropoutRate(double rate) {
            this.dropoutRate = rate;
        }

        @Override
        public void setDropout(StateVisiting visitor) {
            this.isDropped = this.settings.random.nextDouble() < this.settings.dropoutRate;
            this.dropoutProcessed = true;
        }

        public final class ParentsDropoutStore
        extends ParentCounter
        implements State.Neural.Computation.HasDropout {
            public ParentsDropoutStore(Settings settings, Aggregation activationFunction) {
                super(activationFunction);
                DropoutStore.this.settings = settings;
            }

            public ParentsDropoutStore(Aggregation activationFunction) {
                super(activationFunction);
            }

            public ParentsDropoutStore(Settings settings, double dropoutRate, Aggregation aggregation) {
                super(aggregation);
                DropoutStore.this.settings = settings;
                DropoutStore.this.dropoutRate = dropoutRate;
            }

            @Override
            public ParentsDropoutStore clone() {
                ParentsDropoutStore clone = new ParentsDropoutStore(DropoutStore.this.settings, DropoutStore.this.dropoutRate, this.aggregationState.getAggregation());
                clone.parentCount = this.parentCount;
                clone.checked = this.checked;
                clone.calculated = this.calculated;
                return clone;
            }

            @Override
            public double getDropoutRate(StateVisiting visitor) {
                return DropoutStore.this.getDropoutRate(visitor);
            }

            @Override
            public void setDropoutRate(double rate) {
                DropoutStore.this.dropoutRate = rate;
            }

            @Override
            public void setDropout(StateVisiting visitor) {
                DropoutStore.this.setDropout(visitor);
            }

            @Override
            public void setParents(StateVisiting visitor, int parentCount) {
                this.parentCount = parentCount;
            }
        }
    }

    /**
     * Standard computation state that also tracks how many parents a neuron has ({@code parentCount}),
     * how many gradient contributions have arrived ({@code checked}), and whether a value has been
     * set ({@code calculated}); these drive {@link #ready4expansion(StateVisiting)}.
     */
    public static class ParentCounter
    extends ComputationStateStandard
    implements State.Neural.Computation.HasParents {
        public int parentCount;
        public int checked = 0;
        boolean calculated;

        /**
         * @param activationFunction aggregation of the underlying computation state
         * @param count              number of parents
         */
        public ParentCounter(Aggregation activationFunction, int count) {
            super(activationFunction);
            this.parentCount = count;
        }

        public ParentCounter(Aggregation activationFunction) {
            super(activationFunction);
        }

        /** Resets the computation state plus the gradient counter and the calculated flag. */
        @Override
        public void invalidate() {
            super.invalidate();
            this.checked = 0;
            this.calculated = false;
        }

        /**
         * Creates an independent copy of this state including its parent bookkeeping.
         *
         * <p>Constructed directly: {@code super.clone()} returns a plain
         * {@link ComputationStateStandard}, so casting it to {@code ParentCounter} always threw
         * {@link ClassCastException}.
         */
        @Override
        public ParentCounter clone() {
            ParentCounter clone = new ParentCounter(this.aggregationState.getAggregation(), this.parentCount);
            // null-tolerant: dimensions may not have been set up yet
            clone.outputValue = this.outputValue == null ? null : this.outputValue.clone();
            clone.acumGradient = this.acumGradient == null ? null : this.acumGradient.clone();
            clone.checked = this.checked;
            clone.calculated = this.calculated;
            return clone;
        }

        /** Accumulates the gradient and counts one more received contribution. */
        @Override
        public void storeGradient(Value gradient) {
            super.storeGradient(gradient);
            ++this.checked;
        }

        /**
         * Dispatches to the visitor-specific overload; visitors of any other type are always ready.
         */
        @Override
        public boolean ready4expansion(StateVisiting visitor) {
            if (visitor instanceof Backproper) {
                return this.ready4expansion((Backproper)visitor);
            }
            if (visitor instanceof Evaluator) {
                return this.ready4expansion((Evaluator)visitor);
            }
            if (visitor instanceof Invalidator) {
                return this.ready4expansion((Invalidator)visitor);
            }
            return true;
        }

        /** @return {@code true} once the number of stored gradients equals the parent count */
        public boolean ready4expansion(Backproper visitor) {
            return this.checked == this.parentCount;
        }

        /** @return {@code true} once {@link #setValue(Value)} has been called since the last invalidation */
        public boolean ready4expansion(Evaluator visitor) {
            return this.calculated;
        }

        /** @return always {@code true} */
        public boolean ready4expansion(Invalidator visitor) {
            return true;
        }

        @Override
        public int getParents(StateVisiting visitor) {
            return this.parentCount;
        }

        @Override
        public int getChecked(StateVisiting visitor) {
            return this.checked;
        }

        @Override
        public void setChecked(StateVisiting visitor, int checked) {
            this.checked = checked;
        }

        @Override
        public void setParents(StateVisiting visitor, int parentCount) {
            this.parentCount = parentCount;
        }

        /** Stores the value and marks this state as calculated. */
        @Override
        public void setValue(Value value) {
            super.setValue(value);
            this.calculated = true;
        }
    }

    /**
     * Default computation state of a neuron: an aggregation state that cumulates inputs, the
     * neuron's output value, and an accumulated gradient. Value and gradient stay {@code null}
     * until {@link #setupValueDimensions(Value)} is called.
     */
    public static class ComputationStateStandard
    implements State.Neural.Computation {
        public AggregationState aggregationState;
        public Value outputValue;
        public Value acumGradient;

        /** @param activation aggregation from which the aggregation state is built */
        public ComputationStateStandard(Aggregation activation) {
            this.aggregationState = StatesBuilder.getAggregationState(activation);
        }

        /**
         * Zeroes the output value and the accumulated gradient (when already set up) and
         * invalidates the aggregation state.
         */
        @Override
        public void invalidate() {
            if (this.outputValue != null) {
                this.outputValue.zero();
            }
            if (this.acumGradient != null) {
                this.acumGradient.zero();
            }
            this.aggregationState.invalidate();
        }

        @Override
        public Aggregation getAggregation() {
            return this.aggregationState.getAggregation();
        }

        /**
         * Returns a new state with a fresh aggregation state built from the same aggregation and
         * clones of the current value and gradient ({@code null} ones stay {@code null}).
         */
        @Override
        public ComputationStateStandard clone() {
            ComputationStateStandard clone = new ComputationStateStandard(this.aggregationState.getAggregation());
            clone.outputValue = this.outputValue == null ? null : this.outputValue.clone();
            clone.acumGradient = this.acumGradient == null ? null : this.acumGradient.clone();
            return clone;
        }

        /**
         * Sets up value/gradient storage shaped like {@code value}. Once set up, later calls only
         * log when the new dimensions differ from the previously inferred ones.
         */
        @Override
        public void setupValueDimensions(Value value) {
            if (this.acumGradient != null) {
                // Report a mismatch (the previous check was inverted). deepEquals also compares
                // contents if size() returns an array, where plain equals() would compare references.
                if (!Objects.deepEquals(value.size(), this.acumGradient.size())) {
                    LOG.severe("Collision with previously inferred Value dimensions!");
                }
                return;
            }
            this.aggregationState.setupValueDimensions(value);
            this.outputValue = value.getForm();
            this.acumGradient = value.getForm();
        }

        @Override
        public AggregationState getAggregationState() {
            return this.aggregationState;
        }

        /** Replaces the output value with the given instance (no copy). */
        @Override
        public void setValue(Value value) {
            this.outputValue = value;
        }

        /** Replaces the accumulated gradient with the given instance (no copy). */
        @Override
        public void setGradient(Value gradient) {
            this.acumGradient = gradient;
        }

        @Override
        public Value getValue() {
            return this.outputValue;
        }

        @Override
        public Value getGradient() {
            return this.acumGradient;
        }

        /** Feeds an input value into the aggregation state. */
        @Override
        public void storeValue(Value value) {
            this.aggregationState.cumulate(value);
        }

        /** Adds {@code value} into the accumulated gradient in place. */
        @Override
        public void storeGradient(Value value) {
            this.acumGradient.incrementBy(value);
        }
    }

    /**
     * Neural state composed of several computation states; visits and views are routed to the
     * element selected by an index.
     *
     * @param <T> type of the composed computation states
     */
    public static final class ComputationStateComposite<T extends State.Neural.Computation>
    implements State.Neural<Value> {
        public final T[] states;
        Aggregation aggregation;

        /**
         * Creates a composite without an aggregation; {@link #getAggregation()} then returns
         * {@code null} unless the package-private field is assigned.
         */
        public ComputationStateComposite(T[] states) {
            this.states = states;
        }

        /**
         * Creates a composite reporting the given aggregation from {@link #getAggregation()}.
         *
         * @param states      the composed states (stored as-is, not copied)
         * @param aggregation aggregation to report; may be {@code null}
         */
        public ComputationStateComposite(T[] states, Aggregation aggregation) {
            this.states = states;
            this.aggregation = aggregation;
        }

        @Override
        public State.Neural.Computation getComputationView(int index) {
            return this.states[index];
        }

        @Override
        public Aggregation getAggregation() {
            return this.aggregation;
        }

        /** Forwards the visitor to the state at {@code visitor.stateIndex}. */
        public Value accept(StateVisiting.Computation visitor) {
            return this.states[visitor.stateIndex].accept(visitor);
        }

        /** Invalidates every composed state. */
        @Override
        public void invalidate() {
            for (T state : this.states) {
                state.invalidate();
            }
        }
    }
}

