/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.results;

import cz.cvut.fel.ida.algebra.utils.MathUtils;
import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.learning.crossvalidation.MeanStdResults;
import cz.cvut.fel.ida.learning.results.RegressionResults;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.setup.Settings;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Binary-classification view of a set of {@link Result}s.
 *
 * <p>On top of the aggregated error inherited from {@link RegressionResults}, this class tracks
 * accuracy, the majority-class fraction ({@link #majorityErr}) and the dispersion between the
 * mean outputs of the two classes. Both targets and outputs are split into the "one" and "zero"
 * classes by comparing them against {@link #oneHalf}.
 */
public class ClassificationResults
extends RegressionResults {
    private static final Logger LOG = Logger.getLogger(ClassificationResults.class.getName());
    /** Decision threshold: values strictly greater than 0.5 are treated as the "one" class. */
    static Value oneHalf = new ScalarValue(0.5);
    /** Fraction of evaluations whose output falls on the same side of the threshold as the target. */
    private Double accuracy;
    /**
     * Computed as {@code max(zeroCount, oneCount) / n}, i.e. the fraction of the larger class
     * (the accuracy of always predicting the majority class), despite the field's name.
     */
    public Double majorityErr;
    /** Mean output over "one"-target evaluations minus mean output over "zero"-target evaluations. */
    public Double dispersion;
    private int goodCount;
    protected int zeroCount;
    protected int oneCount;

    /**
     * Creates classification results over the given outputs.
     *
     * <p>All work is delegated to the {@link RegressionResults} constructor. NOTE(review): presumably
     * that constructor triggers {@link #recalculate()} to fill in the metrics — confirm in the superclass.
     *
     * @param outputs  per-example evaluation results
     * @param settings settings forwarded to the superclass
     */
    public ClassificationResults(List<Result> outputs, Settings settings) {
        super(outputs, settings);
    }

    /**
     * Creates a results holder with precomputed metric values (used e.g. to carry the mean and
     * standard deviation produced by {@link #aggregateClassifications(List)}).
     */
    protected ClassificationResults(Value error, Double accuracy, Double majorityErr, Double dispersion) {
        super(error);
        this.accuracy = accuracy;
        this.majorityErr = majorityErr;
        this.dispersion = dispersion;
    }

    /**
     * Recomputes the aggregated error and all classification metrics from {@code evaluations}.
     *
     * @return always {@code true}
     */
    @Override
    public boolean recalculate() {
        this.error = this.calculateErrorValue();
        // Must run before loadBinaryMetrics: the dispersion computation divides by the class counts set here.
        this.loadBasicCounts(this.evaluations);
        this.loadBinaryMetrics(this.evaluations);
        return true;
    }

    /**
     * Aggregates the per-example error values with the configured aggregation function.
     *
     * @return the aggregated error value
     */
    public Value calculateErrorValue() {
        ArrayList<Value> errors = new ArrayList<Value>(this.evaluations.size());
        for (Result evaluation : this.evaluations) {
            errors.add(evaluation.errorValue());
        }
        return this.aggregationFcn.evaluate(errors);
    }

    /**
     * Computes the mean and standard deviation of error, accuracy, majority fraction and
     * dispersion across several classification results (e.g. cross-validation folds).
     *
     * @param resultsList the results to aggregate
     * @return mean/std pair, or {@code null} when the list is empty or its first error is {@code null}
     */
    public static MeanStdResults aggregateClassifications(List<ClassificationResults> resultsList) {
        List<Value> errors = resultsList.stream().map(res -> res.error).collect(Collectors.toList());
        if (errors.isEmpty() || errors.get(0) == null) {
            return null;
        }
        Value meanError = MathUtils.getMeanValue(errors);
        Value stdError = MathUtils.getStd(errors, meanError);
        List<Double> accuracies = resultsList.stream().map(res -> res.accuracy).collect(Collectors.toList());
        Double meanAcc = MathUtils.getMean(accuracies);
        Double stdAcc = MathUtils.getStd(accuracies, meanAcc);
        List<Double> dispersions = resultsList.stream().map(res -> res.dispersion).collect(Collectors.toList());
        Double meanDisp = MathUtils.getMean(dispersions);
        Double stdDisp = MathUtils.getStd(dispersions, meanDisp);
        List<Double> majorErrs = resultsList.stream().map(res -> res.majorityErr).collect(Collectors.toList());
        Double meanMajErr = MathUtils.getMean(majorErrs);
        Double stdMajErr = MathUtils.getStd(majorErrs, meanMajErr);
        ClassificationResults mean = new ClassificationResults(meanError, meanAcc, meanMajErr, meanDisp);
        ClassificationResults std = new ClassificationResults(stdError, stdAcc, stdMajErr, stdDisp);
        return new MeanStdResults(mean, std);
    }

    /**
     * Counts evaluations per target class and sets {@link #majorityErr} to the larger class's fraction.
     * With no evaluations the division is 0.0 / 0.0, so {@code majorityErr} becomes NaN.
     */
    private void loadBasicCounts(List<Result> evaluations) {
        this.zeroCount = 0;
        this.oneCount = 0;
        for (Result evaluation : evaluations) {
            if (evaluation.getTarget().greaterThan(oneHalf)) {
                ++this.oneCount;
            } else {
                ++this.zeroCount;
            }
        }
        this.majorityErr = (double) Math.max(this.zeroCount, this.oneCount) / (double) evaluations.size();
    }

    /**
     * Computes {@link #accuracy} and {@link #dispersion}; expects {@link #loadBasicCounts} to have run.
     *
     * <p>An example counts as correct when its output is strictly above 0.5 for a "one" target, or
     * strictly below 0.5 for a "zero" target (an output of exactly 0.5 is never correct).
     * NOTE(review): when one class is absent its count is 0, so its reciprocal is Infinity and the
     * dispersion is presumably NaN — confirm against ScalarValue arithmetic.
     */
    private void loadBinaryMetrics(List<Result> evaluations) {
        this.goodCount = 0;
        // Declared as Value (not ScalarValue) so calls bind to the public Value-typed overloads,
        // mirroring the explicit (Value) casts in the decompiled original.
        Value zeroSum = new ScalarValue(0.0);
        Value oneSum = new ScalarValue(0.0);
        for (Result evaluation : evaluations) {
            Value output = evaluation.getOutput();
            if (evaluation.getTarget().greaterThan(oneHalf)) {
                oneSum.incrementBy(output);
                if (output.greaterThan(oneHalf)) {
                    ++this.goodCount;
                }
            } else {
                zeroSum.incrementBy(output);
                if (oneHalf.greaterThan(output)) {
                    ++this.goodCount;
                }
            }
        }
        Value oneReciprocal = new ScalarValue(1.0 / (double) this.oneCount);
        Value zeroReciprocal = new ScalarValue(1.0 / (double) this.zeroCount);
        Value disp = oneSum.elementTimes(oneReciprocal).minus(zeroSum.elementTimes(zeroReciprocal));
        this.dispersion = ((ScalarValue) disp).value;
        this.accuracy = (double) this.goodCount / (double) evaluations.size();
    }

    @Override
    public String toString() {
        return this.toString(null);
    }

    /**
     * Renders the available metrics as {@code "accuracy: X%, disp: D, error: E"}, omitting null ones.
     *
     * <p>Separators are emitted only between present segments, so the output never starts with a
     * stray {@code ", "} when {@link #accuracy} is absent.
     *
     * @param settings when non-null, the error is prefixed with its {@code aggFcn(errFcn) = } description
     * @return the human-readable summary
     */
    @Override
    public String toString(Settings settings) {
        StringBuilder sb = new StringBuilder();
        if (this.accuracy != null) {
            sb.append("accuracy: ").append(Settings.shortNumberFormat.format(this.accuracy * 100.0)).append("%");
        }
        if (this.dispersion != null) {
            appendSeparator(sb);
            sb.append("disp: ").append(this.dispersion.toString());
        }
        if (this.error != null) {
            appendSeparator(sb);
            sb.append("error: ");
            if (settings != null) {
                String errAggfcn = settings.errorAggregationFcn.toString();
                String errFcn = settings.errorFunction.toString();
                String errString = errAggfcn + "(" + errFcn + ")";
                sb.append(errString).append(" = ");
            }
            sb.append(this.error.toString());
        }
        return sb.toString();
    }

    /** Appends {@code ", "} only if something has already been written to {@code sb}. */
    private static void appendSeparator(StringBuilder sb) {
        if (sb.length() > 0) {
            sb.append(", ");
        }
    }
}

