/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.score.ImagesScore;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.search.utils.TsUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.SublistGenerator;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Searches for per-variable row shifts of a list of data sets that lower the
 * average BIC of the model found by FGES over those data sets.
 *
 * <p>For every subset of variables (up to {@link #getMaxNumShifts()} variables
 * at a time) and every assignment of shift magnitudes in
 * {@code 1..getMaxShift()} to that subset, the data sets are shifted with
 * {@code TsUtils.createShiftedData}, truncated to a common number of rows, and
 * scored. An assignment is kept when it improves on the best score so far by
 * more than 0.1%.
 *
 * <p>{@link #stop()} may be called from another thread to end a running
 * {@link #search()} early; the stop flag is never reset by this class.
 */
public class ShiftSearch {
    /** Name of the variable whose shift is always left at zero. */
    private static final String FIXED_VARIABLE_NAME = "I";

    private final List<DataModel> dataSets;
    private int maxShift = 2;
    private Knowledge knowledge = new Knowledge();
    // Stored and exposed through getC()/setC(); not read anywhere else in this class.
    private int c = 4;
    private int maxNumShifts;
    private PrintStream out = System.out;
    // volatile: written by stop() (typically from another thread) and polled by search().
    private volatile boolean scheduleStop;
    private boolean forwardSearch;
    private boolean precomputeCovariances = false;

    /**
     * Creates a search over the given data sets.
     *
     * @param dataSets the data sets to shift and score; each element is cast to
     *                 {@code DataSet} when the search runs. The list is stored
     *                 by reference, not copied.
     */
    public ShiftSearch(List<DataModel> dataSets) {
        this.dataSets = dataSets;
    }

    /**
     * Runs the shift search.
     *
     * @return one shift per variable (in the column order of the first data
     *         set) for the best-scoring assignment found; all zeros if no
     *         assignment beat the unshifted data.
     * @throws IllegalStateException if {@code maxShift < 1}, if there are no
     *                               data sets, or if {@code maxShift} leaves no
     *                               rows in the first data set.
     */
    public int[] search() {
        if (this.maxShift < 1) {
            throw new IllegalStateException("Max shift should be >= 1: " + this.maxShift);
        }
        if (this.dataSets == null || this.dataSets.isEmpty()) {
            throw new IllegalStateException("At least one data set is required.");
        }

        DataModel firstModel = this.dataSets.get(0);
        DataSet first = (DataSet) firstModel;
        int numVars = first.getNumColumns();
        List<Node> nodes = firstModel.getVariables();
        int[] bestshifts = new int[numVars];

        // Every shifted data set is truncated to this many rows before scoring.
        int maxNumRows = first.getNumRows() - this.maxShift;
        if (maxNumRows < 1) {
            throw new IllegalStateException("Max shift (" + this.maxShift
                    + ") must be smaller than the number of rows (" + first.getNumRows() + ").");
        }

        // Baseline: score of the unshifted data.
        // NOTE(review): the baseline is scored on the untruncated data while candidates
        // are scored on maxNumRows rows — confirm that this comparison is intended.
        double b = this.getAvgBic(this.dataSets);
        this.printShifts(bestshifts, b, nodes);

        // Column index of the variable that is never shifted (-1 if absent); loop-invariant.
        int iIndex = firstModel.getVariables().indexOf(firstModel.getVariable(FIXED_VARIABLE_NAME));

        SublistGenerator generator = new SublistGenerator(nodes.size(), this.getMaxNumShifts());
        int[] choice;
        // Check the stop flag here as well, so that a stop request does not leave the
        // generator enumerating every remaining subset.
        while (!this.scheduleStop && (choice = generator.next()) != null) {
            int[] shifts = new int[nodes.size()];
            double zSize = FastMath.pow((double) this.getMaxShift(), choice.length);

            for (int z = 0; (double) z < zSize && !this.scheduleStop; ++z) {
                // Decode z as a base-maxShift number: one digit per chosen variable.
                int _z = z;
                for (int j : choice) {
                    if (j == iIndex) {
                        continue;
                    }
                    shifts[j] = _z % this.getMaxShift() + 1;
                    if (!this.forwardSearch) {
                        shifts[j] = -shifts[j];
                    }
                    _z /= this.getMaxShift();
                }

                List<DataModel> _shiftedDataSets = this.getShiftedDataSets(shifts, maxNumRows);
                double _b = this.getAvgBic(_shiftedDataSets);

                if (_b < improvementThreshold(b)) {
                    b = _b;
                    this.printShifts(shifts, b, nodes);
                    System.arraycopy(shifts, 0, bestshifts, 0, shifts.length);
                }
            }
        }

        this.println("\nShifts with the lowest BIC score: ");
        this.printShifts(bestshifts, b, nodes);
        return bestshifts;
    }

    /**
     * Returns the score a candidate must go below to count as an improvement of
     * more than 0.1% over {@code best}.
     *
     * <p>For a non-negative {@code best} this is {@code 0.999 * best}. For a
     * negative {@code best}, {@code 0.999 * best} would lie <em>above</em>
     * {@code best} and let equal or slightly worse candidates through, so
     * {@code 1.001 * best} is used instead.
     */
    private static double improvementThreshold(double best) {
        return best >= 0.0 ? 0.999 * best : 1.001 * best;
    }

    /**
     * @return the largest shift magnitude tried for any variable.
     */
    public int getMaxShift() {
        return this.maxShift;
    }

    /**
     * @param maxShift the largest shift magnitude to try; {@link #search()}
     *                 rejects values below 1.
     */
    public void setMaxShift(int maxShift) {
        this.maxShift = maxShift;
    }

    /**
     * @return the knowledge object handed to FGES (the internal instance, not a copy).
     */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Stores a copy of the given knowledge for use by FGES during scoring.
     *
     * @param knowledge the knowledge to copy.
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * @return the stored {@code c} value (default 4).
     */
    public int getC() {
        return this.c;
    }

    /**
     * @param c the {@code c} value to store.
     */
    public void setC(int c) {
        this.c = c;
    }

    /**
     * @return the second argument passed to the {@code SublistGenerator} that
     *         enumerates which variables are shifted together.
     */
    public int getMaxNumShifts() {
        return this.maxNumShifts;
    }

    /**
     * @param maxNumShifts the value passed to the {@code SublistGenerator} as
     *                     its second argument (default 0).
     */
    public void setMaxNumShifts(int maxNumShifts) {
        this.maxNumShifts = maxNumShifts;
    }

    /**
     * Sets the additional stream that progress messages are written to.
     * Messages are always written to {@code System.out} as well.
     *
     * @param out the stream to write to, or {@code null} to write to
     *            {@code System.out} only. A {@code PrintStream} is used as is;
     *            any other stream is wrapped in a new {@code PrintStream}.
     */
    public void setOut(OutputStream out) {
        if (out == null) {
            this.out = null;
        } else if (out instanceof PrintStream) {
            this.out = (PrintStream) out;
        } else {
            this.out = new PrintStream(out);
        }
    }

    /**
     * Requests that a running {@link #search()} stop at its next loop check.
     * The request stays in effect for later calls to {@link #search()} too.
     */
    public void stop() {
        this.scheduleStop = true;
    }

    /**
     * @param forwardSearch {@code true} to try positive shifts
     *                      ({@code 1..maxShift}); {@code false} to try their
     *                      negations.
     */
    public void setForwardSearch(boolean forwardSearch) {
        this.forwardSearch = forwardSearch;
    }

    /**
     * Prints one line of the form {@code name=shift name=shift ... score}.
     */
    private void printShifts(int[] shifts, double b, List<Node> nodes) {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < shifts.length; ++i) {
            buf.append(nodes.get(i)).append("=").append(shifts[i]).append(" ");
        }
        buf.append(b);
        this.println(buf.toString());
    }

    /**
     * Writes a line to {@code System.out} and, when a distinct stream has been
     * configured, to that stream too.
     */
    private void println(String s) {
        System.out.println(s);
        // Skip the second write when the configured stream is System.out itself
        // (the default); otherwise every message would appear twice.
        if (this.out != null && this.out != System.out) {
            this.out.println(s);
            this.out.flush();
        }
    }

    /**
     * Applies {@code shifts} to every data set and truncates each result to
     * {@code maxNumRows} rows.
     */
    private List<DataModel> getShiftedDataSets(int[] shifts, int maxNumRows) {
        List<DataModel> shifted = new ArrayList<DataModel>();
        for (DataModel dataSet : this.dataSets) {
            shifted.add(TsUtils.createShiftedData((DataSet) dataSet, shifts));
        }
        return this.ensureNumRows(shifted, maxNumRows);
    }

    /**
     * Returns new data sets holding rows {@code 0..numRows-1} and all columns
     * of each input data set.
     */
    private List<DataModel> ensureNumRows(List<DataModel> dataSets, int numRows) {
        List<DataModel> truncatedData = new ArrayList<DataModel>();
        for (DataModel _dataSet : dataSets) {
            DataSet dataSet = (DataSet) _dataSet;
            Matrix mat = dataSet.getDoubleData();
            Matrix mat2 = mat.getPart(0, numRows - 1, 0, mat.getNumColumns() - 1);
            truncatedData.add(new BoxDataSet(new DoubleDataBox(mat2.toArray()), dataSet.getVariables()));
        }
        return truncatedData;
    }

    /**
     * Runs FGES with an {@code ImagesScore} built from one {@code SemBicScore}
     * per data set and returns the negated model score divided by the number
     * of data sets (lower is better).
     */
    private double getAvgBic(List<DataModel> dataSets) {
        List<Score> scores = new ArrayList<Score>();
        for (DataModel dataSet : dataSets) {
            scores.add(new SemBicScore((DataSet) dataSet, this.precomputeCovariances));
        }
        ImagesScore imagesScore = new ImagesScore(scores);
        Fges images = new Fges(imagesScore);
        images.setKnowledge(this.knowledge);
        images.search();
        return -images.getModelScore() / (double) dataSets.size();
    }

    /**
     * @param precomputeCovariances the flag passed as the second constructor
     *                              argument of every {@code SemBicScore}.
     */
    public void setPrecomputeCovariances(boolean precomputeCovariances) {
        this.precomputeCovariances = precomputeCovariances;
    }
}

