﻿using System.ClientModel;
using System.ClientModel.Primitives;
using Robotless.Modules.OpenAi;
using Robotless.Modules.OpenAi.Chat;
using Robotless.Framework;
using Robotless.Kernel;
using Robotless.Modules.AiAgent.Messages;
using Robotless.Modules.Injecting;
using Robotless.Modules.Logging;

namespace Robotless.Modules.AiAgent.Agents;

/// <summary>
/// An <see cref="IAgent"/> backed by an OpenAI-compatible chat completion client.
/// </summary>
/// <param name="credential">API key credential used to authenticate the underlying <see cref="OpenAIClient"/>.</param>
/// <param name="model">Model name passed to <c>GetChatClient</c>.</param>
/// <param name="options">Optional client options (for example a custom endpoint); <c>null</c> uses the client defaults.</param>
public class OpenAiAgent(ApiKeyCredential credential, string model, OpenAIClientOptions? options = null) 
    : Entity, IAgent
{
    /// <summary>
    /// Creates an agent from a raw API key string.
    /// </summary>
    /// <param name="key">API key, wrapped into an <see cref="ApiKeyCredential"/>.</param>
    /// <param name="model">Model name passed to <c>GetChatClient</c>.</param>
    public OpenAiAgent(string key, string model) : this(new ApiKeyCredential(key), model)
    {}

    // Built once from the primary-constructor arguments; those arguments are not used afterwards.
    private readonly ChatClient _client =
        new OpenAIClient(credential, options).GetChatClient(model);

    /// <summary>
    /// Logger used to create the exceptions thrown by <see cref="Chat"/>. Declared nullable;
    /// NOTE(review): assumes <c>PlatformException</c> tolerates a null logger — confirm.
    /// </summary>
    [ComponentDependency] public LoggerComponent? Logger { get; init; }
    
    /// <summary>
    /// Appends <paramref name="request"/> to <paramref name="memory"/> and keeps requesting completions,
    /// invoking any tools the model asks for, until the model finishes with a final answer.
    /// </summary>
    /// <param name="request">The new request; appended to <paramref name="memory"/> before the first completion.</param>
    /// <param name="memory">Conversation history sent with every completion request; tool results and the final response are appended to it.</param>
    /// <param name="options">Per-call completion options forwarded to the chat client.</param>
    /// <param name="tools">Tools the model may call, keyed by function name; <c>null</c> means no tool is available.</param>
    /// <param name="cancellation">Checked before each completion round and forwarded to the chat client.</param>
    /// <returns>The final response, which has also been appended to <paramref name="memory"/>.</returns>
    /// <exception cref="OperationCanceledException">
    /// Cancellation was requested before a final response was produced; the exception carries <paramref name="cancellation"/>.
    /// </exception>
    public async Task<AgentResponseMessage> Chat(
        AgentRequestMessage request, 
        IAgentMemory memory, 
        ChatCompletionOptions? options = null,
        IReadOnlyDictionary<string, IAgentTool>? tools = null, CancellationToken cancellation = default)
    {
        memory.Add(request);
        
        while (!cancellation.IsCancellationRequested)
        {
            var completion = (await _client.CompleteChatAsync(memory, options, cancellation)).Value;
            
            switch (completion.FinishReason)
            {
                case ChatFinishReason.Stop:
                    var response = new AgentResponseMessage(completion);
                    memory.Add(response);
                    return response;
                case ChatFinishReason.ToolCalls:
                    // NOTE(review): the OpenAI chat API expects tool messages to follow the assistant message
                    // carrying the matching tool_calls. This branch only appends the tool results; confirm that
                    // IAgentMemory / AgentToolMessage supply that assistant message, otherwise the next
                    // completion request may be rejected.
                    foreach (var toolCall in completion.ToolCalls)
                    {
                        // After this guard, flow analysis knows 'tools' was non-null and the lookup succeeded,
                        // so 'tool' is definitely assigned and non-null.
                        if (tools is null || !tools.TryGetValue(toolCall.FunctionName, out var tool))
                            throw Logger.PlatformException(
                                $"Failed to find the tool \"{toolCall.FunctionName}\"");
                        memory.Add(new AgentToolMessage(toolCall.Id,
                            tool.Invoke(toolCall.FunctionArguments.ToString())));
                    }
                    break;
                case ChatFinishReason.ContentFilter:
                    throw Logger.PlatformException(
                        "Completion is omitted by the content filter.");
                case ChatFinishReason.Length:
                    throw Logger.PlatformException(
                        "Completion exceeds the length limit.");
                case ChatFinishReason.FunctionCall:
                    throw Logger.PlatformException(
                        "Function calls are not supported.");
                default:
                    // Routed through the logger like every other failure path in this method.
                    throw Logger.PlatformException(
                        $"Unknown completion finish reason \"{completion.FinishReason}\".");
            }
        }
        // Attach the token so callers can tell their own cancellation apart from other cancellations.
        throw new OperationCanceledException("Completion is cancelled.", cancellation);
    }
}

/// <summary>
/// Extension methods that register a transient <see cref="OpenAiAgent"/> in an injection container
/// and redirect <see cref="IAgent"/> to it.
/// </summary>
public static class OpenAiAgentFactory
{
    /// <summary>
    /// Registers a transient <see cref="OpenAiAgent"/> using the default client options.
    /// </summary>
    /// <param name="container">Container to register into.</param>
    /// <param name="key">API key credential for the OpenAI client.</param>
    /// <param name="model">Model name passed to the chat client.</param>
    /// <returns>The same <paramref name="container"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public static IInjectionContainer AddOpenAiAgent(this IInjectionContainer container, 
        ApiKeyCredential key, string model)
    {
        // Validate at registration time instead of failing later when the agent is first resolved.
        ArgumentNullException.ThrowIfNull(container);
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(model);

        container.AddTransient(CreateAgent);
        container.AddRedirection<IAgent, OpenAiAgent>();
        return container;
        
        OpenAiAgent CreateAgent(IInjectionProvider provider, InjectionRequester requester)
            => BuildAgent(provider, key, model, options: null);
    }
    
    /// <summary>
    /// Registers a transient <see cref="OpenAiAgent"/> that talks to an OpenAI-compatible server at
    /// <paramref name="endpoint"/>, using the default client retry policy.
    /// </summary>
    /// <param name="container">Container to register into.</param>
    /// <param name="endpoint">Base URI of the OpenAI-compatible server.</param>
    /// <param name="key">API key credential for that server.</param>
    /// <param name="model">Model name passed to the chat client.</param>
    /// <returns>The same <paramref name="container"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public static IInjectionContainer AddOpenSourceAgent(this IInjectionContainer container,
        Uri endpoint, ApiKeyCredential key, string model)
    {
        // A null endpoint would leave OpenAIClientOptions.Endpoint unset and silently fall back to the
        // client's default endpoint, so reject it up front.
        ArgumentNullException.ThrowIfNull(container);
        ArgumentNullException.ThrowIfNull(endpoint);
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(model);

        container.AddTransient(CreateAgent);
        container.AddRedirection<IAgent, OpenAiAgent>();
        return container;
        
        // A fresh options instance is created per resolution, so agents never share one.
        OpenAiAgent CreateAgent(IInjectionProvider provider, InjectionRequester requester)
            => BuildAgent(provider, key, model, new OpenAIClientOptions()
            {
                Endpoint = endpoint,
                RetryPolicy = ClientRetryPolicy.Default
            });
    }

    /// <summary>
    /// Constructs an agent and fills its injected members from <paramref name="provider"/>:
    /// a required <see cref="IWorkspace"/> and an optional <see cref="LoggerComponent"/>.
    /// </summary>
    private static OpenAiAgent BuildAgent(IInjectionProvider provider,
        ApiKeyCredential key, string model, OpenAIClientOptions? options)
    {
        return new OpenAiAgent(key, model, options)
        {
            Workspace = provider.RequireInjection<IWorkspace>(),
            Logger = provider.GetInjection<LoggerComponent>(),
        };
    }
}