import os
import time
import copy
import torch
import shutil
from PIL import Image, ImageDraw

from MobileAgent.api import inference_chat
from MobileAgent.text_localization import ocr
from MobileAgent.icon_localization import load_model, det
from MobileAgent.controller import get_screenshot, tap, slide, type, back, home
from MobileAgent.prompt import get_action_prompt, get_reflect_prompt, get_memory_prompt, get_process_prompt
from MobileAgent.chat import init_action_chat, init_reflect_chat, init_memory_chat, add_response, add_response_two_image, print_status

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope import snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# Fetch the Qwen-VL-Chat-Int4 snapshot; the tokenizer and model are loaded
# from this directory further down the file.
qwen_dir = snapshot_download("qwen/Qwen-VL-Chat-Int4", revision='v1.0.0')
groundingdino_dir = "" # Your groundingdino path
# GroundingDINO detector, loaded on CUDA in eval mode; passed to det() for
# icon detection in get_perception_infos().
groundingdino_model = load_model('./groundingdino/config/GroundingDINO_SwinT_OGC.py', groundingdino_dir, device="cuda").eval()
# ModelScope OCR pipelines (detection + recognition); both are passed to ocr().
ocr_detection = pipeline(Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
ocr_recognition = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-document_damo')
# Fixed torch RNG seed.
torch.manual_seed(1234)

# User configuration: these three values are empty placeholders and must be
# filled in before running. adb_path is passed to every controller call,
# instruction to the prompt builders, and token to inference_chat().
adb_path = "" # Your adb path
instruction = "" # Your instruction
token = "" # Your GPT-4 API token

def get_all_files_in_folder(folder_path):
    """Return the names of all entries directly inside ``folder_path``.

    The result holds bare entry names (not joined with ``folder_path``), in
    whatever order ``os.listdir`` reports them. Sub-directories are included
    as well, since no filtering is applied.
    """
    return [entry_name for entry_name in os.listdir(folder_path)]


def draw_coordinates_on_image(image_path, coordinates):
    """Mark every point with a filled red dot and save the annotated image.

    Args:
        image_path: Path of the image to annotate.
        coordinates: Iterable of points; index 0 is x and index 1 is y.

    Returns:
        The path the annotated copy was saved to, which is always
        './screenshot/output_image.png'.
    """
    radius = 10
    canvas = Image.open(image_path)
    pen = ImageDraw.Draw(canvas)
    for point in coordinates:
        # Bounding box of a circle of `radius` pixels centred on the point.
        bounds = (
            point[0] - radius,
            point[1] - radius,
            point[0] + radius,
            point[1] + radius,
        )
        pen.ellipse(bounds, fill='red')
    destination = './screenshot/output_image.png'
    canvas.save(destination)
    return destination


def crop(image, box, i):
    """Crop ``box`` out of an image file and save it as ``./temp/{i}.jpg``.

    Args:
        image: Path (or file object) of the image to crop from.
        box: Sequence ``(x1, y1, x2, y2)``; each value is truncated with int().
        i: Identifier used as the stem of the output file name.

    Returns:
        None. Boxes that are 10 pixels or less wide or high are skipped and
        nothing is written for them.
    """
    x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    # Validate the box before touching the file, so a rejected box never
    # opens (and leaves open) the source image.
    if x2 - x1 <= 10 or y2 - y1 <= 10:
        return
    # The context manager releases the file handle once the crop is saved.
    with Image.open(image) as source:
        source.crop((x1, y1, x2, y2)).save(f"./temp/{i}.jpg")


def generate(tokenizer, model, image_file, query):
    """Ask a chat model one question about one image and return its reply.

    Args:
        tokenizer: Object providing ``from_list_format``; it is also handed to
            ``model.chat`` as the first argument.
        model: Object providing ``chat(tokenizer, query=..., history=...)``
            that returns a two-item result.
        image_file: Image reference placed in the ``'image'`` part.
        query: Question text placed in the ``'text'`` part.

    Returns:
        The first item returned by ``model.chat``; the second is discarded.
    """
    parts = [{'image': image_file}, {'text': query}]
    formatted_query = tokenizer.from_list_format(parts)
    # Each call starts a fresh conversation (history=None).
    answer, _ = model.chat(tokenizer, query=formatted_query, history=None)
    return answer


def merge_text_blocks(text_list, coordinates_list):
    """Join vertically stacked text lines into multi-line blocks.

    Boxes are ``[x1, y1, x2, y2]``. They are first ordered by ``(y1, x1)``.
    Starting from each box that has not been absorbed yet, later boxes are
    chained onto the group when, compared with the most recently added box,
    all of the following hold:

    * the left edges differ by less than 10,
    * the candidate's top is between 10 above (inclusive) and 30 below
      (exclusive) the previous box's bottom,
    * the two box heights differ by less than 10.

    Args:
        text_list: Text of each block.
        coordinates_list: Box of each block, parallel to ``text_list``.

    Returns:
        ``(merged_text_blocks, merged_coordinates)``: group texts joined with
        newlines, and each group's enclosing ``[x1, y1, x2, y2]`` box.
    """
    order = sorted(
        range(len(coordinates_list)),
        key=lambda idx: (coordinates_list[idx][1], coordinates_list[idx][0]),
    )
    texts = [text_list[idx] for idx in order]
    boxes = [coordinates_list[idx] for idx in order]
    total = len(texts)

    absorbed = [False] * total
    merged_text_blocks = []
    merged_coordinates = []

    for start in range(total):
        if absorbed[start]:
            continue

        # `tail` is the last box added to the group; candidates are always
        # compared against it, not against the group's first box.
        tail = boxes[start]
        lines = [texts[start]]
        members = [tail]

        for nxt in range(start + 1, total):
            if absorbed[nxt]:
                continue
            candidate = boxes[nxt]
            if (
                abs(tail[0] - candidate[0]) < 10
                and -10 <= candidate[1] - tail[3] < 30
                and abs((tail[3] - tail[1]) - (candidate[3] - candidate[1])) < 10
            ):
                lines.append(texts[nxt])
                members.append(candidate)
                absorbed[nxt] = True
                tail = candidate

        merged_text_blocks.append("\n".join(lines))
        merged_coordinates.append([
            min(box[0] for box in members),
            min(box[1] for box in members),
            max(box[2] for box in members),
            max(box[3] for box in members),
        ])

    return merged_text_blocks, merged_coordinates


def get_perception_infos(adb_path, screenshot_file):
    """Capture the screen and list its text blocks and icons.

    Steps: take a screenshot through adb, run OCR and merge stacked lines,
    detect icons, crop every icon into ``temp_file`` and describe each crop
    with the Qwen model (crops covering most of the screen get the literal
    description "None").

    Args:
        adb_path: Passed to ``get_screenshot``.
        screenshot_file: Image path that is read after ``get_screenshot``
            runs (presumably the file that call writes — confirm in
            MobileAgent.controller).

    Returns:
        ``(perception_infos, width, height)``. Each info is a dict with
        ``"text"`` ("text: ...", "icon: <description>", or plain "icon" when
        no crop was produced for it) and ``"coordinates"`` (the ``[x, y]``
        integer centre of the element's box). ``width`` and ``height`` are
        the screenshot size.
    """
    get_screenshot(adb_path)

    width, height = Image.open(screenshot_file).size

    # Text elements: OCR, then merge lines that belong to the same block.
    ocr_text, ocr_boxes = ocr(screenshot_file, ocr_detection, ocr_recognition)
    ocr_text, ocr_boxes = merge_text_blocks(ocr_text, ocr_boxes)

    centers = [[(box[0] + box[2]) / 2, (box[1] + box[3]) / 2] for box in ocr_boxes]
    draw_coordinates_on_image(screenshot_file, centers)

    perception_infos = [
        {"text": "text: " + ocr_text[idx], "coordinates": box}
        for idx, box in enumerate(ocr_boxes)
    ]

    # Icon elements: appended after the text elements, initially undescribed.
    icon_boxes = det(screenshot_file, "icon", groundingdino_model)
    for box in icon_boxes:
        perception_infos.append({"text": "icon", "coordinates": box})

    # Crop each icon; the crop file is named after the element's index in
    # perception_infos so the description can be mapped back afterwards.
    icon_entries = [
        (idx, info['coordinates'])
        for idx, info in enumerate(perception_infos)
        if info['text'] == 'icon'
    ]
    for idx, box in icon_entries:
        crop(screenshot_file, box, idx)

    crop_files = get_all_files_in_folder(temp_file)
    if crop_files:
        def _element_index(file_name):
            return int(file_name.split('/')[-1].split('.')[0])

        crop_files = sorted(crop_files, key=_element_index)
        element_ids = [_element_index(name) for name in crop_files]
        question = 'This image is an icon from a phone screen. Please describe the color and shape of this icon.'

        descriptions = []
        for name in crop_files:
            crop_path = os.path.join(temp_file, name)
            icon_width, icon_height = Image.open(crop_path).size
            # Skip the model for crops taller than 80% of the screen or
            # larger than 20% of its area.
            if icon_height > 0.8 * height or icon_width * icon_height > 0.2 * width * height:
                descriptions.append("None")
            else:
                descriptions.append(generate(tokenizer, model, crop_path, question))

        for element_id, description in zip(element_ids, descriptions):
            if description:
                perception_infos[element_id]['text'] = "icon: " + description

    # Replace every box with its integer centre point.
    for info in perception_infos:
        box = info['coordinates']
        info['coordinates'] = [int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)]

    return perception_infos, width, height

# Qwen-VL tokenizer and model (CUDA, eval mode, sampling disabled); read as
# globals by get_perception_infos() when it calls generate() on icon crops.
tokenizer = AutoTokenizer.from_pretrained(qwen_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(qwen_dir, device_map="cuda", trust_remote_code=True,use_safetensors=True).eval()
model.generation_config = GenerationConfig.from_pretrained(qwen_dir, trust_remote_code=True, do_sample=False)

# Extra hint passed to the action, reflection and process prompt builders.
add_info = "If you want to tap an icon of an app, use the action \"Open app\""
# Histories grow only for steps whose reflection answer contains 'A'.
thought_history = []
summary_history = []
action_history = []
# Operation summary and action text parsed from the latest action response.
summary = ""
action = ""
# Parsed from the process-prompt response after each accepted step.
completed_requirements = ""
# Accumulated "Important content" snippets extracted by the memory prompt.
memory = ""
# Passed to get_memory_prompt(); never reassigned in the visible code.
insight = ""
# Working directories: temp_file holds icon crops, screenshot holds the
# current/previous screenshots. Both are created if missing.
temp_file = "temp"
screenshot = "screenshot"
if not os.path.exists(temp_file):
    os.mkdir(temp_file)
if not os.path.exists(screenshot):
    os.mkdir(screenshot)
# True when the previous step's reflection answer hit the 'B' or 'C' branch.
error_flag = False

# Main agent loop. Each iteration: perceive the screen, ask the model for an
# action, execute it through adb, perceive again, then ask the model to
# reflect on the result. The only exit is the "Stop" branch below.
while True:
    # NOTE(review): this_results is filled in below but is not read anywhere
    # in the visible code.
    this_results = {}
    screenshot_file = "./screenshot/screenshot.jpg"
    perception_infos, width, height = get_perception_infos(adb_path, screenshot_file)
    # Recreate the icon-crop directory so the next perception pass starts empty.
    shutil.rmtree(temp_file)
    os.mkdir(temp_file)

    # The keyboard is considered active when an element whose centre lies in
    # the bottom 5% of the screen has 'ADB Keyboard' in its text.
    keyboard = False
    for perception_info in perception_infos:
        if perception_info['coordinates'][1] < 0.95 * height:
            continue
        if 'ADB Keyboard' in perception_info['text']:
            keyboard = True
            break

    # --- Decision: build the action prompt and query the vision model. ---
    prompt_action = get_action_prompt(instruction, perception_infos, width, height, keyboard, summary_history, action_history, summary, action, add_info, error_flag, completed_requirements, memory)
    chat_action = init_action_chat()
    chat_action = add_response("user", prompt_action, chat_action, screenshot_file)

    output_action = inference_chat(chat_action, 'gpt-4-vision-preview', token)
    this_results['decision'] = output_action
    # Split the response on its "### Thought ###" / "### Action ###" /
    # "### Operation ###" headers and flatten each section to a single line.
    thought = output_action.split("### Thought ###")[-1].split("### Action ###")[0].replace("\n", " ").replace(":", "").replace("  ", " ").strip()
    summary = output_action.split("### Operation ###")[-1].replace("\n", " ").replace("  ", " ").strip()
    action = output_action.split("### Action ###")[-1].split("### Operation ###")[0].replace("\n", " ").replace("  ", " ").strip()
    chat_action = add_response("assistant", output_action, chat_action)

    # --- Memory: continue the same conversation with the memory prompt. ---
    prompt_memory = get_memory_prompt(insight)
    chat_action = add_response("user", prompt_memory, chat_action)
    output_memory = inference_chat(chat_action, 'gpt-4-vision-preview', token)
    chat_action = add_response("assistant", output_memory, chat_action)

    # Keep the first paragraph after the header; store it only if it does not
    # contain "None" and is not already part of the accumulated memory.
    output_memory = output_memory.split("### Important content ###")[-1].split("\n\n")[0].strip() + "\n"
    if "None" not in output_memory and output_memory not in memory:
        memory += output_memory
        this_results['memory'] = output_memory
    print_status(chat_action)

    # --- Execution: dispatch on the action text. Matching is by substring,
    # so the order of these branches decides which one wins. ---
    if "Open app" in action:
        # Find the OCR line that exactly equals the app name in parentheses.
        app_name = action.split("(")[-1].split(")")[0]
        text, coordinate = ocr(screenshot_file, ocr_detection, ocr_recognition)
        # NOTE(review): tap_coordinate is never read below.
        tap_coordinate = [0, 0]
        for ti in range(len(text)):
            if app_name == text[ti]:
                name_coordinate = [int((coordinate[ti][0] + coordinate[ti][2])/2), int((coordinate[ti][1] + coordinate[ti][3])/2)]
                tap(adb_path, name_coordinate[0], name_coordinate[1]- int(coordinate[ti][3] - coordinate[ti][1]))# tap one label-height above the label centre (presumably where the app icon sits — TODO confirm)

    elif "Tap" in action:
        # Coordinates are read from the last "(x, y)" group in the action.
        coordinate = action.split("(")[-1].split(")")[0].split(", ")
        x, y = int(coordinate[0]), int(coordinate[1])
        tap(adb_path, x, y)

    elif "Swipe" in action:
        # Parsed from the form "Swipe (x1, y1), (x2, y2)".
        coordinate1 = action.split("Swipe (")[-1].split("), (")[0].split(", ")
        coordinate2 = action.split("), (")[-1].split(")")[0].split(", ")
        x1, y1 = int(coordinate1[0]), int(coordinate1[1])
        x2, y2 = int(coordinate2[0]), int(coordinate2[1])
        slide(adb_path, x1, y1, x2, y2)

    elif "Type" in action:
        # Normally the text is whatever sits in the last parentheses; when the
        # literal placeholder "(text)" appears, take the quoted string instead.
        if "(text)" not in action:
            text = action.split("(")[-1].split(")")[0]
        else:
            text = action.split(" \"")[-1].split("\"")[0]
        type(adb_path, text)

    elif "Back" in action:
        back(adb_path)

    elif "Home" in action:
        home(adb_path)

    elif "Stop" in action:
        break

    # Pause one second before re-capturing the screen.
    time.sleep(1)

    # --- Reflection: keep the pre-action state, then perceive again. ---
    last_perception_infos = copy.deepcopy(perception_infos)
    last_screenshot_file = "./screenshot/last_screenshot.jpg"
    last_keyboard = keyboard
    if os.path.exists(last_screenshot_file):
        os.remove(last_screenshot_file)
    # The pre-action screenshot becomes last_screenshot.jpg; the perception
    # call below is expected to produce a fresh screenshot_file.
    os.rename(screenshot_file, last_screenshot_file)

    perception_infos, width, height = get_perception_infos(adb_path, screenshot_file)
    shutil.rmtree(temp_file)
    os.mkdir(temp_file)

    # Same keyboard detection as above, on the post-action screen.
    keyboard = False
    for perception_info in perception_infos:
        if perception_info['coordinates'][1] < 0.95 * height:
            continue
        if 'ADB Keyboard' in perception_info['text']:
            keyboard = True
            break

    prompt_reflect = get_reflect_prompt(instruction, last_perception_infos, perception_infos, width, height, last_keyboard, keyboard, summary, action, add_info)
    chat_reflect = init_reflect_chat()
    chat_reflect = add_response_two_image("user", prompt_reflect, chat_reflect, [last_screenshot_file, screenshot_file])

    output_reflect = inference_chat(chat_reflect, 'gpt-4-vision-preview', token)
    this_results['reflect'] = output_reflect
    reflect = output_reflect.split("### Answer ###")[-1].replace("\n", " ").strip()
    chat_reflect = add_response("assistant", output_reflect, chat_reflect)
    print_status(chat_reflect)

    # NOTE(review): these are substring tests, checked in A, B, C order — any
    # capital 'A' anywhere in the answer text selects the first branch.
    # Confirm the answer format the reflection prompt enforces.
    if 'A' in reflect:
        # Accept the step: record it and refresh the completed-requirements summary.
        thought_history.append(thought)
        summary_history.append(summary)
        action_history.append(action)

        prompt_memory = get_process_prompt(instruction, thought_history, summary_history, action_history, completed_requirements, add_info)
        chat_memory = init_memory_chat()
        chat_memory = add_response("user", prompt_memory, chat_memory)
        output_memory = inference_chat(chat_memory, 'gpt-4-1106-preview', token)
        chat_memory = add_response("assistant", output_memory, chat_memory)
        print_status(chat_memory)

        completed_requirements = output_memory.split("### Completed contents ###")[-1].replace("\n", " ").strip()
        this_results['process'] = output_memory

        error_flag = False

    elif 'B' in reflect:
        # Flag the error and press back on the device.
        error_flag = True
        back(adb_path)

    elif 'C' in reflect:
        # Flag the error without touching the device.
        error_flag = True

    os.remove(last_screenshot_file)