/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.gene.tetrad.gene.history;

import edu.cmu.tetrad.gene.tetrad.gene.history.BasicLagGraph;
import edu.cmu.tetrad.gene.tetrad.gene.history.IndexedLagGraph;
import edu.cmu.tetrad.gene.tetrad.gene.history.LagGraph;
import edu.cmu.tetrad.gene.tetrad.gene.history.LaggedFactor;
import edu.cmu.tetrad.gene.tetrad.gene.history.Polynomial;
import edu.cmu.tetrad.gene.tetrad.gene.history.PolynomialFunction;
import edu.cmu.tetrad.gene.tetrad.gene.history.PolynomialTerm;
import edu.cmu.tetrad.gene.tetrad.gene.history.UpdateFunction;
import edu.cmu.tetrad.util.dist.Distribution;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;

/**
 * An {@link UpdateFunction} in which each factor's value is given by a first-order
 * polynomial of its parents in a lag graph: an intercept term plus one single-variable
 * term per parent. All evaluation and bookkeeping is delegated to a wrapped
 * {@link PolynomialFunction}.
 */
public class LinearFunction
implements UpdateFunction {
    static final long serialVersionUID = 23L;

    /** The polynomial function holding one linear polynomial per factor; never null. */
    private final PolynomialFunction polynomialFunction;

    /**
     * Builds a linear function over the given lag graph.
     *
     * <p>For every factor, the initial polynomial has an intercept of {@code 0.0} and
     * gives each of its parents an equal coefficient of {@code 1.0 / numParents}. A
     * factor with no parents therefore gets only the zero intercept.
     *
     * @param lagGraph the lag graph describing each factor's lagged parents
     * @throws NullPointerException if {@code lagGraph} is null
     */
    public LinearFunction(LagGraph lagGraph) {
        if (lagGraph == null) {
            throw new NullPointerException("Lag graph must not be null.");
        }
        this.polynomialFunction = new PolynomialFunction(lagGraph);
        IndexedLagGraph connectivity = this.polynomialFunction.getIndexedLagGraph();
        for (int i = 0; i < connectivity.getNumFactors(); ++i) {
            int numParents = connectivity.getNumParents(i);
            ArrayList<PolynomialTerm> terms = new ArrayList<>(numParents + 1);
            // Intercept term: a term over no variables.
            terms.add(new PolynomialTerm(0.0, new int[0]));
            // One first-order term per parent, weighted equally; the term's single
            // variable index is the parent's index, which is how setCoefficient
            // locates it later. The loop body never runs when numParents == 0.
            double weight = 1.0 / numParents;
            for (int parent = 0; parent < numParents; ++parent) {
                terms.add(new PolynomialTerm(weight, new int[]{parent}));
            }
            this.polynomialFunction.setPolynomial(i, new Polynomial(terms));
        }
    }

    /**
     * Returns an instance built over {@link BasicLagGraph#serializableInstance()}, for
     * serialization tests.
     *
     * @return a new {@code LinearFunction}
     */
    public static LinearFunction serializableInstance() {
        return new LinearFunction(BasicLagGraph.serializableInstance());
    }

    /**
     * Evaluates the polynomial for the given factor against a history of values.
     * Delegates to the wrapped {@link PolynomialFunction}.
     *
     * @param factorIndex index of the factor to evaluate
     * @param history     history of factor values, in the layout the wrapped
     *                    {@code PolynomialFunction} expects
     * @return the computed value
     */
    @Override
    public double getValue(int factorIndex, double[][] history) {
        return this.polynomialFunction.getValue(factorIndex, history);
    }

    /**
     * Returns the indexed lag graph of the wrapped polynomial function.
     *
     * @return the indexed lag graph
     */
    @Override
    public IndexedLagGraph getIndexedLagGraph() {
        return this.polynomialFunction.getIndexedLagGraph();
    }

    /**
     * Sets the intercept of the named factor's polynomial.
     *
     * @param factor    the factor name, resolved through
     *                  {@link IndexedLagGraph#getIndex(String)}
     * @param intercept the new intercept value
     * @return {@code true} if an intercept term was found and updated, {@code false}
     *         otherwise
     */
    public boolean setIntercept(String factor, double intercept) {
        IndexedLagGraph connectivity = this.polynomialFunction.getIndexedLagGraph();
        int factorIndex = connectivity.getIndex(factor);
        return this.setIntercept(factorIndex, intercept);
    }

    /**
     * Sets the intercept (the term over no variables) of the given factor's polynomial.
     *
     * @param factor    the factor index
     * @param intercept the new intercept value
     * @return {@code true} if an intercept term was found and updated, {@code false}
     *         otherwise
     */
    public boolean setIntercept(int factor, double intercept) {
        Polynomial p = this.polynomialFunction.getPolynomial(factor);
        PolynomialTerm term = p.findTerm(new int[0]);
        if (term == null) {
            return false;
        }
        // NOTE(review): assumes getPolynomial returns the live polynomial (not a copy),
        // so mutating the term updates this function in place -- confirm.
        term.setCoefficient(intercept);
        return true;
    }

    /**
     * Sets the coefficient of one lagged parent in the named factor's polynomial.
     *
     * @param factor      the factor name
     * @param parent      the lagged parent whose coefficient is set
     * @param coefficient the new coefficient value
     * @return {@code true} if the parent's term was found and updated, {@code false}
     *         otherwise
     */
    public boolean setCoefficient(String factor, LaggedFactor parent, double coefficient) {
        IndexedLagGraph connectivity = this.polynomialFunction.getIndexedLagGraph();
        int factorIndex = connectivity.getIndex(factor);
        int parentIndex = connectivity.getIndex(factor, parent);
        return this.setCoefficient(factorIndex, parentIndex, coefficient);
    }

    /**
     * Sets the coefficient of the first-order term for the given parent index in the
     * given factor's polynomial.
     *
     * @param factor      the factor index
     * @param parent      the parent index, which is the term's single variable index
     * @param coefficient the new coefficient value
     * @return {@code true} if the term was found and updated, {@code false} otherwise
     */
    public boolean setCoefficient(int factor, int parent, double coefficient) {
        Polynomial p = this.polynomialFunction.getPolynomial(factor);
        PolynomialTerm term = p.findTerm(new int[]{parent});
        if (term == null) {
            return false;
        }
        term.setCoefficient(coefficient);
        return true;
    }

    /**
     * Sets the error distribution for the given factor.
     * Delegates to the wrapped {@link PolynomialFunction}.
     *
     * @param factor       the factor index
     * @param distribution the error distribution to use
     */
    public void setErrorDistribution(int factor, Distribution distribution) {
        this.polynomialFunction.setErrorDistribution(factor, distribution);
    }

    /**
     * Returns the error distribution for the given factor.
     * Delegates to the wrapped {@link PolynomialFunction}.
     *
     * @param factor the factor index
     * @return the factor's error distribution
     */
    public Distribution getErrorDistribution(int factor) {
        return this.polynomialFunction.getErrorDistribution(factor);
    }

    /**
     * Returns a multi-line description listing each factor and its polynomial.
     *
     * @return a human-readable representation of this function
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        IndexedLagGraph connectivity = this.polynomialFunction.getIndexedLagGraph();
        buf.append("\n\nLinear Function:");
        for (int i = 0; i < connectivity.getNumFactors(); ++i) {
            buf.append("\n\tFactor ").append(connectivity.getFactor(i)).append(" --> ").append(this.polynomialFunction.getPolynomial(i));
        }
        return buf.toString();
    }

    /**
     * Returns the number of factors in the wrapped polynomial function.
     *
     * @return the number of factors
     */
    @Override
    public int getNumFactors() {
        return this.polynomialFunction.getNumFactors();
    }

    /**
     * Returns the maximum lag of the wrapped polynomial function.
     *
     * @return the maximum lag
     */
    @Override
    public int getMaxLag() {
        return this.polynomialFunction.getMaxLag();
    }

    /**
     * Restores the default serialized fields, then checks the class invariant that
     * the wrapped polynomial function is present.
     *
     * @param s the stream to read from
     * @throws IOException            if reading the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be found
     * @throws NullPointerException   if the stream did not provide a polynomial function
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.polynomialFunction == null) {
            throw new NullPointerException("polynomialFunction must not be null after deserialization.");
        }
    }
}

