package ceka.KFNN;

import ceka.core.Dataset;
import ceka.core.Example;
import weka.core.matrix.Matrix;

/**
 * Mahalanobis distance calculator fitted on the feature columns of a dataset.
 *
 * <p>The constructor reads every attribute except the last one (presumably the
 * last attribute is the class label; confirm against the callers), computes the
 * per-attribute mean and the sample covariance matrix, and stores a
 * ridge-regularized inverse of that covariance. The two {@code calMDistance}
 * overloads then evaluate {@code sqrt(d * S^-1 * d^T)} for a difference row
 * vector {@code d}.
 */
public class Mdistance {

	/** Ridge added to the covariance diagonal on the first inversion attempt. */
	private static final double INITIAL_RIDGE = 0.1;

	/** Factor applied to the ridge after every failed inversion attempt. */
	private static final double RIDGE_GROWTH = 10;

	/** Feature values, one row per example and one column per used attribute. */
	private final Matrix matrix;
	private final int numInstances;
	private final int numAttributes;
	/** 1 x (numAttributes - 1) row vector of per-attribute means. */
	private final Matrix meanAttributes;
	/** Sample covariance of the feature columns (identity when fewer than two examples). */
	private final Matrix covariances;
	/** Inverse of the ridge-regularized covariance (identity when fewer than two examples). */
	private final Matrix inverseCovariances;

	/**
	 * Converts the dataset into a matrix and computes the attribute means, the
	 * covariance matrix and its regularized inverse.
	 *
	 * <p>With fewer than two examples the sample covariance is undefined
	 * (its denominator is {@code numInstances - 1}), so the identity matrix is used
	 * for both the covariance and its inverse. With zero examples the mean entries
	 * are {@code 0/0 = NaN}, so distances to the mean are NaN as well.
	 *
	 * @param dataset source of the examples; its last attribute is not used
	 * @throws IllegalStateException if the covariance cannot be inverted even with
	 *         the largest finite ridge
	 */
	public Mdistance(Dataset dataset) {
		numInstances = dataset.getExampleSize();
		numAttributes = dataset.numAttributes();
		final int dim = numAttributes - 1;

		matrix = new Matrix(numInstances, dim);
		meanAttributes = new Matrix(1, dim);
		// Copy the feature values and accumulate the per-attribute sums in one pass.
		for (int i = 0; i < numInstances; i++) {
			for (int j = 0; j < dim; j++) {
				double value = dataset.instance(i).value(j);
				matrix.set(i, j, value);
				meanAttributes.set(0, j, meanAttributes.get(0, j) + value);
			}
		}
		// Turn the sums into means.
		for (int j = 0; j < dim; j++) {
			meanAttributes.set(0, j, meanAttributes.get(0, j) / numInstances);
		}

		if (numInstances < 2) {
			covariances = identity(dim);
			inverseCovariances = identity(dim);
		} else {
			covariances = sampleCovariance(matrix, meanAttributes, numInstances, dim);
			inverseCovariances = invertWithRidge(covariances, dim);
		}
	}

	/**
	 * Computes the Mahalanobis distance between an example and the dataset mean.
	 *
	 * @param example the example whose first {@code numAttributes - 1} values are used
	 * @return the distance; 0 when the quadratic form is not positive
	 */
	public double calMDistance(Example example) {
		return mahalanobisNorm(toRowVector(example).minus(meanAttributes));
	}

	/**
	 * Computes the Mahalanobis distance between two examples, using the inverse
	 * covariance fitted in the constructor.
	 *
	 * @param example1 the first example
	 * @param example2 the second example
	 * @return the distance; 0 when the quadratic form is not positive
	 */
	public double calMDistance(Example example1, Example example2) {
		return mahalanobisNorm(toRowVector(example1).minus(toRowVector(example2)));
	}

	/** Copies the first {@code numAttributes - 1} values of an example into a row vector. */
	private Matrix toRowVector(Example example) {
		final int dim = numAttributes - 1;
		Matrix row = new Matrix(1, dim);
		for (int i = 0; i < dim; i++) {
			row.set(0, i, example.value(i));
		}
		return row;
	}

	/** Returns {@code sqrt(diff * S^-1 * diff^T)} for a 1 x dim difference vector. */
	private double mahalanobisNorm(Matrix diff) {
		Matrix quadratic = diff.times(inverseCovariances).times(diff.transpose());
		double squared = quadratic.get(0, 0);
		// Rounding can push the quadratic form slightly below zero; clamp it so the
		// result is 0 rather than NaN. A NaN input fails the comparison and stays NaN.
		if (squared <= 0) {
			return 0;
		}
		return Math.sqrt(squared);
	}

	/** Builds a dim x dim identity matrix. */
	private static Matrix identity(int dim) {
		Matrix result = new Matrix(dim, dim);
		for (int i = 0; i < dim; i++) {
			result.set(i, i, 1);
		}
		return result;
	}

	/** Computes the sample covariance (denominator {@code rows - 1}) of the data columns. */
	private static Matrix sampleCovariance(Matrix data, Matrix mean, int rows, int dim) {
		Matrix result = new Matrix(dim, dim);
		for (int i = 0; i < dim; i++) {
			double meanI = mean.get(0, i);
			// The covariance is symmetric, so compute one triangle and mirror it.
			for (int j = i; j < dim; j++) {
				double meanJ = mean.get(0, j);
				double sum = 0;
				for (int k = 0; k < rows; k++) {
					sum += (data.get(k, i) - meanI) * (data.get(k, j) - meanJ);
				}
				double value = sum / (rows - 1);
				result.set(i, j, value);
				result.set(j, i, value);
			}
		}
		return result;
	}

	/**
	 * Inverts the covariance after adding a ridge to its diagonal. The ridge is
	 * applied on every attempt, including the first, and is multiplied by
	 * {@link #RIDGE_GROWTH} whenever the inversion throws.
	 */
	private static Matrix invertWithRidge(Matrix covariance, int dim) {
		double ridge = INITIAL_RIDGE;
		while (true) {
			// Start from the unregularized covariance each time so the ridge actually
			// applied equals the current ridge value instead of a running total.
			Matrix regularized = covariance.copy();
			for (int i = 0; i < dim; i++) {
				regularized.set(i, i, regularized.get(i, i) + ridge);
			}
			try {
				return regularized.inverse();
			} catch (Exception ex) {
				ridge *= RIDGE_GROWTH;
				// Give up once the ridge overflows; retrying further cannot succeed
				// and would otherwise loop forever.
				if (Double.isInfinite(ridge)) {
					throw new IllegalStateException(
							"Covariance matrix could not be inverted with any finite ridge", ex);
				}
			}
		}
	}
}