/*
 * Decompiled with CFR 0.152.
 */
package islab.bayesian;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import islab.bayesian.DataSet;
import islab.bayesian.IProvideMean;
import islab.lib.XmlHelper;
import islab.lib.XmlXomReader;
import nu.xom.Document;
import nu.xom.Element;
import org.xml.sax.SAXParseException;

public class LinearModel
implements IProvideMean {
    private DoubleMatrix1D b;
    private DoubleMatrix1D x;
    private int nParents;

    public LinearModel(int nParents) {
        this.nParents = nParents;
        this.b = null;
        this.x = new DenseDoubleMatrix1D(nParents + 1);
        this.x.set(0, 1.0);
    }

    public LinearModel(DataSet data, int[] rows, int depVar, int[] indexIncoming) {
        this(indexIncoming.length);
        this.train(data, rows, depVar, indexIncoming);
    }

    public LinearModel(DataSet data, int depVar, int[] indexIncoming) {
        this(indexIncoming.length);
        int[] rows = new int[data.dataset().length];
        int i = 0;
        while (i < rows.length) {
            rows[i] = i;
            ++i;
        }
        this.train(data, rows, depVar, indexIncoming);
    }

    public void setB(DoubleMatrix1D b) {
        this.b = b;
        if (b.size() != this.nParents + 1) {
            throw new RuntimeException("b's size doesn't match");
        }
    }

    public DoubleMatrix1D getB() {
        return this.b;
    }

    public double computeMean(double[] configuration, int[] indexIncoming) {
        if (indexIncoming == null && this.nParents == 0) {
            return this.b.get(0);
        }
        if (indexIncoming != null && this.nParents != indexIncoming.length) {
            throw new RuntimeException("expected nParents == indexIncoming");
        }
        if (this.b == null) {
            throw new RuntimeException("model not yet trained");
        }
        int i = 0;
        while (i < this.nParents) {
            this.x.set(i + 1, configuration[indexIncoming[i]]);
            ++i;
        }
        double result = this.b.zDotProduct(this.x);
        return result;
    }

    public boolean train(DataSet dataset, int[] rows, int depVar, int[] indexIncoming) {
        DoubleMatrix1D bTmp;
        int n = rows.length;
        int p = indexIncoming.length;
        double[][] data = dataset.dataset();
        DenseDoubleMatrix2D X = new DenseDoubleMatrix2D(n, p + 1);
        int i = 0;
        while (i < n) {
            int j = 0;
            while (j < p + 1) {
                X.set(i, j, j == 0 ? 1.0 : data[rows[i]][indexIncoming[j - 1]]);
                ++j;
            }
            ++i;
        }
        DenseDoubleMatrix1D y = new DenseDoubleMatrix1D(n);
        int i2 = 0;
        while (i2 < n) {
            y.set(i2, data[rows[i2]][depVar]);
            ++i2;
        }
        DoubleMatrix2D XtXinv = Algebra.DEFAULT.inverse(((DoubleMatrix2D)X).zMult(X, null, 1.0, 0.0, true, false));
        DoubleMatrix1D Xty = ((DoubleMatrix2D)X).zMult(y, null, 1.0, 0.0, true);
        this.b = bTmp = XtXinv.zMult(Xty, null);
        return true;
    }

    public double RMSE(DataSet dataset, int[] rows, int depVar, int[] indexIncoming) {
        double rmse = 0.0;
        int n = rows.length;
        double[][] data = dataset.dataset();
        int i = 0;
        while (i < n) {
            double measured = data[rows[i]][depVar];
            double predicted = this.computeMean(data[rows[i]], indexIncoming);
            rmse += (measured - predicted) * (measured - predicted);
            ++i;
        }
        rmse /= (double)n;
        rmse = Math.sqrt(rmse);
        return rmse;
    }

    public double RMSE(DataSet dataset, int depVar, int[] indexIncoming) {
        int[] rows = new int[dataset.dataset().length];
        int i = 0;
        while (i < rows.length) {
            rows[i] = i;
            ++i;
        }
        return this.RMSE(dataset, rows, depVar, indexIncoming);
    }

    public String toXML(String indentString) {
        StringBuffer sb = new StringBuffer(String.valueOf(indentString) + "<IProvideMean type=\"LinearModel\">\n");
        sb.append(String.valueOf(indentString) + "<b-vector>\n");
        int i = 0;
        while (i < this.b.size()) {
            sb.append(String.valueOf(indentString) + "  " + XmlHelper.tag("double", this.b.get(i)) + "\n");
            ++i;
        }
        sb.append(String.valueOf(indentString) + "</b-vector>\n");
        sb.append("</IProvideMean>\n");
        return sb.toString();
    }

    public static LinearModel fromXML(String xml) throws SAXParseException {
        Document dom = XmlXomReader.getDocument(xml);
        Element xmlNode = dom.getRootElement();
        if (!xmlNode.getLocalName().equals("IProvideMean") || !xmlNode.getAttribute("type").getValue().toString().equals("LinearModel")) {
            throw new SAXParseException("Error parsing xml: expected <IProvideMean type=\"LinearModel\">, but received <" + xmlNode.getLocalName() + " type=\"" + xmlNode.getAttribute("type").getValue() + "\"", null);
        }
        Element bXml = xmlNode.getChildElements("b-vector").get(0);
        int bSize = bXml.getChildElements("double").size();
        LinearModel lm = new LinearModel(bSize - 1);
        DenseDoubleMatrix1D b = new DenseDoubleMatrix1D(bSize);
        int i = 0;
        while (i < bSize) {
            b.set(i, XmlHelper.getDouble(bXml, "double", i));
            ++i;
        }
        lm.setB(b);
        return lm;
    }
}

