/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.study.gene.tetrad.gene.history;

import edu.cmu.tetrad.study.gene.tetrad.gene.history.BasicLagGraph;
import edu.cmu.tetrad.study.gene.tetrad.gene.history.BooleanFunction;
import edu.cmu.tetrad.study.gene.tetrad.gene.history.IndexedLagGraph;
import edu.cmu.tetrad.study.gene.tetrad.gene.history.IndexedParent;
import edu.cmu.tetrad.study.gene.tetrad.gene.history.LagGraph;
import edu.cmu.tetrad.study.gene.tetrad.gene.history.UpdateFunction;
import edu.cmu.tetrad.util.dist.Distribution;
import edu.cmu.tetrad.util.dist.Normal;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import org.apache.commons.math3.util.FastMath;

public class BooleanGlassFunction
implements UpdateFunction {
    private static final long serialVersionUID = 23L;
    private final IndexedLagGraph connectivity;
    private final BooleanFunction[] booleanFunctions;
    private final Distribution[] errorDistributions;
    private double lowerBound;
    private double basalExpression;
    private double decayRate;
    private double booleanInfluenceRate;

    public BooleanGlassFunction(LagGraph lagGraph) {
        this(lagGraph, Double.NEGATIVE_INFINITY, 0.0);
    }

    public BooleanGlassFunction(LagGraph lagGraph, double lowerBound, double basalExpression) {
        int i;
        if (lagGraph == null) {
            throw new NullPointerException("Graph must not be null.");
        }
        if (lowerBound >= basalExpression) {
            throw new IllegalArgumentException("Lower bound must be less than basal expression.");
        }
        this.lowerBound = lowerBound;
        this.basalExpression = basalExpression;
        this.connectivity = new IndexedLagGraph(lagGraph, true);
        this.errorDistributions = new Distribution[this.connectivity.getNumFactors()];
        for (i = 0; i < this.errorDistributions.length; ++i) {
            this.errorDistributions[i] = new Normal(0.0, 0.05);
        }
        this.booleanFunctions = new BooleanFunction[this.connectivity.getNumFactors()];
        for (i = 0; i < this.booleanFunctions.length; ++i) {
            if (this.connectivity.getNumParents(i) > 0) {
                ArrayList<IndexedParent> parentList = new ArrayList<IndexedParent>();
                for (int j = 0; j < this.connectivity.getNumParents(i); ++j) {
                    IndexedParent parent = this.connectivity.getParent(i, j);
                    parentList.add(parent);
                }
                IndexedParent[] parents = parentList.toArray(new IndexedParent[0]);
                this.booleanFunctions[i] = new BooleanFunction(parents);
                do {
                    this.booleanFunctions[i].randomize();
                } while (!this.booleanFunctions[i].isEffective());
                continue;
            }
            this.booleanFunctions[i] = null;
        }
        this.setDecayRate(0.1);
        this.setBooleanInfluenceRate(0.5);
    }

    public static BooleanGlassFunction serializableInstance() {
        return new BooleanGlassFunction(BasicLagGraph.serializableInstance());
    }

    @Override
    public IndexedLagGraph getIndexedLagGraph() {
        return this.connectivity;
    }

    public double getBasalExpression() {
        return this.basalExpression;
    }

    public void setBasalExpression(double basalExpression) {
        this.basalExpression = basalExpression;
    }

    @Override
    public double getValue(int factor, double[][] history) {
        double v0 = history[1][factor];
        double v1 = -this.decayRate * (v0 - this.basalExpression);
        double v2 = this.booleanInfluenceRate * this.getFValue(factor, history);
        double v3 = this.errorDistributions[factor].nextRandom();
        double v4 = v0 + v1 + v2 + v3;
        return FastMath.max(this.lowerBound, v4);
    }

    public double getFValue(int factor, double[][] history) {
        if (this.booleanFunctions[factor] == null) {
            return 0.0;
        }
        BooleanFunction booleanFunction = this.booleanFunctions[factor];
        Object[] parents = booleanFunction.getParents();
        boolean[] parentValues = new boolean[parents.length];
        for (int i = 0; i < parentValues.length; ++i) {
            IndexedParent parent = (IndexedParent)parents[i];
            double histVal = history[parent.getLag()][parent.getIndex()];
            parentValues[i] = histVal > this.basalExpression;
        }
        int row = booleanFunction.getRow(parentValues);
        boolean functionValue = booleanFunction.getValue(row);
        double trueValue = 1.0;
        double falseValue = -1.0;
        return functionValue ? trueValue : falseValue;
    }

    public BooleanFunction getSubFunction(int factor) {
        return this.booleanFunctions[factor];
    }

    public double getDecayRate() {
        return this.decayRate;
    }

    public void setDecayRate(double decayRate) {
        if (decayRate <= 0.0 || decayRate > 1.0) {
            throw new IllegalArgumentException("Suggested rate out of bounds (0.0 <= decayRate < 1.0): " + decayRate);
        }
        this.decayRate = decayRate;
    }

    public double getBooleanInfluenceRate() {
        return this.booleanInfluenceRate;
    }

    public void setBooleanInfluenceRate(double booleanInfluenceRate) {
        if (booleanInfluenceRate <= 0.0) {
            throw new IllegalArgumentException("Suggested rate out of bounds (0.0 <= booleanInfluenceRate): " + booleanInfluenceRate);
        }
        this.booleanInfluenceRate = booleanInfluenceRate;
    }

    public void setLowerBound(double lowerBound) {
        this.lowerBound = lowerBound;
    }

    public void setErrorDistribution(int factor, Distribution distribution) {
        if (distribution == null) {
            throw new NullPointerException();
        }
        this.errorDistributions[factor] = distribution;
    }

    public Distribution getErrorDistribution(int factor) {
        return this.errorDistributions[factor];
    }

    @Override
    public int getNumFactors() {
        return this.connectivity.getNumFactors();
    }

    @Override
    public int getMaxLag() {
        int maxLag = 0;
        for (int i = 0; i < this.connectivity.getNumFactors(); ++i) {
            for (int j = 0; j < this.connectivity.getNumParents(i); ++j) {
                IndexedParent parent = this.connectivity.getParent(i, j);
                if (parent.getLag() <= maxLag) continue;
                maxLag = parent.getLag();
            }
        }
        return maxLag;
    }

    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.connectivity == null) {
            throw new NullPointerException();
        }
        if (this.booleanFunctions == null) {
            throw new NullPointerException();
        }
        if (this.errorDistributions == null) {
            throw new NullPointerException();
        }
        if (this.lowerBound >= this.basalExpression) {
            throw new IllegalStateException();
        }
        if (this.decayRate <= 0.0 || this.decayRate > 1.0) {
            throw new IllegalStateException();
        }
        if (this.booleanInfluenceRate <= 0.0) {
            throw new IllegalStateException();
        }
    }
}

