/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * A {@link SemOptimizer} that fits the free parameters of a {@link SemIm} by random ("scattershot") search.
 *
 * <p>Each trial works on a copy of the input model. It repeatedly samples points uniformly from a hypercube
 * centered on the current free-parameter vector and moves to any sample with a lower fitting-function value.
 * The cube width shrinks according to a fixed schedule. After all trials, the free-parameter values of the
 * trial whose chi-square has the smallest absolute value are copied back into the input model.</p>
 */
public class SemOptimizerScattershot implements SemOptimizer {
    private static final long serialVersionUID = 23L;

    /** Cube widths of the successive search stages, coarsest first. */
    private static final double[] STAGE_WIDTHS =
            {1.0, 0.5, 0.25, 0.1, 0.1, 0.05, 0.01, 0.005, 0.001, 5.0E-4, 1.0E-4};

    /** Number of feasible samples drawn per improvement round in the matching stage of {@link #STAGE_WIDTHS}. */
    private static final int[] STAGE_NUM_POINTS =
            {1500, 500, 500, 500, 500, 500, 500, 50, 50, 50, 50};

    /** Upper bound (exclusive) on the round counter of each descent loop; at most MAX_ROUNDS - 1 rounds run. */
    private static final int MAX_ROUNDS = 2000;

    /**
     * Infeasible (infinite-score) samples tolerated per requested sample before a round gives up. Without this
     * cap a region where every sample is infeasible would make {@link #findLowerRandom} loop forever.
     */
    private static final int MAX_INVALID_SAMPLES_PER_POINT = 100;

    /** Cube width used by the local refinement step. */
    private static final double LOCAL_WIDTH = 0.2;

    /** Number of samples drawn per local refinement attempt. */
    private static final int LOCAL_NUM_POINTS = 10;

    private int numRestarts;

    /**
     * Returns a new instance of this optimizer.
     * NOTE(review): presumably used by Tetrad's serialization tests; confirm against callers.
     *
     * @return a new optimizer with {@code numRestarts == 0}.
     */
    public static SemOptimizerScattershot serializableInstance() {
        return new SemOptimizerScattershot();
    }

    /**
     * Optimizes the free parameters of {@code semIm} in place.
     *
     * <p>Runs {@code max(1, numRestarts) + 1} trials (one initial trial plus the restarts). Each trial runs on
     * a copy of {@code semIm}. The free-parameter values of the trial with the smallest absolute chi-square are
     * then written back into {@code semIm}. The configured restart count itself is not modified.</p>
     *
     * @param semIm the model to optimize; its sample covariance matrix must be set and free of missing values.
     * @throws NullPointerException     if the sample covariance matrix has not been set, or if no trial produced
     *                                  a chi-square whose absolute value is finite.
     * @throws IllegalArgumentException if the sample covariance matrix contains missing values.
     */
    @Override
    public void optimize(SemIm semIm) {
        Matrix sampleCovar = semIm.getSampleCovar();
        if (sampleCovar == null) {
            throw new NullPointerException("Sample covar has not been set.");
        }
        if (DataUtils.containsMissingValue(sampleCovar)) {
            throw new IllegalArgumentException("Please remove or impute missing values.");
        }

        // Work on a local copy so that optimizing does not silently rewrite the configured restart count.
        int restarts = FastMath.max(1, this.numRestarts);

        TetradLogger.getInstance().log("info", "Trying scattershot...");
        double min = Double.POSITIVE_INFINITY;
        SemIm best = null;
        for (int trial = 0; trial < restarts + 1; trial++) {
            TetradLogger.getInstance().log("details", "Trial " + (trial + 1));
            SemIm candidate = new SemIm(semIm);
            this.optimize2(candidate);
            double absChisq = FastMath.abs(candidate.getChiSquare());
            // NaN / infinite chi-squares never satisfy this comparison, so such trials are never selected.
            if (absChisq < min) {
                min = absChisq;
                best = candidate;
            }
        }
        if (best == null) {
            throw new NullPointerException("Minimal score SEM could not be found.");
        }

        // Copy results back by node name, since the trial model holds its own node objects.
        for (Parameter param : semIm.getFreeParameters()) {
            Node nodeA = best.getVariableNode(param.getNodeA().getName());
            Node nodeB = best.getVariableNode(param.getNodeB().getName());
            semIm.setParamValue(param, best.getParamValue(nodeA, nodeB));
        }
    }

    /**
     * Returns the configured number of restarts.
     *
     * @return the value last passed to {@link #setNumRestarts(int)} (0 by default).
     */
    @Override
    public int getNumRestarts() {
        return this.numRestarts;
    }

    /**
     * Sets the number of restarts. Values below 1 are treated as 1 by {@link #optimize(SemIm)}.
     *
     * @param numRestarts the number of restarts.
     */
    @Override
    public void setNumRestarts(int numRestarts) {
        this.numRestarts = numRestarts;
    }

    @Override
    public String toString() {
        return "Sem Optimizer Scattershot";
    }

    /**
     * Runs one scattershot trial on {@code semIm}. The trial descends through the width schedule and finally
     * stores the best free-parameter vector found back into the model.
     *
     * @param semIm the (copied) model to optimize in place.
     */
    private void optimize2(SemIm semIm) {
        SemFittingFunction f = new SemFittingFunction(semIm);
        f.setAvoidNegativeVariances(true);
        double[] p = semIm.getFreeParamValues();
        for (int stage = 0; stage < STAGE_WIDTHS.length; stage++) {
            // Only the coarsest stage refines each sample with a local search.
            this.iterateFindLowerRandom(f, p, STAGE_WIDTHS[stage], STAGE_NUM_POINTS[stage], stage == 0);
        }
        semIm.setFreeParamValues(p);
    }

    /**
     * Repeats {@link #findLowerRandom} at a fixed width for as long as it keeps finding improvements, up to
     * {@code MAX_ROUNDS - 1} rounds.
     *
     * @param fcn           the function to minimize.
     * @param p             the current point; overwritten with each improvement.
     * @param range         the cube width.
     * @param iterations    the number of feasible samples per round.
     * @param refineLocally whether each sample is refined by a local search before it is compared.
     */
    private void iterateFindLowerRandom(FittingFunction fcn, double[] p, double range, int iterations,
                                        boolean refineLocally) {
        for (int round = 1; round < MAX_ROUNDS; round++) {
            boolean found;
            try {
                found = this.findLowerRandom(fcn, p, range, iterations, refineLocally);
            } catch (RuntimeException e) {
                // Best effort: a failed evaluation ends this stage. p is only written after a successful
                // comparison, so it still holds the best point found so far.
                TetradLogger.getInstance().log("details",
                        "Scattershot stage at width " + range + " stopped: " + e);
                return;
            }
            if (!found) {
                return;
            }
        }
    }

    /**
     * Samples up to {@code numPoints} feasible points uniformly from the cube of side {@code width} centered at
     * {@code p}. It moves {@code p} to the first sample that scores lower than {@code p}.
     *
     * @param fcn           the function to minimize.
     * @param p             the center point; overwritten on success.
     * @param width         the cube width.
     * @param numPoints     the number of feasible samples to try; infeasible samples do not count, but at most
     *                      {@code MAX_INVALID_SAMPLES_PER_POINT * numPoints} of them are tolerated.
     * @param refineLocally whether each feasible sample is first improved by repeated local search.
     * @return {@code true} if {@code p} was moved to a lower-scoring point.
     * @throws IllegalArgumentException if {@code fcn} returns NaN at the center point.
     */
    private boolean findLowerRandom(FittingFunction fcn, double[] p, double width, int numPoints,
                                    boolean refineLocally) {
        double fP = fcn.evaluate(p);
        if (Double.isNaN(fP)) {
            throw new IllegalArgumentException("Center point must evaluate!");
        }
        double[] center = p.clone();
        double[] candidate = p.clone();
        int maxInvalid = MAX_INVALID_SAMPLES_PER_POINT * numPoints;
        int invalid = 0;
        int sampled = 0;
        while (sampled < numPoints) {
            this.randomPointAboutCenter(candidate, center, width);
            double f = fcn.evaluate(candidate);
            if (f == Double.POSITIVE_INFINITY) {
                // Infeasible samples do not use up the budget, but give up rather than loop forever.
                if (++invalid >= maxInvalid) {
                    return false;
                }
                continue;
            }
            sampled++;
            if (refineLocally) {
                boolean refined = false;
                for (int round = 1; round < MAX_ROUNDS && this.findLowerRandomLocal(fcn, candidate); round++) {
                    refined = true;
                }
                if (refined) {
                    // Judge (and report) the refined point, not the pre-refinement sample.
                    f = fcn.evaluate(candidate);
                }
            }
            if (f < fP) {
                System.arraycopy(candidate, 0, p, 0, candidate.length);
                TetradLogger.getInstance().log("optimization", "Cube width = " + width + " FML = " + f);
                return true;
            }
        }
        return false;
    }

    /**
     * Draws up to {@code LOCAL_NUM_POINTS} samples from the cube of side {@code LOCAL_WIDTH} around {@code p}.
     * It moves {@code p} to the first one that scores lower. Each draw, feasible or not, uses one attempt.
     *
     * @param fcn the function to minimize.
     * @param p   the center point; overwritten on success.
     * @return {@code true} if {@code p} was moved to a lower-scoring point.
     * @throws IllegalArgumentException if {@code fcn} returns NaN at the center point.
     */
    private boolean findLowerRandomLocal(FittingFunction fcn, double[] p) {
        double fP = fcn.evaluate(p);
        if (Double.isNaN(fP)) {
            throw new IllegalArgumentException("Center point must evaluate!");
        }
        double[] center = p.clone();
        double[] candidate = p.clone();
        for (int attempt = 0; attempt < LOCAL_NUM_POINTS; attempt++) {
            this.randomPointAboutCenter(candidate, center, LOCAL_WIDTH);
            double f = fcn.evaluate(candidate);
            if (f == Double.POSITIVE_INFINITY) {
                continue;
            }
            if (f < fP) {
                System.arraycopy(candidate, 0, p, 0, candidate.length);
                TetradLogger.getInstance().log("optimization", "Cube width = " + LOCAL_WIDTH + " FML = " + f);
                return true;
            }
        }
        return false;
    }

    /**
     * Fills {@code pTemp} with a point drawn uniformly from the axis-aligned cube of side {@code width} centered
     * at {@code fixedP}.
     *
     * @param pTemp  output array, same length as {@code fixedP}.
     * @param fixedP the cube center.
     * @param width  the cube side length.
     */
    private void randomPointAboutCenter(double[] pTemp, double[] fixedP, double width) {
        for (int j = 0; j < pTemp.length; ++j) {
            double v = this.getRandom().nextDouble();
            pTemp[j] = fixedP[j] + (-width / 2.0 + width * v);
        }
    }

    private RandomUtil getRandom() {
        return RandomUtil.getInstance();
    }

    /**
     * Fitting function backed by a {@link SemIm}. It returns the value of {@link SemIm#getScore()} for a
     * free-parameter vector. It returns {@code Double.POSITIVE_INFINITY} for infeasible inputs: non-finite
     * parameters, non-positive variances (when enabled), or a non-finite or negative score.
     */
    static class SemFittingFunction implements FittingFunction {
        private final SemIm sem;
        private final List<Parameter> freeParameters;
        private boolean avoidNegativeVariances;

        /**
         * @param sem the model whose free parameters are set on each evaluation.
         */
        public SemFittingFunction(SemIm sem) {
            this.sem = sem;
            this.freeParameters = sem.getFreeParameters();
        }

        /**
         * Evaluates the model at the given free-parameter vector.
         *
         * <p>Cheap validity checks run before the model is touched. Infeasible vectors are therefore never
         * written into the model and never scored.</p>
         *
         * @param parameters the free-parameter values, in the order of {@link SemIm#getFreeParameters()}.
         * @return the model score, or {@code Double.POSITIVE_INFINITY} if the point is infeasible.
         */
        @Override
        public double evaluate(double[] parameters) {
            for (double parameter : parameters) {
                if (Double.isNaN(parameter) || Double.isInfinite(parameter)) {
                    return Double.POSITIVE_INFINITY;
                }
            }
            if (this.avoidNegativeVariances) {
                for (int i = 0; i < parameters.length; ++i) {
                    if (this.freeParameters.get(i).getType() == ParamType.VAR && parameters[i] <= 0.0) {
                        return Double.POSITIVE_INFINITY;
                    }
                }
            }
            this.sem.setFreeParamValues(parameters);
            double fml = this.sem.getScore();
            if (Double.isNaN(fml) || Double.isInfinite(fml) || fml < 0.0) {
                return Double.POSITIVE_INFINITY;
            }
            return fml;
        }

        @Override
        public void setAvoidNegativeVariances(boolean avoidNegativeVariances) {
            this.avoidNegativeVariances = avoidNegativeVariances;
        }
    }

    /**
     * A scalar objective over a parameter vector, minimized by the scattershot search.
     */
    static interface FittingFunction {
        /**
         * @param var1 the parameter vector.
         * @return the objective value; {@code Double.POSITIVE_INFINITY} marks an infeasible point.
         */
        public double evaluate(double[] var1);

        /**
         * @param var1 whether non-positive variance parameters should be treated as infeasible.
         */
        public void setAvoidNegativeVariances(boolean var1);
    }
}

