/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Images;
import edu.cmu.tetrad.search.TimeSeriesUtils;
import edu.cmu.tetrad.util.DepthChoiceGenerator;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Searches for per-variable time shifts of a list of data sets that lower the average
 * (negated) {@link Images} model score. Each candidate shift vector is applied to every data
 * set, the shifted data are truncated to a common number of rows, and a configuration is kept
 * when its score beats the best seen so far by at least 0.1%.
 */
public class ShiftSearch {
    /** Data sets searched over; every candidate shift vector is applied to each of them. */
    private List<DataSet> dataSets;
    /** Largest shift magnitude tried for any single variable; must be at least 1. */
    private int maxShift = 2;
    /** Knowledge handed to every Images run. */
    private Knowledge knowledge = new Knowledge();
    /** Penalty discount handed to every Images run. */
    private int c = 4;
    /** Depth passed to the DepthChoiceGenerator that picks which variables are shifted. */
    private int maxNumShifts;
    /** Secondary output for progress lines; null disables it. */
    private PrintStream out = System.out;
    /**
     * Set by {@link #stop()} and polled by {@link #search()}. Volatile so that a stop request
     * issued from another thread (the only way it can arrive while search() is running) is seen.
     */
    private volatile boolean scheduleStop = false;
    /** If true, shifts are positive; otherwise they are negated. */
    private boolean forwardSearch;

    /**
     * Creates a search over the given data sets.
     *
     * @param dataSets the data sets to search; the first one supplies the variable list and the
     *                 row count used for truncation
     */
    public ShiftSearch(List<DataSet> dataSets) {
        this(dataSets, null);
    }

    /**
     * Creates a search over the given data sets.
     *
     * @param dataSets    the data sets to search
     * @param measuredDag currently ignored; kept for API compatibility
     */
    public ShiftSearch(List<DataSet> dataSets, Graph measuredDag) {
        this.dataSets = dataSets;
    }

    /**
     * Enumerates shift configurations and returns the one with the lowest average score.
     *
     * <p>For every variable subset produced by a {@link DepthChoiceGenerator} of depth
     * {@link #getMaxNumShifts()}, every assignment of shift magnitudes 1..maxShift to the subset's
     * members is scored. Shifts are negated unless forward search is enabled. A variable named
     * "I", if present, is never shifted. A line is printed whenever a better configuration is
     * found, and the search ends early once {@link #stop()} has been called.
     *
     * @return the best shift for each variable, indexed like the first data set's variables
     *         (all zeros if nothing beat the unshifted baseline)
     * @throws IllegalStateException if maxShift is less than 1
     */
    public int[] search() {
        if (this.maxShift < 1) {
            throw new IllegalStateException("Max shift should be >= 1: " + this.maxShift);
        }
        DataSet first = this.dataSets.get(0);
        int numVars = first.getNumColumns();
        List<Node> nodes = first.getVariables();
        int[] bestshifts = new int[numVars];
        // Every candidate is cut to this common length.
        // NOTE(review): assumes createShiftedData drops at most maxShift rows - confirm.
        int maxNumRows = first.getNumRows() - this.maxShift;

        // Score the unshifted baseline through the same shift-and-truncate path as the
        // candidates, so that all compared scores are computed on the same number of rows.
        double b = this.getAvgBic(this.getShiftedDataSets(new int[nodes.size()], maxNumRows));
        this.printShifts(bestshifts, b, nodes);

        // Index of the variable named "I", which is never shifted (-1 if absent). Loop-invariant.
        int iIndex = nodes.indexOf(first.getVariable("I"));

        DepthChoiceGenerator generator = new DepthChoiceGenerator(nodes.size(), this.getMaxNumShifts());
        int[] choice;
        while (!this.scheduleStop && (choice = generator.next()) != null) {
            int[] shifts = new int[nodes.size()];
            double zSize = Math.pow(this.getMaxShift(), choice.length);
            for (int z = 0; (double) z < zSize && !this.scheduleStop; ++z) {
                // Read z as a base-maxShift number with one digit per chosen variable; digit d
                // becomes a shift of magnitude d + 1. NOTE(review): "I" consumes no digit, so
                // choices containing it re-score some configurations. This is harmless, assuming
                // getAvgBic is deterministic, because a repeated score cannot pass isImprovement.
                int rest = z;
                for (int index : choice) {
                    if (index == iIndex) {
                        continue;
                    }
                    int magnitude = rest % this.getMaxShift() + 1;
                    shifts[index] = this.forwardSearch ? magnitude : -magnitude;
                    rest /= this.getMaxShift();
                }
                double candidate = this.getAvgBic(this.getShiftedDataSets(shifts, maxNumRows));
                if (isImprovement(candidate, b)) {
                    b = candidate;
                    this.printShifts(shifts, b, nodes);
                    System.arraycopy(shifts, 0, bestshifts, 0, shifts.length);
                }
            }
        }
        this.println("\nShifts with the lowest BIC score: ");
        this.printShifts(bestshifts, b, nodes);
        return bestshifts;
    }

    /**
     * Returns true if {@code candidate} is lower than {@code current} by at least 0.1% of
     * |current|. The earlier form {@code candidate < 0.999 * current} is correct only for
     * non-negative {@code current}. For a negative value it would accept a slightly worse score.
     * The non-negative branch is kept bit-identical to that form.
     */
    private static boolean isImprovement(double candidate, double current) {
        double threshold = current >= 0 ? 0.999 * current : 1.001 * current;
        return candidate < threshold;
    }

    /** Prints "node=shift " for every variable, followed by the score. */
    private void printShifts(int[] shifts, double b, List<Node> nodes) {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < shifts.length; ++i) {
            buf.append(nodes.get(i)).append('=').append(shifts[i]).append(' ');
        }
        buf.append(b);
        this.println(buf.toString());
    }

    /**
     * Writes a line to System.out and, if configured, to the secondary output stream.
     * The secondary stream defaults to System.out, so it is skipped in that case to avoid
     * printing every line twice.
     */
    private void println(String s) {
        System.out.println(s);
        if (this.out != null && this.out != System.out) {
            this.out.println(s);
            this.out.flush();
        }
    }

    /**
     * Applies the shift vector to every data set via TimeSeriesUtils.createShiftedData, then
     * cuts each result to its first {@code maxNumRows} rows.
     */
    private List<DataSet> getShiftedDataSets(int[] shifts, int maxNumRows) {
        List<DataSet> shifted = new ArrayList<DataSet>(this.dataSets.size());
        for (DataSet dataSet : this.dataSets) {
            shifted.add(TimeSeriesUtils.createShiftedData(dataSet, shifts));
        }
        return this.ensureNumRows(shifted, maxNumRows);
    }

    /**
     * Returns continuous copies of the data sets with {@code topMargin} rows removed from the top
     * and {@code bottomMargin} rows removed from the bottom. Not referenced within this class.
     */
    private List<DataSet> truncateDataSets(List<DataSet> dataSets, int topMargin, int bottomMargin) {
        List<DataSet> truncatedData = new ArrayList<DataSet>(dataSets.size());
        for (DataSet dataSet : dataSets) {
            DoubleMatrix2D mat = dataSet.getDoubleData();
            DoubleMatrix2D mat2 = mat.viewPart(topMargin, 0, mat.rows() - topMargin - bottomMargin, mat.columns());
            truncatedData.add(ColtDataSet.makeContinuousData(dataSet.getVariables(), mat2));
        }
        return truncatedData;
    }

    /**
     * Returns continuous copies of the data sets restricted to their first {@code numRows} rows.
     * NOTE(review): a data set with fewer than numRows rows is presumably rejected by Colt's
     * viewPart with an exception - confirm.
     */
    private List<DataSet> ensureNumRows(List<DataSet> dataSets, int numRows) {
        List<DataSet> truncatedData = new ArrayList<DataSet>(dataSets.size());
        for (DataSet dataSet : dataSets) {
            DoubleMatrix2D mat = dataSet.getDoubleData();
            DoubleMatrix2D mat2 = mat.viewPart(0, 0, numRows, mat.columns());
            truncatedData.add(ColtDataSet.makeContinuousData(dataSet.getVariables(), mat2));
        }
        return truncatedData;
    }

    /**
     * Runs Images on the data sets and returns the negated model score divided by the number of
     * data sets. search() treats lower values as better.
     */
    private double getAvgBic(List<DataSet> dataSets) {
        Images images = new Images(dataSets);
        images.setPenaltyDiscount(this.c);
        images.setKnowledge(this.knowledge);
        // The resulting graph is not needed; search() must run before getModelScore() is read.
        images.search();
        return -images.getModelScore() / (double) dataSets.size();
    }

    /** @return the largest shift magnitude tried for a single variable */
    public int getMaxShift() {
        return this.maxShift;
    }

    /** @param maxShift the largest shift magnitude to try; search() requires at least 1 */
    public void setMaxShift(int maxShift) {
        this.maxShift = maxShift;
    }

    /** @return the knowledge passed to each Images run */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** @param knowledge the knowledge to pass to each Images run */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
    }

    /** @return the penalty discount passed to each Images run */
    public int getC() {
        return this.c;
    }

    /** @param c the penalty discount to pass to each Images run */
    public void setC(int c) {
        this.c = c;
    }

    /** @return the depth used for the DepthChoiceGenerator over variables */
    public int getMaxNumShifts() {
        return this.maxNumShifts;
    }

    /** @param maxNumShifts the depth to use for the DepthChoiceGenerator over variables */
    public void setMaxNumShifts(int maxNumShifts) {
        this.maxNumShifts = maxNumShifts;
    }

    /**
     * Sets the secondary stream that progress lines are copied to. A PrintStream is used as-is,
     * so passing System.out does not cause duplicate output. Any other stream is wrapped, and
     * null disables the secondary output.
     *
     * @param out the stream to copy progress lines to, or null
     */
    public void setOut(OutputStream out) {
        if (out == null) {
            this.out = null;
        } else if (out instanceof PrintStream) {
            this.out = (PrintStream) out;
        } else {
            this.out = new PrintStream(out);
        }
    }

    /** Requests that a running search() stop after the configuration currently being scored. */
    public void stop() {
        this.scheduleStop = true;
    }

    /** @param forwardSearch true to try positive shifts; false (the default) to try negated ones */
    public void setForwardSearch(boolean forwardSearch) {
        this.forwardSearch = forwardSearch;
    }
}

