# Copyright 2025 The android_world Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

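"""Runs auto exploration and evaluation over the tasks in `TASK_LIST`.

Based on android_world's minimal_run.py. For every task name in
`task_mapping.TASK_LIST`, the script initializes the task on the emulator,
launches the associated app, runs `auto_exploration` toward the task goal, and
then evaluates whether the task succeeded. An empty task name selects a random
registered task. A summary is printed and saved to `task_results.json`.
"""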

from collections.abc import Sequence
import os
import random
from typing import Type

from absl import app
from absl import flags
from absl import logging
from android_world import registry
from android_world.agents import infer
from android_world.agents import t3a,m3a,UI_TARS_M3A
from android_world.env import env_launcher, adb_utils
from android_world.task_evals import task_eval
import config
from task_explorer.exploration_and_mining import auto_exploration, manually_exploration
from task_mapping import get_app_info, TASK_LIST, TASK_APP_MAPPING
import json

# Show only warnings and above from absl logging.
logging.set_verbosity(logging.WARNING)

# Quiet gRPC: report errors only and disable tracing.
# NOTE(review): these are assigned after the imports above; if gRPC reads them
# when it is first loaded, they may not take effect — confirm.
os.environ.update({
    'GRPC_VERBOSITY': 'ERROR',
    'GRPC_TRACE': 'none',
})


def _find_adb_directory() -> str:
  """Returns the path of the adb executable.

  Despite the name, the returned value is the path to the adb binary itself,
  not its containing directory. Candidates are tried in this order:
    1. ~/Library/Android/sdk/platform-tools/adb
    2. ~/Android/Sdk/platform-tools/adb
    3. $ANDROID_HOME/platform-tools/adb, then
       $ANDROID_SDK_ROOT/platform-tools/adb (each only if the variable is set)
    4. the first `adb` executable found on $PATH

  This function is evaluated at import time as the --adb_path flag default,
  so finding adb in more places avoids failing before the flag can be used.

  Raises:
    EnvironmentError: If no candidate exists.
  """
  import shutil  # Local import: only needed for the $PATH fallback.

  potential_paths = [
      os.path.expanduser('~/Library/Android/sdk/platform-tools/adb'),
      os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
  ]
  for env_var in ('ANDROID_HOME', 'ANDROID_SDK_ROOT'):
    sdk_root = os.environ.get(env_var)
    if sdk_root:
      potential_paths.append(os.path.join(sdk_root, 'platform-tools', 'adb'))

  for path in potential_paths:
    if os.path.isfile(path):
      return path

  # Last resort: an adb installed outside an SDK tree but reachable on $PATH.
  on_path = shutil.which('adb')
  if on_path:
    return on_path

  raise EnvironmentError(
      'adb not found in the common Android SDK paths. Please install Android'
      " SDK and ensure adb is in one of the expected directories. If it's"
      ' already installed, point to the installed location.'
  )


# NOTE(review): the --adb_path default is produced by calling
# _find_adb_directory() while this module is being imported. If that call
# raises, the import fails before --adb_path can be supplied on the command line.
_ADB_PATH = flags.DEFINE_string(
    name='adb_path',
    default=_find_adb_directory(),
    help='Path to adb. Set if not installed through SDK.',
)

_EMULATOR_SETUP = flags.DEFINE_boolean(
    name='perform_emulator_setup',
    default=False,
    help=(
        'Whether to perform emulator setup. This must be done once and only once'
        ' before running Android World. After an emulator is setup, this flag'
        ' should always be False.'
    ),
)

_DEVICE_CONSOLE_PORT = flags.DEFINE_integer(
    name='console_port',
    default=5554,
    help=(
        'The console port of the running Android device. This can usually be'
        ' retrieved by looking at the output of `adb devices`. In general, the'
        ' first connected device is port 5554, the second is 5556, and'
        ' so on.'
    ),
)

# NOTE(review): the three flags below are not read by `_main` in this file.
# `_main` iterates over TASK_LIST and resolves the app via get_app_info()
# instead.
_TASK = flags.DEFINE_string(
    name='task',
    default='ExpenseDeleteSingle',
    help='A specific task to run.',
)

_APP = flags.DEFINE_string(
    name='app_name',
    default='pro expense',
    help='A specific app to open with',
)

_PACKAGE = flags.DEFINE_string(
    name='package_name',
    default='com.arduia.expense',
    help='A specific package to open with',
)



def _tally(results: list[dict]) -> tuple[int, int, int, int, float]:
    """Counts task outcomes.

    Args:
      results: Entries of the form {"task_name": ..., "success": value}, where
        value is True/False for an evaluated task and the string "error" when
        the task raised before it could be evaluated.

    Returns:
      (successful, failed, errors, total, success_rate). success_rate is the
      percentage of successful tasks, or 0 when results is empty.
    """
    successful = sum(1 for r in results if r["success"] is True)
    failed = sum(1 for r in results if r["success"] is False)
    errors = sum(1 for r in results if r["success"] == "error")
    total = len(results)
    success_rate = successful / total * 100 if total > 0 else 0
    return successful, failed, errors, total, success_rate


def _run_task(env, aw_registry, task_name) -> bool:
    """Initializes one task, auto-explores toward its goal, and evaluates it.

    Args:
      env: Environment returned by env_launcher.load_and_setup_env.
      aw_registry: Mapping from task name to TaskEval subclass.
      task_name: Registry key of the task. A falsy value picks a random
        registered task.

    Returns:
      True iff task.is_successful(env) == 1 after exploration.

    Raises:
      ValueError: If task_name is not in aw_registry.
    """
    env.reset(go_home=True)

    task_type: Type[task_eval.TaskEval]
    if task_name:
        if task_name not in aw_registry:
            raise ValueError('Task {} not found in registry.'.format(task_name))
        task_type = aw_registry[task_name]
    else:
        task_type = random.choice(list(aw_registry.values()))

    # Re-seed before generating params so runs are repeatable (assuming
    # generate_random_params draws from the `random` module).
    random.seed(2)
    task = task_type(task_type.generate_random_params())
    task.initialize_task(env)

    print("Starting auto exploration.")
    print('task: ' + str(task_name))
    print('Goal: ' + str(task.goal))

    # Named app_name (not `app`) so the module-level absl `app` is not shadowed.
    package, app_name = get_app_info(task_name)
    adb_utils.launch_app(app_name, env.controller)

    auto_exploration(
        package_name=package,
        exploration_output_root_dir="./exploration_output",
        # Emulator serials have the form emulator-<console port>; follow
        # --console_port instead of hard-coding the first emulator.
        device_serial=f'emulator-{_DEVICE_CONSOLE_PORT.value}',
        user_task=task.goal,
        task_dir=task_name,
    )

    is_successful = task.is_successful(env) == 1
    print(
        f'{"Task Successful ✅" if is_successful else "Task Failed ❌"};'
        f' {task.goal}'
    )
    return is_successful


def _report(results: list[dict], output_path: str = "task_results.json") -> None:
    """Prints a per-task summary and writes the aggregate to output_path as JSON."""
    print("\n" + "=" * 60)
    print("EXECUTION SUMMARY")
    print("=" * 60)

    successful, failed, errors, total, success_rate = _tally(results)

    for r in results:
        status = '✅' if r['success'] is True else '⚠️' if r['success'] == "error" else '❌'
        print(f"{r['task_name']}: {status}")

    print(f"\nResults: {successful}✅ {failed}❌ {errors}⚠️")
    print(f"Success Rate: {successful}/{total} ({success_rate:.1f}%)")

    summary = {
        "results": results,
        "total": total,
        "successful": successful,
        "failed": failed,
        "errors": errors,
        "success_rate": success_rate,
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print(f"Results saved to {output_path}")


def _main() -> None:
    """Runs auto exploration and evaluation for every task in TASK_LIST.

    Each task's outcome is recorded as True/False, or as "error" if it raised.
    A progress line is printed after every task. A final summary is printed
    and saved to task_results.json. The environment is always closed, even
    when the run is aborted.
    """
    env = env_launcher.load_and_setup_env(
        console_port=_DEVICE_CONSOLE_PORT.value,
        emulator_setup=_EMULATOR_SETUP.value,
        adb_path=_ADB_PATH.value,
    )
    try:
        # The registry does not depend on the task, so build it once.
        task_registry = registry.TaskRegistry()
        aw_registry = task_registry.get_registry(task_registry.ANDROID_WORLD_FAMILY)

        results = []
        for task_name in TASK_LIST:
            try:
                is_successful = _run_task(env, aw_registry, task_name)
            except Exception as e:  # Keep going: one bad task must not abort the run.
                print(f"Task Error ⚠️ {task_name}: {str(e)}")
                logging.exception('Task %s raised an exception.', task_name)
                results.append({"task_name": task_name, "success": "error"})
            else:
                # Record evaluated outcomes too; previously only errors were
                # recorded, so success counts and rates were always wrong.
                results.append({"task_name": task_name, "success": is_successful})

            successful, failed, errors, total, success_rate = _tally(results)
            print(
                f"Progress: {successful}✅ {errors}⚠️ {failed}❌ / {total} ({success_rate:.1f}%)")

        _report(results)
    finally:
        env.close()


def main(argv: Sequence[str]) -> None:
  """absl entry point: ignores positional arguments and runs `_main`."""
  _ = argv  # absl always passes argv; nothing positional is consumed here.
  _main()


if __name__ == '__main__':
  # app.run parses the flags defined in this module, then calls main(argv).
  app.run(main)
