import sys

from PIL import Image, ImageFont, ImageDraw
import numpy as np
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
import json

class VLM_Agent(object):
    """Control agent that asks Qwen2-VL-7B-Instruct for an action on each frame.

    ``select_action`` supports the tasks ``"CarRacing"``, ``"cartpole"`` and
    ``"cartpole_swingup"``. Every query and answer is appended to
    ``<res_dir>/vlm_log.txt``.
    """

    def __init__(self, task="ghost_static", device="cuda:0", res_dir=""):
        """Load the model and processor.

        Args:
            task: Task name; selects the prompt and the action format used by
                ``select_action``.
            device: Passed as ``device_map`` when loading the model; model
                inputs are moved to it before generation.
            res_dir: Directory in which ``vlm_log.txt`` is appended to.
        """
        self.task = task
        self.device = device
        self.res_dir = res_dir
        model_id = "Qwen/Qwen2-VL-7B-Instruct"
        # Load the model on the requested device(s).
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype="auto", device_map=self.device
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        # Fallback action (kept as the raw string) reused whenever the model's
        # answer cannot be parsed; it is always a parseable value.
        if self.task == "CarRacing":
            self.last_action = "[0, 0, 0]"
        else:
            self.last_action = "0"
        # Not read anywhere in this class.
        self.first_time = True

    def select_action(self, obs, time_step=0, seed=0):
        """Query the VLM for an action given the current observation frame.

        Args:
            obs: Frame as a numpy array; it is cast to uint8 and read as an RGB
                image (presumably H x W x 3 -- confirm with the caller).
            time_step: Step index, written to the log.
            seed: Unused; kept for interface compatibility.

        Returns:
            For ``"CarRacing"`` a list of three finite numbers
            (steering, throttle, brake); otherwise a finite float. If the
            model's answer is missing or malformed, the previous valid action
            is returned instead.

        Raises:
            SystemExit: if ``self.task`` is not a supported task.
        """
        image = Image.fromarray(obs.astype(np.uint8), 'RGB')
        statement, messages = self._build_messages(image)
        response = self._generate(messages)
        action, reason = self._extract_action_and_reason(response)

        action_res = self._parse_action(action)
        if action_res is None:
            # Missing or malformed answer: repeat the last valid action. Only
            # parseable strings are ever stored in last_action, so this parse
            # succeeds.
            print("unparseable action {!r}, reusing {!r}".format(action, self.last_action))
            action = self.last_action
            action_res = self._parse_action(action)
        self.last_action = action

        self._log(time_step, statement, action, reason)
        return action_res

    def _build_messages(self, image):
        """Return ``(statement, messages)``: the task prompt and the chat messages.

        Exits the process with an error message for unsupported tasks.
        """
        if self.task == "CarRacing":
            system_prompt = (
                "You are a control system for a car in a racing game. Your task is to provide the correct action to control the car "
                "based on the image input."
            )
            statement = """
                        Task: Control a car to drive on a racing track as fast as possible without going off the track.  
                        Input:  
                        - Image: The current state of the car, showing the track, car position, and direction.  
                        Output:  
                        - Action: Provide the 3D continuous action to apply to the car. The action should be in the following ranges:  
                          - action[0]: Steering (left/right), range [-1, 1], where -1 is full left and 1 is full right.  
                          - action[1]: Throttle (acceleration), range [0, 1], where 0 is no acceleration and 1 is full acceleration.  
                          - action[2]: Brake (deceleration), range [0, 1], where 0 is no braking and 1 is full braking.  
                
                        Please output the action to apply to the car based on the image information.
                        You need to explain the reason for your chosen actions. 
                        The format of your output is: \ndescription:xxx\naction:xxx\nreason:xxx
                    """
        elif self.task == "cartpole":
            system_prompt = (
                "You are a control system for an inverted pendulum. "
                "Your task is to provide the correct action to keep the pendulum upright based on the current image input."
            )
            statement = """
                    Task: Control an inverted pendulum to keep it upright.  
                    Input:  
                    - Image: The current state of the pendulum, showing the rod's position and angle.  
                    Output:  
                    - Action: Provide the 1D continuous action to apply to the pendulum. The action should be in the range [-1, 1], where:  
                      - -1: Apply maximum force to the left.  
                      - 1: Apply maximum force to the right.  
                    Example:  
                    - If the pendulum is leaning to the left, the action could be 0.5 (push to the right).  
                    - If the pendulum is leaning to the right, the action could be -0.5 (push to the left).  
            
                    Please output the action to apply to the pendulum based on the image information. You need to explain the reason for your chosen actions. 
                    The format of your output is: \ndescription:xxx\naction:xxx\nreason:xxx
                    """
        elif self.task == "cartpole_swingup":
            system_prompt = (
                "You are a control system for a cartpole swingup task. "
                "Your task is to provide the correct action to swing the pole upright based on the current image input."
            )
            # The output-format lines match the other tasks: the answer parser
            # looks for "action:" / "reason:" lines.
            statement = """
                        Task: Control a cartpole to swing it from a hanging position to an upright position.  
                        Input:  
                        - Image: The current state of the cartpole, showing the pole's position and angle.  
                        Output:  
                        - Action: Provide the 1D continuous action to apply to the cartpole. The action should be in the range [-1, 1], where:  
                          - -1: Apply maximum force to the left.  
                          - 1: Apply maximum force to the right.  
                        Example:  
                        - If the pole is leaning to the left, the action could be 0.5 (push to the right).  
                        - If the pole is leaning to the right, the action could be -0.5 (push to the left).  
                        
                        Please output the action to apply to the cartpole based on the image information.
                        You need to explain the reason for your chosen actions. 
                        The format of your output is: \ndescription:xxx\naction:xxx\nreason:xxx
                    """
        else:
            # Non-zero exit with a message instead of a silent successful exit.
            sys.exit("VLM_Agent: unsupported task {!r}".format(self.task))

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": statement},
                ],
            },
        ]
        return statement, messages

    def _generate(self, messages):
        """Run the model on ``messages`` and return only the newly generated text."""
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Decoding settings come from the model's generation config. A sampling
        # variant kept from earlier code: do_sample=True, temperature=0.7, top_p=0.8.
        generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
        # Drop the prompt tokens so only the answer is decoded.
        answer_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            answer_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

    @staticmethod
    def _extract_action_and_reason(response):
        """Pull the ``action:`` and ``reason:`` fields out of a model answer.

        Returns ``(action, reason)``:

        * ``action`` is the text after the colon on the last line that
          contains ``action:`` (case-insensitive), with surrounding whitespace
          and markdown ``*``/backtick decoration removed. It is ``""`` if no
          such line exists.
        * ``reason`` is the full response followed by the text of every
          ``reason:`` line.
        """
        action = ""
        reason = response
        for line in response.split("\n"):
            lowered = line.lower()
            if "action:" in lowered:
                # Keep only the segment up to the next colon. A valid action
                # never contains ':', and this preserves the "action[0]" case below.
                action = line.split(":")[1].strip().strip("*`").strip()
            elif "reason:" in lowered:
                # Split once so a reason that itself contains ':' is kept whole.
                reason += line.split(":", 1)[1].strip().strip("*`").strip()
        if action == "action[0]":
            # Presumably the model echoed the prompt's "action[0]: ..." template
            # line; this is mapped to a zero action.
            action = "[0, 0, 0]"
        return action, reason

    @staticmethod
    def _is_finite_number(value):
        """Return True for int/float values (not bool) that are finite as floats."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        try:
            return bool(np.isfinite(float(value)))
        except OverflowError:
            # Integers too large to represent as a float.
            return False

    def _parse_action(self, action):
        """Convert an action string into the value handed to the environment.

        Returns a list of three finite numbers for ``"CarRacing"``, a finite
        float otherwise, or None if ``action`` does not have that form.
        """
        if self.task == "CarRacing":
            try:
                value = json.loads(action)
            except ValueError:
                return None
            if (isinstance(value, list) and len(value) == 3
                    and all(self._is_finite_number(v) for v in value)):
                return value
            return None
        try:
            value = float(action)
        except ValueError:
            return None
        return value if self._is_finite_number(value) else None

    def _log(self, time_step, statement, action, reason):
        """Append one query/answer entry to ``<res_dir>/vlm_log.txt``."""
        vlm_log_filepath = os.path.join(self.res_dir, "vlm_log.txt")
        eval_log_txt_formatter = "[step]:{step},\n[statement]:{statement},\n[action]:{action}," \
                                 "\n[reason]:{reason}\n----------------\n "
        to_write = eval_log_txt_formatter.format(
            step=time_step, statement=statement, action=action, reason=reason
        )
        # Explicit encoding: model answers may contain non-ASCII text.
        with open(vlm_log_filepath, "a", encoding="utf-8") as f:
            f.write(to_write)

    def save_Q_action_image(self, obs, action=0, Q=0, time_step=0, path="",
                            font_path="/home/xcm/LLM_RL/test_llm/SimSun.ttf"):
        """Save ``obs`` as ``<path>/<time_step>.png`` annotated with action and Q.

        Args:
            obs: Frame array; it is cast to uint8 and read as an RGB image.
            action: Drawn in red as ``action:<action>``.
            Q: Drawn in green as ``Q:<Q>`` with two decimals.
            time_step: File name stem.
            path: Output directory. With ``""`` the file goes to the current
                working directory.
            font_path: TrueType font to draw with. If it cannot be opened,
                Pillow's built-in default font is used instead.
        """
        image = Image.fromarray(obs.astype(np.uint8), 'RGB')

        text1 = "action:{}".format(str(action))
        text2 = "Q:{:.2f}".format(Q)

        try:
            font = ImageFont.truetype(font_path, 15)
        except OSError:
            # The default path is machine-specific; degrade instead of crashing.
            font = ImageFont.load_default()

        draw = ImageDraw.Draw(image)
        # Two text lines in the top-left corner, 20 pixels apart.
        draw.text((10, 10), text1, font=font, fill=(255, 0, 0))
        draw.text((10, 30), text2, font=font, fill=(0, 255, 0))

        image.save(os.path.join(path, "{}.png".format(time_step)))
