package jsat.classifiers.boosting;

import java.util.*;
import jsat.classifiers.*;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * Modest Ada Boost is a generalization of Discrete Ada Boost that attempts to 
 * reduce the generalization error and avoid over-fitting. Empirically, 
 * ModestBoost usually maintains a higher training-set error, and may take more
 * iterations to obtain the same test set error as other algorithms, but its 
 * test error does not increase as much after it reaches the minimum - which 
 * should make it easier to obtain the higher accuracy.
 * <br>
 * See: <br>
 * Vezhnevets, A., &amp; Vezhnevets, V. (2005). <i>“Modest AdaBoost” – Teaching 
 * AdaBoost to Generalize Better</i>. GraphiCon. Novosibirsk Akademgorodok, 
 * Russia. Retrieved from 
 * <a href="http://www.inf.ethz.ch/personal/vezhneva/Pubs/ModestAdaBoost.pdf">
 * here</a>
 * 
 * @author Edward Raff
 */
public class ModestAdaBoost  implements Classifier, Parameterized, BinaryScoreClassifier
{

    private static final long serialVersionUID = 8223388561185098909L;
    /**
     * The prototype weak learner; a fresh clone of it is trained in every 
     * boosting round. 
     */
    private Classifier weakLearner;
    /**
     * The upper bound on the number of boosting rounds performed by 
     * {@link #train(ClassificationDataSet, boolean)}
     */
    private int maxIterations;
    /**
     * The list of weak hypothesis. {@code null} until training has been called
     */
    protected List<Classifier> hypoths;
    /**
     * The weights for each weak learner, index aligned with {@link #hypoths}. 
     * {@code null} until training has been called
     */
    protected List<Double> hypWeights;
    /**
     * The target attribute of the last data set trained on, {@code null} if 
     * the model has not been trained. 
     */
    protected CategoricalData predicting;

    /**
     * Creates a new ModestBoost learner
     * @param weakLearner the weak learner to use
     * @param maxIterations the maximum number of boosting iterations
     * @throws IllegalArgumentException if the weak learner does not support 
     * weighted data, or if the number of iterations is not positive
     */
    public ModestAdaBoost(Classifier weakLearner, int maxIterations)
    {
        setWeakLearner(weakLearner);
        setMaxIterations(maxIterations);
    }
    
    /**
     * Copy constructor. The weak learner, and - if present - every trained 
     * hypothesis, its weight, and the target attribute are copied. 
     * @param toClone the object to clone
     */
    protected ModestAdaBoost(ModestAdaBoost toClone)
    {
        this(toClone.weakLearner.clone(), toClone.maxIterations);
        if(toClone.hypWeights != null)
        {
            this.hypWeights = new DoubleList(toClone.hypWeights);
            this.hypoths = new ArrayList<Classifier>(toClone.hypoths.size());
            for(Classifier weak : toClone.hypoths)
                this.hypoths.add(weak.clone());
            this.predicting = toClone.predicting.clone();
        }
    }
    
    /**
     * Returns a read-only view of the weak learners in this ensemble. 
     * @return a list of the models that are in this ensemble, which is empty 
     * if the model has not been trained yet. 
     */
    public List<Classifier> getModels()
    {
        //unmodifiableList rejects null, so an untrained model needs its own case
        if(hypoths == null)
            return Collections.emptyList();
        return Collections.unmodifiableList(hypoths);
    }
    
    /**
     * Returns a read-only view of the weight of each weak learner, in the same
     * order as {@link #getModels()}. 
     * @return a list of the models weights that are in this ensemble, which is
     * empty if the model has not been trained yet. 
     */
    public List<Double> getModelWeights()
    {
        //unmodifiableList rejects null, so an untrained model needs its own case
        if(hypWeights == null)
            return Collections.emptyList();
        return Collections.unmodifiableList(hypWeights);
    }
    
    /**
     * Returns the maximum number of iterations used
     * @return the maximum number of iterations used
     */
    public int getMaxIterations()
    {
        return maxIterations;
    }

    /**
     * Sets the maximal number of boosting iterations that may be performed 
     * @param maxIterations the maximum number of iterations
     * @throws IllegalArgumentException if the value is less than 1
     */
    public void setMaxIterations(int maxIterations)
    {
        if(maxIterations < 1)
            throw new IllegalArgumentException("Iterations must be positive, not " + maxIterations);
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the weak learner currently being used by this method. 
     * @return the weak learner currently being used by this method. 
     */
    public Classifier getWeakLearner()
    {
        return weakLearner;
    }

    /**
     * Sets the weak learner used during training. 
     * @param weakLearner the weak learner to use
     * @throws IllegalArgumentException if the weak learner does not report 
     * support for weighted data
     */
    public void setWeakLearner(Classifier weakLearner)
    {
        if(!weakLearner.supportsWeightedData())
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        this.weakLearner = weakLearner;
    }
    
    /**
     * Computes the raw ensemble score for a data point. Each weak learner 
     * contributes {@code 2*P(class 1)-1} multiplied by its weight. 
     * @param dp the data point to score
     * @return the weighted sum of the weak learner outputs
     */
    @Override
    public double getScore(DataPoint dp)
    {
        double score = 0;
        for(int i = 0; i < hypoths.size(); i++)
            score += (hypoths.get(i).classify(dp).getProb(1)*2-1)*hypWeights.get(i);
        return score;
    }

    /**
     * Classifies the data point by the sign of {@link #getScore(DataPoint)}: 
     * a negative score assigns all probability to class 0, any other score 
     * assigns it to class 1. 
     * @param data the data point to classify
     * @return the hard classification result
     * @throws RuntimeException if the model has not been trained yet
     */
    @Override
    public CategoricalResults classify(DataPoint data)
    {
        if(predicting == null)
            throw new RuntimeException("Classifier has not been trained yet");
        
        CategoricalResults cr = new CategoricalResults(predicting.getNumOfCategories());
        
        double score =  getScore(data);
        if(score < 0)
            cr.setProb(0, 1.0);
        else
            cr.setProb(1, 1.0);
        return cr;
    }

    /**
     * Trains the ensemble. Up to {@link #getMaxIterations()} clones of the 
     * weak learner are trained on a re-weighted shallow clone of the data. 
     * Training stops early once a round produces a non-positive (or otherwise
     * inconsistent) hypothesis weight, in which case that round's hypothesis 
     * is discarded. Category 1 is treated as the positive class and every 
     * other category as the negative class. 
     * 
     * @param dataSet the data set to train on
     * @param parallel passed through to the weak learner's train method
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel)
    {
        predicting = dataSet.getPredicting();
        hypWeights = new DoubleList(maxIterations);
        hypoths = new ArrayList<Classifier>(maxIterations);
        final int N = dataSet.size();
        
        //D is the current boosting distribution, D_inv its normalized "inverse"
        double[] D_inv = new double[N];
        double[] D = new double[N];
        
        ClassificationDataSet cds = dataSet.shallowClone();
        //Start from the uniform distribution, every point has weight 1/N
        Arrays.fill(D, 1.0/N);
        for(int i = 0; i < N; i++)
            cds.setWeight(i, D[i]);
        
        //Output of the current weak learner for each point, 2*P(class 1)-1
        double[] H_cur = new double[N];
        
        for(int t = 0; t < maxIterations; t++)
        {
            Classifier weak = weakLearner.clone();
            weak.train(cds, parallel);
            
            //Inverse distribution: 1-D[i], re-normalized to sum to one
            double invSum = 0;
            for(int i = 0; i < N; i++)
                invSum += (D_inv[i] = 1-D[i]);
            
            for(int i = 0; i < N; i++)
                D_inv[i] /= invSum;
            //Signed hard outputs summed over the positive (p) and negative (n) 
            //examples, under the distribution (d) and its inverse (id)
            double p_d = 0, p_id = 0, n_d = 0, n_id = 0;
            
            for(int i = 0; i < N; i++)
            {
                H_cur[i] = (weak.classify(cds.getDataPoint(i)).getProb(1)*2-1);
                double outPut = Math.signum(H_cur[i]);
                int c = cds.getDataPointCategory(i);
                if(c == 1)//positive example case
                {
                    p_d  += outPut * D[i];
                    p_id += outPut * D_inv[i];
                }
                else
                {
                    n_d  += outPut * D[i];
                    n_id += outPut * D_inv[i];
                }
                
            }
            
            double alpha_m = p_d * (1 - p_id) - n_d * (1 - n_id); 
            
            //Stop when the hypothesis is useless or its weight is inconsistent.
            //A NaN alpha_m also ends up here, since signum(NaN) != anything
            if(Math.signum(alpha_m) != Math.signum(p_d-n_d) || Math.abs((p_d - n_d)) < 1e-6 || alpha_m <= 0)
                return;
            
            double weightSum = 0;
            for(int i = 0; i < N; i++)
            {
                double w_i = cds.getWeight(i);
                int y_i = cds.getDataPointCategory(i)*2-1;
                w_i *= Math.exp(-y_i*alpha_m*H_cur[i]);
                if(Double.isInfinite(w_i))
                    w_i = 1;//Let it grow back
                else if(w_i <= 0)
                    w_i = 1e-3/N;//Don't let it go quite to zero
                weightSum += w_i;
                cds.setWeight(i, w_i);
            }
            
            //Normalize, and keep D in sync with the data set weights so that 
            //the next round's statistics use the updated distribution 
            for(int i = 0; i < N; i++)
            {
                D[i] = Math.max(cds.getWeight(i)/weightSum, 1e-10);
                cds.setWeight(i, D[i]);
            }
            
            hypWeights.add(alpha_m);
            hypoths.add(weak);
        }
    }

    /**
     * Training overwrites the weight of every data point with its own 
     * boosting weight, so weights on the input are not used. 
     * @return {@code false} always
     */
    @Override
    public boolean supportsWeightedData()
    {
        return false;
    }

    @Override
    public ModestAdaBoost clone()
    {
        return new ModestAdaBoost(this);
    }
}
