﻿namespace Robotless.Modules.Mocking.Learning.Metrics;

/// <summary>
/// Trainer metric that tracks the fraction of reported results whose actual value
/// matches the expected value according to <see cref="Comparer"/>.
/// </summary>
/// <typeparam name="TResult">Type of the results being compared.</typeparam>
public class AccuracyMetrics<TResult> : TrainerMetrics<TResult>
{
    // long rather than int so that very long runs cannot wrap the counters
    // around to negative values and corrupt the computed ratio.
    private long _totalCount;
    private long _correctCount;

    /// <summary>
    /// Fraction (between 0 and 1) of reported results judged correct by <see cref="Comparer"/>.
    /// Returns <see cref="double.NaN"/> when nothing has been reported yet, because
    /// accuracy over zero samples is undefined.
    /// </summary>
    public double Accuracy => _totalCount == 0
        ? double.NaN
        : (double)_correctCount / _totalCount;

    /// <summary>
    /// Predicate deciding whether an actual result counts as correct.
    /// Invoked as <c>Comparer(expectedResult, actualResult)</c>.
    /// </summary>
    public Func<TResult, TResult, bool> Comparer { get; }

    /// <summary>Creates an accuracy metric with no samples recorded.</summary>
    /// <param name="comparer">
    /// Custom match predicate. When <c>null</c>, results are compared with
    /// <see cref="EqualityComparer{T}.Default"/>.
    /// </param>
    public AccuracyMetrics(Func<TResult, TResult, bool>? comparer = null)
    {
        Comparer = comparer ?? DefaultComparer;
    }

    /// <summary>
    /// Records one sample: always counts it toward the total, and counts it as correct
    /// when <see cref="Comparer"/> returns <c>true</c>.
    /// </summary>
    /// <param name="expectedResult">The expected (reference) result.</param>
    /// <param name="actualResult">The result actually produced.</param>
    public override void Report(TResult expectedResult, TResult actualResult)
    {
        ++_totalCount;

        if (Comparer(expectedResult, actualResult))
        {
            ++_correctCount;
        }
    }

    /// <summary>
    /// Formats <see cref="Accuracy"/> as a percentage with four decimal places
    /// (format string <c>"P4"</c>, current culture).
    /// </summary>
    public override string ToText()
    {
        return Accuracy.ToString("P4");
    }

    // Default equality: two nulls are equal, IEquatable<T> is used when available,
    // and value-type results are not boxed on each comparison.
    private static bool DefaultComparer(TResult expectedResult, TResult actualResult)
        => EqualityComparer<TResult>.Default.Equals(expectedResult, actualResult);
}

/// <summary>
/// Registration helpers for <see cref="AccuracyMetrics{TResult}"/>.
/// </summary>
public static class AccuracyMetricsExtensions
{
    /// <summary>
    /// Registers, or replaces, the <c>"Accuracy"</c> entry in <paramref name="metrics"/> with a
    /// factory that creates a new <see cref="AccuracyMetrics{TResult}"/> each time it is invoked.
    /// </summary>
    /// <typeparam name="TResult">Type of the results being compared.</typeparam>
    /// <param name="metrics">Metric factory table to register into; modified in place.</param>
    /// <param name="comparer">
    /// Optional match predicate passed to every created metric. When <c>null</c>,
    /// the metric's default equality comparison is used.
    /// </param>
    /// <returns>The same <paramref name="metrics"/> instance, to allow call chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="metrics"/> is <c>null</c>.</exception>
    public static Dictionary<string, Func<TrainerMetrics<TResult>>> UseAccuracy<TResult>(
        this Dictionary<string, Func<TrainerMetrics<TResult>>> metrics,
        Func<TResult, TResult, bool>? comparer = null)
    {
        // Fail with a descriptive argument error instead of a NullReferenceException from the indexer.
        ArgumentNullException.ThrowIfNull(metrics);

        metrics["Accuracy"] = () => new AccuracyMetrics<TResult>(comparer);
        return metrics;
    }
}