﻿using System.Diagnostics;
using System.Linq.Expressions;
using Microsoft.Extensions.Logging;
using Robotless.Kernel;
using Robotless.Modules.AiAgent.Messages;
using Robotless.Modules.Injecting;
using Robotless.Modules.Logging;
using Robotless.Modules.Utilities;
using Spectre.Console;

namespace Robotless.Modules.Mocking.Learning;

/// <summary>
/// A trainer that provides methods to train mock functions.
/// </summary>
/// <remarks>
/// Data entries are pulled from a single enumerator created in the constructor and shared by
/// <see cref="Train"/> and <see cref="Evaluate"/>, so consecutive calls continue where the previous call stopped.
/// </remarks>
/// <typeparam name="TMockDelegate">Delegate type of the mock functions to train.</typeparam>
/// <typeparam name="TDataEntry">Type of the data entry in the dataset.</typeparam>
/// <typeparam name="TResult">Type of the return value of the mock functions.</typeparam>
public class Trainer<TMockDelegate, TDataEntry, TResult> :
    IIdentifiable, IDisposable
    where TMockDelegate : Delegate
{
    public Guid Identifier { get; init; } = Guid.CreateVersion7();

    /// <summary>
    /// Create a trainer for the specified mock function and dataset.
    /// </summary>
    /// <param name="mock">Mock function to train.</param>
    /// <param name="data">Data to use.</param>
    /// <param name="argumentsMapping">
    /// Map an entry into input arguments. The body must directly invoke the mock delegate,
    /// e.g. <c>(mock, entry) => mock(entry.First, entry.Second)</c>.
    /// </param>
    /// <param name="resultMapping">Map an entry into the return value.</param>
    /// <param name="resultVerifier">
    /// Verifier to verify if a result should be considered as 'correct'.
    /// Defaults to <see cref="object.Equals(object)"/> on the expected result, treating two nulls as equal.
    /// </param>
    /// <exception cref="ArgumentNullException">Thrown when a required argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown when <paramref name="argumentsMapping"/> is not a direct invocation expression.
    /// </exception>
    public Trainer(
        MockDelegate<TMockDelegate> mock,
        IEnumerable<TDataEntry> data,
        Expression<Action<TMockDelegate, TDataEntry>> argumentsMapping,
        Func<TDataEntry, TResult> resultMapping,
        Func<TResult, TResult, bool>? resultVerifier = null)
    {
        ArgumentNullException.ThrowIfNull(mock);
        ArgumentNullException.ThrowIfNull(data);
        ArgumentNullException.ThrowIfNull(argumentsMapping);
        ArgumentNullException.ThrowIfNull(resultMapping);

        // Validate and compile the mapping before acquiring the enumerator,
        // so an invalid mapping does not leave an undisposed enumerator behind.
        ArgumentMapping = BuildArgumentSelector(argumentsMapping);
        ResultMapping = resultMapping;
        ResultVerifier = resultVerifier ?? DefaultResultVerifier;

        _delegate = mock;
        _data = data.GetEnumerator();

        return;

        static bool DefaultResultVerifier(TResult expected, TResult actual)
            => expected is null ? actual is null : expected.Equals(actual);
    }

    /// <summary>
    /// Dispose the enumerator over the dataset.
    /// </summary>
    public void Dispose()
    {
        _data.Dispose();
    }

    /// <summary>
    /// Logger for this trainer to use.
    /// </summary>
    [Injection]
    public LoggerComponent? Logger { get; init; }

    /// <summary>
    /// The mock function to train.
    /// </summary>
    private readonly MockDelegate<TMockDelegate> _delegate;

    /// <summary>
    /// The functor to select arguments from the data entry,
    /// and then invoke the mock function to acquire the result.
    /// </summary>
    public Func<TDataEntry, object?[]> ArgumentMapping { get; }

    /// <summary>
    /// The functor to select expected result from the data entry.
    /// </summary>
    public Func<TDataEntry, TResult> ResultMapping { get; }

    /// <summary>
    /// If this functor returns false, the reflection procedure will be conducted.
    /// </summary>
    public Func<TResult, TResult, bool> ResultVerifier { get; }

    /// <summary>
    /// When there are more invocations than this threshold in the memory, these invocations will be compressed.
    /// Set it to null to disable the compression. It is null by default.
    /// </summary>
    public int? CompressionThreshold { get; set; } = null;

    /// <summary>
    /// When there are more invocations than this threshold in the memory, oldest correct invocations will be removed,
    /// and if current invocation is correct and there is no correct invocations in the context,
    /// then current invocation will be dropped.
    /// Set it to null to disable the removal. It is null by default.
    /// </summary>
    public int? ReplacementThreshold { get; set; } = null;

    /// <summary>
    /// Maximum number of attempts to execute the mock function for a single data entry.
    /// When every attempt throws, the last exception is rethrown.
    /// Set it to null to retry indefinitely. It is null by default. Values below 1 behave like 1.
    /// </summary>
    public int? MaxAttempts { get; set; } = null;

    /// <summary>
    /// The metrics for humans to assess the performance.
    /// </summary>
    public Dictionary<string, Func<TrainerMetrics<TResult>>> Metrics { get; } = new();

    private readonly IEnumerator<TDataEntry> _data;

    /// <summary>
    /// Instantiate a fresh set of metrics from the registered factories in <see cref="Metrics"/>.
    /// </summary>
    private Dictionary<string, TrainerMetrics<TResult>> BuildMetrics()
        => Metrics.ToDictionary(pair => pair.Key, pair => pair.Value());

    /// <summary>
    /// Train the mock function with specified count of data entries.
    /// </summary>
    /// <param name="count">Count of data entries to use.</param>
    /// <returns>
    /// The number of data entries actually used in the training.
    /// If there is not enough data, it will be less than the specified count.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="count"/> is negative.</exception>
    public async Task<int> Train(int count)
    {
        ArgumentOutOfRangeException.ThrowIfNegative(count);

        AnsiConsole.Write(new Rule($"{Stage.Training} with {count} Data Entries"));
        Logger?.LogInformation("Training started with {Count} data entries.", count);
        var session = new TrainerSession<TResult>(BuildMetrics(), Stage.Training, count);
        var trainedCount = await ExecuteSession(session, current => current.Index >= count);

        var contextLength = _delegate.Memory.Invocations.Count;
        var positiveCount = _delegate.Memory.Invocations.Count(
            invocation => invocation.IsPassed());
        // The memory can be empty (no data was available, or compression cleared the invocations);
        // report 0 instead of NaN in that case.
        var positiveRatio = contextLength == 0 ? 0d : positiveCount / (double)contextLength;
        AnsiConsole.MarkupLine("[bold lightskyblue1]Training Finished:[/] Passed Invocations " +
                               $"{positiveCount}/{contextLength} ({positiveRatio:P2})");
        Logger?.LogDetails(LogLevel.Information,
            "Training finished.", details =>
            {
                details.Count = trainedCount;
                details.Metrics = session.ToStringMetrics();
                details.Memory = _delegate.Memory.LocalMessages;
                details.ContextLength = contextLength;
                details.PositiveRatio = positiveRatio;
                details.TotalTokenUsage = _delegate.Memory.TokenUsages.Values.Sum();
                details.TokenUsages = _delegate.Memory.TokenUsages;
                details.AverageMilliseconds = session.AverageDuration.TotalMilliseconds;
            });
        return trainedCount;
    }

    /// <summary>
    /// Evaluate the mock function with specified count of data entries.
    /// Invocations made during evaluation are not committed to the mock's memory,
    /// and no reflection, replacement or compression is conducted.
    /// </summary>
    /// <param name="count">Count of data entries to use.</param>
    /// <returns>
    /// The number of data entries actually used in the evaluation.
    /// If there is not enough data, it will be less than the specified count.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="count"/> is negative.</exception>
    public async Task<int> Evaluate(int count)
    {
        ArgumentOutOfRangeException.ThrowIfNegative(count);

        AnsiConsole.Write(new Rule($"{Stage.Evaluation} with {count} Data Entries"));
        Logger?.LogDetails(LogLevel.Information,
            "Evaluation started.", details =>
            {
                details.Count = count;
                details.Memory = _delegate.Memory.LocalMessages;
            });
        var session = new TrainerSession<TResult>(BuildMetrics(), Stage.Evaluation, count);
        var evaluatedCount = await ExecuteSession(session,
            current => current.Index >= count);
        Logger?.LogDetails(LogLevel.Information,
            "Evaluation finished.", details =>
            {
                details.Count = evaluatedCount;
                details.Metrics = session.ToStringMetrics();
                details.TotalTokenUsage = _delegate.Memory.TokenUsages.Values.Sum();
                details.TokenUsages = _delegate.Memory.TokenUsages;
                details.AverageMilliseconds = session.AverageDuration.TotalMilliseconds;
            });
        return evaluatedCount;
    }

    /// <summary>
    /// Pull data entries and execute them one by one until the stopping condition holds or the data runs out,
    /// showing a console spinner with the progress and an estimated remaining time.
    /// </summary>
    /// <returns>The session index after the loop ends.</returns>
    private async Task<int> ExecuteSession(
        TrainerSession<TResult> session,
        Func<TrainerSession<TResult>, bool> stoppingCondition)
    {
        await AnsiConsole.Status()
            .Spinner(Spinner.Known.Dots)
            .SpinnerStyle(new Style(Color.LightSkyBlue1))
            .StartAsync($"{session.Stage}...", async context =>
            {
                // Duration of the previous entry; used as a rough per-entry estimate for the ETA.
                var elapsedTime = TimeSpan.Zero;

                while (!stoppingCondition(session))
                {
                    var timestamp = Stopwatch.GetTimestamp();

                    var status = session.Count == null
                        ? $"{session.Stage} {session.Index}..."
                        : $"{session.Stage} {session.Index}/{session.Count}... " +
                          $"({(elapsedTime * (session.Count - session.Index)).Value
                              .ToSmartString(1)})";
                    context.Status($"[bold lightskyblue1]{status}[/]");

                    if (!_data.MoveNext())
                    {
                        Logger?.LogInformation("Session stopped: no more data from the data source. " +
                                               "{Count} data entries have been used.", session.Index);
                        // The markup tag must be closed, otherwise Spectre.Console rejects the markup.
                        AnsiConsole.MarkupLine(
                            $"[bold white]{session.Stage} finished: {session.Index} data entries have been used.[/]");
                        break;
                    }

                    await ExecuteEntry(session, _data.Current);

                    // Stopwatch timestamps are not TimeSpan ticks (their frequency is platform dependent),
                    // so convert through Stopwatch.GetElapsedTime rather than TimeSpan.FromTicks.
                    elapsedTime = Stopwatch.GetElapsedTime(timestamp);
                }
            });
        return session.Index;
    }

    /// <summary>
    /// Execute the mock function on one data entry, report the metrics, and, during training,
    /// verify the result, conduct reflection on failure, and apply the replacement and compression policies.
    /// </summary>
    private async Task ExecuteEntry(TrainerSession<TResult> session, TDataEntry entry)
    {
        var timestamp = Stopwatch.GetTimestamp();

        // The mappings are deterministic functions of the entry; evaluate them once, outside the retry loop,
        // so a faulty mapping surfaces as an exception instead of being retried forever.
        var arguments = ArgumentMapping(entry);
        var expectedResult = ResultMapping(entry);

        MockInvocation invocation;
        TResult actualResult;
        string remarks;
        for (var attempt = 1; ; attempt++)
        {
            try
            {
                using (_delegate.NewMemoryScope(out var memory))
                {
                    invocation = _delegate.MakeInvocation(arguments);
                    remarks = invocation.Remarks.AsString;
                    actualResult = (TResult)_delegate.DeserializeResults(invocation.Results)!;
                    // Only training keeps the invocation in the mock's memory.
                    if (session.Stage == Stage.Training)
                        memory.Commit();
                    break;
                }
            }
            catch (Exception error)
            {
                var cause = error is AggregateException aggregate
                    ? aggregate.InnerException ?? aggregate
                    : error;
                Logger?.LogException(cause, "Exception occurs while executing the mock function.");
                if (MaxAttempts is { } limit && attempt >= limit)
                    throw;
            }
        }

        // Update metrics.
        session.ReportResult(expectedResult, actualResult);

        var bsonExpectedResult = _delegate.SerializeResults(expectedResult!);

        var metrics = session.Metrics
            .Select(pair => new KeyValuePair<string, string>(pair.Key, pair.Value.ToText()))
            .ToDictionary();

        if (session.Stage == Stage.Evaluation)
        {
            var duration = Stopwatch.GetElapsedTime(timestamp);
            Logger?.LogDetails(LogLevel.Information, "Evaluated with a data entry.",
                details =>
                {
                    details.Stage = session.Stage.ToString();
                    details.Index = session.Index;
                    details.Arguments = invocation.Arguments;
                    details.Remarks = remarks;
                    details.ActualResults = invocation.Results;
                    details.ExpectedResults = bsonExpectedResult;
                    details.Metrics = metrics;
                    details.ElapsedMilliseconds = duration.TotalMilliseconds;
                });
            session.ReportDuration(duration);
            ConcludeEntry(session, timestamp: timestamp);
            return;
        }

        if (ResultVerifier(expectedResult, actualResult))
        {
            var duration = Stopwatch.GetElapsedTime(timestamp);
            invocation.SetTrainerFlag(TrainerInvocationFlag.Passed);
            Logger?.LogDetails(LogLevel.Information, "Result verification passed.",
                details =>
                {
                    details.Stage = session.Stage.ToString();
                    details.Index = session.Index;
                    details.Arguments = invocation.Arguments;
                    details.Results = invocation.Results;
                    details.Remarks = remarks;
                    details.Metrics = metrics;
                    details.ElapsedMilliseconds = duration.TotalMilliseconds;
                });
            session.ReportDuration(duration);
            ConcludeEntry(session, timestamp, "[green]Passed[/]");
        }
        else // Conduct reflection.
        {
            invocation.SetTrainerFlag(TrainerInvocationFlag.Failed);
            // invocation.Results is overwritten with the expected result below;
            // keep the wrong answer so the log still reports what the mock actually returned.
            var wrongResults = invocation.Results;
            string reflection;
            using (_delegate.NewMemoryScope(out _))
            {
                reflection = await _delegate.Reflect(invocation.Arguments, wrongResults,
                    remarks, bsonExpectedResult);
                invocation.Remarks = $"I previously given the wrong result {wrongResults}, " +
                                     $"but the correct answer is {bsonExpectedResult}. {reflection}";
                invocation.Results = bsonExpectedResult;
                invocation.UpdateText();
            }

            var duration = Stopwatch.GetElapsedTime(timestamp);
            Logger?.LogDetails(LogLevel.Information, "Result verification failed. Reflection conducted.",
                details =>
                {
                    details.Stage = session.Stage.ToString();
                    details.Index = session.Index;
                    details.Arguments = invocation.Arguments;
                    details.Remarks = remarks;
                    details.ActualResults = wrongResults;
                    details.ExpectedResults = bsonExpectedResult;
                    details.Reflection = reflection;
                    details.Metrics = metrics;
                    details.ElapsedMilliseconds = duration.TotalMilliseconds;
                });
            session.ReportDuration(duration);
            ConcludeEntry(session, timestamp, "[red]Failed[/]");
        }

        ApplyReplacementPolicy(invocation, remarks);
        await CompressMemoryIfNeeded();
    }

    /// <summary>
    /// Enforce <see cref="ReplacementThreshold"/>: when memory holds too many invocations,
    /// remove the oldest passed invocation (or the oldest invocation if none passed),
    /// or drop the current passed invocation if no passed invocation can be found in memory.
    /// </summary>
    /// <param name="invocation">The invocation produced for the current data entry.</param>
    /// <param name="remarks">Remarks of the current invocation as returned by the mock, before any reflection.</param>
    private void ApplyReplacementPolicy(MockInvocation invocation, string remarks)
    {
        if (ReplacementThreshold is not { } threshold || _delegate.Memory.Invocations.Count <= threshold)
            return;

        // Passed invocations are evicted first; presumably failed ones are kept because they carry
        // the corrective reflection recorded above.
        var targetInvocation = _delegate.Memory.Invocations.FirstOrDefault(
            candidate => candidate.IsPassed());
        if (targetInvocation == null && invocation.IsPassed())
        {
            _delegate.Memory.Remove(invocation);
            Logger?.LogDetails(LogLevel.Information,
                "Due to replacement policy, current passed invocation is dropped.",
                details =>
                {
                    details.Arguments = invocation.Arguments;
                    details.Results = invocation.Results;
                    details.Remarks = remarks;
                    details.Verification = invocation.IsPassed()
                        ? "Passed"
                        : "Failed";
                    details.Memory = _delegate.Memory.LocalMessages;
                });
            return;
        }

        targetInvocation ??= _delegate.Memory.Invocations.First();
        _delegate.Memory.Remove(targetInvocation);
        Logger?.LogDetails(LogLevel.Information,
            "Due to replacement policy, the oldest invocation is removed.",
            details =>
            {
                details.Arguments = targetInvocation.Arguments;
                details.Results = targetInvocation.Results;
                details.Remarks = targetInvocation.Remarks;
                details.Verification = targetInvocation.IsPassed()
                    ? "Passed"
                    : "Failed";
                details.Memory = _delegate.Memory.LocalMessages;
            });
    }

    /// <summary>
    /// Enforce <see cref="CompressionThreshold"/>: when memory holds too many invocations,
    /// replace all of them with a single message containing notes produced by <c>RemarksCompressor</c>.
    /// </summary>
    private async Task CompressMemoryIfNeeded()
    {
        if (CompressionThreshold is not { } threshold || _delegate.Memory.Invocations.Count <= threshold)
            return;

        var compressed = await RemarksCompressor.Compress(_delegate);
        _delegate.Memory.ClearInvocations();
        _delegate.Memory.Add(new AgentRequestMessage(
                $"""
                 Here are notes summarized by yourself to help you avoid mistakes and maximize accuracy:
                 {compressed}
                 """)
            {
                ParticipantName = "Compressor"
            }
        );
        Logger?.LogDetails(LogLevel.Information, "Remarks compressed.",
            details =>
            {
                details.CompressedRemarks = compressed;
                details.Memory = _delegate.Memory.LocalMessages;
            });
    }

    /// <summary>
    /// Print a tree summarizing the finished entry: its index, status, duration and the current metrics.
    /// </summary>
    /// <param name="session">The running session.</param>
    /// <param name="timestamp">The <see cref="Stopwatch.GetTimestamp"/> value taken when the entry started.</param>
    /// <param name="status">Optional Spectre.Console markup describing the outcome.</param>
    private static void ConcludeEntry(TrainerSession<TResult> session, long timestamp, string status = "")
    {
        var duration = Stopwatch.GetElapsedTime(timestamp);
        var progress = session.Count != null ? "/" + session.Count : string.Empty;
        var tree = new Tree(new Markup(
            $"[bold lightskyblue1]{session.Stage} {session.Index}{progress}:[/] " + status +
            $" ({duration.TotalMilliseconds:F0}ms - Average {session.AverageDuration.TotalMilliseconds:F0}ms)"));
        foreach (var (name, metric) in session.Metrics)
        {
            // Metric names are caller-supplied keys of Metrics; escape them so '[' or ']' cannot break the markup.
            tree.Nodes.Add(new TreeNode(new Markup(
                "[bold white]" + Markup.Escape(name) + $":[/] {metric.ToMarkup()}".PadRight(40))));
        }

        AnsiConsole.Write(tree);
    }

    /// <summary>
    /// Compile an expression like <c>(mock, entry) => mock(entry.A, entry.B)</c> into a function
    /// that produces the boxed argument array <c>[entry.A, entry.B]</c> for a given entry.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the expression body is not an invocation.</exception>
    private static Func<TDataEntry, object?[]> BuildArgumentSelector(
        Expression<Action<TMockDelegate, TDataEntry>> expression)
    {
        if (expression.Body is not InvocationExpression invocation)
            throw new ArgumentException(
                "The arguments mapping must directly invoke the mock delegate, " +
                "e.g. (mock, entry) => mock(entry.First, entry.Second).",
                "argumentsMapping");

        var entry = expression.Parameters[1];
        // Box every argument so they all fit into a single object?[] regardless of their static types.
        var arguments = invocation.Arguments.Select(argument =>
            Expression.Convert(argument, typeof(object)));
        return Expression.Lambda<Func<TDataEntry, object?[]>>(
                Expression.NewArrayInit(typeof(object), arguments), entry)
            .Compile();
    }
}