/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.results;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.results.ClassificationResults;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.learning.results.Results;
import cz.cvut.fel.ida.learning.results.metrics.AUC;
import cz.cvut.fel.ida.setup.Settings;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

public class DetailedClassificationResults
extends ClassificationResults {
    private static final Logger LOG = Logger.getLogger(ClassificationResults.class.getName());
    public Value bestThreshold = new ScalarValue(0.5);
    public Double bestAccuracy;
    private Double precision;
    private Double recall;
    private Double f_Measure;
    public Double AUCroc;
    private Double AUCrocEmpirical;
    private Double AUCpr;

    public DetailedClassificationResults(List<Result> outputs, Settings aggregationFcn) {
        super(outputs, aggregationFcn);
    }

    @Override
    public boolean betterThan(Results results, Settings.ModelSelection criterion) {
        DetailedClassificationResults other = (DetailedClassificationResults)results;
        switch (criterion) {
            case AUCpr: {
                return this.AUCpr > other.AUCpr;
            }
            case AUCroc: {
                return this.AUCroc > other.AUCroc;
            }
            case ACCURACY: {
                return this.bestAccuracy > other.bestAccuracy;
            }
            case DISPERSION: {
                return this.dispersion > other.dispersion;
            }
        }
        return super.betterThan(other, criterion);
    }

    @Override
    public boolean recalculate() {
        super.recalculate();
        if (this.settings.forceDetailedResults) {
            this.computeDetailedStats(this.evaluations);
        }
        return true;
    }

    public void computeDetailedStats(List<Result> evaluations) {
        if (this.settings.alternativeAUC) {
            this.AUCrocEmpirical = this.calculateAUCsmaller(evaluations);
        }
        try {
            this.setFullAUC(evaluations);
        }
        catch (Exception e) {
            LOG.warning("Could not calculate AUC stats");
        }
    }

    public Double computeBestAccuracy(List<Result> evaluations, Value trainedThreshold) {
        int TP = 0;
        int TN = 0;
        for (Result evaluation : evaluations) {
            if (evaluation.getOutput().greaterThan(trainedThreshold) && evaluation.getTarget().greaterThan(trainedThreshold)) {
                ++TP;
                continue;
            }
            if (!trainedThreshold.greaterThan(evaluation.getOutput()) || !trainedThreshold.greaterThan(evaluation.getTarget())) continue;
            ++TN;
        }
        this.bestThreshold = trainedThreshold;
        this.bestAccuracy = (double)(TP + TN) / (double)evaluations.size();
        return this.bestAccuracy;
    }

    public Value computeBestAccuracyThreshold(List<Result> evaluations) {
        Collections.sort(evaluations);
        double allCount = evaluations.size();
        int cumNegCount = 0;
        int cumPosCount = 0;
        double bestCumErr = evaluations.size();
        int bestIndex = -1;
        int i = 0;
        while (i < evaluations.size()) {
            double cumErr = (double)(cumPosCount + this.zeroCount - cumNegCount) / allCount;
            if (cumErr < bestCumErr) {
                bestIndex = i;
                bestCumErr = cumErr;
            }
            do {
                Result evaluation;
                if ((evaluation = evaluations.get(i)).getTarget().greaterThan(oneHalf)) {
                    ++cumPosCount;
                    continue;
                }
                ++cumNegCount;
            } while (++i < evaluations.size() && evaluations.get(i).getOutput().equals(evaluations.get(i - 1).getOutput()));
        }
        try {
            this.bestThreshold = evaluations.get(bestIndex).getOutput();
        }
        catch (IndexOutOfBoundsException e) {
            this.bestThreshold = evaluations.get(0).getOutput();
        }
        if (bestIndex - 1 >= 0) {
            this.bestThreshold = this.bestThreshold.plus(evaluations.get(bestIndex - 1).getOutput()).times(oneHalf);
        }
        this.bestAccuracy = 1.0 - bestCumErr;
        return this.bestThreshold;
    }

    public double calculateAUCsmaller(List<Result> evaluations) {
        double pos = this.oneCount;
        double neg = this.zeroCount;
        Collections.sort(evaluations);
        double[] ranks = new double[evaluations.size()];
        for (int i = 0; i < evaluations.size(); ++i) {
            int j;
            if (i == evaluations.size() - 1 || !evaluations.get(i).getOutput().equals(evaluations.get(i + 1).getOutput())) {
                ranks[i] = i + 1;
                continue;
            }
            for (j = i + 1; j < evaluations.size() && evaluations.get(j).getOutput() == evaluations.get(i).getOutput(); ++j) {
            }
            double r = (double)(i + 1 + j) / 2.0;
            for (int k = i; k < j; ++k) {
                ranks[k] = r;
            }
            i = j - 1;
        }
        double auc = 0.0;
        for (int i = 0; i < evaluations.size(); ++i) {
            if (!evaluations.get(i).getTarget().greaterThan(oneHalf)) continue;
            auc += ranks[i];
        }
        auc = (auc - pos * (pos + 1.0) / 2.0) / (pos * neg);
        return auc;
    }

    public void setFullAUC(List<Result> evaluations) {
        AUC auc = new AUC(evaluations);
        this.AUCroc = auc.getAUCroc();
        this.AUCpr = auc.getAUCpr();
    }

    @Override
    public String toString() {
        return super.toString();
    }

    @Override
    public String toString(Settings settings) {
        String s = super.toString(settings);
        StringBuilder sb = new StringBuilder(s);
        if (this.bestAccuracy != null) {
            sb.append(", (best thresh acc: " + Settings.shortNumberFormat.format(this.bestAccuracy * 100.0) + "%)");
        }
        sb.append(" (maj. " + Settings.shortNumberFormat.format(this.majorityErr * 100.0) + "%)");
        if (this.AUCroc != null) {
            sb.append(", (AUC-ROC: " + Settings.detailedNumberFormat.format(this.AUCroc) + ")");
        }
        if (this.AUCrocEmpirical != null) {
            sb.append(", (AUC-ROC [empirical]: " + Settings.detailedNumberFormat.format(this.AUCrocEmpirical) + ")");
        }
        if (this.AUCpr != null) {
            sb.append(", (AUC-PR: " + Settings.detailedNumberFormat.format(this.AUCpr) + ")");
        }
        return sb.toString();
    }
}

