import argparse
import importlib
import json
import os
import pickle
import time
import uuid

import litellm
import numpy as np
import openai

from pwp.bench import PwPBench
from pwp.prompts.cua_prompt import SCREENSHOT_PROMPT, SOM_ENABLED_PROMPT
from pwp.prompts.owl_prompt_cua import (
    initial_message,
    later_messages,
    system_message,
)
from pwp.prompts.prompts import get_prompt_cua_categ1
from pwp.tools.functions_cua_owl import FUNCTIONS
from pwp.tools.tools import cua_tool_categ1
from pwp.utils.llm_utils import encode_image
from pwp.utils.utils import get_images_from_text

# litellm.set_verbose=True
# Command-line interface. Flags, types and defaults are unchanged; the help
# text describes how each value is consumed further down in this script.
parser = argparse.ArgumentParser(
    description=(
        "Run an LLM agent over every instance of a PwPBench task and log "
        "its tool calls, screenshots and rewards."
    ),
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--model",
    type=str,
    default="mPLUG/GUI-Owl-32B",
    help=(
        "Model name sent to the chat-completions endpoint; also used as the "
        "last component of the output directory."
    ),
)
parser.add_argument(
    "--system_prompt",
    type=str,
    default="owl_prompt_cua",
    help="Module under pwp.prompts whose `system_message` is loaded as the system prompt.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="logs_icml_cua_categ1",
    help="Root directory for results; runs are written to <output_dir>/<task>/<model>/.",
)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.3,
    help="Sampling temperature passed to the model.",
)
parser.add_argument(
    "--max_iters",
    type=int,
    default=20,
    help="Maximum number of agent steps per task instance.",
)
parser.add_argument(
    "--max_images",
    type=int,
    default=5,
    help="Cap on the number of images kept in a single model request.",
)
parser.add_argument(
    "--task",
    type=str,
    default="vscode",
    help="Benchmark name handed to PwPBench; must be a key of `workdirs` for bash tool calls.",
)
args = parser.parse_args()

# Benchmark handle for the selected --task and the instances it provides.
bench = PwPBench(args.task)
dataset = bench.get_dataset()

# The system prompt is the `system_message` attribute of the module
# pwp.prompts.<--system_prompt>, resolved by name at runtime.
system_prompt = importlib.import_module(f"pwp.prompts.{args.system_prompt}").system_message

# Everything this run produces is written below <output_dir>/<task>/<model>/.
OUTPUT_DIR = os.path.join(args.output_dir, args.task, args.model)
os.makedirs(OUTPUT_DIR, exist_ok=True)


# Directory passed as `workdir` to bash tool calls, keyed by benchmark name.
# The "intercode" and "minictx" entries are only placeholders: the main loop
# overwrites them with a per-instance path before they are used.
_EVALUATION_DIR = "/home/devuser/evaluation"
_TESTBED_DIR = "/testbed"

workdirs = {
    "humaneval": _EVALUATION_DIR,
    "swebench": _TESTBED_DIR,
    "swtbench": _TESTBED_DIR,
    "swebench-java": _TESTBED_DIR,
    "dsbench": _EVALUATION_DIR,
    "chartmimic": _EVALUATION_DIR,
    "intercode": _EVALUATION_DIR,
    "design2code": _EVALUATION_DIR,
    "canitedit": _EVALUATION_DIR,
    "resq": _EVALUATION_DIR,
    "minictx": _EVALUATION_DIR,
    "bird": _EVALUATION_DIR,
    "vscode": "/home/devuser/",
    "swebench_mm": _TESTBED_DIR,
    "nocode": _EVALUATION_DIR,
}

# NOTE(review): nothing in this file explains this fixed 5-second pause --
# presumably it gives the benchmark backend time to come up; confirm before
# changing or removing it. (The per-instance workdir overrides for
# intercode/minictx are applied inside the main loop, not here.)
time.sleep(5)

# Benchmarks whose bash tool calls are executed with root=True.
run_as_root = {"swebench", "swtbench", "swebench-java", "swebench_mm"}





def screenshot_message(env):
    """Build the content parts that show the model the current screen.

    When the configured model name contains "claude", or contains "owl" in
    any letter case, the result is a single screenshot resized to 1920x1080.
    For every other model it is the unresized screenshot, the SoM image, and
    a text part listing the SoM elements.

    Args:
        env: Environment exposing ``render()`` and ``get_som_image(image)``.

    Returns:
        A list of chat-message content parts (``image_url`` / ``text`` dicts).
    """

    def as_image_part(picture):
        # Every image is sent inline as a base64 PNG data URL at high detail.
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{encode_image(picture)}",
                "detail": "high",
            },
        }

    model_name = args.model
    if "claude" in model_name or "owl" in model_name.lower():
        return [as_image_part(env.render().resize((1920, 1080)))]

    # SoM path: only the first entry of the SoM result is used; its element 0
    # is the annotated image and its element 2 is the element listing.
    screen = env.render()
    first_som_entry = env.get_som_image(screen)[0]
    return [
        as_image_part(screen),
        as_image_part(first_som_entry[0]),
        {"type": "text", "text": "SOM Elements:\n" + first_som_entry[2]},
    ]


def call_llm(
    model,
    messages,
    temperature=0.3,
    top_p=0.95,
    *,
    base_url="http://localhost:4243/v1",
    api_key="EMPTY",
):
    """Send one chat-completion request and return the raw response.

    Args:
        model: Model name forwarded to the endpoint.
        messages: Chat messages in the OpenAI chat-completions format.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        base_url: Base URL of the OpenAI-compatible server. Defaults to the
            local endpoint this script was written against.
        api_key: API key sent to the server. Defaults to the placeholder
            ``"EMPTY"``.

    Returns:
        The response object returned by ``client.chat.completions.create``.

    Raises:
        Whatever the ``openai`` client raises on connection or request errors;
        nothing is caught here.
    """
    client = openai.OpenAI(base_url=base_url, api_key=api_key)
    return client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        # Extra sampling fields outside the standard request schema, passed
        # through in the request body (presumably for a vLLM-style server --
        # TODO confirm).
        extra_body={"repetition_penalty": 1.05, "top_k": -1},
    )

# ---------------------------------------------------------------------------
# Main evaluation loop.
#
# For every dataset instance that does not already have an output directory:
# build the task prompt, start the environment, then run up to --max_iters
# agent steps. Each step sends the task prompt, the running list of step
# summaries and a fresh screenshot to the model, executes the tool call found
# in the reply, and records the reward. Per-instance artefacts are written to
# <OUTPUT_DIR>/task_<n>/.
# ---------------------------------------------------------------------------

# First path component of a miniCTX instance's "file" field -> name of the
# matching directory under /home/devuser/evaluation/test-envs/.
_MINICTX_REPO_DIRS = {
    "PFR": "pfr",
    "PrimeNumberTheoremAnd": "PrimeNumberTheoremAnd",
    "hep_lean": "HepLean-v4.7",
    "htpi": "HTPILeanPackage4.7",
    "mathlib4": "mathlib4",
    "scilean": "SciLean",
}


def log(*parts):
    """Append *parts*, space-joined, as one line to the current instance's log.txt.

    Reads the module-level ``INSTANCE_DIR`` that the loop below sets for each
    instance, so it must only be called after that assignment. The file is
    opened as UTF-8 because the log lines contain emoji.
    """
    with open(os.path.join(INSTANCE_DIR, "log.txt"), "a", encoding="utf-8") as log_file:
        log_file.write(" ".join(str(part) for part in parts) + "\n")


def _trim_images(messages, max_images):
    """Remove the earliest image parts so that at most *max_images* remain.

    Only messages whose ``content`` is a list are inspected. A message is
    never left with empty content: its last remaining part is kept even if
    it is an image. *messages* is modified in place.
    """
    image_slots = [
        (msg_idx, part_idx)
        for msg_idx, message in enumerate(messages)
        if isinstance(message.get("content"), list)
        for part_idx, part in enumerate(message["content"])
        if part.get("type") == "image_url"
    ]
    excess = len(image_slots) - max_images
    if excess <= 0:
        return
    # Delete back to front so the indices collected above stay valid.
    for msg_idx, part_idx in reversed(image_slots[:excess]):
        parts = messages[msg_idx]["content"]
        if len(parts) > 1:
            del parts[part_idx]


for instance_num, row in enumerate(dataset):
    # intercode and minictx use a per-instance working directory, so the
    # generic entry in `workdirs` is replaced for the current instance.
    if args.task == "intercode":
        workdirs["intercode"] = "/home/devuser/evaluation/ctf/" + str(row["task_id"])
    elif args.task == "minictx":
        repo_key = row["file"].split("/")[0]
        workdirs["minictx"] = (
            "/home/devuser/evaluation/test-envs/" + _MINICTX_REPO_DIRS[repo_key]
        )

    # An existing output directory means this instance was already started by
    # an earlier run; it is skipped rather than overwritten.
    INSTANCE_DIR = os.path.join(OUTPUT_DIR, f"task_{instance_num}")
    if os.path.exists(INSTANCE_DIR):
        continue
    os.makedirs(INSTANCE_DIR, exist_ok=True)

    log("")

    # Build the task prompt once per instance. The work-dir substitution and
    # the gemini hint are applied to the prompt text itself so that they are
    # part of every request made in the step loop.
    task_description = get_prompt_cua_categ1(row, args.task)
    user_prompt = initial_message.replace("<<task_description>>", task_description)
    if args.task in ("minictx", "intercode"):
        user_prompt = user_prompt.replace("<<work_dir>>", workdirs[args.task])
    if "gemini" in args.model:
        user_prompt += (
            "\nImportant: Use one tool call at a time. Do not use multiple tools at once."
        )

    # NOTE(review): for swebench_mm the earlier code also extracted the images
    # referenced by the prompt, but attached them to a message list that was
    # overwritten before the first request, so they never reached the model.
    # They are deliberately still not sent; adding them to the user content
    # needs confirmation that the serving endpoint accepts the extra images.

    env = bench.get_env(row)
    env.set_assisted_mode_on()

    all_tool_calls = []
    all_function_calls = []
    rewards = []
    history = []
    messages = []  # stays empty in the saved pickle if --max_iters is 0

    # Install `tree` inside the environment before the agent starts.
    env.run_command("apt-get update -y", root=True)
    env.run_command("apt-get install -y tree", root=True)

    for iter_num in range(args.max_iters):
        # Every request is rebuilt from scratch: the model only sees the task
        # prompt, the summaries of earlier steps, and the current screen. The
        # assistant/tool messages appended further down are kept solely for
        # the messages.pkl written after the last step.
        user_text = user_prompt
        if iter_num > 0:
            user_text += "\n\nHistory: " + "\n".join(
                f"Step {step}: {summary}" for step, summary in enumerate(history, start=1)
            )
        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                    # Honour --system_prompt instead of the statically
                    # imported owl prompt (identical for the default value).
                    {"type": "text", "text": system_prompt},
                ],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": user_text}] + screenshot_message(env),
            },
        ]
        _trim_images(messages, args.max_images)

        response = call_llm(args.model, messages, temperature=args.temperature)
        try:
            reply = response.choices[0].message
            # A reply may carry no text at all; treat that as an empty string
            # so the tag checks below cannot fail on None.
            content = reply.content or ""
        except (AttributeError, IndexError, TypeError) as e:
            log(f"Error appending message: {e}")
            break
        messages.append({"role": "assistant", "content": content})

        log("🤖🗣️:", reply.content)
        log("🤖🛠️:", reply.tool_calls)

        # The step summary is the text of the first <conclusion> element; when
        # the reply has none, the whole reply is used instead.
        if "<conclusion>" in content:
            conclusion = content.split("<conclusion>")[1].split("</conclusion>")[0]
            log("Found conclusion:", conclusion)
            history.append(conclusion)
        else:
            log("No conclusion found")
            history.append(content)

        # A reply without a tool call is allowed; the step then only records
        # the reward and a screenshot.
        terminated = False
        if "<tool_call>" in content:
            # Only the last <tool_call> element of the reply is executed.
            payload = content.split("<tool_call>")[-1].split("</tool_call>")[0]
            try:
                tool_calls = [json.loads(payload)]
            except json.JSONDecodeError as e:
                # Malformed JSON from the model is recorded like any other
                # failed tool call instead of aborting the whole run.
                log(f"Error parsing tool call: {e}")
                messages.append(
                    {
                        "role": "user",
                        "content": "Response from the tool call: " + str(e),
                    }
                )
                all_tool_calls.append(payload)
                all_function_calls.append({"name": "error", "args": {}, "response": str(e)})
                tool_calls = []

            for tool_call in tool_calls:
                function_name, function_args, function_response = None, None, None
                try:
                    function_name = tool_call["name"]
                    function_args = tool_call["arguments"]
                    if "action" in function_args and function_args["action"] == "terminate":
                        # The agent declared the task finished; nothing is
                        # executed or recorded for this call.
                        terminated = True
                        break
                    if function_name == "bash":
                        if args.task in run_as_root:
                            function_args["root"] = True
                        function_args["workdir"] = workdirs[args.task]
                    if function_name == "computer_use":
                        # The model's tool name differs from the FUNCTIONS key.
                        function_name = "computer_control"
                    function_response = FUNCTIONS[function_name](env, **function_args)
                    # Keep very long outputs bounded: head and tail only.
                    if len(function_response) > 20000:
                        function_response = (
                            function_response[:15000]
                            + "\n\n...truncated"
                            + function_response[-2000:]
                        )
                    messages.append(
                        {
                            "role": "user",
                            "content": "Response from the tool call: " + function_response,
                        }
                    )
                except Exception as e:
                    # Deliberately broad: a tool may fail in arbitrary ways and
                    # one bad call must not end the episode.
                    log(f"Error calling function {function_name}: {e}")
                    messages.append(
                        {
                            "role": "user",
                            "content": "Response from the tool call: " + str(e),
                        }
                    )
                    function_response = str(e)
                    function_name = "error" if function_name is None else function_name
                    function_args = {} if function_args is None else function_args
                all_tool_calls.append(tool_call)
                all_function_calls.append(
                    {
                        "name": function_name,
                        "args": function_args,
                        "response": function_response,
                    }
                )
                log("💻:", function_name, function_args, function_response)
                time.sleep(0.5)

        # The reward and a screenshot are recorded after every step, including
        # the terminating one.
        reward = bench.get_reward(env, row)
        log("🏆:", reward)
        rewards.append(reward)

        env.render().save(os.path.join(INSTANCE_DIR, f"screenshot_{iter_num}.png"))
        if terminated:
            break
        log("\n\n\n")

    # Persist the raw tool calls, the per-step rewards and the executed calls.
    with open(os.path.join(INSTANCE_DIR, "tool_calls.pkl"), "wb") as f:
        pickle.dump(all_tool_calls, f)
    with open(os.path.join(INSTANCE_DIR, "rewards.json"), "w") as f:
        json.dump(rewards, f)
    try:
        with open(os.path.join(INSTANCE_DIR, "function_calls.json"), "w") as f:
            json.dump(all_function_calls, f)
    except (TypeError, ValueError) as e:
        # Tool arguments or responses were not JSON-serialisable; fall back to
        # pickle so the record is not lost.
        log(f"Error dumping function calls: {e}")
        with open(os.path.join(INSTANCE_DIR, "function_calls.pkl"), "wb") as f:
            pickle.dump(all_function_calls, f)

    # Messages of the final step only (the list is rebuilt on every step).
    with open(os.path.join(INSTANCE_DIR, "messages.pkl"), "wb") as f:
        pickle.dump(messages, f)

    del env
