import json
import re
from tqdm import tqdm
import copy
import os

def main():
    """Run every benchmark-construction step in a fixed order.

    The width and depth steps read ``explore_ui_eight_item_gpt4omini.json``,
    which is written by the first step, so the steps must run in the order
    listed below.
    """
    gpt4o_split_path = "explore_ui_eight_item_gpt4omini.json"
    clickable_ui_path = "androidcontrol_test_parsed_is_clickable_ui.jsonl"

    # (step function, keyword arguments) pairs, executed top to bottom.
    pipeline = [
        (collect_gpt4o_instruction, {
            "input_path": "android_control_test_gpt4omini.jsonl",
            "output_path": gpt4o_split_path,
        }),
        (collect_width_benchmark, {
            "input_data_path": gpt4o_split_path,
            "location_data_path": clickable_ui_path,
            "output_path": "exploreBenchmark_qwen_inference_input.json",
        }),
        (collect_width_benchmark_ValidatedData, {
            "input_data_path": "android_control_test_gpt4omini_width_checkingcorrect.jsonl",
            "output_path": "exploreBenchmark_width_correct.json",
        }),
        (collect_depth_benchmark, {
            "input_data_path": gpt4o_split_path,
            "output_path": "explore_ui_eight_item_gpt4omini_depth2.json",
            "depth": 2,
        }),
        (collect_depth_benchmark_ValidatedData, {
            "input_data_path": "android_control_test_gpt4omini_depth2_checkingcorrect.jsonl",
            "output_path": "exploreBenchmark_depth2.json",
        }),
        (collect_Qwen2d5_action_answer, {
            "golden_action_data_path": "android_control_test_Qwen2d5_action_answer.jsonl",
            "location_data_path": clickable_ui_path,
            "output_path": "exploreBenchmark_qwen_ValidUIComponents.json",
        }),
        # NOTE: output filename spelling is kept as-is; other tooling may
        # expect this exact name.
        (collect_original_instruction, {
            "input_path": "android_control_test_simple_input.json",
            "output_path": "origianlBenchmark.json",
        }),
    ]
    for step, kwargs in pipeline:
        step(**kwargs)


def collect_original_instruction(input_path, output_path):
    """Convert the raw test split into the original-instruction benchmark format.

    Reads a JSON array of step records from ``input_path``. Each output record
    is built as follows:
    - ``episode_id``, ``img_filename``, ``width``, ``height`` and ``action``
      are copied unchanged;
    - ``total_goal`` is renamed to ``high_level_instruction``;
    - ``step_instruction`` is renamed to ``low_level_instruction``;
    - a sequential ``idx`` starting at 0 is added.

    Args:
        input_path: Path to a JSON file containing a list of step dicts.
        output_path: Destination JSON file (UTF-8, non-ASCII characters kept).
    """
    # Explicit encoding: the platform default may not be UTF-8, and the
    # output below is written as UTF-8 with ensure_ascii=False.
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    original_benchmark_data = [
        {
            "idx": idx,
            "episode_id": item["episode_id"],
            "img_filename": item["img_filename"],
            "width": item["width"],
            "height": item["height"],
            "high_level_instruction": item["total_goal"],
            "low_level_instruction": item["step_instruction"],
            "action": item["action"],
        }
        for idx, item in enumerate(data)
    ]
    print(len(original_benchmark_data))
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(original_benchmark_data, f, indent=4, ensure_ascii=False)

def collect_gpt4o_instruction(input_path, output_path):
    """Extract and parse the fenced ```json blocks from each GPT response.

    Each non-blank line of ``input_path`` is a JSON object with a
    ``gpt_output`` string. Every ```json ... ``` fenced block in that string
    is parsed with ``json.loads`` and stored, in order, under the new key
    ``gpt_output_split``. Downstream consumers index these entries as
    dicts/lists, so they must be parsed rather than kept as raw text. The
    augmented records are written to ``output_path`` as one JSON array.

    Raises:
        ValueError: if a response contains no ```json block, or a block is not
            valid JSON.
    """
    pattern = re.compile(r"```json\s*(.*?)```", re.DOTALL)
    data = []
    with open(input_path, encoding="utf-8") as f:
        for line_no, ln in enumerate(f, start=1):
            if not ln.strip():
                continue
            item = json.loads(ln)
            matches = pattern.findall(item["gpt_output"])
            # Explicit raise instead of `assert`, which is stripped under -O.
            if not matches:
                raise ValueError(
                    f"{input_path}:{line_no}: no ```json block found in gpt_output"
                )
            try:
                item["gpt_output_split"] = [json.loads(m) for m in matches]
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{input_path}:{line_no}: invalid JSON inside a ```json block"
                ) from err
            data.append(item)
    print(len(data))
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
        

def collect_depth_benchmark(input_data_path, output_path, depth):
    """Pair each step with instructions generated ``depth - 1`` steps later.

    For every record at position ``idx`` in the JSON array at
    ``input_data_path``, the record at ``idx + depth - 1`` is the "target".
    The pair is used only if the target exists and shares the same
    ``episode_id``. For each target GPT output, one benchmark record is
    emitted. It combines the target's ``High-Level-Instruction`` with the
    *current* record's image metadata, ``step_instruction`` and ``action``.

    If the first entry of the target's ``gpt_output_split`` is a list, only
    that list is used. Entries stored as JSON strings are parsed first.

    Args:
        input_data_path: JSON array of records with ``gpt_output_split``.
        output_path: Destination JSON file.
        depth: Look-ahead distance in steps; must be >= 1.

    Raises:
        ValueError: if ``depth`` < 1 (a negative offset would silently wrap
            around to the end of the list).
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth!r}")

    with open(input_data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    offset = depth - 1
    output = []
    for idx, item in enumerate(data):
        target_idx = idx + offset
        if target_idx >= len(data) or data[target_idx]["episode_id"] != item["episode_id"]:
            continue
        # Accept both parsed entries and raw JSON strings.
        split = [
            json.loads(entry) if isinstance(entry, str) else entry
            for entry in data[target_idx]["gpt_output_split"]
        ]
        gpt_outputs = split[0] if split and isinstance(split[0], list) else split
        for gpt_output in gpt_outputs:
            output.append({
                "idx": len(output),
                "episode_id": item["episode_id"],
                "img_filename": item["img_filename"].split('./')[-1],
                "width": item["width"],
                "height": item["height"],
                "high_level_instruction": gpt_output["High-Level-Instruction"],
                "low_level_instruction": item["step_instruction"],
                "action": item["action"],
            })
    print(f"The number of deep exploration degree benchmarks: {len(output)}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=4)
    

def collect_width_benchmark_ValidatedData(input_data_path, output_path):
    """Keep only width-benchmark records the checker marked ``"Correct": "Yes"``.

    Each line of ``input_data_path`` is a JSON object. Its ``gpt_output`` field
    is itself a JSON string containing a ``Correct`` field. A record is
    skipped if either of the following holds:
    - its line or its ``gpt_output`` cannot be parsed;
    - it lacks one of the fields read below.
    Kept records are re-indexed from 0 via ``idx`` and written to
    ``output_path`` as a JSON array.
    """
    validated = []
    with open(input_data_path, encoding='utf-8') as f:
        for ln in tqdm(f):
            try:
                item = json.loads(ln)
                if json.loads(item["gpt_output"])["Correct"] != "Yes":
                    continue
                new_item = {
                    "idx": len(validated),
                    "episode_id": item["episode_id"],
                    "img_filename": item["img_filename"].split('./')[-1],
                    "width": item["width"],
                    "height": item["height"],
                    "high_level_instruction": item["high_level_instruction"],
                    "low_level_instruction": item["low_level_instruction"],
                    "action": item["action"],
                }
            # Skip malformed records, but (unlike a bare `except:`) do not
            # swallow KeyboardInterrupt/SystemExit.
            except (ValueError, KeyError, TypeError, AttributeError):
                continue
            validated.append(new_item)
    print(f"The number of width exploration degree benchmarks: {len(validated)}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(validated, f, indent=4)


def collect_depth_benchmark_ValidatedData(input_data_path, output_path):
    """Keep only depth-benchmark records the checker marked ``"Correct": "Yes"``.

    Each line of ``input_data_path`` is a JSON object. Its ``gpt_output`` field
    is itself a JSON string containing a ``Correct`` field. A record is
    skipped if either of the following holds:
    - its line or its ``gpt_output`` cannot be parsed;
    - it lacks one of the fields read below.
    This is the same policy as ``collect_width_benchmark_ValidatedData``, so
    one malformed checker response no longer aborts the run. Kept records are
    re-indexed from 0 via ``idx`` and written to ``output_path``.
    """
    validated = []
    with open(input_data_path, encoding='utf-8') as f:
        for ln in tqdm(f):
            try:
                item = json.loads(ln)
                if json.loads(item["gpt_output"])["Correct"] != "Yes":
                    continue
                new_item = {
                    "idx": len(validated),
                    "episode_id": item["episode_id"],
                    "img_filename": item["img_filename"].split('./')[-1],
                    "width": item["width"],
                    "height": item["height"],
                    "high_level_instruction": item["high_level_instruction"],
                    "low_level_instruction": item["low_level_instruction"],
                    "action": item["action"],
                }
            # Narrow catch: skip bad records without hiding KeyboardInterrupt.
            except (ValueError, KeyError, TypeError, AttributeError):
                continue
            validated.append(new_item)
    print(f"The number of deep exploration degree benchmarks: {len(validated)}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(validated, f, indent=4)



def collect_width_benchmark(input_data_path, location_data_path, output_path):
    """Flatten GPT-generated instructions into the width-benchmark input file.

    For every record in the JSON array at ``input_data_path``, each entry of
    its ``gpt_output_split`` becomes one output record. That record carries
    the source record's episode/image metadata plus the entry's
    ``High-Level-Instruction``, ``Sub-Instruction`` and ``UI item``. If the
    first entry of ``gpt_output_split`` is itself a list, only that list is
    used. Entries stored as JSON strings are parsed first.

    Args:
        input_data_path: JSON array of records with ``gpt_output_split``.
        location_data_path: Unused; kept so existing callers keep working.
            The previous implementation loaded and deep-copied every UI box
            from this file but never used the result.
        output_path: Destination JSON file.
    """
    with open(input_data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    explore_benchmark_data = []
    for item in data:
        # Accept both parsed entries and raw JSON strings.
        split = [
            json.loads(entry) if isinstance(entry, str) else entry
            for entry in item["gpt_output_split"]
        ]
        gpt_outputs = split[0] if split and isinstance(split[0], list) else split
        for gpt_output in gpt_outputs:
            explore_benchmark_data.append({
                "episode_id": item["episode_id"],
                "img_filename": item["img_filename"].split('./')[-1],
                "width": item["width"],
                "height": item["height"],
                "high_level_instruction": gpt_output["High-Level-Instruction"],
                "low_level_instruction": gpt_output["Sub-Instruction"],
                "ui_item": gpt_output["UI item"],
            })
    print(len(explore_benchmark_data))
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(explore_benchmark_data, f, indent=4)
            
def _load_ui_boxes_by_screenshot(location_data_path):
    """Map each screenshot path to its list of ``[left, top, right, bottom]`` UI boxes."""
    boxes_by_img = {}
    with open(location_data_path, encoding='utf-8') as f:
        for ln in f:
            if not ln.strip():
                continue
            record = json.loads(ln)
            screenshots = record["screenshots"]
            # accessibility_trees[i] describes screenshots[i].
            for idx, tree in enumerate(record["accessibility_trees"]):
                boxes_by_img[screenshots[idx]] = [
                    [b["left"], b["top"], b["right"], b["bottom"]]
                    for b in (ui["location_and_size"]["bounds_in_screen"] for ui in tree)
                ]
    return boxes_by_img


def collect_Qwen2d5_action_answer(golden_action_data_path, location_data_path, output_path):
    """Build the valid-UI-component benchmark from model action responses.

    Each line of ``golden_action_data_path`` is a JSON object whose
    ``Qwen2d5_72B`` field holds the model response. A record is kept only
    when all of the following hold:

    * the response contains neither ``"infeasible"`` nor ``"successful"``;
    * the text after the last ``"actions:"`` parses as a JSON object that has
      an ``action_type`` key;
    * if that object has a ``ui`` key, it is an integer index in range for the
      UI boxes loaded from ``location_data_path`` for the record's screenshot.
      The matching box is stored as ``action["location"]``.

    Kept records are numbered from 0 via ``idx`` and written to ``output_path``.
    """
    boxes_by_img = _load_ui_boxes_by_screenshot(location_data_path)

    explore_benchmark_data = []
    with open(golden_action_data_path, encoding='utf-8') as f:
        for ln in tqdm(f):
            item = json.loads(ln)
            new_item = {
                "idx": len(explore_benchmark_data),
                "episode_id": item["episode_id"],
                "img_filename": item["img_filename"].split('./')[-1],
                "width": item["width"],
                "height": item["height"],
                "high_level_instruction": item["high_level_instruction"],
                "low_level_instruction": item["low_level_instruction"],
            }
            response = item["Qwen2d5_72B"]
            if "infeasible" in response or "successful" in response:
                continue
            try:
                action = json.loads(response.split("actions:")[-1])
            except ValueError:
                continue
            if not isinstance(action, dict) or "action_type" not in action:
                continue
            if "ui" in action:
                ui_index = action["ui"]
                boxes = boxes_by_img.get(new_item["img_filename"])
                # Reject unknown screenshots and non-integer / out-of-range
                # indices; a negative index would otherwise silently pick a
                # box from the end of the list.
                if (
                    boxes is None
                    or isinstance(ui_index, bool)
                    or not isinstance(ui_index, int)
                    or not 0 <= ui_index < len(boxes)
                ):
                    continue
                action["location"] = boxes[ui_index]
            new_item["action"] = action
            explore_benchmark_data.append(new_item)
    print(len(explore_benchmark_data))
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(explore_benchmark_data, f, indent=4)


# Script entry point: running this file executes every step listed in main().
if __name__ == "__main__":
    main()