# This script is a modified version of the MagenticOne scenario.py file.
# It adds cost tracking to the model clients.

import asyncio
import os
import warnings
from dataclasses import dataclass, field
from typing import Dict

import yaml
from autogen_agentchat.agents import CodeExecutorAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ChatCompletionClient
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor

# Silence ResourceWarning messages that begin with "unclosed". Per the original
# note, these come from requests.Session() objects that are never explicitly
# closed; the filter itself matches any ResourceWarning with that message prefix.
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)

@dataclass
class CostTracker:
    """Accumulate API costs, keyed by the name of the model client that incurred them."""

    # Running cost per client name; entries are created on first add_cost() call.
    costs: Dict[str, float] = field(default_factory=dict)

    def calculate_cost(self, prompt_tokens: int, completion_tokens: int, model: str) -> float:
        """Return the cost of one request from its token counts.

        Args:
            prompt_tokens: Number of tokens in the prompt.
            completion_tokens: Number of tokens in the completion.
            model: Model name; only "gpt-4o-mini" is priced here.

        Returns:
            Prompt cost plus completion cost.

        Raises:
            ValueError: If ``model`` is anything other than "gpt-4o-mini".
        """
        # Guard first so the pricing below only ever applies to the one known model.
        if model != "gpt-4o-mini":
            raise ValueError(f"Unsupported model: {model}")

        # Rates are expressed per 1,000 tokens.
        input_rate_per_1k = 0.00015
        output_rate_per_1k = 0.0006

        input_charge = (prompt_tokens / 1000) * input_rate_per_1k
        output_charge = (completion_tokens / 1000) * output_rate_per_1k
        return input_charge + output_charge

    def add_cost(self, client_name: str, cost: float) -> None:
        """Add ``cost`` to the running total of ``client_name`` and print all totals."""
        updated_total = self.costs.get(client_name, 0.0) + cost
        self.costs[client_name] = updated_total
        print(f"Current cost: {self.costs}")

    def get_total_cost(self) -> float:
        """Return the sum of the recorded costs of every client."""
        per_client_totals = self.costs.values()
        return sum(per_client_totals)

    def get_cost_summary(self) -> str:
        """Return a two-line summary: a heading and the total to four decimals."""
        lines = [
            "Cost Summary:",
            f"Total Cost: ${self.get_total_cost():.4f}",
        ]
        return "\n".join(lines)

def _attach_cost_tracking(
    client: "ChatCompletionClient",
    client_name: str,
    cost_tracker: "CostTracker",
    pricing_model: str = "gpt-4o-mini",
) -> None:
    """Replace ``client.create`` with a wrapper that records the cost of each call.

    This lives in its own function (rather than inline in a loop) so that every
    wrapper closes over its *own* ``original_create`` and ``client_name``.
    Defining the wrapper directly inside a ``for`` loop makes all wrappers share
    the loop variables, so every client would end up calling the last client's
    ``create`` and charging the last client's name.

    Args:
        client: Model client whose ``create`` coroutine is wrapped in place.
        client_name: Key under which costs are recorded in ``cost_tracker``.
        cost_tracker: Tracker providing ``add_cost`` and ``calculate_cost``.
        pricing_model: Model name passed to ``calculate_cost`` when the response
            reports token counts but no ready-made ``total_cost``.
    """
    original_create = client.create

    async def create_with_cost_tracking(*args, **kwargs):
        response = await original_create(*args, **kwargs)
        usage = getattr(response, "usage", None)
        if usage is None:
            return response

        total_cost = getattr(usage, "total_cost", None)
        if total_cost is not None:
            # Prefer a cost reported by the client itself when one is present.
            cost_tracker.add_cost(client_name, total_cost)
        elif hasattr(usage, "prompt_tokens") and hasattr(usage, "completion_tokens"):
            # Otherwise derive the cost from the token counts.
            cost = cost_tracker.calculate_cost(
                usage.prompt_tokens,
                usage.completion_tokens,
                pricing_model,
            )
            cost_tracker.add_cost(client_name, cost)
        return response

    client.create = create_with_cost_tracking


async def main() -> None:
    """Run the MagenticOne team on the task in ``prompt.txt`` and report API cost.

    Reads model client definitions from ``config.yaml``, wraps each client's
    ``create`` with cost tracking, builds the agent team, streams the run to the
    console, and finally prints the accumulated cost summary.
    """
    # Initialize cost tracker
    cost_tracker = CostTracker()

    # Load model configuration and create the model clients.
    with open("config.yaml", "r") as f:
        config = yaml.safe_load(f)

    orchestrator_client = ChatCompletionClient.load_component(config["orchestrator_client"])
    coder_client = ChatCompletionClient.load_component(config["coder_client"])
    web_surfer_client = ChatCompletionClient.load_component(config["web_surfer_client"])
    file_surfer_client = ChatCompletionClient.load_component(config["file_surfer_client"])

    # Add cost tracking to model clients.
    # NOTE: token-based costs are priced as "gpt-4o-mini" for every client,
    # regardless of which model config.yaml actually selects.
    for client, name in [
        (orchestrator_client, "Orchestrator"),
        (coder_client, "Coder"),
        (web_surfer_client, "WebSurfer"),
        (file_surfer_client, "FileSurfer"),
    ]:
        _attach_cost_tracking(client, name, cost_tracker)

    # Read the prompt
    with open("prompt.txt", "rt") as fh:
        prompt = fh.read().strip()
    # "__FILE_NAME__" is kept verbatim from the original script.
    filename = "__FILE_NAME__".strip()

    # Set up the team
    coder = MagenticOneCoderAgent(
        "Assistant",
        model_client=coder_client,
    )

    executor = CodeExecutorAgent("ComputerTerminal", code_executor=LocalCommandLineCodeExecutor())

    file_surfer = FileSurfer(
        name="FileSurfer",
        model_client=file_surfer_client,
    )

    web_surfer = MultimodalWebSurfer(
        name="WebSurfer",
        model_client=web_surfer_client,
        downloads_folder=os.getcwd(),
        debug_dir="logs",
        to_save_screenshots=True,
    )

    # The template previously opened with a stray "," that survived .strip()
    # and was sent to the model as the first character of the prompt.
    team = MagenticOneGroupChat(
        [coder, executor, file_surfer, web_surfer],
        model_client=orchestrator_client,
        max_turns=20,
        final_answer_prompt=f"""
We have completed the following task:

{prompt}

The above messages contain the conversation that took place to complete the task.
Read the above conversation and output a FINAL ANSWER to the question.
To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
""".strip()
    )

    # Prepare the prompt
    filename_prompt = ""
    if filename:
        filename_prompt = f"The question is about a file, document or image, which can be accessed by the filename '{filename}' in the current working directory."
    task = f"{prompt}\n\n{filename_prompt}"

    # Run the task; report what was spent even if the run fails part-way.
    try:
        stream = team.run_stream(task=task.strip())
        await Console(stream)
    finally:
        # Print cost summary after task completion
        print("\n" + cost_tracker.get_cost_summary())

# Run the scenario only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
