import collections
import datetime
import hashlib
import os
import random
import time
from typing import Any, Type

import numpy as np
import pandas as pd
from absl import flags
from android_env import env_interface
from android_world import checkpointer as checkpointer_lib
from android_world import constants
from android_world.env import adb_utils
from android_world.env import interface
from android_world.task_evals import task_eval
from fuzzywuzzy import process

# Absl flags. `process_episodes` reads `FLAGS.log_folder_exp`; that flag is
# presumably defined by the launching script (it is not defined in this file).
FLAGS = flags.FLAGS
# Fixed seed used when `use_identical_params` is True but no seed is given.
_FIXED_SEED = 123
# Name of the task-template column/index in results and metadata tables.
_TASK_TEMPLATE_COLUMN = 'task_template'
# Name given to the metadata JSON's original `task_template` column.
_TASK_PROMPT_COLUMN = 'task_prompt'


class Suite(dict[str, list[task_eval.TaskEval]]):
  """A dict-based suite of task instances.

  Keys are task names (as registered in registry.py). Each value is a list of
  instantiated task objects of that type; the instances differ only in the
  parameters they were created with.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Unset until assigned through the `suite_family` property.
    self._suite_family = None

  @property
  def suite_family(self) -> str:
    """The suite family name.

    Raises:
      ValueError: If the family has not been assigned yet.
    """
    family = self._suite_family
    if family is not None:
      return family
    raise ValueError('Suite family is not set; please first set it.')

  @suite_family.setter
  def suite_family(self, value: str):
    """Assigns the suite family name."""
    self._suite_family = value


def _instantiate_task(
  task: Type[task_eval.TaskEval],
  params: dict[str, Any] | None = None,
  seed: int | None = None,
  env: interface.AsyncEnv | None = None,
) -> task_eval.TaskEval:
  """Builds one task instance, generating params when none are supplied.

  Non-empty `params` are passed to the task unchanged and `seed` is ignored.
  When `params` is None or empty, the global `random` module is seeded with
  `seed` (if not None) and the task's own `generate_random_params()` is used.
  The seed is then recorded in the generated params.

  Args:
    task: The task class to instantiate.
    params: Explicit params; None or {} means "generate random ones".
    seed: Seed for `random` when params are generated.
    env: The environment, passed to `task.set_device_time`.

  Returns:
    The constructed task instance.
  """
  task.set_device_time(env)
  if params:
    return task(params)

  if seed is not None:
    random.seed(seed)
  generated_params = task.generate_random_params()
  generated_params[constants.EpisodeConstants.SEED] = seed
  return task(generated_params)


def create_suite(
  task_registry: dict[str, Type[task_eval.TaskEval]],
  n_task_combinations: int = 1,
  seed: int | None = None,
  tasks: list[str] | None = None,
  use_identical_params: bool = False,
  env: interface.AsyncEnv | None = None,
) -> Suite:
  """Creates task suite.

  A task suite is a set of tasks. Each task is instantiated
  `n_task_combinations` times using new parameters. For example a task suite
  could look like:

  ```python
  {
      'GoogleSearchTask': [
          GoogleSearchTask({'term': 'cute cats'}),
          GoogleSearchTask({'term': 'comfy pillows'}),
      ],
      'WifiDisable': [  # No params for WiFi task.
          WifiDisable({}),
          WifiDisable({}),
      ],
  }
  ```

  Args:
    task_registry: Maps task names to their TaskEvals.
    n_task_combinations: Number of instances to create per task. Each instance
      will have unique param combinations.
    seed: Seed for the random number generator. When set, instance `i` of task
      `name` is seeded with a SHA-256-derived value of f'{seed}_{name}_{i}', so
      the params are reproducible per task regardless of which other tasks are
      selected.
    tasks: List of task types that should be in the suite. If value is `None`
      all task types and associated instances will be created.
    use_identical_params: If True, every instance of a task uses the seed of
      instance 0 (or `_FIXED_SEED` when `seed` is None).
    env: The environment that will be run on.

  Returns:
    A mapping of task name to instances of the task, sorted by task name.

  Raises:
    ValueError: If a name in `tasks` is not in `task_registry`.
  """

  def _get_instance_seed(name: str, i: int) -> int:
    # Deterministic 32-bit seed derived from (seed, task name, instance index).
    unique_seed_str = f'{seed}_{name}_{i}'
    return int(hashlib.sha256(unique_seed_str.encode()).hexdigest(), 16) % (
      2 ** 32
    )

  # Validate the requested names and narrow the registry *before* building
  # anything: a typo fails fast, and unrequested task types are never
  # instantiated (each instantiation calls `set_device_time(env)` and may draw
  # random params).
  selected_registry = filter_tasks(task_registry, task_registry, tasks)  # pytype: disable=wrong-arg-types

  suite = {}
  for name, task_type in selected_registry.items():
    current = []
    for i in range(n_task_combinations):
      if use_identical_params:
        instance_seed = (
          _get_instance_seed(name, 0) if seed is not None else _FIXED_SEED
        )
      elif seed is not None:
        instance_seed = _get_instance_seed(name, i)
      else:
        instance_seed = None

      # MarkorCreateFolder gets explicit params (a per-instance folder name)
      # instead of random ones; all other tasks get {} -> random params.
      # NOTE(review): folder_name varies with `i` even when
      # use_identical_params is True — confirm this is intended.
      params = {}
      if name == 'MarkorCreateFolder':
        params['folder_name'] = f'folder_{i}'
        params[constants.EpisodeConstants.SEED] = instance_seed

      current.append(
        _instantiate_task(task_type, seed=instance_seed, env=env, params=params)
      )
    suite[name] = current

  # Sort suite alphabetically by task name.
  return Suite(sorted(suite.items()))


def suggest_keyword(
  typo: str, keywords: list[str], threshold: int = 80
) -> str:
  """Returns a "Did you mean ...?" hint for the closest keyword, if any.

  Args:
    typo: The (possibly misspelled) string supplied by the user.
    keywords: Candidate strings to match against.
    threshold: Minimum fuzzywuzzy similarity score (0-100) needed to suggest
      a match.

  Returns:
    " Did you mean '<keyword>'?" (with a leading space, so it can be appended
    to an error message), or '' when there are no candidates or the best match
    scores below `threshold`.
  """
  if not keywords:
    # extractOne returns None for an empty choice list; unpacking that would
    # raise TypeError and hide the caller's original error message.
    return ''
  best = process.extractOne(typo, keywords)
  if best is None:
    return ''
  suggestion, score = best[0], best[1]
  if score < threshold:
    return ''
  return f" Did you mean '{suggestion}'?"


def filter_tasks(
  suite: dict[str, list[task_eval.TaskEval]],
  task_registry: dict[str, Type[task_eval.TaskEval]],
  tasks: list[str] | None = None,
) -> dict[str, list[task_eval.TaskEval]]:
  """Returns the subset of `suite` whose names appear in `tasks`.

  Every requested name is first checked against `task_registry`; the first
  unknown one raises, with a fuzzy "did you mean" hint when available.

  Args:
    suite: The suite to retrieve tasks from.
    task_registry: The task registry the suite is from.
    tasks: The task names to keep. If None, `suite` itself is returned.

  Returns:
    A new dict containing only the requested entries, in `suite` order (or
    `suite` itself when `tasks` is None).

  Raises:
    ValueError: If a requested name is not in `task_registry`.
  """
  if tasks is None:
    return suite

  for requested in tasks:
    if requested in task_registry:
      continue
    raise ValueError(
      f'Task {requested} not found in the task registry.'
      + suggest_keyword(requested, list(task_registry))
    )

  return {
    task_name: task_instances
    for task_name, task_instances in suite.items()
    if task_name in tasks
  }


def get_task_info(
  episodes: list[dict[str, Any]],
) -> tuple[dict[str, list[dict[str, Any]]], dict[str, list[dict[str, Any]]]]:
  """Groups episodes by task instance into completed and failed tables.

  Each episode is keyed by its task template, followed by
  `checkpointer_lib.INSTANCE_SEPARATOR`, followed by its instance id. An
  episode counts as failed when its exception-info entry is present and not
  None; otherwise it counts as completed.

  Args:
    episodes: Episode result dicts.

  Returns:
    A `(completed, failed)` tuple of `defaultdict(list)` lookup tables.
  """
  completed = collections.defaultdict(list)
  failed = collections.defaultdict(list)
  for record in episodes:
    key_prefix = (
      record[constants.EpisodeConstants.TASK_TEMPLATE]
      + checkpointer_lib.INSTANCE_SEPARATOR
    )
    key = key_prefix + str(record[constants.EpisodeConstants.INSTANCE_ID])
    raised = record.get(constants.EpisodeConstants.EXCEPTION_INFO) is not None
    target = failed if raised else completed
    target[key].append(record)
  return completed, failed


def allocate_step_budget(task_complexity: float) -> int:
  """Converts a task complexity score into a base step budget.

  The budget is ten steps per unit of complexity, truncated with `int()`.

  Args:
    task_complexity: Complexity score of the task.

  Returns:
    `int(10 * task_complexity)`.

  Raises:
    ValueError: If `task_complexity` is None.
  """
  if task_complexity is None:
    raise ValueError('Task complexity must be provided.')
  steps_per_unit = 10
  return int(steps_per_unit * task_complexity)


def calculate_max_steps(
  task_complexity: float,
  task_name: str | None = None,
  log_prefix: str = '',
  *,
  additional_steps: int = 10,
  max_steps_cap: int | None = None,
) -> int:
  """Calculates the maximum number of steps for a task from its complexity.

  This is the unified function used across the system for step budget
  calculation:

    max_steps = allocate_step_budget(task_complexity) + additional_steps

  With the defaults this is `int(10 * task_complexity) + 10`. No upper bound
  is applied unless `max_steps_cap` is given.

  Args:
    task_complexity: Complexity score of the task.
    task_name: Optional task name, used only in the log message.
    log_prefix: Optional prefix for the log message (e.g. "Exploration").
    additional_steps: Extra steps added on top of the base budget.
    max_steps_cap: If not None, the result is clamped to at most this value.

  Returns:
    Maximum number of steps allocated for this task.

  Raises:
    ValueError: If `task_complexity` is None.
  """
  if task_complexity is None:
    raise ValueError('Task complexity must be provided.')

  base_steps = allocate_step_budget(task_complexity)
  uncapped_steps = base_steps + additional_steps
  max_steps = uncapped_steps
  if max_steps_cap is not None:
    max_steps = min(max_steps, max_steps_cap)

  task_info = f" for task '{task_name}'" if task_name else ''
  prefix = f'[{log_prefix}] ' if log_prefix else ''
  cap_info = f' (capped at {max_steps_cap})' if max_steps_cap is not None else ''
  message = (
    f'{prefix}📊 Max steps calculation{task_info}: '
    f'complexity={task_complexity:.1f} → '
    f'steps={base_steps}+{additional_steps}={uncapped_steps} → '
    f'final={max_steps}{cap_info}'
  )

  # Import lazily to avoid a circular dependency; fall back to plain print
  # when agent_utils is unavailable (e.g. in isolated tests).
  try:
    from .utils.agent_utils import print_with_color
  except ImportError:
    print(message)
  else:
    print_with_color(message, 'cyan')

  return max_steps


def display_message(
  header: str, body: str, env: env_interface.AndroidEnvInterface
) -> None:
  """Broadcasts an overlay update carrying a header and a body string.

  Args:
    header: Sent as the `task_type_string` intent extra.
    body: Sent as the `goal_string` intent extra.
    env: Passed through to `adb_utils.send_android_intent`.
  """
  overlay_extras = {'task_type_string': header, 'goal_string': body}
  adb_utils.send_android_intent(
    'broadcast',
    'com.example.ACTION_UPDATE_OVERLAY',
    env,
    extras=overlay_extras,
  )


def display_goal(env: interface.AsyncEnv, task: task_eval.TaskEval) -> None:
  """Shows the task's goal through the Android World app, then goes home.

  Launches the 'android world' app, broadcasts the overlay message, leaves it
  up for six seconds and presses the home button.

  Args:
    env: The environment; its `controller` is used for all device calls.
    task: The current task.
  """
  controller = env.controller
  adb_utils.launch_app('android world', controller)
  time.sleep(1.0)
  # NOTE(review): `task.goal` is passed as the header (sent as
  # `task_type_string`) and `task.name` as the body (sent as `goal_string`).
  # That looks swapped relative to the extra names — confirm against the app.
  display_message(task.goal, task.name, controller)
  time.sleep(6.0)
  adb_utils.press_home_button(controller)
  time.sleep(1.0)


def get_screen_config(task: task_eval.TaskEval) -> dict[str, Any]:
  """Returns the task's screen settings, filling in defaults.

  Each of `width`, `height`, `orientation` and `config_name` is taken from
  `task` when it has that attribute. Otherwise it defaults to 1080, 2400,
  'portrait' and 'default' respectively.
  """
  fallbacks = {
    'width': 1080,
    'height': 2400,
    'orientation': 'portrait',
    'config_name': 'default',
  }
  return {
    attr: getattr(task, attr) if hasattr(task, attr) else default
    for attr, default in fallbacks.items()
  }


def create_failed_result(
  name: str, goal: str, exception: str, run_time: float
) -> dict[str, Any]:
  """Builds a placeholder episode result for a run that failed.

  Metric fields (episode data, success, episode length) are NaN, the finish
  time is the current local time, aux data is None, and the exception text is
  recorded.

  Args:
    name: Task template name.
    goal: Task goal.
    exception: Description of the failure.
    run_time: Time spent on the run.

  Returns:
    An episode result dict keyed by `constants.EpisodeConstants` fields.
  """
  ec = constants.EpisodeConstants
  return {
    ec.GOAL: goal,
    ec.TASK_TEMPLATE: name,
    ec.EPISODE_DATA: np.nan,
    ec.IS_SUCCESSFUL: np.nan,
    ec.FINISH_DTIME: datetime.datetime.now(),
    ec.RUN_TIME: run_time,
    ec.EPISODE_LENGTH: np.nan,
    ec.EXCEPTION_INFO: exception,
    ec.AUX_DATA: None,
  }


def display_success_overlay(
  env: env_interface.AndroidEnvInterface, success: float
) -> None:
  """Broadcasts the task's success value to the overlay, then pauses 1 s.

  Args:
    env: Passed through to `adb_utils.send_android_intent`.
    success: Success value; sent as `str(int(success))`. Note that
      `int()` raises ValueError for NaN.
  """
  success_label = str(int(success))
  adb_utils.send_android_intent(
    'broadcast',
    'com.example.ACTION_UPDATE_OVERLAY',
    env,
    extras={'success_string': success_label},
  )
  # Keep the overlay visible briefly before continuing.
  time.sleep(1.0)


def update_scoreboard(
  n_correct: int, n: int, env: env_interface.AndroidEnvInterface
) -> None:
  """Broadcasts the running score (e.g. '3/4 (75.0%)') to the scoreboard.

  Args:
    n_correct: Number of successful tasks so far.
    n: Number of tasks counted so far. When 0, the percentage is reported as
      0.0% instead of raising ZeroDivisionError.
    env: Passed through to `adb_utils.send_android_intent`.
  """
  # Avoid ZeroDivisionError when no tasks have been counted yet.
  percentage = (n_correct / n) * 100 if n else 0.0
  scoreboard_value = f'{n_correct}/{n} ({percentage:.1f}%)'

  adb_utils.send_android_intent(
    'broadcast',
    'com.example.ACTION_UPDATE_SCOREBOARD',
    env,
    extras={'scoreboard_value': scoreboard_value},
  )


def extract_task_metadata() -> pd.DataFrame:
  """Loads per-task metadata from `android_world/task_metadata.json`.

  The path is resolved relative to the current working directory. In the
  JSON, the original `task_template` column is renamed to `task_prompt`, and
  the `task_name` column becomes the `task_template` index.

  Returns:
    A DataFrame indexed by `task_template` with columns `difficulty`,
    `optimal_steps` and `tags`. If the file does not exist, an empty
    DataFrame with that index name and those columns is returned instead.
  """
  wanted_columns = ['difficulty', 'optimal_steps', 'tags']
  metadata_path = os.path.join('android_world', 'task_metadata.json')

  if not os.path.exists(metadata_path):
    empty_index = pd.Index([], name=_TASK_TEMPLATE_COLUMN)
    return pd.DataFrame(columns=['difficulty', 'optimal_steps', 'tags']).set_index(
      empty_index
    )

  raw = pd.read_json(metadata_path)
  # Order matters: move the JSON's own `task_template` column out of the way
  # first, so that `task_name` can take over that name.
  raw = raw.rename(columns={_TASK_TEMPLATE_COLUMN: _TASK_PROMPT_COLUMN})
  raw = raw.rename(columns={'task_name': _TASK_TEMPLATE_COLUMN})
  indexed = raw.set_index(_TASK_TEMPLATE_COLUMN)
  return indexed[wanted_columns]


def print_results_by_tag(result_df: pd.DataFrame) -> pd.DataFrame:
  """Builds a table of mean success rate by tag and difficulty.

  Despite its name, this function does not print anything; it returns the
  table and the caller is responsible for displaying it.

  Args:
    result_df: Per-task results with `tags`, `difficulty` and
      `mean_success_rate` columns. `task_template` must be available as a
      column after `reset_index()`. `tags` is presumably list-valued per row —
      confirm against the metadata JSON.

  Returns:
    A DataFrame indexed by tag, with `mean_success_rate` columns split by
    difficulty and reindexed to 'easy', 'medium', 'hard'; empty cells are '-'.
  """
  # One row per (task, tag) pair.
  exploded_df = result_df.explode('tags').reset_index()
  # NOTE(review): pandas appears to treat an empty regex pattern as a literal
  # match, so this likely relabels only tags equal to '' as 'untagged'; NaN
  # tags (e.g. from empty lists) are left as-is and dropped by groupby — confirm.
  exploded_df.replace(regex={'tags': r''}, value='untagged', inplace=True)  # pytype: disable=wrong-arg-types
  return (
    exploded_df.groupby(['tags', 'difficulty'], as_index=False)
    # `num_tasks` is computed here but not carried into the pivot below,
    # which only uses `mean_success_rate` as its values.
    .agg(
      num_tasks=(_TASK_TEMPLATE_COLUMN, 'count'),
      mean_success_rate=('mean_success_rate', 'mean'),
    )
    .pivot_table(
      index=['tags'],
      columns='difficulty',
      values=[
        'mean_success_rate',
      ],
    )
    .fillna('-')
    .reindex(columns=['easy', 'medium', 'hard'], level='difficulty')
  )


def process_episodes(
  episodes: list[dict[str, Any]], print_summary: bool = False
) -> pd.DataFrame:
  """Aggregates task-suite episode results into a per-task summary table.

  Example::

    results = run_task_suite(...)
    # `results` is a list of episode dicts, e.g.
    # [
    #   {'goal': 'Pause the stopwatch.',
    #    'task_template': 'ClockStopWatchPaused',
    #    'episode_data': ..., 'is_successful': True},
    #   {'goal': 'Pause the stopwatch.',
    #    'task_template': 'ClockStopWatchPaused',
    #    'episode_data': ..., 'is_successful': False},
    #   {'goal': 'Run the stopwatch.',
    #    'task_template': 'ClockStopWatchRunning',
    #    'episode_data': ..., 'is_successful': True},
    # ]
    summary = process_episodes(results)

  Episodes are grouped by task template, sorted by template name. For each
  template the result has these columns:
    num_complete_trials: count of non-null success values.
    mean_success_rate: mean of the success values.
    mean_episode_length: mean episode length.
    total_runtime_s: summed run time, rounded to 0.1.
    num_fail_trials: count of non-null exception info.
  It is then left-joined with `difficulty`, `optimal_steps` and `tags` from
  `extract_task_metadata()`.

  Args:
    episodes: Results from running `run_task_suite`.
    print_summary: If True, also print the table with an appended average row
      and a success-rate-by-tag table, and write them to `result.csv` and
      `tags_df.csv` under `FLAGS.log_folder_exp`.

  Returns:
    The per-task results merged with task metadata.
  """

  def _save_csv(frame: pd.DataFrame, filename: str) -> None:
    # Write `frame` under the experiment log folder; `with` guarantees the
    # file is closed even if `to_csv` raises.
    path = os.path.join(FLAGS.log_folder_exp, filename)
    with open(path, 'w', encoding='utf-8') as csv_file:
      frame.to_csv(csv_file)

  df = pd.DataFrame(list(episodes))

  # Add exception info for backwards compatibility: older results may lack
  # the column, in which case it is filled with NaN.
  df = df.assign(**{
    constants.EpisodeConstants.EXCEPTION_INFO: df.get(
      constants.EpisodeConstants.EXCEPTION_INFO, np.nan
    )
  })

  result_df = df.groupby(
    constants.EpisodeConstants.TASK_TEMPLATE, dropna=True
  ).agg({
    # `count` skips NaN, so failed runs (success recorded as NaN) are excluded.
    constants.EpisodeConstants.IS_SUCCESSFUL: ['count', 'mean'],
    constants.EpisodeConstants.EPISODE_LENGTH: 'mean',
    constants.EpisodeConstants.RUN_TIME: 'sum',
    constants.EpisodeConstants.EXCEPTION_INFO: [
      ('none_count', lambda x: x.notnull().sum())
    ],
  })
  result_df = result_df.sort_index()
  result_df.columns = [
    'num_complete_trials',
    'mean_success_rate',
    'mean_episode_length',
    'total_runtime_s',
    'num_fail_trials',
  ]
  result_df['total_runtime_s'] = result_df['total_runtime_s'].map(
    lambda x: float('{:.1f}'.format(x))
  )

  # Extract metadata and merge with the results table.
  metadata_df = extract_task_metadata()
  tagged_result_df = result_df.merge(
    metadata_df, on=[_TASK_TEMPLATE_COLUMN], how='left'
  )

  if print_summary:
    avg = result_df.mean(axis=0)
    avg.name = '========= Average ========='

    result = pd.concat([result_df, avg.to_frame().T])
    result.index.name = 'task'
    # Task rows are numbered 0..N-1; the average row is given 0.
    result.insert(0, 'task_num', list(range(len(result) - 1)) + [0])
    result.task_num = result.task_num.astype(int)
    pd.set_option('display.max_columns', 100)
    pd.set_option('display.width', 1000)
    print(f'\n\n{result}')
    _save_csv(result, 'result.csv')

    # Add a chart that shows mean success rate by tag and difficulty.
    tags_df = print_results_by_tag(tagged_result_df)
    pd.set_option('display.precision', 2)
    print(f'\n\n{tags_df}')
    _save_csv(tags_df, 'tags_df.csv')

  return tagged_result_df
