/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training;

import constructs.template.Template;
import java.io.Reader;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import learning.Model;
import networks.computation.evaluation.values.Value;
import networks.computation.evaluation.values.distributions.ValueInitializer;
import networks.structure.components.neurons.QueryNeuron;
import networks.structure.components.weights.Weight;
import settings.Settings;

/**
 * A {@link Model} over {@link QueryNeuron} queries whose state is a flat list of {@link Weight}s.
 *
 * <p>Provides helpers to clone, reset, copy, filter and index the weights. {@link #getId()} and
 * {@link #evaluate(QueryNeuron)} are not implemented in this class and return {@code null}.
 */
public class NeuralModel
implements Model<QueryNeuron> {
    private static final Logger LOG = Logger.getLogger(NeuralModel.class.getName());

    /** The weights making up this model; held and exposed by reference, not copied. */
    public List<Weight> weights;

    /** Settings given at construction; marked {@code transient}, so skipped by serializers honoring it. */
    private transient Settings settings;

    /** Not read or written anywhere in this class; presumably set by callers - TODO confirm its use. */
    public Value threshold;

    /** Source template; only kept when {@code settings.debugTemplateTraining} was enabled. */
    private Template template;

    /**
     * Creates a model over the given weights.
     *
     * <p>When the configured optimizer is {@code ADAM}, the optimizer state of every weight is
     * initialized through {@link #init4Adam(List)}.
     *
     * @param weights  the weights of the model (stored by reference, not copied)
     * @param settings the settings to use; must not be {@code null}
     */
    public NeuralModel(List<Weight> weights, Settings settings) {
        this.settings = settings;
        this.weights = weights;
        if (settings.optimizer == Settings.OptimizerSet.ADAM) {
            this.init4Adam(weights);
        }
    }

    /**
     * Creates a model over all weights of the given template.
     *
     * <p>The template itself is retained (for {@link #getTemplate()}) only when
     * {@code settings.debugTemplateTraining} is set.
     *
     * @param template the template whose {@code getAllWeights()} become this model's weights
     * @param settings the settings to use; must not be {@code null}
     */
    public NeuralModel(Template template, Settings settings) {
        this(template.getAllWeights(), settings);
        if (settings.debugTemplateTraining) {
            this.template = template;
        }
    }

    /**
     * Bare constructor used by {@link #cloneValues()}; leaves {@link #weights} unset and does not
     * initialize any optimizer state.
     */
    private NeuralModel(Settings settings) {
        this.settings = settings;
    }

    /**
     * Initializes Adam optimizer state: sets both {@code velocity} and {@code momentum} of each
     * weight to {@code weight.value.getForm()} (presumably a fresh zero value of the same shape -
     * confirm in {@code Value}).
     *
     * <p>NOTE(review): this overridable method is called from the constructor, so an override
     * runs before the subclass's own fields are initialized.
     *
     * @param weights the weights whose optimizer state is initialized
     */
    protected void init4Adam(List<Weight> weights) {
        for (Weight weight : weights) {
            weight.velocity = weight.value.getForm();
            weight.momentum = weight.value.getForm();
        }
    }

    /**
     * Creates a new model whose weights are {@code Weight.clone()} copies of this model's weights.
     *
     * <p>The settings and the (optional) template are shared with this model, not copied.
     * {@link #threshold} is not carried over, and Adam state is not re-initialized.
     *
     * @return the new model
     */
    public NeuralModel cloneValues() {
        List<Weight> clonedWeights = this.weights.stream()
                .map(Weight::clone)
                .collect(Collectors.toList());
        NeuralModel clone = new NeuralModel(this.settings);
        clone.weights = clonedWeights;
        clone.template = this.template;
        return clone;
    }

    /**
     * Re-initializes every weight of this model by calling {@code Weight.init(valueInitializer)}.
     *
     * @param valueInitializer the initializer passed to each weight
     */
    public void resetWeights(ValueInitializer valueInitializer) {
        for (Weight weight : this.weights) {
            weight.init(valueInitializer);
        }
    }

    /**
     * Takes the {@code value} of each weight from the weight with the same {@code index} in
     * {@code otherModel}.
     *
     * <p>All indices are checked before anything is assigned, so a failure leaves this model
     * unchanged.
     *
     * <p>NOTE(review): values are assigned by reference, so afterwards both models share the same
     * {@link Value} objects. If values are mutated in place elsewhere, changes show up in both.
     *
     * @param otherModel the model to take values from
     * @throws IllegalArgumentException if {@code otherModel} has no weight with the index of one of
     *     this model's weights
     * @throws IllegalStateException if {@code otherModel} contains several weights with the same
     *     index (raised by {@link #mapWeightsToIds()})
     */
    public void loadWeightValues(NeuralModel otherModel) {
        Map<Integer, Weight> otherWeights = otherModel.mapWeightsToIds();
        // Validate first so a missing weight cannot leave this model half-updated.
        for (Weight weight : this.weights) {
            if (!otherWeights.containsKey(weight.index)) {
                throw new IllegalArgumentException(
                        "No weight with index " + weight.index + " in the model to load values from");
            }
        }
        for (Weight weight : this.weights) {
            weight.value = otherWeights.get(weight.index).value;
        }
    }

    /** Not implemented: currently a no-op. */
    public void dropoutWeights() {
    }

    /**
     * Selects the learnable weights from the given list (this model's own weights are not used).
     *
     * @param allWeights the weights to filter
     * @return a new list with the weights for which {@code isLearnable()} is true, in input order
     */
    public List<Weight> filterLearnable(List<Weight> allWeights) {
        return allWeights.stream().filter(Weight::isLearnable).collect(Collectors.toList());
    }

    /**
     * Indexes this model's weights by their {@code index} field.
     *
     * @return a new map from weight index to weight
     * @throws IllegalStateException if two weights share the same index
     */
    public Map<Integer, Weight> mapWeightsToIds() {
        return this.weights.stream().collect(Collectors.toMap(w -> w.index, w -> w));
    }

    /**
     * Indexes this model's weights by their {@code name} field.
     *
     * @return a new map from weight name to weight
     * @throws IllegalStateException if two weights share the same name
     */
    public Map<String, Weight> mapWeightsToNames() {
        return this.weights.stream().collect(Collectors.toMap(w -> w.name, w -> w));
    }

    /**
     * Not implemented: currently a no-op, and both arguments are ignored.
     *
     * @param tensorflow source of the weights to import (unused)
     * @param mapping    mapping from names to weights (unused)
     */
    public void importWeights(Reader tensorflow, Map<String, Weight> mapping) {
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public String getId() {
        return null;
    }

    /** Not implemented: always returns {@code null}, regardless of {@code query}. */
    @Override
    public Value evaluate(QueryNeuron query) {
        return null;
    }

    /**
     * Returns this model's weights.
     *
     * @return the live weights list (not a copy)
     */
    @Override
    public List<Weight> getAllWeights() {
        return this.weights;
    }

    /**
     * Returns the template this model was built from, if it was retained.
     *
     * <p>A template is kept only when the model was constructed from a {@link Template} with
     * {@code settings.debugTemplateTraining} enabled, or cloned from such a model.
     *
     * @return the stored template, or {@code null} (after logging at SEVERE) if none was stored
     */
    public Template getTemplate() {
        if (this.template == null) {
            LOG.severe("No template was stored in this NeuralModel");
        }
        return this.template;
    }
}

