/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.structure.components.neurons.states;

import cz.cvut.fel.ida.algebra.functions.Activation;
import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

public abstract class AggregationState
implements Aggregation.State {
    private static final Logger LOG = Logger.getLogger(AggregationState.class.getName());

    /**
     * Returns the {@link Aggregation} function object associated with this state.
     *
     * @return the aggregation (or activation) this state was created for
     */
    public abstract Aggregation getAggregation();

    /**
     * Prepares this state's internal storage using the given value as a template.
     * <p>
     * The implementations in this file either do nothing or replace their accumulator with
     * {@code var1.getForm()}.
     * NOTE(review): {@code getForm()} presumably yields an empty value with the same dimensions;
     * confirm against {@code Value}.
     *
     * @param var1 the value used as the template
     */
    public abstract void setupValueDimensions(Value var1);

    public static class CrossProducState
    extends CumulationState {
        // Table of mappings built so far, shared by all instances and keyed by array content
        // (Mapping.hashCode()/equals() use Arrays.deepHashCode/deepEquals).
        // NOTE(review): plain HashMap with no synchronization visible here; confirm single-threaded use.
        private static Map<Mapping, Mapping> cache = new HashMap<Mapping, Mapping>();
        // Combination table filled by initMapping(): one row per combination, one column per input.
        public int[][] mapping;
        // Index of the next row of 'mapping' that combinations() will write.
        int cross = 0;

        /**
         * Creates a cross-product state for the given aggregation; the mapping table stays
         * {@code null} until {@link #initMapping(List)} is called.
         *
         * @param aggregation the aggregation passed on to the {@code CumulationState} constructor
         */
        public CrossProducState(Aggregation aggregation) {
            super(aggregation);
        }

        /**
         * Builds {@link #mapping}: the table of all index combinations (Cartesian product) over
         * the given inputs.
         * <p>
         * For every input the product of the entries of {@code value.size()} is taken as its
         * number of addressable positions. The table then gets one row per combination and one
         * column per input, where column {@code i} holds an index in {@code [0, sizes[i])}.
         * <p>
         * The row cursor is reset on every call, so the method may be invoked repeatedly on the
         * same state. If a table with identical content is already in the shared cache, that
         * cached array is reused instead of the freshly built one.
         *
         * @param inputValues the inputs whose positions are to be combined; must not be {@code null}
         */
        public void initMapping(List<Value> inputValues) {
            int inputCount = inputValues.size();
            int[] sizes = new int[inputCount];
            int rows = 1;
            for (int i = 0; i < inputCount; ++i) {
                int flatSize = 1;
                for (int dimension : inputValues.get(i).size()) {
                    flatSize *= dimension;
                }
                sizes[i] = flatSize;
                rows *= flatSize;
            }
            this.mapping = new int[rows][inputCount];
            // combinations() appends rows at this.cross; without this reset a second call would
            // start past the end of the new table and throw ArrayIndexOutOfBoundsException.
            this.cross = 0;
            this.combinations(0, new int[inputCount], sizes);

            Mapping wrap = new Mapping(this.mapping);
            Mapping load = cache.get(wrap);
            if (load == null) {
                cache.put(wrap, wrap);
            } else {
                // Same content already cached: share that array and let the new one be collected.
                this.mapping = load.mapping;
            }
        }

        /**
         * Recursively enumerates every index tuple and appends each finished tuple to
         * {@code mapping} at row {@code cross}.
         *
         * @param input   position in the tuple currently being varied
         * @param current scratch tuple, overwritten in place during the enumeration
         * @param sizes   exclusive upper bound of the index at each tuple position
         */
        private void combinations(int input, int[] current, int[] sizes) {
            if (input != sizes.length) {
                // Try every index for this position and recurse into the next one.
                for (int index = 0; index < sizes[input]; ++index) {
                    current[input] = index;
                    combinations(input + 1, current, sizes);
                }
                return;
            }
            // All positions are fixed: copy the tuple into the next free row.
            System.arraycopy(current, 0, mapping[cross], 0, sizes.length);
            cross += 1;
        }

        /**
         * Wrapper that gives an {@code int[][]} content-based {@code equals}/{@code hashCode},
         * so that it can be used as a key in hash-based collections.
         * <p>
         * The hash is computed lazily and then remembered; the wrapped array must therefore not
         * be modified after the first call to {@link #hashCode()}.
         */
        public static class Mapping {
            int[][] mapping;
            // Cached hash; -1 means "not computed yet". A real hash of -1 is simply recomputed
            // on each call.
            int hashcode = -1;

            /**
             * @param mapping the table to wrap; it is stored by reference, not copied
             */
            public Mapping(int[][] mapping) {
                this.mapping = mapping;
            }

            /**
             * @return the deep hash of the wrapped table, computed on first use
             */
            @Override
            public int hashCode() {
                if (this.hashcode == -1) {
                    this.hashcode = Arrays.deepHashCode(this.mapping);
                }
                return this.hashcode;
            }

            /**
             * Compares by content of the wrapped tables.
             *
             * @param obj the object to compare with; may be {@code null} or of any type
             * @return {@code true} only if {@code obj} is a {@code Mapping} whose table is
             *         deeply equal to this one
             */
            @Override
            public boolean equals(Object obj) {
                if (this == obj) {
                    return true;
                }
                // Reject null and foreign types instead of failing on the cast.
                if (!(obj instanceof Mapping)) {
                    return false;
                }
                return Arrays.deepEquals(this.mapping, ((Mapping) obj).mapping);
            }
        }
    }

    /**
     * State that stores every incoming value in a list and hands the whole list to the
     * aggregation when a result or gradient is requested.
     */
    public static class CumulationState
    extends AggregationState {
        Aggregation aggregation;
        List<Value> accumulatedInputs;

        /**
         * @param aggregation the aggregation that will later be applied to the collected inputs
         */
        public CumulationState(Aggregation aggregation) {
            this.accumulatedInputs = new ArrayList<Value>();
            this.aggregation = aggregation;
        }

        /** No preparation is done here; the inputs are kept as they arrive. */
        @Override
        public void setupValueDimensions(Value value) {
        }

        /** Appends the value to the list of collected inputs. */
        @Override
        public void cumulate(Value value) {
            accumulatedInputs.add(value);
        }

        /** Forgets all collected inputs. */
        @Override
        public void invalidate() {
            accumulatedInputs.clear();
        }

        /** @return the aggregation evaluated on the collected inputs */
        @Override
        public Value evaluate() {
            return aggregation.evaluate(accumulatedInputs);
        }

        /** @return the aggregation differentiated on the collected inputs */
        @Override
        public Value gradient() {
            return aggregation.differentiate(accumulatedInputs);
        }

        /** @return always {@code null} */
        @Override
        public int[] getInputMask() {
            return null;
        }

        /** @return the aggregation given at construction */
        @Override
        public Aggregation getAggregation() {
            return aggregation;
        }
    }

    public static abstract class Pooling
    extends AggregationState {
        Aggregation aggregation;

        /**
         * Stores the aggregation this pooling state belongs to.
         *
         * @param aggregation the aggregation later returned by {@code getAggregation()}
         */
        public Pooling(Aggregation aggregation) {
            this.aggregation = aggregation;
        }

        /**
         * @return the aggregation given at construction
         */
        @Override
        public Aggregation getAggregation() {
            return aggregation;
        }

        /**
         * Placeholder pooling state: every method body below is empty or returns a fixed value,
         * so no inputs are recorded and no result is produced.
         * NOTE(review): looks like an unfinished "top-k maximum" pooling; confirm before use.
         */
        public static class MaxK
        extends Pooling {
            /**
             * @param aggregation the aggregation passed on to {@code Pooling}
             * @param k           currently ignored (it is neither stored nor read)
             */
            public MaxK(Aggregation aggregation, int k) {
                super(aggregation);
            }

            /** Does nothing; the value is discarded. */
            @Override
            public void cumulate(Value value) {
            }

            /** Does nothing; there is no state to reset. */
            @Override
            public void invalidate() {
            }

            /** @return a new empty array on every call */
            @Override
            public int[] getInputMask() {
                return new int[0];
            }

            /** @return always {@code null} */
            @Override
            public Value gradient() {
                return null;
            }

            /** @return always {@code null} */
            @Override
            public Value evaluate() {
                return null;
            }

            /** Does nothing. */
            @Override
            public void setupValueDimensions(Value value) {
            }
        }

        /**
         * Pooling state that keeps a running total of the incoming values in a single
         * accumulator.
         * <p>
         * The accumulator is {@code null} until it is supplied through the two-argument
         * constructor or created by {@link #setupValueDimensions(Value)}; {@code cumulate} and
         * {@code invalidate} dereference it without a null check.
         */
        public static class Sum
        extends Pooling {
            Value sum;

            /**
             * Creates the state without an accumulator.
             *
             * @param aggregation the aggregation passed on to {@code Pooling}
             */
            public Sum(Aggregation aggregation) {
                this(aggregation, null);
            }

            /**
             * @param aggregation the aggregation passed on to {@code Pooling}
             * @param initSum     the object to use as the accumulator (stored by reference)
             */
            public Sum(Aggregation aggregation, Value initSum) {
                super(aggregation);
                sum = initSum;
            }

            /** Replaces the accumulator with {@code value.getForm()}. */
            @Override
            public void setupValueDimensions(Value value) {
                sum = value.getForm();
            }

            /** Adds the value into the accumulator via {@code incrementBy}. */
            @Override
            public void cumulate(Value value) {
                sum.incrementBy(value);
            }

            /** Resets the accumulator by calling {@code zero()} on it. */
            @Override
            public void invalidate() {
                sum.zero();
            }

            /** @return the accumulator object itself, not a copy */
            @Override
            public Value evaluate() {
                return sum;
            }

            /** @return a new scalar {@code 1.0} */
            @Override
            public Value gradient() {
                return new ScalarValue(1.0);
            }

            /** @return always {@code null} */
            @Override
            public int[] getInputMask() {
                return null;
            }
        }

        /**
         * Pooling state that keeps a running total plus the number of cumulated values, and
         * scales the total by {@code 1 / count} on evaluation.
         * <p>
         * The accumulator is {@code null} until it is supplied through the two-argument
         * constructor or created by {@link #setupValueDimensions(Value)}. With nothing cumulated
         * the scaling factor is {@code 1.0 / 0}, i.e. positive infinity.
         */
        public static class Avg
        extends Pooling {
            int count = 0;
            Value sum;

            /**
             * Creates the state without an accumulator.
             *
             * @param aggregation the aggregation passed on to {@code Pooling}
             */
            public Avg(Aggregation aggregation) {
                this(aggregation, null);
            }

            /**
             * @param aggregation the aggregation passed on to {@code Pooling}
             * @param initSum     the object to use as the accumulator (stored by reference)
             */
            public Avg(Aggregation aggregation, Value initSum) {
                super(aggregation);
                sum = initSum;
            }

            /** Replaces the accumulator with {@code value.getForm()}; the count is left untouched. */
            @Override
            public void setupValueDimensions(Value value) {
                sum = value.getForm();
            }

            /** Adds the value into the accumulator, then counts it. */
            @Override
            public void cumulate(Value value) {
                sum.incrementBy(value);
                count += 1;
            }

            /** Resets the count to zero, then zeroes the accumulator. */
            @Override
            public void invalidate() {
                count = 0;
                sum.zero();
            }

            /** @return the accumulator multiplied by the scalar {@code 1 / count} */
            @Override
            public Value evaluate() {
                // Keep the argument typed as Value so the same times(...) overload is chosen.
                Value scale = reciprocalOfCount();
                return sum.times(scale);
            }

            /** @return a new scalar {@code 1 / count} */
            @Override
            public Value gradient() {
                return reciprocalOfCount();
            }

            /** @return always {@code null} */
            @Override
            public int[] getInputMask() {
                return null;
            }

            /** Builds a fresh scalar holding {@code 1 / count}, using floating-point division. */
            private ScalarValue reciprocalOfCount() {
                return new ScalarValue(1.0 / count);
            }
        }

        /**
         * Pooling state that remembers the largest value seen so far and the position (in
         * cumulation order, starting at 0) at which it arrived.
         */
        public static class Max
        extends Pooling {
            int maxIndex = -1;
            int currentIndex = 0;
            Value maxValue;

            /**
             * @param aggregation the aggregation passed on to {@code Pooling}
             */
            public Max(Aggregation aggregation) {
                super(aggregation);
            }

            /**
             * Seeds the current maximum with {@code value.getForm()}.
             * NOTE(review): after this call the first cumulated value is compared against that
             * seed instead of being accepted unconditionally, unlike after {@code invalidate()};
             * confirm that this is intended.
             */
            @Override
            public void setupValueDimensions(Value value) {
                maxValue = value.getForm();
            }

            /**
             * Takes the value as the new maximum when there is none yet or when
             * {@code value.greaterThan(maxValue)} holds, then advances the position counter.
             */
            @Override
            public void cumulate(Value value) {
                boolean beatsCurrent = maxValue == null || value.greaterThan(maxValue);
                if (beatsCurrent) {
                    maxValue = value;
                    maxIndex = currentIndex;
                }
                currentIndex += 1;
            }

            /** Clears the maximum and resets both counters to their initial values. */
            @Override
            public void invalidate() {
                maxValue = null;
                currentIndex = 0;
                maxIndex = -1;
            }

            /** @return the stored maximum object itself, or {@code null} if there is none */
            @Override
            public Value evaluate() {
                return maxValue;
            }

            /** @return a new scalar {@code 1.0} */
            @Override
            public Value gradient() {
                return new ScalarValue(1.0);
            }

            /** @return a one-element array holding the position of the maximum ({@code -1} if none) */
            @Override
            public int[] getInputMask() {
                return new int[]{maxIndex};
            }
        }
    }

    /**
     * State for an activation function: incoming values are added into one accumulator, and
     * the activation is applied to that accumulator on evaluation.
     * <p>
     * The accumulator is {@code null} until it is supplied through the two-argument constructor
     * or created by {@link #setupValueDimensions(Value)}; {@code cumulate} and
     * {@code invalidate} dereference it without a null check.
     */
    public static class ActivationState
    extends AggregationState {
        Activation activation;
        Value summedInputs;

        /**
         * Creates the state without an accumulator.
         *
         * @param activation the activation to apply
         */
        public ActivationState(Activation activation) {
            this(activation, null);
        }

        /**
         * @param activation the activation to apply
         * @param valueStore the object to use as the accumulator (stored by reference)
         */
        public ActivationState(Activation activation, Value valueStore) {
            this.activation = activation;
            summedInputs = valueStore;
        }

        /** Replaces the accumulator with {@code value.getForm()}. */
        @Override
        public void setupValueDimensions(Value value) {
            summedInputs = value.getForm();
        }

        /** Adds the value into the accumulator via {@code incrementBy}. */
        @Override
        public void cumulate(Value value) {
            summedInputs.incrementBy(value);
        }

        /** Resets the accumulator by calling {@code zero()} on it. */
        @Override
        public void invalidate() {
            summedInputs.zero();
        }

        /** @return the activation evaluated on the accumulator */
        @Override
        public Value evaluate() {
            return activation.evaluate(summedInputs);
        }

        /** @return the activation differentiated on the accumulator */
        @Override
        public Value gradient() {
            return activation.differentiate(summedInputs);
        }

        /** @return always {@code null} */
        @Override
        public int[] getInputMask() {
            return null;
        }

        /** @return the activation given at construction */
        @Override
        public Activation getAggregation() {
            return activation;
        }
    }
}

