/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.evaluation.results;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import networks.computation.evaluation.results.RegressionResults;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.values.ScalarValue;
import networks.computation.evaluation.values.Value;
import settings.Settings;
import utils.exporting.Exporter;

/**
 * Binary-classification evaluation results.
 *
 * <p>On top of the aggregated error inherited from {@link RegressionResults}, this class derives
 * accuracy, the majority-class rate and an output "dispersion". It treats every target (and, for
 * the default metrics, every output) greater than 0.5 as the positive class. It can optionally
 * search for the output threshold that maximizes accuracy.
 *
 * <p>NOTE(review): all comparisons below assume {@code a.greaterThan(b)} means strict
 * {@code a > b}, as the usage in this class suggests; confirm against {@code Value}.
 */
public class ClassificationResults
extends RegressionResults {
    /** Fixed decision boundary used to binarize targets (and outputs for the default metrics). */
    static Value oneHalf = new ScalarValue(0.5);
    /** Fraction of samples whose output lies strictly on the same side of 0.5 as their target. */
    private Double accuracy;
    // NOTE(review): precision, recall and f_Measure are declared but never assigned in this class.
    private Double precision;
    private Double recall;
    private Double f_Measure;
    /**
     * Share of the larger class among all samples, i.e. the accuracy of always predicting the
     * majority class. Despite the name, this is a rate of correct predictions, not an error.
     */
    private Double majorityErr;
    /** Mean output over positive samples minus mean output over negative samples. */
    private Value dispersion;
    /** Number of correctly classified samples at the fixed 0.5 boundary. */
    private int goodCount;
    /** Number of samples whose target is not greater than 0.5. */
    private int zeroCount;
    /** Number of samples whose target is greater than 0.5. */
    private int oneCount;
    /** Threshold recorded by the last threshold search/evaluation, or null if none. */
    public Value bestThreshold;
    /** Accuracy achieved with {@link #bestThreshold}, or null if not computed. */
    public Double bestAccuracy;

    public ClassificationResults(List<Result> outputs, Settings settings) {
        super(outputs, settings);
    }

    /**
     * Recomputes the aggregated error and the binary metrics from {@code this.evaluations}.
     * When {@code settings.calculateBestThreshold} is set, it also runs the best-threshold
     * search, which sorts {@code this.evaluations} in place.
     *
     * @return always {@code true}
     */
    @Override
    public boolean recalculate() {
        this.error = this.calculateErrorValue();
        this.loadBinaryMetrics(this.evaluations);
        if (this.settings.calculateBestThreshold) {
            this.getBestAccuracyThreshold(this.evaluations);
        }
        return true;
    }

    /**
     * Aggregates the per-sample error values with the configured aggregation function.
     *
     * @return the aggregated error over {@code this.evaluations}
     */
    public Value calculateErrorValue() {
        ArrayList<Value> errors = new ArrayList<Value>(this.evaluations.size());
        for (Result evaluation : this.evaluations) {
            errors.add(evaluation.errorValue());
        }
        return this.aggregationFcn.evaluate(errors);
    }

    /**
     * Computes class counts, accuracy, majority-class rate and dispersion at the fixed 0.5
     * boundary.
     *
     * <p>A positive sample is correct iff its output is greater than 0.5. A negative sample is
     * correct iff its output is less than 0.5. An output of exactly 0.5 is therefore counted as
     * wrong for either class. For an empty list the ratios are NaN (0.0 / 0.0). If one class is
     * absent, the corresponding mean uses a factor of 1.0 / 0 (infinity), so the dispersion is
     * presumably non-finite — TODO confirm how {@code ScalarValue} handles that.
     *
     * @param evaluations the results to score
     */
    private void loadBinaryMetrics(List<Result> evaluations) {
        this.goodCount = 0;
        this.zeroCount = 0;
        this.oneCount = 0;
        ScalarValue zeroSum = new ScalarValue(0.0);
        ScalarValue oneSum = new ScalarValue(0.0);
        for (Result evaluation : evaluations) {
            if (evaluation.target.greaterThan(oneHalf)) {
                ++this.oneCount;
                // The (Value) casts are kept from the original: they may select a specific overload.
                ((Value)oneSum).incrementBy(evaluation.output);
                if (evaluation.output.greaterThan(oneHalf)) {
                    ++this.goodCount;
                }
            } else {
                ++this.zeroCount;
                ((Value)zeroSum).incrementBy(evaluation.output);
                if (oneHalf.greaterThan(evaluation.output)) {
                    ++this.goodCount;
                }
            }
        }
        this.majorityErr = (double)Math.max(this.zeroCount, this.oneCount) / (double)evaluations.size();
        // mean(output | positive) - mean(output | negative)
        this.dispersion = ((Value)oneSum).elementTimes((Value)new ScalarValue(1.0 / (double)this.oneCount)).minus(((Value)zeroSum).elementTimes((Value)new ScalarValue(1.0 / (double)this.zeroCount)));
        this.accuracy = (double)this.goodCount / (double)evaluations.size();
    }

    /**
     * Evaluates accuracy at a given (for example previously trained) threshold, records it as
     * {@link #bestThreshold}/{@link #bestAccuracy} and returns it.
     *
     * <p>An output is predicted positive iff it is greater than {@code trainedThreshold}, so an
     * output equal to the threshold is predicted negative. A target is positive iff it is
     * greater than 0.5, consistent with the rest of this class. For an empty list the result
     * is NaN (0.0 / 0.0).
     *
     * @param evaluations the results to score
     * @param trainedThreshold the decision threshold applied to outputs
     * @return the fraction of correctly classified samples
     */
    public Double getBestAccuracy(List<Result> evaluations, Value trainedThreshold) {
        int correct = 0;
        for (Result evaluation : evaluations) {
            boolean predictedPositive = evaluation.output.greaterThan(trainedThreshold);
            // Targets are binarized at the fixed 0.5 boundary, not at the output threshold.
            boolean actualPositive = evaluation.target.greaterThan(oneHalf);
            if (predictedPositive == actualPositive) {
                ++correct;
            }
        }
        this.bestThreshold = trainedThreshold;
        this.bestAccuracy = (double)correct / (double)evaluations.size();
        return this.bestAccuracy;
    }

    /**
     * Finds the output threshold that maximizes accuracy. It stores the threshold in
     * {@link #bestThreshold} and the accuracy it achieves in {@link #bestAccuracy}.
     *
     * <p>The list is sorted in place with {@link Collections#sort(List)}. The search then scores
     * every cut position between runs of equal outputs. Samples before the cut are predicted
     * negative; samples from the cut on are predicted positive. Both extremes ("everything
     * positive" and "everything negative") are candidates. Targets are binarized at 0.5.
     * NOTE(review): this assumes {@code Result}'s natural ordering is by ascending output —
     * confirm against {@code Result.compareTo}.
     *
     * <p>For an empty list, {@link #bestThreshold} and {@link #bestAccuracy} are reset to
     * {@code null}.
     *
     * @param evaluations the results to search over; reordered by this call
     */
    public void getBestAccuracyThreshold(List<Result> evaluations) {
        if (evaluations.isEmpty()) {
            // Nothing to search: clear stale values from an earlier call instead of failing.
            this.bestThreshold = null;
            this.bestAccuracy = null;
            return;
        }
        Collections.sort(evaluations);
        int size = evaluations.size();
        double allCount = size;
        // Count negatives here so the search does not depend on loadBinaryMetrics having run.
        int negTotal = 0;
        for (Result evaluation : evaluations) {
            if (!evaluation.target.greaterThan(oneHalf)) {
                ++negTotal;
            }
        }
        int cumNegCount = 0;
        int cumPosCount = 0;
        // Start above any achievable error so the first candidate is always accepted.
        double bestCumErr = Double.POSITIVE_INFINITY;
        int bestIndex = 0;
        int i = 0;
        while (true) {
            // Errors for a cut at i: positives before the cut (false negatives)
            // plus negatives at or after the cut (false positives).
            double cumErr = (double)(cumPosCount + negTotal - cumNegCount) / allCount;
            if (cumErr < bestCumErr) {
                bestIndex = i;
                bestCumErr = cumErr;
            }
            if (i == size) {
                break; // the "everything negative" cut has just been scored
            }
            // Consume a whole run of equal outputs so tied samples never straddle a cut.
            do {
                if (evaluations.get(i).target.greaterThan(oneHalf)) {
                    ++cumPosCount;
                } else {
                    ++cumNegCount;
                }
                ++i;
            } while (i < size && evaluations.get(i).output.equals(evaluations.get(i - 1).output));
        }
        Value threshold;
        if (bestIndex == size) {
            // Everything is predicted negative: use the last output, which (under the assumed
            // ascending order) no output is strictly greater than.
            threshold = evaluations.get(size - 1).output;
        } else {
            threshold = evaluations.get(bestIndex).output;
            if (bestIndex > 0) {
                // Midpoint between the last output before the cut and the first one after it.
                threshold = threshold.plus(evaluations.get(bestIndex - 1).output).times(oneHalf);
            }
        }
        this.bestThreshold = threshold;
        this.bestAccuracy = 1.0 - bestCumErr;
    }

    @Override
    public String toString() {
        return this.toString(null);
    }

    /**
     * Renders the available metrics. When {@code settings} is non-null, the error is labelled
     * with the configured aggregation and error functions.
     *
     * @param settings optional settings used to label the error; may be {@code null}
     * @return a one-line summary of the metrics that have been computed
     */
    @Override
    public String toString(Settings settings) {
        StringBuilder sb = new StringBuilder();
        if (this.accuracy != null) {
            sb.append("accuracy: " + Settings.shortNumberFormat.format(this.accuracy * 100.0) + "% (maj. " + Settings.shortNumberFormat.format(this.majorityErr * 100.0) + "%)");
        }
        if (this.bestAccuracy != null) {
            sb.append("(best thresh acc: " + Settings.shortNumberFormat.format(this.bestAccuracy * 100.0) + "%)");
        }
        if (this.dispersion != null) {
            sb.append(", disp: " + this.dispersion.toString());
        }
        if (this.error != null) {
            if (settings == null) {
                sb.append(", error: ").append(this.error.toString());
            } else {
                String errAggfcn = settings.errorAggregationFcn.toString();
                String errFcn = settings.errorFunction.toString();
                String errString = errAggfcn + "(" + errFcn + ")";
                sb.append(", error: ").append(errString).append(" = ").append(this.error.toString());
            }
        }
        return sb.toString();
    }

    @Override
    public void export(Exporter exporter) {
        exporter.export(this);
    }
}

