"""Eval job using a variable container to fetch the weights of the policy."""

import collections
import os
import statistics
import time
from typing import Text
import numpy as np

from absl import logging
import tensorflow as tf
# Pin TensorFlow to the first GPU (when any GPU exists) and enable memory
# growth on it, so the process allocates GPU memory as needed rather than
# grabbing it all at start-up.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
  tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')
  tf.config.experimental.set_memory_growth(physical_devices[0], True)


from tf_agents.experimental.distributed import reverb_variable_container
from tf_agents.metrics import py_metric
from tf_agents.metrics import py_metrics
from tf_agents.policies import greedy_policy  # pylint: disable=unused-import
from tf_agents.policies import py_tf_eager_policy
from tf_agents.train import actor
from tf_agents.train import learner
from tf_agents.train.utils import train_utils
from tf_agents.trajectories import trajectory
from tf_agents.utils import common


class InfoMetric(py_metric.PyStepMetric):
  """Step metric that tracks one entry of the environment's info dict.

  Each time the metric is called it checks `env.done`. When that is true it
  reads `env.get_info()[info_metric_key]` and appends it to a bounded buffer.
  `result()` is the mean of the buffered values.
  """

  def __init__(
      self,
      env,
      info_metric_key: Text,
      buffer_size: int = 1,
      name: Text = 'InfoMetric',
  ):
    """Creates the metric.

    Args:
      env: environment exposing a `done` attribute and a `get_info()` method
        returning a mapping that contains `info_metric_key`.
      info_metric_key: key of the info entry to report, e.g. 'wirelength',
        'density' or 'congestion'.
      buffer_size: number of most recent values averaged by `result()`.
      name: prefix of the metric name; the final name is
        `name + '_' + info_metric_key`.
    """
    super().__init__(name + '_' + info_metric_key)

    self._env = env
    self._info_metric_key = info_metric_key
    # Bounded buffer: only the latest `buffer_size` values are retained.
    self._buffer = collections.deque(maxlen=buffer_size)

  def call(self, traj: trajectory.Trajectory):
    """Records the info value when the environment reports the episode done.

    Args:
      traj: unused; accepted only for compatibility with the actor/learner
        metric API. The value is read from the environment instead.
    """
    del traj

    if self._env.done:
      self._buffer.append(self._env.get_info()[self._info_metric_key])

  def result(self):
    """Returns the mean of the buffered values, or 0.0 if none recorded yet.

    Returning 0.0 on an empty buffer (instead of letting `statistics.mean`
    raise `StatisticsError`) matches TF-Agents' streaming metrics and keeps
    summary writing safe before the first episode has finished.
    """
    if not self._buffer:
      return 0.0
    return statistics.mean(self._buffer)

  def reset(self):
    """Drops all buffered values."""
    self._buffer.clear()


def _load_greedy_policy(root_dir):
  """Waits for the learner's exported greedy policy and loads it.

  Args:
    root_dir: root directory of the experiment.

  Returns:
    A `py_tf_eager_policy.SavedModelPyTFEagerPolicy` loaded from
    `root_dir/learner.POLICY_SAVED_MODEL_DIR/learner.GREEDY_POLICY_SAVED_MODEL_DIR`.

  Raises:
    TimeoutError: if `saved_model.pb` does not appear after polling every
      2 seconds for 86400 retries (about two days). That most likely means
      the learner is not running.
  """
  policy_saved_model_path = os.path.join(root_dir,
                                         learner.POLICY_SAVED_MODEL_DIR,
                                         learner.GREEDY_POLICY_SAVED_MODEL_DIR)
  saved_model_pb_path = os.path.join(policy_saved_model_path, 'saved_model.pb')
  try:
    train_utils.wait_for_file(
        saved_model_pb_path, sleep_time_secs=2, num_retries=86400)
  except TimeoutError:
    logging.error('Could not get the file %s. Exiting.', saved_model_pb_path)
    # Bare `raise` re-raises with the original traceback intact.
    raise
  return py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      policy_saved_model_path, load_specs_from_pbtxt=True)


def _weighted_cost(wirelength, density, congestion, density_weight=0.5,
                   congestion_weight=0.5):
  """Combines placement metrics into one scalar.

  cost = density_weight * density + congestion_weight * congestion + wirelength
  """
  return (float(density_weight * density) +
          float(congestion_weight * congestion) + float(wirelength))


def evaluate(root_dir, variable_container_server_address, create_env_fn,
             *, density_weight=0.5, congestion_weight=0.5):
  """Continuously evaluates the greedy policy exported by the learner.

  Waits for the learner's greedy SavedModel, then loops forever. Each
  iteration does the following:
    * runs one evaluation episode;
    * pulls the latest policy variables and train step from the Reverb
      variable container;
    * prints the metrics and the running list of costs;
    * writes the actor's metric summaries, plus congestion, wirelength,
      density and cost scalars, under `root_dir/learner.TRAIN_DIR/eval`;
    * sleeps 20 seconds.

  Args:
    root_dir: root directory of the experiment.
    variable_container_server_address: address of the Reverb server backing
      the variable container.
    create_env_fn: zero-argument callable returning the evaluation
      environment. `InfoMetric` reads `env.done` and `env.get_info()`, whose
      result must contain 'wirelength', 'congestion' and 'density'.
    density_weight: weight of density in the reported cost.
    congestion_weight: weight of congestion in the reported cost.

  Raises:
    TimeoutError: if the learner's greedy policy never appears.

  Never returns normally.
  """
  policy = _load_greedy_policy(root_dir)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  model_id = common.create_variable('model_id')

  # Create the environment.
  env = create_env_fn()
  variables = {
      reverb_variable_container.POLICY_KEY: policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step,
      'model_id': model_id,
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Keep handles to the exact metric objects the actor updates. Reading
  # `.result()` from freshly constructed metric objects would always see
  # empty buffers and report zeros.
  num_episodes_metric = py_metrics.NumberOfEpisodes()
  env_steps_metric = py_metrics.EnvironmentSteps()
  return_metric = py_metrics.AverageReturnMetric(
      name='eval_episode_return', buffer_size=1)
  episode_length_metric = py_metrics.AverageEpisodeLengthMetric(buffer_size=1)
  wirelength_metric = InfoMetric(env, 'wirelength')
  congestion_metric = InfoMetric(env, 'congestion')
  density_metric = InfoMetric(env, 'density')

  eval_summary_dir = os.path.join(root_dir, learner.TRAIN_DIR, 'eval')

  # Create the evaluator actor.
  eval_actor = actor.Actor(
      env,
      policy,
      train_step,
      episodes_per_run=1,
      summary_dir=eval_summary_dir,
      metrics=[
          num_episodes_metric,
          env_steps_metric,
          return_metric,
          episode_length_metric,
          wirelength_metric,
          congestion_metric,
          density_metric,
      ],
      name='performance')

  summary_writer = tf.summary.create_file_writer(eval_summary_dir)

  # History of costs, printed every iteration.
  cost_track = []

  # Run the experience evaluation loop.
  while True:
    eval_actor.run()
    variable_container.update(variables)
    # NOTE: the step is read after `variable_container.update`, so it reflects
    # the freshly fetched weights rather than those used by the run above.
    current_step = train_step.numpy()
    logging.info('Evaluating using greedy policy at step: %d', current_step)

    eval_return = return_metric.result()
    num_episodes = num_episodes_metric.result()
    avg_episode_length = episode_length_metric.result()

    congestion = congestion_metric.result()
    wirelength = wirelength_metric.result()
    density = density_metric.result()
    cost = _weighted_cost(wirelength, density, congestion,
                          density_weight=density_weight,
                          congestion_weight=congestion_weight)
    cost_track.append(cost)

    print(f"Step: {current_step}, Evaluation Return: {eval_return}, "
          f"Number of Episodes: {num_episodes}, Average Episode Length: {avg_episode_length}, "
          f"Congestion: {congestion}, Wirelength: {wirelength}, Density: {density}")
    print("Cost_List", cost_track)
    eval_actor.write_metric_summaries()

    # Write out summaries at the end of each evaluation iteration.
    with summary_writer.as_default():
      tf.summary.scalar('congestion', congestion, step=current_step)
      tf.summary.scalar('wirelength', wirelength, step=current_step)
      tf.summary.scalar('density', density, step=current_step)
      tf.summary.scalar('Cost', cost, step=current_step)
    # Flush so TensorBoard sees the new points promptly.
    summary_writer.flush()

    time.sleep(20)