import base64
import copy
import io
import json
import os

import tensorflow as tf
from PIL import Image
from tqdm import tqdm

from android_env.proto.a11y import android_accessibility_forest_pb2


def node_to_dict(node):
    """Convert one accessibility-tree node into a nested, JSON-friendly dict.

    Args:
        node: an accessibility node object exposing ``class_name``, the
            textual attributes, ``bounds_in_screen`` (with ``left``/``top``/
            ``right``/``bottom``) and the ``is_*`` status/property flags.

    Returns:
        dict with the groups ``element_type``, ``textual_attributes``,
        ``location_and_size``, ``element_status`` and ``element_properties``.
        Keys appear in a fixed order so the serialized JSON is stable.
    """
    def pick(owner, names):
        # Each output key equals the attribute name it is read from; the
        # tuple order fixes the key order in the resulting dict.
        return {name: getattr(owner, name) for name in names}

    text_fields = ('text', 'content_description', 'hint_text',
                   'tooltip_text', 'view_id_resource_name')
    box_sides = ('left', 'top', 'right', 'bottom')
    status_flags = ('is_checked', 'is_enabled', 'is_focused', 'is_selected')
    property_flags = ('is_checkable', 'is_clickable', 'is_editable',
                      'is_focusable', 'is_long_clickable', 'is_scrollable',
                      'is_password', 'is_visible_to_user')

    element = {'element_type': node.class_name}
    element['textual_attributes'] = pick(node, text_fields)
    element['location_and_size'] = {
        'bounds_in_screen': pick(node.bounds_in_screen, box_sides),
    }
    element['element_status'] = pick(node, status_flags)
    element['element_properties'] = pick(node, property_flags)
    return element

def _forest_to_nodes(forest):
    """Flatten every node of every window in ``forest`` into a list of dicts."""
    return [
        node_to_dict(node)
        for window in forest.windows
        for node in window.tree.nodes
    ]


def parse_example_to_dict(example, screenshot_dir="androidcontrol_test"):
    """Flatten one AndroidControl ``tf.train.Example`` episode into a plain dict.

    Args:
        example: a ``tf.train.Example`` whose features include ``episode_id``,
            ``goal``, ``accessibility_trees``, ``screenshot_widths``,
            ``screenshot_heights``, ``actions`` and ``step_instructions``.
        screenshot_dir: directory prefix used to build the screenshot paths.
            It should match the directory the screenshot PNGs are written to.

    Returns:
        dict with, in this key order:
          * ``episode_id``: int.
          * ``goal``: UTF-8 decoded str.
          * ``accessibility_trees``: one entry per serialized accessibility
            forest, each a flat list of node dicts from ``node_to_dict``.
          * ``screenshots``: one relative PNG path per forest, named
            ``{screenshot_dir}/android_control_episode_{episode_id}_{i}.png``.
          * ``screenshot_widths`` / ``screenshot_heights``: lists of ints.
          * ``actions`` / ``step_instructions``: lists of UTF-8 decoded str.
    """
    features = example.features.feature

    episode_id = features['episode_id'].int64_list.value[0]
    result_dict = {
        'episode_id': episode_id,
        'goal': features['goal'].bytes_list.value[0].decode('utf-8'),
    }

    # FromString is a class-level parser; no throwaway instance is needed.
    forests = [
        android_accessibility_forest_pb2.AndroidAccessibilityForest.FromString(raw)
        for raw in features['accessibility_trees'].bytes_list.value
    ]
    # 'accessibility_trees' is inserted before 'screenshots' to keep the
    # established JSON key order.
    result_dict['accessibility_trees'] = [_forest_to_nodes(forest) for forest in forests]
    result_dict['screenshots'] = [
        f"{screenshot_dir}/android_control_episode_{episode_id}_{i}.png"
        for i in range(len(forests))
    ]

    result_dict['screenshot_widths'] = list(features['screenshot_widths'].int64_list.value)
    result_dict['screenshot_heights'] = list(features['screenshot_heights'].int64_list.value)

    result_dict['actions'] = [
        action.decode('utf-8') for action in features['actions'].bytes_list.value
    ]
    result_dict['step_instructions'] = [
        instruction.decode('utf-8')
        for instruction in features['step_instructions'].bytes_list.value
    ]

    return result_dict

# Collect the episode ids from every test sub-split. Each value in the JSON is
# a list of episode ids. A set keeps the per-record membership test O(1).
with open("test_subsplits.json", 'r', encoding='utf-8') as f:
    test_dict = json.load(f)

test_id = {episode for split_ids in test_dict.values() for episode in split_ids}

# NOTE(review): set this to the directory holding the raw AndroidControl
# TFRecord shards; with the empty default the glob searches the filesystem root.
raw_input_dir = ""
screenshot_dir = "androidcontrol_test"

filenames = tf.io.gfile.glob(f'{raw_input_dir}/android_control*')
if not filenames:
    # Fail loudly instead of silently producing an empty output file.
    raise FileNotFoundError(
        f"No files match '{raw_input_dir}/android_control*'; "
        "set raw_input_dir to the AndroidControl TFRecord directory.")
raw_dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')

# Image.save does not create missing directories.
os.makedirs(screenshot_dir, exist_ok=True)

# The context manager guarantees the JSON-lines file is flushed and closed.
with open("androidcontrol_test_parsed.json", 'w', encoding='utf-8') as fout:
    # Each dataset element is one serialized tf.train.Example record
    # (assumes TF2 eager execution, as the .numpy() call already did).
    for data in tqdm(raw_dataset):
        example = tf.train.Example.FromString(data.numpy())
        episode_id = example.features.feature['episode_id'].int64_list.value[0]
        if episode_id not in test_id:
            continue

        # tqdm.write prints without breaking the progress bar.
        tqdm.write(str(episode_id))
        episode_data_dict = parse_example_to_dict(example)
        fout.write(json.dumps(episode_data_dict, ensure_ascii=False) + '\n')

        screenshots = example.features.feature['screenshots'].bytes_list.value
        for i, image_data in enumerate(screenshots):
            with Image.open(io.BytesIO(image_data)) as image:
                image.save(f"{screenshot_dir}/android_control_episode_{episode_id}_{i}.png")

