/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.pipelines.debugging.drawing;

import cz.cvut.fel.ida.algebra.functions.Aggregation;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.drawing.Drawer;
import cz.cvut.fel.ida.drawing.GraphViz;
import cz.cvut.fel.ida.neural.networks.computation.iteration.modes.Topologic;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.neurons.NeuronVisitor;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.Evaluator;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.structure.components.NeuralNetwork;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.BaseNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.Neurons;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.QueryNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.WeightedNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.states.State;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AggregationNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.AtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.FactNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.RuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedAtomNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.neurons.types.WeightedRuleNeuron;
import cz.cvut.fel.ida.neural.networks.structure.components.types.TopologicNetwork;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

public class NeuralNetDrawer
extends Drawer<NeuralSample> {
    private static final Logger LOG = Logger.getLogger(NeuralNetDrawer.class.getName());
    private NeuronDrawer neuronDrawer;
    private Topologic.BUpIterator bUpIterator;
    public int stateIndex = -1;

    public NeuralNetDrawer(Settings settings) {
        super(settings);
    }

    @Override
    public void loadGraph(NeuralSample sample) {
        this.neuronDrawer = new NeuronDrawer((NeuralNetwork<State.Structure>)((NeuralNetwork)((QueryNeuron)sample.query).evidence), this.stateIndex, this.graphviz);
        Topologic topologic = new Topologic((TopologicNetwork)((QueryNeuron)sample.query).evidence);
        topologic.getClass();
        this.bUpIterator = new Topologic.BUpIterator(topologic, (Neurons)((QueryNeuron)sample.query).neuron, (NeuronVisitor.Weighted)this.neuronDrawer);
        this.graphviz.start_graph();
        this.iterateNetwork();
        this.graphviz.addln(((QueryNeuron)sample.query).neuron.getIndex() + " [shape = tripleoctagon]");
        this.graphviz.end_graph();
    }

    private void iterateNetwork() {
        while (this.bUpIterator.hasNext()) {
            Object nextNeuron = this.bUpIterator.next();
            nextNeuron.visit(this.neuronDrawer);
        }
    }

    public static class NeuronDrawer
    extends NeuronVisitor.Weighted.Detailed {
        private final GraphViz gv;
        private final NeuralNetwork<State.Structure> neuralNetwork;

        public NeuronDrawer(NeuralNetwork<State.Structure> network, int stateIndex, GraphViz gv) {
            super(network, new Evaluator(stateIndex), null);
            this.gv = gv;
            this.neuralNetwork = network;
        }

        private String getNeuronLabel(BaseNeuron neuron) {
            String name = neuron.getClass().getSimpleName() + ":" + neuron.index + ":" + neuron.name;
            State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
            String value = state.getValue().toString();
            Value stateGradient = state.getGradient();
            String gradient = "";
            if (stateGradient != null) {
                gradient = stateGradient.toString();
            }
            String dimensions = Arrays.toString(state.getValue().size());
            Aggregation neuronAggregation = neuron.getAggregation();
            String aggregation = "";
            if (neuronAggregation != null) {
                aggregation = neuron.getAggregation().getName();
            }
            StringBuilder sb = new StringBuilder();
            sb.append("\"");
            sb.append(name).append("\n");
            sb.append("val: ").append(value).append("\n");
            sb.append("grad: ").append(gradient).append("\n");
            sb.append("dim: ").append(dimensions).append("\n");
            sb.append("fcn: ").append(aggregation).append("\n");
            sb.append("\"");
            return sb.toString();
        }

        private String getEdgeLabel(int from, Integer to, Weight weight) {
            StringBuilder sb = new StringBuilder();
            sb.append(" [label=");
            sb.append("\"");
            sb.append(weight.index).append(":");
            sb.append(weight.name).append(":");
            sb.append(Arrays.toString(weight.value.size())).append(":");
            sb.append(weight.value.toString());
            sb.append("\"");
            if (weight.isFixed) {
                sb.append(", style=dashed ");
            }
            sb.append("]");
            return sb.toString();
        }

        private <T extends Neurons, S extends State.Neural> String getEdges(BaseNeuron<T, S> neuron) {
            Iterator<T> inputs = this.neuralNetwork.getInputs(neuron);
            StringBuilder sb = new StringBuilder();
            while (inputs.hasNext()) {
                Neurons input = (Neurons)inputs.next();
                sb.append(neuron.index + " -> " + input.getIndex() + " [style=dashed] ").append("\n");
            }
            return sb.toString();
        }

        private <T extends Neurons, S extends State.Neural> String getEdges(WeightedNeuron<T, S> neuron) {
            Pair<Iterator<T>, Iterator<Weight>> inputs = this.neuralNetwork.getInputs(neuron);
            Iterator inputNeurons = (Iterator)inputs.r;
            Iterator inputWeights = (Iterator)inputs.s;
            StringBuilder sb = new StringBuilder();
            while (inputNeurons.hasNext()) {
                Neurons input = (Neurons)inputNeurons.next();
                Weight weight = (Weight)inputWeights.next();
                sb.append(neuron.index + " -> " + input.getIndex() + this.getEdgeLabel(neuron.index, input.getIndex(), weight)).append("\n");
            }
            return sb.toString();
        }

        @Override
        public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
            LOG.severe("DDD while drawing");
        }

        @Override
        public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
            LOG.severe("DDD while drawing");
        }

        public void visit(AtomNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=ellipse, color=blue, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln(this.getEdges(neuron));
        }

        public void visit(WeightedAtomNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=ellipse, color=blue, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln(this.getEdges(neuron));
        }

        public void visit(AggregationNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=box, color=green, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln(this.getEdges(neuron));
        }

        public void visit(RuleNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=ellipse, color=red, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln(this.getEdges(neuron));
        }

        public void visit(WeightedRuleNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=ellipse, color=red, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln(this.getEdges(neuron));
        }

        @Override
        public void visit(FactNeuron neuron) {
            this.gv.addln(neuron.index + " [shape=house, color=black, label=" + this.getNeuronLabel(neuron) + "]");
            this.gv.addln();
        }
    }
}

