import sys

from PIL import Image, ImageFont, ImageDraw
import numpy as np
import requests
from PIL import Image
import re
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import warnings
warnings.filterwarnings('ignore')
import os


class VLM_Agent(object):
    """Drive the CARLA ego vehicle by prompting LLaVA-1.5 for throttle and steering values.

    Each step sends the front RGB frame plus a task-specific prompt to
    ``llava-hf/llava-1.5-7b-hf``, parses ``throttle:`` / ``steer:`` values from the reply,
    clips them to [-1, 1] and appends the exchange to ``<res_dir>/vlm_log.txt``.
    """

    # Signed decimal number such as "1", "-0.5", ".25" or "1e-2" (never inf/nan).
    _NUMBER = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
    # "<key>: <number>" within one line, case-insensitive, tolerating markdown emphasis or
    # quotes around key/value (e.g. "**Throttle:** 0.5", "steer: '-0.2'"). Other text between
    # the colon and the number (e.g. "throttle: xxx 30") deliberately does not match.
    _THROTTLE_PATTERN = re.compile(r"\bthrottle[ \t*_`]*:[ \t*_`'\"]*" + _NUMBER, re.IGNORECASE)
    _STEER_PATTERN = re.compile(r"\bsteer(?:ing)?[ \t*_`]*:[ \t*_`'\"]*" + _NUMBER, re.IGNORECASE)

    def __init__(self, task="ghost_static", device="cuda:0", res_dir=""):
        """Load LLaVA-1.5-7B (float16) with its processor and reset the per-episode memory.

        Args:
            task: ``"ghost_static"`` selects the pedestrian ghost-probe prompt (with yaw
                calibration); any other value selects the highway prompt.
            device: Passed as ``device_map`` to the model and used for the model inputs.
            res_dir: Directory in which ``vlm_log.txt`` is appended on every step.
        """
        self.task = task
        self.device = device
        self.res_dir = res_dir
        model_id = "llava-hf/llava-1.5-7b-hf"
        # default: Load the model on the available device(s)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            force_download=False,
            device_map=self.device
        )

        self.processor = AutoProcessor.from_pretrained(model_id)

        self.last_steer = 0  # previous returned steer; reused when a reply has no parsable steer
        self.last_throttle = 0  # previous returned throttle; reused likewise
        self.last_ego_orientation = 0  # yaw at the first ghost_static step (zero reference)
        self.first_time = True  # True until that reference yaw has been recorded

    @classmethod
    def _parse_action_text(cls, assistant_text):
        """Return ``(throttle, steer)`` floats parsed from a reply; a missing field is ``None``.

        The LAST ``throttle:`` / ``steer:`` (or ``steering:``) key that is directly followed
        by a number wins, so a dedicated output line overrides a mention in the description.
        """
        def last_number(pattern):
            found = pattern.findall(assistant_text)
            return float(found[-1]) if found else None

        return last_number(cls._THROTTLE_PATTERN), last_number(cls._STEER_PATTERN)

    @staticmethod
    def _clip_unit(value):
        """Clamp ``value`` to the [-1, 1] control range that both prompts advertise."""
        return min(max(float(value), -1.0), 1.0)

    def _append_log(self, time_step, statement, throttle, steer, reason):
        """Append one step's prompt, chosen action and reply to ``<res_dir>/vlm_log.txt``."""
        eval_log_txt_formatter = "[step]:{step},\n[statement]:{statement},\n[throttle]:{throttle},\n[steer]:{steer}," \
                                 "\n[reason]:{reason}\n----------------\n "
        to_write = eval_log_txt_formatter.format(step=time_step,
                                                 statement=statement,
                                                 throttle=throttle,
                                                 steer=steer,
                                                 reason=reason
                                                 )
        # Explicit utf-8: the model reply may contain non-ASCII text and the locale default may differ.
        with open(os.path.join(self.res_dir, "vlm_log.txt"), "a", encoding="utf-8") as f:
            f.write(to_write)

    def select_action(self, obs, time_step=0, seed=0):
        """Ask the VLM for the next control and return ``[steer, throttle]``.

        Args:
            obs: ``(rgb_array, selected_ego_velocity, ego_orientation)``. The frame is
                converted to a uint8 RGB image; speed (m/s) and yaw are written into the prompt.
            time_step: Written to the log as ``[step]``.
            seed: Currently unused.

        Returns:
            ``[steer, throttle]`` as floats, each clipped to [-1, 1]. A value missing from the
            reply (or both, if the ``ASSISTANT:`` marker is absent) reuses the previous step's
            value instead of aborting.
        """
        rgb_array, selected_ego_velocity, ego_orientation = obs
        image = Image.fromarray(rgb_array.astype(np.uint8), 'RGB')

        if self.task == "ghost_static":
            if self.first_time:
                # The first observed yaw becomes the zero reference for all later steps.
                self.last_ego_orientation = ego_orientation
                self.first_time = False
            ego_orientation = ego_orientation - self.last_ego_orientation  # calibrate
            statement = "The scene where ego vehicles are located is the ghost probe scene, that is, pedestrians suddenly cross the road. Please describe the content of the image, determine if there is a pedestrian likely to appear, or cross the road, if there is, and the distance is close, ego need to consider adjusting action to avoid collision. " \
                        "It should be noted that the throttle range is [-1, 1], the throttle <=0 indicates slowing down, and the throttle >0 indicates different acceleration. Steering ranges from -1 to 1, with steering >0 indicating a right turn, steering <0 indicating a left turn, and steering=0 indicating straight driving. " \
                        "In addition to the image, ego's current speed is {:.2f}m/s, yaw is {:.2f}. If there is no visible pedestrian in the image or the distance of the pedestrian is relatively far, ego should set the throttle =1.0 to accelerate. " \
                        "If ego's direction deviates from the current lane, steering adjustment is required. Finally, design reasonable throttle and steering output " \
                        "values for the ego based on the description of the image, the current speed and yaw. Please follow the following example to output content:\
                        \nDescribe:xxx\
                        \nthrottle:xxx\
                        \nsteer:xxx".format(selected_ego_velocity, ego_orientation)

        else:  # highway
            # NOTE: "a mmediate" below is a typo in the prompt, kept so results stay comparable.
            statement = "The scene where ego is located is a high-speed scene in automatic driving. Ego needs to " \
                        "avoid collisions on highways, try to stay in the same lane. Please describe the image in detail to determine whether there is a mmediate danger of collision. If so, ego needs to consider slowing down (set throttle <0). Furthermore, please describe the image in detail to determine if the ego is going through a bend. If so, the ego's steering needs to be adjusted. " \
                        "It should be noted that the throttle range is -1 ~ 1, the throttle <=0 indicates slowing down, and the throttle >0 indicates different acceleration. Steering ranges from -1 to 1, with steering >0 indicating a right turn, steering <0 indicating a left turn, and steering=0 indicating straight driving. " \
                        "In addition to the image, ego's current speed is {:.2f}m/s, yaw is {:.2f}. If there is no immediate danger of collision, ego can set throttle=1 to accelerate. Finally, design reasonable throttle and steering output " \
                        "values for the ego based on the description of the image, the current speed and yaw. Please follow the following template to output content:\
                        \nDescribe:xxx\
                        \nthrottle:xxx\
                        \nsteer:xxx".format(selected_ego_velocity, ego_orientation)

        messages = [
            {
                "role": "system",
                "content": "You are an AI assistant integrated into an autonomous driving system in Carla. Your task is to "
                        "analyze the RGB camera image taken from the front of the ego vehicle and output throttle and "
                        "steering values."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image
                    },
                    {"type": "text", "text": statement},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        # Cast floating-point inputs (pixel_values) to the model's float16 weights;
        # BatchFeature.to leaves integer tensors such as input_ids as they are.
        inputs = self.processor(images=image, text=prompt, return_tensors='pt').to(self.device, torch.float16)

        output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        output_text = self.processor.decode(output[0], skip_special_tokens=True)

        # The decoded sequence is searched for the LLaVA-1.5 "ASSISTANT:" marker; everything
        # from there on is treated as the reply.
        assistant_index = output_text.find("ASSISTANT:")
        if assistant_index != -1:
            assistant_text = output_text[assistant_index:]
            throttle, steer = self._parse_action_text(assistant_text)
            reason = assistant_text
        else:
            # Keep the episode running instead of exiting the whole process.
            print("ASSISTANT: not found in the text.")
            throttle, steer = None, None
            reason = output_text

        missing = [name for name, value in (("throttle", throttle), ("steer", steer)) if value is None]
        if missing:
            reason = "format error, reuse last {}: {}".format("/".join(missing), reason)

        # Fall back per field, so one unparsable value does not discard the other.
        throttle_res = self._clip_unit(self.last_throttle if throttle is None else throttle)
        steer_res = self._clip_unit(self.last_steer if steer is None else steer)

        self.last_steer = steer_res
        self.last_throttle = throttle_res

        self._append_log(time_step, statement, throttle_res, steer_res, reason)

        return [steer_res, throttle_res]

    def save_Q_action_image(self, obs, action=0, Q=0, time_step=0, path="",
                            font_path="/home/xcm/LLM_RL/test_llm/SimSun.ttf"):
        """Save ``obs`` as ``<path>/<time_step>.png`` annotated with an action and its Q-value.

        Args:
            obs: RGB frame array; converted to uint8 before drawing.
            action: Rendered in red as ``action:<action>`` at (10, 10).
            Q: Number rendered in green as ``Q:<Q with 2 decimals>`` at (10, 30).
            time_step: Used as the output file name.
            path: Output directory; ``""`` means the current working directory.
            font_path: TrueType font for the labels (size 15). If it cannot be loaded,
                PIL's built-in default font is used instead.
        """
        image = Image.fromarray(obs.astype(np.uint8), 'RGB')

        try:
            font = ImageFont.truetype(font_path, 15)
        except OSError:
            # The default font path is machine-specific; a missing font should not abort the run.
            font = ImageFont.load_default()

        draw = ImageDraw.Draw(image)
        # The second line sits 20 px below the first.
        draw.text((10, 10), "action:{}".format(str(action)), font=font, fill=(255, 0, 0))
        draw.text((10, 30), "Q:{:.2f}".format(Q), font=font, fill=(0, 255, 0))

        # Join rather than concatenate so path="" means the CWD, not the filesystem root.
        image.save(os.path.join(path, "{}.png".format(time_step)))
