import java.util.*;
import java.io.*;

class DPAffinityPropagation extends AffinityPropagation {

    /*****************************************************************/
    /*                     Instance Variables                        */
    /*****************************************************************/
    // Messages currently on the edges.
    double[][] mHijMujT_;  // Messages from c_i to mu_j, transposed.
    double[][] mMujHij_;  // messages from gamma_j to c_i
    double[][] mHijSigi_;
    double[][] mSigiHij_;

    // Messages that have been computed but not applied.
    double[][] mHijMujTNew_; 
    double[][] mMujHijNew_;  
    double[][] mHijSigiNew_;
    double[][] mSigiHijNew_;

    // Checking for convergence
    double maxDiff_;
    boolean converged_;

    // Variables used for caching sorts
    boolean[] dSortedDirty_;  // have mu factor messages been sorted?
    MessageValuePair[][] sortedMHijMujTs_;
    int sortOperations_;

    // Tuneable parameters
    DPAPParameters parameters_;
    double mutexDampingFactor_;
    boolean doFineTuning_;

    /*****************************************************************/
    /*                       Private Classes                         */
    /*****************************************************************/    
    private class MessageValuePair {
        public int index_;
        public double value_;

	// If no arguments, set value to -Inf and index to -1
	public MessageValuePair() {
	    value_ = -Double.MAX_VALUE;
	    index_ = -1;
	}

        public MessageValuePair(int index, double value) {
	    index_ = index;
	    value_ = value;
        }
    }
    
    private class MessageValuePairComparator
        implements Comparator<MessageValuePair> {
        
        public int compare(MessageValuePair m1, MessageValuePair m2) {
            if (m1.value_ < m2.value_) return 1;
            if (m1.value_ == m2.value_) return 0;
            else return -1;
        }
    }

    public MessageValuePair[] SortWithIndices(double[] a) {
	MessageValuePair[] result = new MessageValuePair[a.length];
	for (int i = 0; i < a.length; i++) {
	    result[i] = new MessageValuePair(i, a[i]);
	}
	Arrays.sort(result, new MessageValuePairComparator());

	return result;
    }

    /*****************************************************************/
    /*                            Methods                            */
    /*****************************************************************/

    /**
     * Simple default constructor.
     */
    public DPAffinityPropagation(int numVars) {
       	super(numVars);
	dSortedDirty_ = new boolean[numVars_];

	sortOperations_ = 0;
	doFineTuning_ = true;
	parameters_ = new DPAPParameters();

	converged_ = false;
    }

    /**
     * Construct from parameter file.
     */
    public DPAffinityPropagation(String parameterFile) {
	super(parameterFile);
	parameters_ = new DPAPParameters(parameterFile);
    }

    /**
     * Copy parameters from the parameter file into the actual instance
     * variables that are relevant.
     */
    protected void SyncParameters(Parameters p) {
	parameters_ = (DPAPParameters) p;
	System.out.println("Parameter file: " + parameters_.parameterFile_);
	System.out.println("Syncing DPAP Params: " + parameters_.FileName());
	dampingFactor_ = parameters_.dampingFactor_;
	mutexDampingFactor_ = parameters_.mutexDampingFactor_;
    }

    /**
     * Keep track of the name of the algorithm being implemented.
     */
    public String FullName() { 
	return "Dirichlet Process Affinity Propagation"; 
    }
    public String ShortName() { return "DPAP"; }

    /**
     * Actually run inference and come up with the MAP assignment.
     * First, run AP, then if fineTuning_ is on, run Maximizing
     * Gibbs
     */
    public void MAPInference() {
	super.MAPInference();

	if (doFineTuning_) {
	    // Run the Maximizing Gibbs sampler as the final step
	    MaximizingGibbsSampler mg = new MaximizingGibbsSampler(numVars_);
	    mg.LoadSimilaritiesFromSimilarities(similarities_);
	    mg.InitializeToAssignment(CurrentAssignments());
	    mg.MAPInference();
	    tunedAssignment_ = mg.CurrentAssignments();
	}
    }

    /**
     * Do one round of message passing.
     */
    public void PassMessages() {
	maxDiff_ = 0;

	// Compute messages over all mutual exclusion factors.
	for (int i = 0; i < numVars_; i++) {
	    ComputeAllHijSigiMessages(i);
	    UpdateAllHijSigiMessages(i);
	}
	for (int i = 0; i < numVars_; i++) {
	    ComputeAllSigiHijMessages(i);
	    UpdateAllSigiHijMessages(i);
	}

	// Compute messages over exemplar factors.
	for (int j = 0; j < numVars_; j++) {
	    ComputeAllHijMujMessages(j);
	    UpdateAllHijMujMessages(j);
	}
	for (int j = 0; j < numVars_; j++) {
	    ComputeAllMujHijMessages(j);
	    UpdateAllMujHijMessages(j);
	}
	PrintAssignment(CurrentAssignments());

	System.err.println("Max diff: " + maxDiff_);
	if (maxDiff_ < .00001) {
	    converged_ = true;
	}
	iteration_++;
    }

    public boolean HasConverged() {
	return converged_ || Iteration() > 500;
    }

    public boolean MessagesConverged() {
	return converged_;
    }

    /**
     * We use a different set of messages here, so don't bother with the
     * super class's initialization.
     */
    protected void InitializeMessages() {
	// Allocate space for DPAP-specific messages
	mHijMujT_ = new double[numVars_][numVars_] ;
	mMujHij_ = new double[numVars_][numVars_];
	mHijSigi_ = new double[numVars_][numVars_];
	mSigiHij_ = new double[numVars_][numVars_];
	sortedMHijMujTs_ = new MessageValuePair[numVars_][numVars_];

	// Storage for computed but not yet applied messages
	mHijMujTNew_ = new double[numVars_][numVars_];
	mMujHijNew_ = new double[numVars_][numVars_];
	mHijSigiNew_ = new double[numVars_][numVars_];
	mSigiHijNew_ = new double[numVars_][numVars_];

	for (int i = 0; i < numVars_; i++) {
	    for (int j = 0; j < numVars_; j++) {
		mMujHij_[i][j] = 0;
		mSigiHij_[i][j] = 0;
	    }
	}	
    }

    /**
     * Given the current messages on the edges, compute the beliefs for
     * each variable and use it to choose exemplars.
     */
    public ArrayList<Integer> CurrentExemplars() {
	ArrayList<Integer> exemplars = new ArrayList<Integer>();

	for (int i = 0; i < numVars_; i++) {
	    if (Assignment(i) == i) {
		exemplars.add(new Integer(i));
	    }
	}
	return exemplars;
    }

    /**
     * Choose the value with the largest belief for each variable.  Don't
     * worry about whether it forms a legal clustering or not yet.
     */
    public int[] CurrentAssignments() {
	// If we're doing fine tuning and it's been run already, just
	// return the output of that.
	if (doFineTuning_ && tunedAssignment_ != null) {
	    return tunedAssignment_;
	}

	// Otherwise, go through the process of computing beliefs, etc.
	int[] assignment = new int[numVars_];

	for (int i = 0; i < numVars_; i++) {
	    assignment[i] = Assignment(i);
	}

	return assignment;
    }

    /**
     * Ideally there would only be one h_ij whose belief is greater than 0,
     * but if that's not the case, choose the one with the highest belief.
     */
    public int Assignment(int i) {
	double maxBelief = -Double.MAX_VALUE;
	int maxIndex = -1;
	for (int j = 0; j < numVars_; j++) {
	    double b = Belief(i,j);
	    if (b > maxBelief) {
		maxBelief = b;
		maxIndex = j;
	    }
	}
	return maxIndex;
    }

    /**
     * Print a matrix of beliefs.
     */
    public void PrintBeliefs() {
	for (int i = 0; i < numVars_; i++) {
	    for (int j = 0; j < numVars_; j++) {
		System.out.print(twoDigitNF_.format(Belief(i,j)) + " ");
	    }
	    System.out.println();
	}
    }

    /**
     * If this is greater than 0, then we believe that c_i = j.  Otherwise
     * we think c_i != j.
     */
    public double Belief(int i, int j) {
	return similarities_[i][j] + mMujHij_[i][j] + mSigiHij_[i][j];
    }

    // Compute all outgoing messages from a single node at the same
    // time, since computing the first one takes the most work, then
    // the rest can be done in O(N) time each
    protected void ComputeAllMujHijMessages(int j) {
	dSortedDirty_[j] = true;
	for (int i = 0; i < numVars_; i++) {
	    mMujHijNew_[i][j] = MujHijMessage(i, j);
	}
    }

    protected void UpdateAllMujHijMessages(int j) {
	for (int i = 0; i < numVars_; i++) {
	    double oldMessage = mMujHij_[i][j];
	    mMujHij_[i][j] = dampingFactor_ * mMujHij_[i][j] + 
		(1 - dampingFactor_) * mMujHijNew_[i][j];
	    double diff = Math.abs(oldMessage - mMujHij_[i][j]);
	    if (diff > maxDiff_) {
		maxDiff_ = diff;
	    }
	    messagesUpdated_++;
	}
    }

    // A more efficient version.  Rather than enumerating all possible
    // values, use smarter optimization schemes that leverage additional
    // structure in prior (e.g. monotonicity, concavity, convexity, or
    // block structure).
    public double FastMujHijMessage(int i, int j) {
	// If we haven't sorted since updating messages, sort now
	if (dSortedDirty_[j]) {
	    sortedMHijMujTs_[j] = SortWithIndices(mHijMujT_[j]);

	    // cache the function of cumulative sums of sorted messages.  this
	    // is guaranteed to be concave
	    for (int l = 0; l < numVars_; l++) {
		cumsumMHijMujTs_[j][l] = (l > 0) ? cumsumMHijMujTs_[j][l-1] : 0;
		cumsumMHijMujTs_[j][l] += sortedMHijMujTs_[j][l];
	    }
	    cumsumHijMujTs_[j]
	    dSortedDirty_[j] = false;
	    sortOperations_++;
	}

	// We need to compute m(0) and m(1), then what we put on the
	// edge is m(1) - m(0).
	double maxM0 = -Double.MAX_VALUE;
	double maxM1 = -Double.MAX_VALUE;
	double cumLogSum0 = 0;
	double cumLogSum1 = 0;
	double cumMSum = 0;
	int k = 0;

	for (int l = 0; l < numVars_; l++) {
	    int iPrime = sortedMHijMujTs_[j][l].index_;
	    double value = sortedMHijMujTs_[j][l].value_;
	    if (iPrime != i && iPrime != j) {
		// Add in value of next best point to the cluster
		cumMSum += value;

		// Increment cumulative sums of logs
		if (k > 1) cumLogSum0 += Math.log(k);
		if (k > 0) cumLogSum1 += Math.log(k+1);

		// Is this the best setting for m(0) or m(1)?
		double m0K = cumMSum + cumLogSum0 - Math.log(k+1);
		double m1K = cumMSum + cumLogSum1 - Math.log(k+2);
		if (m0K > maxM0) {
		    maxM0 = m0K;
		}
		if (m1K > maxM1) {
		    maxM1 = m1K;
		}

		k++;
	    }
	}

	// Return different values based on whether we're forcing j to be
	// an exemplar by setting c_i=j.
	messagesComputed_++;
	if (i == j) {
	    return maxM0;
	} else {
	    maxM0 = (maxM0 > -mHijMujT_[j][j]) ? maxM0 : -mHijMujT_[j][j];
	    //System.out.println("M1: " + maxM1 + ", M2: " + maxM0);
	    return maxM1 - maxM0;
	    //return mHijMujT_[j][j] + maxM1 - maxM0;
	}
    }

    public double MujHijMessage(int i, int j) {
	// If we haven't sorted since updating messages, sort now
	if (dSortedDirty_[j]) {
	    sortedMHijMujTs_[j] = SortWithIndices(mHijMujT_[j]);
	    dSortedDirty_[j] = false;
	    sortOperations_++;
	}

	// We need to compute m(0) and m(1), then what we put on the
	// edge is m(1) - m(0).
	double maxM0 = -Double.MAX_VALUE;
	double maxM1 = -Double.MAX_VALUE;
	double cumLogSum0 = 0;
	double cumLogSum1 = 0;
	double cumMSum = 0;
	int k = 0;

	for (int l = 0; l < numVars_; l++) {
	    int iPrime = sortedMHijMujTs_[j][l].index_;
	    double value = sortedMHijMujTs_[j][l].value_;
	    if (iPrime != i && iPrime != j) {
		// Add in value of next best point to the cluster
		cumMSum += value;

		// Increment cumulative sums of logs
		if (k > 1) cumLogSum0 += Math.log(k);
		if (k > 0) cumLogSum1 += Math.log(k+1);

		// Is this the best setting for m(0) or m(1)?
		double m0K = cumMSum + cumLogSum0 - Math.log(k+1);
		double m1K = cumMSum + cumLogSum1 - Math.log(k+2);
		if (m0K > maxM0) {
		    maxM0 = m0K;
		}
		if (m1K > maxM1) {
		    maxM1 = m1K;
		}

		k++;
	    }
	}


	// Return different values based on whether we're forcing j to be
	// an exemplar by setting c_i=j.
	messagesComputed_++;
	if (i == j) {
	    return maxM0;
	} else {
	    maxM0 = (maxM0 > -mHijMujT_[j][j]) ? maxM0 : -mHijMujT_[j][j];
	    //System.out.println("M1: " + maxM1 + ", M2: " + maxM0);
	    return maxM1 - maxM0;
	    //return mHijMujT_[j][j] + maxM1 - maxM0;
	}
    }

    // No marginalization needed, so just add the incoming messages to 
    // the h_ij node.
    public void ComputeAllHijMujMessages(int j) {
	for (int i = 0; i < numVars_; i++) {
	    mHijMujTNew_[j][i] = similarities_[i][j] + mSigiHij_[i][j];
	    messagesComputed_++;
	}
    }

    protected void UpdateAllHijMujMessages(int j) {
	for (int i = 0; i < numVars_; i++) {
	    double oldMessage = mHijMujT_[j][i];
	    mHijMujT_[j][i] = dampingFactor_ * mHijMujT_[j][i] +
		(1 - dampingFactor_) * mHijMujTNew_[j][i];
	    double diff = Math.abs(oldMessage - mHijMujT_[j][i]);
	    if (diff > maxDiff_) {
		maxDiff_ = diff;
	    }
	    messagesUpdated_++;
	}
    }

    // We really just need a first and second max, so we might as well
    // share the computation.
    public void ComputeAllSigiHijMessages(int i) {
	// Find the first and second largest values
	double firstMax = -Double.MAX_VALUE;
	int firstMaxIndex = -1;
	double secondMax = -Double.MAX_VALUE;
	for (int j = 0; j < numVars_; j++) {
	    if (mHijSigi_[i][j] > firstMax) {
		secondMax = firstMax;
		firstMax = mHijSigi_[i][j];
		firstMaxIndex = j;
	    } else if (mHijSigi_[i][j] > secondMax) {
		secondMax = mHijSigi_[i][j];
	    }
	}

	// Put the messages on the channel
	for (int j = 0; j < numVars_; j++) {
	    messagesComputed_++;
	    mSigiHijNew_[i][j] = (j == firstMaxIndex) ? 
		-secondMax : -firstMax;
	}
    }
    
    public void UpdateAllSigiHijMessages(int i) {
	for (int j = 0; j < numVars_; j++) {
	    double oldMessage = mSigiHij_[i][j];
	    mSigiHij_[i][j] = mutexDampingFactor_ * mSigiHij_[i][j] +
		(1 - mutexDampingFactor_) * mSigiHijNew_[i][j];

	    double diff = Math.abs(oldMessage - mSigiHij_[i][j]);
	    if (diff > maxDiff_) {
		maxDiff_ = diff;
	    }
	    messagesUpdated_++;
	}
    }

    // This one is easy.  No marginalization needed, so just add the
    // incoming messages to the h_ij node.
    public void ComputeAllHijSigiMessages(int i) {
	for (int j = 0; j < numVars_; j++) {
	    mHijSigiNew_[i][j] = similarities_[i][j] + mMujHij_[i][j];
	    messagesComputed_++;
	}
    }

    public void UpdateAllHijSigiMessages(int i) {
	for (int j = 0; j < numVars_; j++) {
	    double oldMessage = mHijSigi_[i][j];
	    mHijSigi_[i][j] = dampingFactor_ * mHijSigi_[i][j] + 
		(1 - dampingFactor_) * mHijSigiNew_[i][j];

	    double diff = Math.abs(oldMessage - mHijSigi_[i][j]);
	    if (diff > maxDiff_) {
		maxDiff_ = diff;
	    }
	    messagesUpdated_++;
	}
    }

    /**
     * Print statistics about the run.
     */
    public void PrintStats() {
	super.PrintStats();
	System.out.println("Did " + sortOperations_ + " sorts");
    }


    /**
     * Given a filename, load parameters.  The file should be a text file
     * of the following form:
     * <parameter1_name> <parameter1_value>
     * ...
     * <parameterN_name> <parameterN_value>
     */
    public void LoadParametersFromFile(String filename) {
	try {
	    BufferedReader input = 
		new BufferedReader(new FileReader(filename));
	    
	    // There's just one line to read
	    String line ;
	    while ( (line = input.readLine()) != null ) {
		String[] entries = line.trim().split("\\s+");		
		
		// Allow blank lines
		if (entries.length == 0) continue;
		
		// Otherwise, there must be 2 entries per line
		assert(entries.length == 2);
		
		if (entries[0].equals("dampingFactor")) {
		    dampingFactor_ = Double.parseDouble(entries[1]);
		    
		} else if (entries[0].equals("mutexDampingFactor")) {
		    mutexDampingFactor_ = Double.parseDouble(entries[1]);

		} else if (entries[0].equals("doFineTuning")) {
		    if (entries[1].equals("true")) {
			doFineTuning_ = true;
		    } else if (entries[1].equals("false")) {
			doFineTuning_ = false;
		    } else {
			assert false : "Unexpected setting of doFineTuning: " + entries[1]; 
		    }
		    
		} else {
		    System.err.println("DPAP load Warning: unknown parameter " +
				       entries[0] + ", value = " + entries[1]);
		}
	    }
	}
	catch (Exception ex){
	    ex.printStackTrace();
	    System.exit(1);
	}
    }

    /*****************************************************************/
    /*                              Main                             */
    /*****************************************************************/
    
    public static void main(String args[]) {
	DPAffinityPropagation dpap = 
	    new DPAffinityPropagation(100);
	
	String baseFilename =
	    "data/exemplar_model/ex2_a1_d5_g0100_s50_sgiven50_id2";
	String similaritiesFile = baseFilename + "_similarities.txt";
	String labelsFile = baseFilename + "_labels.txt";

	dpap.LoadSimilaritiesFromFile(similaritiesFile);
	dpap.LoadTrueLabelsFromFile(labelsFile);

	int[] trueLabels = dpap.TrueLabels(); 

	while (dpap.Iteration() < 500 && !dpap.HasConverged()) {
	    dpap.PassMessages();

	    ArrayList<Integer> exemplars = dpap.CurrentExemplars();
	    System.out.println("Exemplars: " + exemplars);
	    int[] assignment = dpap.CurrentAssignments();
	    double trueScore = dpap.RandIndex(assignment, trueLabels);
	    System.out.println("Rand index: " + trueScore);
	    dpap.PrintAssignment(assignment);
	}

	System.out.println("Beliefs:");
	dpap.PrintBeliefs();
	ArrayList<Integer> exemplars = dpap.CurrentExemplars();
	System.out.println("Exemplars: " + exemplars);
	dpap.PrintAssignment(dpap.CurrentAssignments());
	dpap.PrintStats();

	boolean doFineTuning = false;
	if (doFineTuning) {
	    for (int i = 0; i < 5; i++) {
		System.out.println("********************************");
	    }

	    // Run the Maximizing Gibbs sampler as the final step
	    MaximizingGibbsSampler mg = new MaximizingGibbsSampler(100);
	    mg.LoadSimilaritiesFromFile(similaritiesFile);
	    mg.InitializeToAssignment(dpap.CurrentAssignments());
	    mg.MAPInference();
	    exemplars = mg.CurrentExemplars();
	    System.out.println("Exemplars: " + exemplars);
	    int[] assignment = mg.CurrentAssignments();
	    mg.PrintAssignment(assignment);
	    mg.PrintStats();
	}

    }
}
