import os
import re

import numpy as np
from PIL import Image, ImageFont, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

class VLM_Agent(object):
    """Driving agent that asks a Qwen2-VL model for throttle and steering.

    Each call to :meth:`select_action` turns the camera frame into a PIL
    image, builds a task-specific text prompt, runs the vision-language
    model once, parses ``throttle:`` / ``steer:`` lines out of the reply and
    appends the exchange to ``<res_dir>/vlm_log.txt``.

    The agent is stateful: it remembers the last throttle/steer it returned
    (used as a fallback when the reply cannot be parsed) and, for the
    ``"ghost_static"`` task, the yaw seen on the very first call (used as the
    zero reference for all later yaw values).
    """

    # Font used by save_Q_action_image when the caller does not supply one.
    DEFAULT_FONT_PATH = "/home/xcm/LLM_RL/test_llm/SimSun.ttf"

    # A "throttle:", "steer:" or "steering:" label anywhere in a line.
    _FIELD_LABEL = re.compile(r"\b(throttle|steer(?:ing)?)\s*:", re.IGNORECASE)
    # A decimal number at the start of a field value; leading whitespace and
    # markdown emphasis characters are skipped, trailing text is ignored.
    _LEADING_NUMBER = re.compile(
        r"[\s*`]*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
    )

    def __init__(self, task="ghost_static", device="cuda:0", res_dir=""):
        """Load the model and processor and reset the agent state.

        Args:
            task: ``"ghost_static"`` selects the pedestrian ("ghost probe")
                prompt; any other value selects the highway prompt.
            device: Device passed as ``device_map`` to the model and used
                for the input tensors.
            res_dir: Directory that receives ``vlm_log.txt``. An empty
                string means the current working directory.
        """
        self.task = task
        self.device = device
        self.res_dir = res_dir
        model_id = "Qwen/Qwen2-VL-2B-Instruct"
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype="auto", device_map=self.device
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

        # Last controls returned; reused when the model reply is unparsable.
        self.last_steer = 0.0
        self.last_throttle = 0.0
        # Yaw captured on the first ghost_static call (zero reference).
        self.last_ego_orientation = 0
        self.first_time = True

    @staticmethod
    def _clamp(value):
        """Limit *value* to the valid control range [-1, 1]."""
        return max(-1.0, min(1.0, value))

    @staticmethod
    def _to_image(rgb_array):
        """Convert a numeric array to an RGB PIL image (cast to uint8 first)."""
        # assumes rgb_array is (H, W, 3) - TODO confirm against the caller
        return Image.fromarray(rgb_array.astype(np.uint8), 'RGB')

    @classmethod
    def _parse_number(cls, raw_value, fallback):
        """Return the number that starts *raw_value*, or *fallback* if none."""
        found = cls._LEADING_NUMBER.match(raw_value)
        if found is None:
            return fallback
        return float(found.group(1))

    def _parse_controls(self, output):
        """Extract ``(throttle, steer)`` from the model reply.

        The reply is scanned line by line; when a field occurs several times
        the last occurrence wins. A field that is absent yields 0.0. A field
        that is present but has no leading number falls back to the
        previously returned value for that field only, so one malformed
        field does not discard the other. Both results are clamped to
        [-1, 1].
        """
        throttle = 0.0
        steer = 0.0
        for line in output.split("\n"):
            label = self._FIELD_LABEL.search(line)
            if label is None:
                continue
            raw_value = line[label.end():]
            if label.group(1).lower() == "throttle":
                throttle = self._parse_number(raw_value, self.last_throttle)
            else:
                steer = self._parse_number(raw_value, self.last_steer)
        return self._clamp(throttle), self._clamp(steer)

    def _build_statement(self, speed, yaw):
        """Return the user prompt for the configured task.

        The prompt text is kept exactly as originally written. Note that the
        backslash continuations *inside* the last string literal make the
        leading whitespace of the following source lines part of the prompt;
        do not re-indent those lines without accepting a prompt change.
        """
        if self.task == "ghost_static":
            statement = "The scene where ego vehicles are located is the ghost probe scene, that is, pedestrians suddenly cross the road. Please describe the content of the image, determine if there is a pedestrian likely to appear, or cross the road, if there is, and the distance is close, ego need to consider adjusting action to avoid collision. " \
                        "It should be noted that the throttle range is [-1, 1], the throttle <=0 indicates slowing down, and the throttle >0 indicates different acceleration. Steering ranges from -1 to 1, with steering >0 indicating a right turn, steering <0 indicating a left turn, and steering=0 indicating straight driving. " \
                        "In addition to the image, ego's current speed is {:.2f}m/s, yaw is {:.2f}. If there is no visible pedestrian in the image or the distance of the pedestrian is relatively far, ego should set the throttle =1.0 to accelerate. " \
                        "If ego's direction deviates from the current lane, steering adjustment is required. Finally, design reasonable throttle and steering output " \
                        "values for the ego based on the description of the image, the current speed and yaw. Please follow the following example to output content:\
                        \nDescribe:xxx\
                        \nthrottle:xxx\
                        \nsteer:xxx".format(speed, yaw)

        else:  # highway
            statement = "The scene where ego is located is a high-speed scene in automatic driving. Ego needs to " \
                        "avoid collisions on highways, try to stay in the same lane. Please describe the image in detail to determine whether there is a mmediate danger of collision. If so, ego needs to consider slowing down (set throttle <0). Furthermore, please describe the image in detail to determine if the ego is going through a bend. If so, the ego's steering needs to be adjusted. " \
                        "It should be noted that the throttle range is -1 ~ 1, the throttle <=0 indicates slowing down, and the throttle >0 indicates different acceleration. Steering ranges from -1 to 1, with steering >0 indicating a right turn, steering <0 indicating a left turn, and steering=0 indicating straight driving. " \
                        "In addition to the image, ego's current speed is {:.2f}m/s, yaw is {:.2f}. If there is no immediate danger of collision, ego can set throttle=1 to accelerate. Finally, design reasonable throttle and steering output " \
                        "values for the ego based on the description of the image, the current speed and yaw. Please follow the following template to output content:\
                        \nDescribe:xxx\
                        \nthrottle:xxx\
                        \nsteer:xxx".format(speed, yaw)
        return statement

    def _query_model(self, image, statement):
        """Run one generation for *image* + *statement*; return the reply text."""
        messages = [
            {
                "role": "system",
                "content": "You are an AI assistant integrated into an autonomous driving system in Carla. Your task is to "
                        "analyze the RGB camera image taken from the front of the ego vehicle and output throttle and "
                        "steering values."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image
                    },
                    {"type": "text", "text": statement},
                ],
            },
        ]

        # Preprocess the inputs
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Inference, then drop the prompt tokens so only the reply is decoded.
        generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]

    def _append_log(self, time_step, statement, throttle, steer, reason):
        """Append one step record to ``<res_dir>/vlm_log.txt``."""
        if self.res_dir:
            # Create the directory on demand instead of failing on first write.
            os.makedirs(self.res_dir, exist_ok=True)
        vlm_log_filepath = os.path.join(self.res_dir, "vlm_log.txt")
        eval_log_txt_formatter = "[step]:{step},\n[statement]:{statement},\n[throttle]:{throttle},\n[steer]:{steer}," \
                                 "\n[reason]:{reason}\n----------------\n "
        to_write = eval_log_txt_formatter.format(step=time_step,
                                                 statement=statement,
                                                 throttle=throttle,
                                                 steer=steer,
                                                 reason=reason
                                                 )
        # Explicit UTF-8: the model reply may contain non-ASCII characters.
        with open(vlm_log_filepath, "a", encoding="utf-8") as f:
            f.write(to_write)

    def select_action(self, obs, time_step=0, seed=0):
        """Query the model for one control step.

        Args:
            obs: ``(rgb_array, ego_velocity, ego_orientation)``. The array is
                cast to uint8 and interpreted as an RGB image; velocity and
                orientation are formatted into the prompt.
            time_step: Step index, only written to the log.
            seed: Unused; kept for interface compatibility.

        Returns:
            ``[steer, throttle]``, both floats clamped to [-1, 1].
        """
        rgb_array, selected_ego_velocity, ego_orientation = obs
        image = self._to_image(rgb_array)

        if self.task == "ghost_static":
            if self.first_time:
                self.last_ego_orientation = ego_orientation
                self.first_time = False
            # Calibrate: report yaw relative to the first observed yaw.
            # NOTE(review): no angle wrap-around handling - confirm the yaw
            # range supplied by the caller makes this safe.
            ego_orientation = ego_orientation - self.last_ego_orientation

        statement = self._build_statement(selected_ego_velocity, ego_orientation)
        reason = self._query_model(image, statement)

        throttle_res, steer_res = self._parse_controls(reason)
        # Remember the (already clamped) controls as the next fallback.
        self.last_throttle = throttle_res
        self.last_steer = steer_res

        self._append_log(time_step, statement, throttle_res, steer_res, reason)
        return [steer_res, throttle_res]

    def save_Q_action_image(self, obs, action=0, Q=0, time_step=0, path="", font_path=None):
        """Save the frame annotated with the chosen action and its Q value.

        Args:
            obs: Image array; cast to uint8 and interpreted as RGB.
            action: Value rendered as ``action:<action>`` in red.
            Q: Number rendered as ``Q:<Q with two decimals>`` in green.
            time_step: Used as the file name, ``<time_step>.png``.
            path: Output directory, created if missing. An empty string
                means the current working directory.
            font_path: Optional ``.ttf`` file; defaults to
                ``DEFAULT_FONT_PATH``. If the font cannot be loaded, Pillow's
                built-in default font is used instead of raising.
        """
        image = self._to_image(obs)

        # Text to draw
        text1 = "action:{}".format(str(action))
        text2 = "Q:{:.2f}".format(Q)

        # Font and size; fall back so a missing font file does not abort.
        try:
            font = ImageFont.truetype(font_path or self.DEFAULT_FONT_PATH, 15)
        except OSError:
            font = ImageFont.load_default()

        # Drawing context bound to the image
        draw = ImageDraw.Draw(image)

        # Text positions; the two lines are 20 pixels apart.
        text1_position = (10, 10)
        text2_position = (10, 30)

        draw.text(text1_position, text1, font=font, fill=(255, 0, 0))
        draw.text(text2_position, text2, font=font, fill=(0, 255, 0))

        # Save the image. os.path.join keeps an empty `path` relative instead
        # of producing "/<time_step>.png" at the filesystem root.
        if path:
            os.makedirs(path, exist_ok=True)
        image.save(os.path.join(path, '{}.png'.format(time_step)))
