#!/usr/bin/env python3

import argparse
import json
import os
import re
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import tensorflow as tf
from android_env.proto.a11y import android_accessibility_forest_pb2
from tqdm import tqdm


class AndroidControlProcessor:
    """Convert Android Control TFRecord shards into per-episode test JSON files.

    The pipeline driven by :meth:`run` has three stages:

    1. :meth:`extract_tfrecord_data` decodes every GZIP-compressed
       ``tf.train.Example``, writes its raw screenshot bytes to ``.png`` files
       under ``<output_dir>/screenshots/<episode_id>/`` and stores the other
       fields as one JSON line per example in ``<output_dir>/data.jsonl``.
    2. :meth:`load_train_episodes` reads the ids listed under ``"train"`` in
       the splits file.
    3. :meth:`process_episodes` turns every episode that is not a train
       episode into a list of per-step records and writes it to
       ``<output_dir>/test/android_control/<episode_id>/<episode_id>.json``.
    """

    # Android Control action name -> integer written to ``result_action_type``.
    # Action types missing from this map are dropped by ``transform_action``.
    # NOTE(review): the codes look like the AndroidInTheWild ActionType enum
    # (0 long point, 1 no-op, 3 type, 4 dual point, 5 back, 6 home,
    # 10 task complete) -- confirm against the consumer of this data.
    ACTION_TYPE_MAP = {
        'click': 4,
        'long_press': 0,
        'scroll': 4,
        'input_text': 3,
        'navigate_home': 6,
        'navigate_back': 5,
        'wait': 1,
        'finish': 10,
    }

    # Recorded scroll direction -> direction the synthetic swipe moves in.
    # The directions are deliberately inverted; presumably because scrolling
    # content one way means dragging the finger the other way -- confirm.
    SCROLL_DIRECTION_MAP = {'left': 'right', 'right': 'left', 'up': 'down', 'down': 'up'}
    # Swipe length as a fraction of the screen width/height.
    SCROLL_DELTA = 0.3

    # Accessibility action ids that make a node count as interactive.
    # NOTE(review): presumably AccessibilityNodeInfo ACTION_FOCUS (0x1),
    # ACTION_CLICK (0x10) and ACTION_LONG_CLICK (0x20) -- confirm.
    _INTERACTIVE_ACTION_IDS = frozenset({1, 16, 32})

    # Captures the body of each ``bounds_in_screen { ... }`` message in a
    # text-format accessibility forest; compiled once and reused per tree.
    _BOUNDS_PATTERN = re.compile(r'bounds_in_screen \{([^}]+)\}')

    def __init__(self, input_dir: str, output_dir: str, splits_file: str):
        """Store the input/output locations and create the output directories.

        Args:
            input_dir: Directory holding the ``android_control*`` TFRecord shards.
            output_dir: Root for ``data.jsonl``, ``screenshots/`` and
                ``test/android_control/``; created if missing.
            splits_file: JSON file whose ``"train"`` key lists train episode ids.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.splits_file = Path(splits_file)
        self.images_dir = self.output_dir / 'screenshots'
        self.test_output_dir = self.output_dir / 'test' / 'android_control'

        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.images_dir.mkdir(parents=True, exist_ok=True)
        self.test_output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Accessibility-tree helpers
    # ------------------------------------------------------------------

    def _is_interactive(self, node) -> bool:
        """Return True if *node* is clickable, focusable, or exposes an interactive action id."""
        return bool(
            getattr(node, 'is_clickable', False)
            or getattr(node, 'is_focusable', False)
            or any(action.id in self._INTERACTIVE_ACTION_IDS for action in node.actions)
        )

    @staticmethod
    def _node_bounds(node) -> Dict[str, int]:
        """Return ``bounds_in_screen`` of *node* as a dict, using 0 for anything missing."""
        bounds_obj = getattr(node, 'bounds_in_screen', None)
        if not bounds_obj:
            return {'left': 0, 'top': 0, 'right': 0, 'bottom': 0}
        return {side: getattr(bounds_obj, side, 0) for side in ('left', 'top', 'right', 'bottom')}

    def _collect_interactive(self, root, id_to_node: Dict[Any, Any], out: List[Dict[str, Any]]) -> None:
        """Append every interactive node reachable from *root* to *out*.

        Iterative pre-order DFS; children are visited in ``child_ids`` order.
        The visited set is local to this root, so it only guards against
        cycles within one traversal.
        """
        if not root:
            return
        stack = [root]
        visited = set()
        while stack:
            node = stack.pop()
            if id(node) in visited:
                continue
            visited.add(id(node))
            if self._is_interactive(node):
                out.append({
                    'unique_id': getattr(node, 'unique_id', None),
                    'class_name': node.class_name,
                    'content_description': node.content_description,
                    'text': node.text,
                    'resource_id': getattr(node, 'view_id_resource_name', 'N/A'),
                    'bounds': self._node_bounds(node),
                })
            # Push children reversed so they are popped in their listed order.
            for child_id in reversed(node.child_ids):
                child_node = id_to_node.get(child_id)
                if child_node is not None:
                    stack.append(child_node)

    def get_interactive_nodes(self, forest) -> List[Dict[str, Any]]:
        """Collect the interactive nodes of every window in an accessibility forest.

        For each window, the traversal starts at every node whose ``unique_id``
        is not listed in any other node's ``child_ids``.

        Args:
            forest: An ``AndroidAccessibilityForest`` message, or any object
                exposing ``windows[].tree.nodes`` with the same node attributes.

        Returns:
            One dict per interactive node, with keys ``unique_id``,
            ``class_name``, ``content_description``, ``text``, ``resource_id``
            and ``bounds`` (``left``/``top``/``right``/``bottom`` in pixels).
        """
        interactives: List[Dict[str, Any]] = []
        for window in forest.windows:
            nodes = window.tree.nodes
            id_to_node = {getattr(node, 'unique_id', idx): node for idx, node in enumerate(nodes)}
            all_child_ids = {cid for node in nodes for cid in node.child_ids}
            roots = [node for node in nodes if getattr(node, 'unique_id', None) not in all_child_ids]
            for root in roots:
                self._collect_interactive(root, id_to_node, interactives)
        return interactives

    # ------------------------------------------------------------------
    # Stage 1: TFRecord -> JSONL + screenshots
    # ------------------------------------------------------------------

    def _example_to_record(self, serialized: bytes, index: int) -> Dict[str, Any]:
        """Parse one serialized ``tf.train.Example`` and save its screenshots.

        Args:
            serialized: Raw bytes of one TFRecord entry.
            index: Sequential index stored in the returned record.

        Returns:
            The JSON-serializable dict written as one line of ``data.jsonl``.
        """
        example = tf.train.Example.FromString(serialized)
        feature = example.features.feature

        step_instructions = [d.decode('utf-8') for d in feature['step_instructions'].bytes_list.value]
        episode_id = list(feature['episode_id'].int64_list.value)
        goal = [d.decode('utf-8') for d in feature['goal'].bytes_list.value]
        screenshot_widths = list(feature['screenshot_widths'].int64_list.value)
        screenshot_heights = list(feature['screenshot_heights'].int64_list.value)
        actions = [d.decode('utf-8') for d in feature['actions'].bytes_list.value]

        forests_list = []
        node_info_list = []
        for forest_bytes in feature['accessibility_trees'].bytes_list.value:
            forest = android_accessibility_forest_pb2.AndroidAccessibilityForest().FromString(forest_bytes)
            forests_list.append(str(forest))
            node_info_list.append(self.get_interactive_nodes(forest))

        screenshot_paths = []
        screenshot_bytes = feature['screenshots'].bytes_list.value
        if screenshot_bytes:
            # One directory per episode; created once rather than per image.
            episode_dir = self.images_dir / str(episode_id[0])
            episode_dir.mkdir(parents=True, exist_ok=True)
            for img_idx, screenshot_byte in enumerate(screenshot_bytes):
                screenshot_path = episode_dir / f'screenshot_{episode_id[0]}_{img_idx}.png'
                screenshot_path.write_bytes(screenshot_byte)
                screenshot_paths.append(str(screenshot_path))

        return {
            'index': index,
            'screenshot_path': screenshot_paths,
            'accessibility_tree': forests_list,
            'step_instructions': step_instructions,
            'episode_id': episode_id,
            'goal': goal,
            'screenshot_widths': screenshot_widths,
            'screenshot_heights': screenshot_heights,
            'actions': actions,
            'node_info': node_info_list,
        }

    def extract_tfrecord_data(self) -> str:
        """Decode all ``android_control*`` GZIP TFRecord shards into ``data.jsonl``.

        Records that fail to read or parse are reported and skipped; only
        successfully parsed records receive an ``index``.

        Returns:
            Path of the written JSONL file, as a string.

        Raises:
            ValueError: If no ``android_control*`` file exists in ``input_dir``.
        """
        pattern = str(self.input_dir / 'android_control*')
        filenames = tf.io.gfile.glob(pattern)
        if not filenames:
            raise ValueError(f"No TFRecord files found matching pattern: {pattern}")

        raw_dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
        dataset_iterator = tf.compat.v1.data.make_one_shot_iterator(raw_dataset)

        # A full extra decompression pass, used only to size the progress bar.
        total_samples = sum(1 for _ in tf.data.TFRecordDataset(filenames, compression_type='GZIP'))
        jsonl_path = self.output_dir / 'data.jsonl'

        index = 0
        with open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
            with tqdm(total=total_samples, desc="Extracting TFRecord data") as pbar:
                while True:
                    try:
                        serialized = dataset_iterator.get_next().numpy()
                    except tf.errors.OutOfRangeError:
                        break
                    except Exception as e:
                        # NOTE(review): relies on the TFRecord reader moving past a
                        # failing record/shard; if it did not, this would retry forever.
                        print(f"Error processing sample {index}: {e}")
                        continue
                    # Count every consumed record, including ones that fail to
                    # parse, so the bar can actually reach its total.
                    pbar.update(1)
                    try:
                        line = json.dumps(self._example_to_record(serialized, index), ensure_ascii=False)
                    except Exception as e:
                        print(f"Error processing sample {index}: {e}")
                        continue
                    jsonl_file.write(line + '\n')
                    index += 1

        return str(jsonl_path)

    # ------------------------------------------------------------------
    # Stage 3 helpers: action / record conversion
    # ------------------------------------------------------------------

    def transform_action(self, action_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Convert one Android Control action into the ``result_*`` action fields.

        Args:
            action_data: Must contain ``'action'`` (a JSON string with at least
                ``action_type``). For click/long_press/scroll it should also
                contain ``'screenshot_width'`` and ``'screenshot_height'`` in pixels.
                Missing, null or unusable coordinates fall back to the screen centre.

        Returns:
            A dict with ``result_action_type``, ``result_action_text``,
            ``result_lift_yx``/``result_touch_yx`` (normalized ``[y, x]``,
            ``[-1.0, -1.0]`` for actions without a gesture) and ``duration``
            (always None). Returns None if the action cannot be parsed or its
            type is not in ``ACTION_TYPE_MAP``.
        """
        try:
            action = json.loads(action_data['action'])
            action_type = action['action_type']
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            # TypeError: 'action' is not a string, or decodes to a non-object.
            print(f"Error parsing action: {e}")
            return None

        if action_type not in self.ACTION_TYPE_MAP:
            return None

        result = {
            'result_action_type': self.ACTION_TYPE_MAP[action_type],
            'result_action_text': "",
            'result_lift_yx': [-1.0, -1.0],
            'result_touch_yx': [-1.0, -1.0],
            'duration': None,
        }

        if action_type in ('click', 'long_press', 'scroll'):
            try:
                start_x = action['x'] / action_data['screenshot_width']
                start_y = action['y'] / action_data['screenshot_height']
            except (KeyError, TypeError, ZeroDivisionError):
                # Missing/null coordinates or a zero-sized screenshot: use the centre.
                start_x, start_y = 0.5, 0.5

            if action_type == 'scroll':
                direction = self.SCROLL_DIRECTION_MAP.get(action.get('direction', 'down'), 'down')
                delta = self.SCROLL_DELTA
                if direction == 'up':
                    end_x, end_y = start_x, max(0.0, start_y - delta)
                elif direction == 'down':
                    end_x, end_y = start_x, min(1.0, start_y + delta)
                elif direction == 'left':
                    end_x, end_y = max(0.0, start_x - delta), start_y
                else:
                    end_x, end_y = min(1.0, start_x + delta), start_y
            else:
                # click / long_press: touch and lift at the same point.
                end_x, end_y = start_x, start_y

            result['result_touch_yx'] = [start_y, start_x]
            result['result_lift_yx'] = [end_y, end_x]

        elif action_type == 'input_text':
            result['result_action_text'] = action.get('text', '')

        return result

    def parse_ui_tree_bounds(self, accessibility_tree: str, screenshot_width: int, screenshot_height: int) -> List[List[float]]:
        """Extract normalized element boxes from a text-format accessibility forest.

        Args:
            accessibility_tree: ``str()`` of an accessibility forest message.
            screenshot_width: Screenshot width in pixels.
            screenshot_height: Screenshot height in pixels.

        Returns:
            One ``[y, x, height, width]`` list per ``bounds_in_screen`` entry,
            each value divided by the screenshot height/width. Absent
            ``left``/``top``/``right``/``bottom`` fields count as 0. Returns an
            empty list unless both dimensions are positive.
        """
        if not (screenshot_height > 0 and screenshot_width > 0):
            return []

        ui_positions = []
        for match in self._BOUNDS_PATTERN.finditer(accessibility_tree):
            bounds: Dict[str, int] = {}
            for item in match.group(1).strip().split('\n'):
                item = item.strip()
                if not item or ':' not in item:
                    continue
                try:
                    key, value = item.split(': ', 1)
                    bounds[key] = int(value)
                except ValueError:
                    continue

            top = bounds.get('top', 0)
            left = bounds.get('left', 0)
            ui_positions.append([
                top / screenshot_height,
                left / screenshot_width,
                (bounds.get('bottom', 0) - top) / screenshot_height,
                (bounds.get('right', 0) - left) / screenshot_width,
            ])
        return ui_positions

    def build_test_data(self, action_data: Dict[str, Any], transformed_action: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble one test-set step record.

        Args:
            action_data: Per-step inputs (``ui_trees``, ``screenshot_width``,
                ``screenshot_height``, ``episode_id``, ``step``,
                ``episode_length``, ``screenshot_path``, ``goal``,
                ``low_instruction``).
            transformed_action: Output of :meth:`transform_action`.

        Returns:
            The step record. Coordinate lists, the action text and the UI
            positions are stored as their ``str()`` representations.
        """
        ui_positions = self.parse_ui_tree_bounds(
            action_data['ui_trees'],
            action_data['screenshot_width'],
            action_data['screenshot_height'],
        )

        return {
            'episode_id': action_data['episode_id'],
            'step_id': action_data['step'],
            'episode_length': action_data['episode_length'],
            'image_width': action_data['screenshot_width'],
            'image_height': action_data['screenshot_height'],
            'image_path': action_data['screenshot_path'],
            'instruction': action_data['goal'],
            'result_action_type': transformed_action['result_action_type'],
            'result_touch_yx': str(transformed_action['result_touch_yx']),
            'result_lift_yx': str(transformed_action['result_lift_yx']),
            'duration': transformed_action['duration'],
            'result_action_text': str(transformed_action['result_action_text']),
            'ui_positions': str(ui_positions),
            'low_instruction': action_data['low_instruction'],
            'subset': 'android_control',
        }

    # ------------------------------------------------------------------
    # Stage 2 / 3: splits and per-episode output
    # ------------------------------------------------------------------

    def load_train_episodes(self) -> set:
        """Return the set of episode ids listed under ``"train"`` in the splits file.

        Returns an empty set and prints a message if the file is missing or
        cannot be read or parsed.
        """
        if not self.splits_file.exists():
            print(f"Warning: Splits file not found: {self.splits_file}")
            return set()

        try:
            with open(self.splits_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return set(data.get('train', []))
        except (OSError, ValueError, AttributeError, TypeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON/encoding;
            # AttributeError/TypeError: unexpected top-level or 'train' shape.
            print(f"Error loading train episodes: {e}")
            return set()

    def _build_episode_steps(self, data: Dict[str, Any], episode_id: int) -> List[Dict[str, Any]]:
        """Build the step records for one extracted episode, ending with a synthetic ``finish`` step.

        Raises IndexError/KeyError if the per-step lists in *data* are shorter
        than ``actions`` or are empty; the caller treats that as a failed episode.
        """
        goal = data['goal'][0]
        actions = data['actions']
        screenshot_widths = data['screenshot_widths']
        screenshot_heights = data['screenshot_heights']
        screenshot_paths = data['screenshot_path']
        episode_length = len(screenshot_paths)

        episode_data: List[Dict[str, Any]] = []

        def add_step(action_data: Dict[str, Any]) -> None:
            transformed = self.transform_action(action_data)
            if transformed:
                episode_data.append(self.build_test_data(action_data, transformed))

        # Indexing (not zip) so mismatched list lengths fail loudly instead of
        # silently truncating the episode.
        for step, action in enumerate(actions):
            add_step({
                'action': action,
                'screenshot_width': screenshot_widths[step],
                'screenshot_height': screenshot_heights[step],
                'screenshot_path': screenshot_paths[step],
                'low_instruction': data['step_instructions'][step],
                'goal': goal,
                'episode_id': episode_id,
                'step': step,
                'episode_length': episode_length,
                'ui_trees': data['accessibility_tree'][step],
            })

        # The last screenshot/tree is paired with a synthetic 'finish' action.
        add_step({
            'action': '{"action_type":"finish"}',
            'screenshot_width': screenshot_widths[-1],
            'screenshot_height': screenshot_heights[-1],
            'screenshot_path': screenshot_paths[-1],
            'goal': goal,
            'episode_id': episode_id,
            'step': len(actions),
            'episode_length': episode_length,
            'ui_trees': data['accessibility_tree'][-1],
            'low_instruction': 'finish the task',
        })
        return episode_data

    def _write_episode(self, episode_id: Any, episode_data: List[Dict[str, Any]]) -> None:
        """Write *episode_data* to ``<test_output_dir>/<episode_id>/<episode_id>.json``."""
        episode_dir = self.test_output_dir / str(episode_id)
        episode_dir.mkdir(parents=True, exist_ok=True)
        with open(episode_dir / f"{episode_id}.json", 'w', encoding='utf-8') as out:
            json.dump(episode_data, out, ensure_ascii=False, indent=2)

    def process_episodes(self, jsonl_path: str, train_episodes: set) -> Dict[str, List[int]]:
        """Write one test JSON file per non-train episode found in *jsonl_path*.

        NOTE(review): every episode whose id is not in *train_episodes* is
        treated as a test episode, so any validation episodes present in the
        input would also be written here -- confirm that is intended.

        Args:
            jsonl_path: File produced by :meth:`extract_tfrecord_data`.
            train_episodes: Episode ids to skip.

        Returns:
            ``{'in_train': [...], 'not_in_train': [...]}``. The first list holds the
            skipped train episode ids. The second holds the ids of test episodes
            whose file was written; episodes that fail are reported and left out.
        """
        episodes: Dict[str, List[int]] = {'in_train': [], 'not_in_train': []}

        with open(jsonl_path, 'r', encoding='utf-8') as src:
            total_lines = sum(1 for _ in src)

        with open(jsonl_path, 'r', encoding='utf-8') as src:
            for line in tqdm(src, total=total_lines, desc="Processing episodes"):
                episode_id = None
                try:
                    data = json.loads(line.strip())
                    episode_id = data['episode_id'][0]
                    if episode_id in train_episodes:
                        episodes['in_train'].append(episode_id)
                        continue
                    self._write_episode(episode_id, self._build_episode_steps(data, episode_id))
                except Exception as e:
                    # Best-effort: one bad episode must not stop the whole run.
                    print(f"Error processing episode {episode_id}: {e}")
                    continue
                episodes['not_in_train'].append(episode_id)

        return episodes

    def run(self):
        """Run extraction, split loading and test-set generation, printing progress."""
        print("Step 1: Extracting TFRecord data...")
        jsonl_path = self.extract_tfrecord_data()
        print(f"Extracted data saved to: {jsonl_path}")

        print("Step 2: Loading train episode splits...")
        train_episodes = self.load_train_episodes()
        print(f"Loaded {len(train_episodes)} train episodes")

        print("Step 3: Processing episodes...")
        episodes = self.process_episodes(jsonl_path, train_episodes)

        print(f"Number of test episodes: {len(episodes['not_in_train'])}")
        print(f"Test set saved at: {self.test_output_dir}")


def main():
    """Parse the command-line options and run the Android Control conversion pipeline."""
    parser = argparse.ArgumentParser(description='Process Android Control dataset')

    # (flag, default, help) for each string-valued option.
    cli_options = (
        ('--input_dir', '/INPUT_DIR', 'Input directory containing TFRecord files'),
        ('--output_dir', '/OUTPUT_DIR', 'Output directory for processed data'),
        ('--splits_file', '/SPLITS_FILE', 'Path to splits.json file'),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    cli = parser.parse_args()

    AndroidControlProcessor(
        input_dir=cli.input_dir,
        output_dir=cli.output_dir,
        splits_file=cli.splits_file,
    ).run()


if __name__ == '__main__':
    main()
