import argparse
import json
import os
import random
import time

import cohere
import openai
import requests
from levels.constants import all_objecsts, base_ingridients, capacity
from levels.utils import compute_dependency
from mistralai import Mistral
from openai import AzureOpenAI, OpenAI
from anthropic import AnthropicBedrock
from overcooked.game import World

# ---------------------------------------------------------------------------
# LLM provider clients, created once at import time.
# NOTE(review): the os.environ[...] lookups below (MISTRAL_API_KEY,
# COHERE_API_KEY, AZURE_OPENAI_API_KEY, AWS_ACESS_KEY, AWS_SECRET_KEY) raise
# KeyError at import when the variable is unset, whereas the os.getenv(...)
# lookups pass None through to the client constructor instead.
# `api_key` is rebound several times; each client keeps the value it was
# constructed with.
# ---------------------------------------------------------------------------

# Avior OpenAI-compatible endpoint; query_llm sends every model name that
# matches no other provider branch here.
api_key = os.getenv("AVIOR_API_KEY") # Not supported yet
api_base = "http://avior.mlfoundry.com/live-inference/v1"
client_avior = OpenAI(
    api_key=api_key,
    base_url=api_base
)

# Plain OpenAI client configured from OPENAI_API_KEY / OPENAI_BASE_URL.
# NOTE(review): not referenced in the visible part of this file.
client_gpt = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL")
)

# Azure OpenAI clients whose API version and endpoint can be overridden via
# AZURE_OPENAI_API_VERSION / AZURE_OPENAI_ENDPOINT (defaults are given inline).
# NOTE(review): client_4o_mini and client_4o are not referenced in the visible
# part of this file; query_llm uses client_openai / client_openai_mini below.
client_4o_mini = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-10-21"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://llm-co-ncus.openai.azure.com/")
)

client_4o = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://llm-co-ncus.openai.azure.com/")
)

# Mistral client; query_llm uses it for model names containing "mistral".
api_key = os.environ["MISTRAL_API_KEY"]
client_mistral = Mistral(api_key=api_key)

# Cohere v2 client; query_llm uses it for model names containing "command".
api_key = os.environ["COHERE_API_KEY"]
client_cohere = cohere.ClientV2(api_key=api_key)

# Azure OpenAI clients used by query_llm for model names containing "gpt":
# client_openai_mini when the name also contains "mini", client_openai
# otherwise. Unlike the clients above, endpoint and API versions are hard-coded.
api_key = os.environ["AZURE_OPENAI_API_KEY"]
client_openai = AzureOpenAI(
    azure_endpoint="https://llm-co-ncus.openai.azure.com/",
    api_key=api_key,
    api_version="2024-08-01-preview")

client_openai_mini = AzureOpenAI(
    azure_endpoint="https://llm-co-ncus.openai.azure.com/",
    api_key=api_key,
    api_version="2024-10-21")

# Anthropic-on-Bedrock client; query_llm uses it for "claude-37".
# NOTE(review): the env var name "AWS_ACESS_KEY" is missing a "C"; it is kept
# as-is because existing environments may set exactly this name.
client_bedrock = AnthropicBedrock(
    aws_access_key=os.environ["AWS_ACESS_KEY"],
    aws_secret_key=os.environ["AWS_SECRET_KEY"],
    aws_region="us-east-2",
)


# Price table used by query_llm: cost per one million tokens, keyed by the
# model name callers pass in. Each value is {"input": rate, "output": rate}.
# Rows are (model name, input rate, output rate).
models_cost = {
    model_name: {"input": input_rate, "output": output_rate}
    for model_name, input_rate, output_rate in (
        # Mistral
        ("mistral-large-latest", 2, 6),
        ("mistral-small-latest", 0.1, 0.3),
        # Llama
        ("meta-llama/Llama-3.1-70B-Instruct", 0.8, 2.80),
        ("meta-llama/Meta-Llama-3.1-8B-Instruct", 0.1, 0.4),
        # Cohere
        ("command-a-03-2025", 2.5, 10),
        ("command-r7b-12-2024", 0.0375, 0.15),
        # OpenAI
        ("gpt-4o-v2", 2.5, 10),
        ("gpt-4o-mini", 0.150, 0.60),
        # Anthropic
        ("claude-37", 3, 15),
        # Qwen
        ("Qwen/Qwen2.5-32B-Instruct", 0.4, 1.4),
    )
}


# NOTE(review): not referenced in the visible part of this file; presumably
# read or updated elsewhere — confirm before removing.
key_id = 0
def rules(env, notes=True):
    """Build the system prompt describing actions, gameplay notes, recipes and tools.

    Args:
        env: The game ``World``; forwarded to ``recipes`` and
            ``generate_tool_descriptions``.
        notes: When True, include the extra gameplay notes (capacity, holding,
            serving, waste, order lifetime) between the action list and the
            recipes.

    Returns:
        The assembled prompt string, ending with two newlines.
    """
    parts = [
        'The available actions are :\n',
        '1) goto: goto a tool location \n',
        '\t\tExample: goto_agent0_storage0 \n',
        '2) get: get some object from a tool \n',
        '\t\t Example: get_agent0_beef_storage0 \n',
        '3) put: put some object into a tool \n',
        '\t\t Example: put_beef_agent0_blender0 \n',
        '4) activate: activate the tool to cook all ingredients inside the tool into a different tools \n',
        '\t\t Example: activate_agent0_blender0 \n',
        '5) noop: not performing any actions \n',
        '\t\t Example: noop_agent0 \n',
        'Sometimes the system will give you error messages. Please consider these error messages when executing actions.  \n',
        'You need to specify action for your agent. You can only specify one action at a time. \n',
        '\n',
    ]

    if notes:
        parts += [
            'When the tools reach its capacity, you need to take stuff out. Otherwise, you cannot put items inside.\n',
            'When you are holding objects, you cannot get any more objects. \n',
            "When you are holding objects, you cannot activate tools. \n",
            "After you cooked a required dish, you need to put it into the servingtable. \n",
            'You can only pick up objects from the tool location, if you are located at the tool location. \n',
            "When you activate any tools, make sure all the items inside the tool are respecting the recipes. Otherwise, you will cook waste. Avoid waste at all cost. \n",
            "*** You should mix salad in the mixer. To make salad you should chop veggies first. *** \n",
            "*** If the tool is occupied, indicated by the occupy() predicate, you cannot get objects from it or put objects into it. *** \n",
            # The serving tool must be named "servingtable" everywhere so the
            # model is not told about a tool that does not exist.
            "*** The food orders are keep coming. You should finish as many dishes as possible and finish every dish as soon as possible. Please deliver the order to the servingtable when it is finished. *** \n",
            # End with a newline so the recipes header starts on its own line.
            "*** The dish will expire after the lifetime reaches 0 and it's not at the servingtable. Please avoid this. *** \n",
        ]

    parts += [
        'Here are the recipes: \n',
        recipes(env),
        generate_tool_descriptions(env),
        '\n\n',
    ]
    return ''.join(parts)

def recipes(env: World):
    """Build the recipe, available-object and tool-capacity section of the prompt.

    Collects every component that ``compute_dependency`` reports for the
    level's tasks (plus ``porkMeatcake``), then describes the matching entries
    of ``assets/recipe.json``. It lists where each dish is cooked and from which
    ingredients, which objects exist, and which of them are base ingredients.
    It also states how many items each tool in ``env.name_mapping`` can hold.

    Args:
        env: The game ``World``; reads ``env.task_manager._all_tasks`` and
            ``env.name_mapping``.

    Returns:
        The prompt section as a string.
    """
    import copy

    # Copy so appending below does not mutate the environment's task list.
    tasks = copy.deepcopy(env.task_manager._all_tasks)
    # NOTE(review): porkMeatcake is always added; the reason is not visible here.
    tasks.append('porkMeatcake')

    required_components = set()
    for task in tasks:
        required_components.update(compute_dependency(task)[0])

    with open('assets/recipe.json', 'r', encoding='utf-8') as f:
        recipe = json.load(f)

    parts = ['\n']
    task_related_objects = set()
    for dish, value in recipe.items():
        if dish not in required_components:
            continue
        ingredients = value['ingredients']
        location = value['location']
        parts.append(f'Cook {dish} at: \n')
        parts.append(f' -- location: {location} \n')
        parts.append(' -- with ingredients: ')
        parts.extend(f'    {ingredient}, ' for ingredient in ingredients)
        parts.append('\n')
        task_related_objects.add(dish)
        task_related_objects.update(ingredients)

    # Sort so the prompt is identical across runs: set iteration order for
    # strings follows their hashes, which are randomized per interpreter run.
    # Assumes object names are mutually comparable (strings from recipe.json).
    objects = sorted(task_related_objects)

    parts.append('The following objects are available: \n')
    parts.extend(f' --{idx}) {item} \n' for idx, item in enumerate(objects, start=1))
    parts.append("The objects are cooked using tools or are just base ingredients. \n")

    parts.append("Among them, the following are base ingredients: \n")
    base_items = [item for item in objects if item in base_ingridients]
    parts.extend(f" --{idx}) {item} \n" for idx, item in enumerate(base_items, start=1))
    parts.append("You can only obtain base ingredients from the storage initially.  \n")

    parts.append('Additional rules: \n')
    for tool_name in env.name_mapping:
        # Assumes tool names are a `capacity` key plus one trailing index
        # character (e.g. "blender0" -> "blender").
        num = capacity[tool_name[:-1]]
        if num == -1:
            num = 'infinite'
        parts.append(f'You can place up to {num} item into the {tool_name} \n')

    return ''.join(parts)

def generate_tool_descriptions(env: "World"):
    """Build the prompt text listing the tools available in ``env``.

    Args:
        env: The game ``World``; only the keys of ``env.name_mapping`` (tool
            names) are read.

    Returns:
        A header line, the comma-separated tool names on their own line, and a
        note that tools cannot be picked up, ending with a newline.
    """
    # Join the names instead of appending ", " after each one, so the list does
    # not end with a dangling comma that runs into the next sentence.
    tool_list = ', '.join(env.name_mapping)
    return (
        '** Only ** the following tools are available: \n'
        f'{tool_list}\n'
        'You cannot pick up these tools. You can only use those tools at the corresponding location.'
        '\n'
    )


def prepend_prompt(prompt, add, verbose=True):
    """Return ``prompt + add``, echoing ``add`` to stdout first when ``verbose``.

    Despite the name, ``add`` is placed at the END of ``prompt``.
    """
    if not verbose:
        return prompt + add
    print(add)
    return prompt + add

def prepend_history(history, add, role='user', verbose=True):
    """Append a ``(role, add)`` turn to ``history`` in place and return it.

    Despite the name, the turn is added at the END of ``history``.

    Args:
        history: List of ``(role, content)`` tuples; mutated in place.
        add: Message content for the new turn.
        role: ``'user'`` or ``'assistant'``.
        verbose: When True, print the turn under a ``[[role]]`` banner.

    Returns:
        The same ``history`` list.

    Raises:
        AssertionError: If ``role`` is not ``'user'`` or ``'assistant'``.
    """
    # Validate before printing, with an explicit raise so the check survives
    # `python -O`; AssertionError is kept for compatibility with the old assert.
    if role not in ('user', 'assistant'):
        raise AssertionError(f"role must be 'user' or 'assistant', got {role!r}")
    if verbose:
        print(f'\n\n[[{role}]]\n\n' + add)
    history.append((role, add))
    return history


def _to_chat_messages(history):
    """Convert ``(role, content)`` pairs into ``{'role': ..., 'content': ...}`` dicts.

    Raises:
        NotImplementedError: For any role other than ``'user'`` or ``'assistant'``.
    """
    messages = []
    for turn in history:
        role = turn[0]
        if role not in ('user', 'assistant'):
            print(f"Role: {role}")
            raise NotImplementedError
        messages.append({'role': role, 'content': turn[1]})
    return messages


def _call_cost(model, input_tokens, output_tokens):
    """Return the cost of one call from the per-million-token rates in ``models_cost``.

    Models missing from ``models_cost`` are priced at 0 instead of raising, so
    an unlisted model name cannot discard a response that was already paid for.
    """
    rates = models_cost.get(model, {"input": 0, "output": 0})
    return (input_tokens * rates['input'] + output_tokens * rates['output']) / 1e6


def query_llm(history, max_tokens=100, temperature=0.0, stop=None, model=None):
    """Send a chat history to the provider selected by ``model`` and return the reply.

    The provider is chosen by substring match on ``model``, in this order:
    ``'mistral'`` (Mistral), ``'command'`` (Cohere), ``'gpt'`` (Azure OpenAI,
    using the mini client when ``'mini'`` is also in the name), ``'claude-37'``
    (Anthropic via Bedrock), and otherwise the Avior OpenAI-compatible endpoint.

    Args:
        history: A single user prompt string, or a sequence of
            ``(role, content)`` pairs whose role is ``'user'`` or ``'assistant'``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        stop: Stop sequences; ``None`` means a single newline and a double newline.
        model: Model name; required.

    Returns:
        ``(text, price, {"input": input_tokens, "output": output_tokens})``,
        where ``price`` is derived from ``models_cost`` (per-million-token
        rates, 0 for models not listed there).

    Raises:
        TypeError: If ``model`` is not given.
        NotImplementedError: If a history entry has an unsupported role.
    """
    if model is None:
        raise TypeError("query_llm() requires a model name")
    # Fresh list per call instead of a shared mutable default argument.
    if stop is None:
        stop = ['\n', '\n\n']
    if isinstance(history, str):
        history = [('user', history)]

    chat_history = _to_chat_messages(history)
    print('>>>>', chat_history)

    if 'mistral' in model:
        response = client_mistral.chat.complete(
            model=model,
            messages=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop
        )
        text = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens

    elif 'command' in model:
        response = client_cohere.chat(
            model=model,
            messages=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            stop_sequences=stop
        )
        text = response.message.content[0].text
        # Cohere reports token counts under `billed_units`.
        input_tokens = response.usage.billed_units.input_tokens
        output_tokens = response.usage.billed_units.output_tokens

    elif 'gpt' in model:
        client = client_openai_mini if 'mini' in model else client_openai
        response = client.chat.completions.create(
            model=model,
            messages=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop
        )
        text = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens

    elif 'claude-37' in model:
        # Bedrock addresses Claude 3.7 Sonnet through this inference-profile ARN;
        # pricing still uses the short alias passed in as `model`.
        model_name = "arn:aws:bedrock:us-east-2:288380904485:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        response = client_bedrock.messages.create(
            model=model_name,
            messages=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            stop_sequences=stop
        )
        text = response.content[0].text
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens

    else:
        response = client_avior.chat.completions.create(
            model=model,
            messages=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop
        )
        text = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens

    price = _call_cost(model, input_tokens, output_tokens)
    return text, price, {"input": input_tokens, "output": output_tokens}

