/* KFNN
 * */
package ceka.KFNN;

import java.util.Arrays;

import ceka.core.Dataset;
import ceka.core.Example;
import ceka.core.Label;
import ceka.core.MultiNoisyLabelSet;
import weka.core.Utils;

/**
 * Infers an integrated label for every example of a {@link Dataset} from the example's
 * multiple noisy labels (the label set at index 0).
 *
 * <p>Pipeline implemented by {@link #doInference(Dataset)}:
 * <ol>
 *   <li>Majority vote over each example's noisy labels; the winner is stored as a provisional
 *       integrated label and the examples are partitioned into one cluster per provisional class.</li>
 *   <li>If every class received at least {@link #MIN_CLASS_SIZE} votes, each example's vote
 *       distribution is enhanced with a score derived from its distance to each cluster
 *       ({@code Mdistance}; Mahalanobis distance according to the original comments).</li>
 *   <li>Neighbour voting: neighbours are ranked by the summed per-cluster distance; each neighbour
 *       contributes its vote distribution weighted by the fraction of shared workers whose labels
 *       agree with this example's labels. The neighbourhood size is chosen where the
 *       Kalman-smoothed margin between the top two classes is largest.</li>
 * </ol>
 */
public class KFNN {

	/** Passed as the last argument to every integrated {@link Label} created by this class. */
	public static final String NAME = "G";

	/** Cluster-distance enhancement is skipped when the smallest majority-vote class is below this size. */
	private static final int MIN_CLASS_SIZE = 10;

	/** The neighbour window is {@code (int) (numExample * NEIGHBOUR_FRACTION / numCategory)}, at least 1. */
	private static final double NEIGHBOUR_FRACTION = 0.5;

	private static final double KALMAN_Q = 0.1; // process noise covariance
	private static final double KALMAN_R = 1; // measurement noise covariance
	private static final double KALMAN_P = 1; // estimation error covariance
	private static final double KALMAN_INITIAL_VALUE = 0; // initial estimated value

	private int numExample = 0;
	private int numCategory = 0;

	/**
	 * Replaces the integrated label of every example in {@code dataset} with the inferred one and
	 * finally calls {@code dataset.assignIntegeratedLabel2WekaInstanceClassValue()}.
	 *
	 * @param dataset the dataset to label; its examples are modified in place
	 * @throws Exception propagated from the dataset, distance and Kalman helpers
	 */
	public void doInference(Dataset dataset) throws Exception {
		numCategory = dataset.getCategorySize();
		numExample = dataset.getExampleSize();

		// Step 1: majority vote (also sets the provisional integrated labels).
		double[][] probs = new double[numExample][numCategory];
		double[] count = new double[numCategory];
		initialMajorityVote(dataset, probs, count);

		// Nothing left to refine: the top-two margin below needs at least two classes.
		if (numExample == 0 || numCategory < 2) {
			dataset.assignIntegeratedLabel2WekaInstanceClassValue();
			return;
		}

		// Stability guard: skip enhancement if the data is too unbalanced or a class is empty.
		boolean enhance = count[Utils.minIndex(count)] >= MIN_CLASS_SIZE;

		Dataset[] subDatasets = partitionByIntegratedLabel(dataset);
		Mdistance[] mdistances = new Mdistance[numCategory];
		for (int k = 0; k < numCategory; k++) {
			mdistances[k] = new Mdistance(subDatasets[k]);
		}

		// Step 2: label enhancement (only when its result is actually used).
		if (enhance) {
			enhanceWithClusterDistances(dataset, subDatasets, mdistances, probs);
		}

		// Step 3: neighbour-based inference.
		double[][] distances = pairwiseDistances(dataset, mdistances);
		for (int i = 0; i < numExample; i++) {
			Example example = dataset.getExampleByIndex(i);
			int inferred = inferLabel(dataset, example, distances[i], probs);
			example.setIntegratedLabel(new Label(null, String.valueOf(inferred), example.getId(), NAME));
		}
		dataset.assignIntegeratedLabel2WekaInstanceClassValue();
	}

	/** Fills {@code probs} with per-example vote fractions, {@code count} with majority-class counts, and sets provisional labels. */
	private void initialMajorityVote(Dataset dataset, double[][] probs, double[] count) throws Exception {
		for (int i = 0; i < numExample; i++) {
			Example example = dataset.getExampleByIndex(i);
			MultiNoisyLabelSet labels = example.getMultipleNoisyLabelSet(0);
			int labelSize = labels.getLabelSetSize();
			for (int k = 0; k < labelSize; k++) {
				probs[i][labels.getLabel(k).getValue()] += 1.0 / labelSize;
			}
			int majority = Utils.maxIndex(probs[i]);
			count[majority] += 1;
			example.setIntegratedLabel(new Label(null, String.valueOf(majority), example.getId(), NAME));
		}
	}

	/** Splits the examples into one sub-dataset per provisional integrated label. */
	private Dataset[] partitionByIntegratedLabel(Dataset dataset) throws Exception {
		Dataset[] subDatasets = new Dataset[numCategory];
		for (int k = 0; k < numCategory; k++) {
			subDatasets[k] = new Dataset(dataset, 0);
		}
		for (int i = 0; i < numExample; i++) {
			Example example = dataset.getExampleByIndex(i);
			subDatasets[example.getIntegratedLabel().getValue()].addExample(example);
		}
		return subDatasets;
	}

	/**
	 * Adds to each example's vote distribution a score in [0, 1] per non-empty cluster
	 * (1 for the nearest cluster, 0 for the farthest), then renormalises.
	 */
	private void enhanceWithClusterDistances(Dataset dataset, Dataset[] subDatasets,
			Mdistance[] mdistances, double[][] probs) throws Exception {
		double[] clusterDis = new double[numCategory];
		for (int i = 0; i < numExample; i++) {
			Example example = dataset.getExampleByIndex(i);
			double minDis = Double.POSITIVE_INFINITY;
			double maxDis = Double.NEGATIVE_INFINITY;
			for (int k = 0; k < numCategory; k++) {
				// NaN marks "no usable distance" (empty cluster); NaN fails every comparison below.
				clusterDis[k] = subDatasets[k].getExampleSize() > 0
						? mdistances[k].calMDistance(example)
						: Double.NaN;
				if (clusterDis[k] < minDis) {
					minDis = clusterDis[k];
				}
				if (clusterDis[k] > maxDis) {
					maxDis = clusterDis[k];
				}
			}
			double range = maxDis - minDis;
			// When all distances coincide there is no ranking signal; avoid 0/0 = NaN.
			if (range > 0) {
				for (int k = 0; k < numCategory; k++) {
					if (!Double.isNaN(clusterDis[k])) {
						probs[i][k] += (maxDis - clusterDis[k]) / range;
					}
				}
			}
			if (Utils.sum(probs[i]) != 0) {
				Utils.normalize(probs[i]);
			}
		}
	}

	/**
	 * Returns the symmetric matrix of distances summed over all per-class distance models.
	 * The diagonal stays 0, so an example is normally ranked among its own nearest neighbours.
	 */
	private double[][] pairwiseDistances(Dataset dataset, Mdistance[] mdistances) throws Exception {
		double[][] distances = new double[numExample][numExample];
		for (int i = 0; i < numExample; i++) {
			Example e1 = dataset.getExampleByIndex(i);
			for (int j = i + 1; j < numExample; j++) {
				Example e2 = dataset.getExampleByIndex(j);
				double sum = 0;
				// NOTE(review): this also queries models built on empty clusters, unlike the
				// per-cluster distance above, which skips them — confirm Mdistance tolerates that.
				for (int k = 0; k < numCategory; k++) {
					sum += mdistances[k].calMDistance(e1, e2);
				}
				distances[i][j] = sum;
				distances[j][i] = sum;
			}
		}
		return distances;
	}

	/**
	 * Infers the label of one example from its nearest neighbours.
	 *
	 * <p>For each window size {@code j + 1} it computes the weighted, normalised class distribution
	 * of the first {@code j + 1} neighbours and the margin between its top two classes. The margins
	 * are smoothed with a Kalman filter. The distribution with the largest smoothed margin (the
	 * least risky decision) wins; ties go to the smaller window.
	 */
	private int inferLabel(Dataset dataset, Example example, double[] distanceRow, double[][] probs)
			throws Exception {
		MultiNoisyLabelSet ownLabels = example.getMultipleNoisyLabelSet(0);
		int[] order = Utils.sort(distanceRow);
		// Only this prefix of neighbours is ever considered; later rows would never be read.
		int window = Math.max(1, Math.min(numExample,
				(int) (numExample * NEIGHBOUR_FRACTION / numCategory)));

		Kalman kalmanFilter = new Kalman(KALMAN_Q, KALMAN_R, KALMAN_P, KALMAN_INITIAL_VALUE);
		double[] classCounts = new double[numCategory];
		double[][] posteriors = new double[window][];
		double[] smoothedMargins = new double[window];
		for (int j = 0; j < window; j++) {
			int neighbourIndex = order[j];
			double weight = agreementWeight(ownLabels, dataset.getExampleByIndex(neighbourIndex));
			for (int k = 0; k < numCategory; k++) {
				classCounts[k] += weight * probs[neighbourIndex][k];
			}
			// Snapshot of the running counts for the first j + 1 neighbours.
			double[] posterior = classCounts.clone();
			if (Utils.sum(posterior) != 0) {
				Utils.normalize(posterior);
			}
			posteriors[j] = posterior;

			int[] ranked = Utils.sort(posterior);
			double margin = posterior[ranked[numCategory - 1]] - posterior[ranked[numCategory - 2]];
			// The first margin seeds the sequence unfiltered; later ones go through the filter.
			smoothedMargins[j] = (j == 0) ? margin : kalmanFilter.update(margin, 0);
		}
		int best = Utils.maxIndex(smoothedMargins);
		return Utils.maxIndex(posteriors[best]);
	}

	/**
	 * Returns the fraction of this example's noisy labels whose worker also labelled the neighbour
	 * with the same value, counted only over workers who labelled both; 0 if no worker is shared.
	 */
	private static double agreementWeight(MultiNoisyLabelSet ownLabels, Example neighbour) throws Exception {
		int shared = 0;
		int agreed = 0;
		for (int k = 0; k < ownLabels.getLabelSetSize(); k++) {
			Label own = ownLabels.getLabel(k);
			Label theirs = neighbour.getNoisyLabelByWorkerId(own.getWorkerId());
			if (theirs != null) {
				shared++;
				if (theirs.getValue() == own.getValue()) {
					agreed++;
				}
			}
		}
		return shared == 0 ? 0.0 : (double) agreed / shared;
	}
}
