import functools
import warnings
warnings.filterwarnings('ignore')
from absl import app
from absl import flags
import acme
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.jax import variable_utils
from acme.utils import counting
import jax
from ml_collections import config_flags
import numpy as np
import optax
import tqdm

from otr import dataset_utils_my_scale_detach
from otr import evaluation
from otr import experiment_utils
from otr.agents import iql
from otr.agents.otil import rewarder as rewarder_lib

# `--config`: ml_collections config file, defaulting to
# configs/otr_iql_mujoco.py; main() reads it via `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_file("config", "configs/otr_iql_mujoco.py")
# `--workdir`: output directory flag (default '/tmp/otr').
_WORKDIR = flags.DEFINE_string('workdir', '/tmp/otr', '')


def relabel_rewards(rewarder, trajectory):
  """Return a copy of `trajectory` whose rewards come from `rewarder`.

  Args:
    rewarder: object exposing `compute_offline_rewards(trajectory)`, which
      yields one reward per step.
    trajectory: sequence of namedtuple-like transitions with a `reward`
      field.

  Returns:
    A new list of transitions, identical to the input except for `reward`.
    As with `zip`, the output stops at the shorter of the trajectory and the
    computed reward sequence.
  """
  new_rewards = rewarder.compute_offline_rewards(trajectory)
  return [step._replace(reward=r) for step, r in zip(trajectory, new_rewards)]


def compute_iql_reward_scale(trajs):
  """Return the reward scale that maps the dataset's return range to 1000.

  This is the normalization used by the original IQL implementation:
  ``1000 / (max_return - min_return)``, where a trajectory's return is the
  undiscounted sum of its step rewards. The input is not modified.

  Args:
    trajs: non-empty iterable of trajectories; each trajectory is an iterable
      of steps exposing a ``reward`` attribute.

  Returns:
    The multiplicative reward scale.

  Raises:
    ValueError: if `trajs` is empty, or if every trajectory has the same
      return (the scale would be a division by zero; with numpy rewards that
      silently produces ``inf`` instead of failing).
  """
  # Compute each return once; only the extremes are needed, so no sort.
  returns = [sum(step.reward for step in tr) for tr in trajs]
  if not returns:
    raise ValueError('Cannot compute a reward scale from an empty dataset.')
  return_range = max(returns) - min(returns)
  if return_range == 0:
    raise ValueError(
        'All trajectories have the same return; the reward scale is '
        'undefined.')
  return 1000.0 / return_range


def get_demonstration_dataset(config, offline_dataset_name):
  """Return the offline dataset as merged transitions with rescaled rewards.

  Trajectories are loaded with
  `dataset_utils_my_scale_detach.load_trajectories(offline_dataset_name,
  config.data_load, config.score_lambda)` and merged into one transition
  batch. Each reward is then mapped to ``reward * reward_scale +
  reward_bias``, where `reward_scale` is computed by
  `compute_iql_reward_scale` over the loaded trajectories.

  Args:
    config: experiment config; `data_load` and `score_lambda` are read.
    offline_dataset_name: dataset name, e.g. 'ant-random-v2'.

  Returns:
    The merged transitions with their `reward` field rescaled.
  """
  offline_traj = dataset_utils_my_scale_detach.load_trajectories(
      offline_dataset_name, config.data_load, config.score_lambda)

  # Every dataset family uses the same return-range scale. For AntMaze the
  # only differences are the printed banner and using -0.0 rather than 0.0
  # as the (effectively zero) bias.
  is_antmaze = "antmaze" in offline_dataset_name
  reward_scale = compute_iql_reward_scale(offline_traj)
  print('==========' if is_antmaze else '=========', reward_scale)
  reward_bias = -0.0 if is_antmaze else 0.0

  transitions = dataset_utils_my_scale_detach.merge_trajectories(offline_traj)
  return transitions._replace(
      reward=transitions.reward * reward_scale + reward_bias)


def main(_):
  """Train IQL on the relabeled 'ant-random-v2' data, then evaluate once.

  Loads and rescales the offline dataset, builds the IQL networks and
  learner, runs `num_train_steps` learner updates, and finally runs a single
  evaluation episode with the policy's mode (deterministic) action.

  Logs go to `--workdir` when that flag is passed explicitly. Otherwise they
  go to the per-seed directory ./tmp_class_ant/ant_random/class_<seed>.
  """
  config = _CONFIG.value
  offline_dataset_name = 'ant-random-v2'
  # Total number of learner updates for this experiment.
  num_train_steps = 100000

  # Honor an explicitly passed --workdir; otherwise use a per-seed directory.
  if flags.FLAGS[_WORKDIR.name].present:
    workdir = _WORKDIR.value
  else:
    workdir = './tmp_class_ant/ant_random/class_' + str(config.seed)
  log_to_wandb = config.log_to_wandb
  # Per-seed data path passed to load_trajectories. It is set before
  # `config.to_dict()` below so that it is recorded in the wandb config.
  config.data_load = ('datasets_class_scores/ant_random/ant_class_' +
                      str(config.seed) + '-oriscores')

  wandb_kwargs = {
      'project': config.wandb_project,
      'entity': config.wandb_entity,
      'config': config.to_dict(),
  }

  logger_factory = experiment_utils.LoggerFactory(
      workdir=workdir,
      log_to_wandb=log_to_wandb,
      wandb_kwargs=wandb_kwargs,
      learner_time_delta=30,
      evaluator_time_delta=0)

  dataset = get_demonstration_dataset(config, offline_dataset_name)

  # Create dataset iterator for the relabeled dataset.
  key = jax.random.PRNGKey(config.seed)
  key_learner, key_demo, key = jax.random.split(key, 3)

  iterator = dataset_utils_my_scale_detach.JaxInMemorySampler(dataset, key_demo,
                                              config.batch_size)

  # Create an environment and grab the spec.
  environment = dataset_utils_my_scale_detach.make_environment(
      offline_dataset_name, seed=config.seed)
  # Create the networks to optimize.
  spec = acme.make_environment_spec(environment)
  networks = iql.make_networks(
      spec, hidden_dims=config.hidden_dims, dropout_rate=config.dropout_rate)

  counter = counting.Counter(time_delta=0.0)

  if config.opt_decay_schedule == "cosine":
    # The schedule value multiplies the raw Adam direction, so it starts at
    # -actor_lr to make the resulting update a descent step.
    # NOTE(review): the schedule decays over config.max_steps, but training
    # below runs num_train_steps updates. If the two differ, the LR either
    # never finishes decaying or sits at its final value for the remaining
    # steps. Confirm this is intended.
    schedule_fn = optax.cosine_decay_schedule(-config.actor_lr,
                                              config.max_steps)
    policy_optimizer = optax.chain(optax.scale_by_adam(),
                                   optax.scale_by_schedule(schedule_fn))
  else:
    policy_optimizer = optax.adam(config.actor_lr)

  # Create the learner.
  learner_counter = counting.Counter(counter, "learner", time_delta=0.0)
  learner = iql.IQLLearner(
      networks=networks,
      random_key=key_learner,
      dataset=iterator,
      policy_optimizer=policy_optimizer,
      critic_optimizer=optax.adam(config.critic_lr),
      value_optimizer=optax.adam(config.value_lr),
      **config.iql_kwargs,
      logger=logger_factory('learner', learner_counter.get_steps_key(), 0),
      counter=learner_counter,
  )

  def evaluator_network(params, key, observation):
    """Act with the mode of the policy distribution (dropout disabled)."""
    del key  # Deterministic policy: the RNG key is unused.
    action_distribution = networks.policy_network.apply(
        params, observation, is_training=False)
    return action_distribution.mode()

  eval_actor = actors.GenericActor(
      actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network),
      random_key=key,
      variable_client=variable_utils.VariableClient(
          learner, "policy", device="cpu"),
      backend="cpu",
  )

  eval_counter = counting.Counter(counter, "eval_loop", time_delta=0.0)
  eval_loop = evaluation.D4RLEvalLoop(
      environment,
      eval_actor,
      counter=eval_counter,
      logger=logger_factory('eval_loop', eval_counter.get_steps_key(), 0),
  )

  # Run the learner. Updates are issued in chunks of config.evaluate_every,
  # so the total can exceed num_train_steps when evaluate_every does not
  # divide it evenly.
  steps = 0
  while steps < num_train_steps:
    for _ in range(config.evaluate_every):
      learner.step()
    steps += config.evaluate_every
  # A single evaluation episode after all training; there is no periodic
  # evaluation during training.
  eval_loop.run(1)


# Script entry point: absl's app.run parses the command-line flags, then
# calls main.
if __name__ == '__main__':
  app.run(main)
