﻿using System.ClientModel;
using OpenAI;
using OpenAI.Chat;
using Robotless.Framework;
using Robotless.Kernel;
using Robotless.Modules.Injecting;
using Robotless.Modules.Logging;

namespace Robotless.Modules.Agenting.Clients;

/// <summary>
/// <see cref="IAgent"/> implementation backed by an OpenAI chat model.
/// It repeatedly requests chat completions and executes any tool calls the model makes.
/// This continues until the model produces a final (<see cref="ChatFinishReason.Stop"/>) answer.
/// </summary>
/// <param name="credential">API credential used to construct the underlying <see cref="OpenAIClient"/>.</param>
/// <param name="model">Name of the chat model to use.</param>
/// <param name="options">Optional client options (endpoint, retry policy, ...).</param>
public class OpenAiAgent(ApiKeyCredential credential, string model, OpenAIClientOptions? options = null) : IAgent
{
    /// <summary>Creates an agent from a raw API key string, using default client options.</summary>
    /// <param name="key">The OpenAI API key.</param>
    /// <param name="model">Name of the chat model to use.</param>
    public OpenAiAgent(string key, string model) : this(new ApiKeyCredential(key), model)
    {}

    private readonly ChatClient _client =
        new OpenAIClient(credential, options).GetChatClient(model);

    /// <summary>
    /// Logger used to build the exceptions thrown by <see cref="Complete"/>.
    /// Marked as a component dependency (presumably filled by the injection framework — confirm).
    /// </summary>
    [ComponentDependency] public LoggerComponent? Logger { get; init; }

    /// <summary>
    /// Tools available to every completion made by this agent, keyed by function name.
    /// These take precedence over the per-call <c>tools</c> passed to <see cref="Complete"/>.
    /// </summary>
    public IDictionary<string, ITool> Tools { get; } = new Dictionary<string, ITool>();

    /// <summary>
    /// Runs the completion loop against <paramref name="memory"/> until the model stops.
    /// Tool calls are executed along the way, and their results are appended to memory.
    /// </summary>
    /// <param name="memory">Conversation history; sent to the model and appended to in place.</param>
    /// <param name="options">Per-request chat completion options.</param>
    /// <param name="tools">Additional tools for this call only; consulted after <see cref="Tools"/>.</param>
    /// <param name="cancellation">Token checked before every model request and flowed into it.</param>
    /// <returns>The final assistant message, which has also been appended to <paramref name="memory"/>.</returns>
    /// <exception cref="OperationCanceledException">Cancellation was requested between model round-trips.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The model returned an unrecognized finish reason.</exception>
    public virtual async Task<AgentAssistantMessage> Complete(IAgentMemory memory, 
        ChatCompletionOptions? options = null, IReadOnlyDictionary<string, ITool>? tools = null,
        CancellationToken cancellation = default)
    {
        while (!cancellation.IsCancellationRequested)
        {
            var completion = (await _client.CompleteChatAsync(memory, options, cancellation)).Value;
            switch (completion.FinishReason)
            {
                case ChatFinishReason.Stop:
                    var response = new AgentAssistantMessage(completion);
                    memory.Add(response);
                    return response;
                case ChatFinishReason.ToolCalls:
                    // The Chat Completions API requires every tool result to follow the assistant
                    // message that issued the corresponding tool call. Record that message first;
                    // otherwise the next request is rejected.
                    memory.Add(new AgentAssistantMessage(completion));
                    foreach (var toolCall in completion.ToolCalls)
                    {
                        // Agent-level tools win over per-call tools with the same name.
                        if (!Tools.TryGetValue(toolCall.FunctionName, out var tool) &&
                            tools?.TryGetValue(toolCall.FunctionName, out tool) != true)
                            throw Logger.PlatformException(
                                $"Failed to find the tool \"{toolCall.FunctionName}\"");
                        memory.Add(new AgentToolMessage(toolCall.Id,
                            tool!.Invoke(toolCall.FunctionArguments)));
                    }
                    // Loop again so the model can consume the tool results.
                    break;
                case ChatFinishReason.ContentFilter:
                    throw Logger.PlatformException(
                        "Completion is omitted by the content filter.");
                case ChatFinishReason.Length:
                    throw Logger.PlatformException(
                        "Completion exceeds the length limit.");
                case ChatFinishReason.FunctionCall:
                    throw Logger.PlatformException(
                        "Function calls are not supported.");
                default:
                    throw new ArgumentOutOfRangeException(nameof(completion.FinishReason),
                        completion.FinishReason, "Unknown chat finish reason.");
            }
        }
        // Attach the token so callers can tell this cancellation apart from others.
        throw new OperationCanceledException("Completion is cancelled.", cancellation);
    }
}

public static class OpenAiAgentFactory
{
    /// <summary>
    /// Registers a factory for <see cref="OpenAiAgent"/> in <paramref name="container"/>.
    /// Each resolution builds a new agent from the given credential, model and options.
    /// The agent's <c>Logger</c> is taken from the provider's <c>IPlatformLogger</c> injection.
    /// </summary>
    /// <param name="container">Container to register the agent factory in.</param>
    /// <param name="key">API credential handed to every created agent.</param>
    /// <param name="model">Chat model name handed to every created agent.</param>
    /// <param name="options">Optional OpenAI client options handed to every created agent.</param>
    /// <returns>The same <paramref name="container"/>, to allow call chaining.</returns>
    public static IInjectionContainer AddOpenAiAgent(this IInjectionContainer container, 
        ApiKeyCredential key, string model, OpenAIClientOptions? options = null)
    {
        // The injection target is not needed: every resolution yields an identically configured agent.
        OpenAiAgent Build(IInjectionProvider provider, InjectionTarget? target) =>
            new(key, model, options) { Logger = provider.GetInjection<IPlatformLogger>() };

        container[typeof(OpenAiAgent)] = Build;
        return container;
    }
}