# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive RL learner implementation."""
import time
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Callable

import acme
from acme import types
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.utils import counting
from acme.utils import loggers
from contrastive import config as contrastive_config
from contrastive import networks as contrastive_networks
import functools
import jax
import jax.numpy as jnp
import optax
import reverb
import numpy as np
from jax.flatten_util import ravel_pytree


class TrainingState(NamedTuple):
  """Contains training state for the learner.

  The alpha fields default to `None`; the learner only fills them in when the
  entropy coefficient is learned rather than fixed.
  """
  # Optimizer state for the policy (actor) parameters.
  policy_optimizer_state: optax.OptState
  # Optimizer state for the Q-function (critic) parameters.
  q_optimizer_state: optax.OptState
  # Current policy network parameters.
  policy_params: networks_lib.Params
  # Current critic network parameters.
  q_params: networks_lib.Params
  # Slowly-updated copy of the critic parameters (soft update with config.tau).
  target_q_params: networks_lib.Params
  # PRNG key; split on every update step.
  key: networks_lib.PRNGKey
  # Exponential moving average of the flattened critic gradient.
  critic_mean_grad: jnp.ndarray
  # Exponential moving average of the flattened actor gradient.
  actor_mean_grad: jnp.ndarray
  # Optimizer state for the entropy temperature (adaptive entropy only).
  alpha_optimizer_state: Optional[optax.OptState] = None
  # Log of the entropy temperature alpha (adaptive entropy only).
  alpha_params: Optional[networks_lib.Params] = None


class ContrastiveLearner(acme.Learner):
  """Contrastive RL learner."""

  _state: TrainingState

  def __init__(
      self,
      networks,
      rng,
      policy_optimizer,
      q_optimizer,
      iterator,
      counter,
      logger,
      obs_to_goal,
      config):
    """Initialize the Contrastive RL learner.

    Builds the alpha, critic and actor losses as closures over `networks` and
    `config`, wraps the resulting update step so that it processes
    `config.num_sgd_steps_per_step` batches per call (optionally jitted), and
    creates the initial training state.

    Args:
      networks: Contrastive RL networks.
      rng: a key for random number generation.
      policy_optimizer: the policy optimizer.
      q_optimizer: the Q-function optimizer.
      iterator: an iterator over training data.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      obs_to_goal: a function for extracting the goal coordinates.
      config: the experiment config file.

    Raises:
      ValueError: if `config.target_entropy` is set although a fixed
        `config.entropy_coefficient` is provided.
    """
    if config.add_mc_to_td:
      assert config.use_td
    adaptive_entropy_coefficient = config.entropy_coefficient is None
    self._num_sgd_steps_per_step = config.num_sgd_steps_per_step
    self.policy_frequency = config.policy_frequency
    self._obs_dim = config.obs_dim
    self._use_td = config.use_td

    num_bins = config.num_classifier_bins
    # Upper bound applied to the importance ratios v / (1 - v) in the losses.
    w_clipping = 20.0

    # Row i of the mask enables classifier bins i .. num_bins - 2. The last
    # bin is never enabled here; it is handled by the "final classifier" loss.
    mid_loss_mask = np.zeros((num_bins, num_bins))
    for i in range(num_bins):
      mid_loss_mask[i, i:-1] = 1
    self.mid_loss_mask = jnp.array(mid_loss_mask)[:, None, :]  # [nC, 1, nC]

    # Mixing weights over classifier bins, one row per shift.
    # Row 0 corresponds to a zero shift and puts all mass on the last bin.
    # Row i + 1 starts with (1 - discount) at column num_bins - 2 - i, decays
    # geometrically by `discount` towards the right, and assigns whatever mass
    # is left to the last column so that every row sums to one.
    classifier_weights = np.zeros((num_bins, num_bins))
    classifier_weights[0, -1] = 1
    for i in range(num_bins - 1):
      classifier_weights[i + 1, num_bins - 2 - i] = 1 - config.discount
      for j in range(i):
        classifier_weights[i + 1, num_bins - 1 - i + j] = (
            config.discount * classifier_weights[i + 1, num_bins - 2 - i + j])
      classifier_weights[i + 1, num_bins - 1] = (
          1 - classifier_weights[i + 1].sum())
    # Shape [nC, 1, 1, nC, 1] so that a row broadcasts against
    # [B, B, nC, num_q] logits.
    self.classifier_weights = jnp.array(
        classifier_weights)[:, None, None, :, None]

    if adaptive_entropy_coefficient:
      # alpha is the temperature parameter that determines the relative
      # importance of the entropy term versus the reward.
      log_alpha = jnp.asarray(0., dtype=jnp.float32)
      alpha_optimizer = optax.adam(learning_rate=3e-4)
      alpha_optimizer_state = alpha_optimizer.init(log_alpha)
    elif config.target_entropy:
      raise ValueError('target_entropy should not be set when '
                       'entropy_coefficient is provided')

    def alpha_loss(log_alpha,
                   policy_params,
                   transitions,
                   key):
      """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
      dist_params = networks.policy_network.apply(
          policy_params, transitions.observation)
      action = networks.sample(dist_params, key)
      log_prob = networks.log_prob(dist_params, action)
      alpha = jnp.exp(log_alpha)
      # Only alpha receives gradients; the entropy gap is treated as constant.
      loss = alpha * jax.lax.stop_gradient(
          -log_prob - config.target_entropy)
      return jnp.mean(loss)

    def critic_loss(q_params,
                    policy_params,
                    target_q_params,
                    transitions,
                    key):
      """Returns the critic loss and a dict of diagnostic metrics.

      Uses the TD variant when `config.use_td` is set and the Monte-Carlo
      variant (optionally with the self-supervised term) otherwise.
      """
      batch_size = transitions.observation.shape[0]
      # Note: We might be able to speed up the computation for some of the
      # baselines to making a single network that returns all the values. This
      # avoids computing some of the underlying representations multiple times.
      if config.use_td:
        # For TD learning, the diagonal elements are the immediate next state.
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
        next_s, _ = jnp.split(transitions.next_observation, [config.obs_dim],
                              axis=1)
        if config.add_mc_to_td:
          next_fraction = (1 - config.discount) / ((1 - config.discount) + 1)
          num_next = int(batch_size * next_fraction)
          new_g = jnp.concatenate([
              obs_to_goal(next_s[:num_next]),
              g[num_next:],
          ], axis=0)
        else:
          new_g = obs_to_goal(next_s)  # TD: 1-step future state is the goal
        obs = jnp.concatenate([s, new_g], axis=1)
        transitions = transitions._replace(observation=obs)
      I = jnp.eye(batch_size)  # pylint: disable=invalid-name
      logits = networks.q_network.apply(
          q_params, transitions.observation, transitions.action)

      if config.use_td:
        # Make sure to use the twin Q trick.
        assert len(logits.shape) == 4

        # We evaluate the next-state Q function using random goals.
        # Note: The goals g were overwritten with s_{t+1} for computing the
        # positive logits above.
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
        del s
        next_s = transitions.next_observation[:, :config.obs_dim]
        goal_indices = jnp.roll(jnp.arange(batch_size, dtype=jnp.int32), -1)
        g = g[goal_indices]
        transitions = transitions._replace(
            next_observation=jnp.concatenate([next_s, g], axis=1))

        next_dist_params = networks.policy_network.apply(
            policy_params, transitions.next_observation)
        # This action is conditioned on the (random) goal g.
        next_action = networks.sample(next_dist_params, key)
        next_q = networks.q_network.apply(target_q_params,
                                          transitions.next_observation,
                                          next_action)  # This outputs logits.

        if config.use_cpc:
          raise NotImplementedError

        else:
          # [B, B, nC, 2] - last axis has q values from twin critics.
          next_q = jax.nn.sigmoid(next_q)
          next_v = jnp.min(next_q, axis=-1)  # [B, B, nC]
          next_v = jax.lax.stop_gradient(next_v)
          next_v = jax.vmap(jnp.diag, -1, -1)(next_v)  # [B, nC]

          # diag(logits) are predictions for future states.
          # diag(next_q) are predictions for random states, which correspond to
          # the predictions logits[range(B), goal_indices].
          w = next_v / (1 - next_v)
          w = jnp.clip(w, 0, w_clipping)

          # First Classifier Loss
          # (B, B, nC, 2) --> (B, 2), computes diagonal of each twin Q.
          pos_logits = jax.vmap(jnp.diag, -1, -1)(logits[:, :, 0, :])  # [B, 2]
          neg_logits = logits[jnp.arange(batch_size), goal_indices]  # [B, nC, 2]
          first_loss_pos = optax.sigmoid_binary_cross_entropy(
              logits=pos_logits, labels=1)  # [B, 2]
          first_loss_neg = optax.sigmoid_binary_cross_entropy(
              logits=neg_logits[:, 0], labels=0)  # [B, 2]
          loss = jnp.mean(first_loss_pos + first_loss_neg)

          # Middle Classifier Loss
          if config.num_classifier_bins > 2:
            mid_loss_pos = w[:, :-2, None] * optax.sigmoid_binary_cross_entropy(
                logits=neg_logits[:, 1:-1], labels=1)
            mid_loss_neg = optax.sigmoid_binary_cross_entropy(
                logits=neg_logits[:, 1:-1], labels=0)
            loss += jnp.mean(mid_loss_pos + mid_loss_neg)

          # Final Classifier Loss
          final_weight = ((1 - config.discount) * w[:, -2, None]
                          + config.discount * w[:, -1, None])
          final_loss_pos = final_weight * optax.sigmoid_binary_cross_entropy(
              logits=neg_logits[:, -1], labels=1)
          final_loss_neg = optax.sigmoid_binary_cross_entropy(
              logits=neg_logits[:, -1], labels=0)
          loss += jnp.mean(final_loss_pos + final_loss_neg)

          if config.add_mc_to_td:
            raise NotImplementedError

          # Condense the per-bin logits into one logit per (state, goal) pair
          # using the last row of the classifier weights, then average the
          # twin critics so that the accuracy metrics can be computed.
          logits = jax.nn.logsumexp(
              a=logits, b=self.classifier_weights[-1], axis=2)
          logits = jnp.mean(logits, axis=-1)

      else:  # For the MC losses.
        relative_goal_idx = transitions.extras['relative_goal_idx']

        def loss_fn(_logits):  # pylint: disable=invalid-name
          # For row i keep only the classifier bin relative_goal_idx[i] - 1;
          # the result has one logit per (state, goal) pair.
          _logits_ = _logits.at[jnp.arange(batch_size), :,
                                relative_goal_idx - 1].get()
          if config.use_cpc:
            if config.cpc_term_state:
              raise NotImplementedError
            result = optax.softmax_cross_entropy(logits=_logits_, labels=I)
            if not config.no_cpc_regLSE:
              # Regularizer pulling logsumexp over goals towards log(B).
              result += jnp.mean(
                  config.cpc_reg_coeff
                  * (jax.nn.logsumexp(_logits, axis=1)
                     - jnp.log(batch_size))**2,
                  axis=-1)
            return result
          else:
            # Every slice along the num_classifier channel sees a positive
            # goal example for (s,a,H) for every (B-1) negative goal examples.
            return optax.sigmoid_binary_cross_entropy(
                logits=_logits_, labels=I)

        if len(logits.shape) == 4:  # twin q
          # loss.shape = [.., num_q]
          loss = jax.vmap(loss_fn, in_axes=-1, out_axes=-1)(logits)
          loss = jnp.mean(loss, axis=-1)
          # Take the mean here so that we can compute the accuracy.
          logits = jnp.mean(logits, axis=-1)
        else:
          loss = jnp.mean(loss_fn(logits))

        if config.selfsup_flag:
          # NOTE(review): the shape comments below assume the critic returns
          # [B, B, nC] here (no twin axis) -- confirm against the network.
          _, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
          selfsup_s = transitions.extras['selfsup_state']
          selfsup_idx = transitions.extras['relative_selfsup_idx']
          next_q = networks.q_network.apply(
              q_params,
              jnp.concatenate([selfsup_s, g], axis=1),
              transitions.extras['selfsup_action'])  # [B, B, nC]

          def shuffle(row, roll_ind):
            # row -> [nC] (one sample under vmap) | roll_ind -> scalar
            return jnp.roll(row, roll_ind, axis=-1)

          selfsup_weights = self.classifier_weights[selfsup_idx, 0, 0, :, 0]
          selfsup_mask = self.mid_loss_mask[selfsup_idx, 0]  # [B, nC]

          if config.use_cpc:
            next_v = jax.nn.softmax(next_q, axis=1)
            next_v = jax.lax.stop_gradient(next_v)  # [B, B, nC]
            next_v = jax.vmap(jnp.diag, -1, -1)(next_v)  # [B, nC]

            imp_weight = (next_v * selfsup_weights).sum(-1)  # [B]

            # ([B, nC], [B]) -> [B, nC]
            rotated_next_v = jax.vmap(shuffle, in_axes=0, out_axes=0)(
                next_v, selfsup_idx)

            # Middle Classifier Loss
            def CELoss(bin_logits, labels):  # pylint: disable=invalid-name
              # bin_logits -> B x B | labels -> B
              soft_logits = jax.nn.log_softmax(bin_logits, axis=-1)  # B x B
              soft_logits = jnp.diag(soft_logits)  # B
              return optax.sigmoid_binary_cross_entropy(
                  logits=soft_logits, labels=labels)

            mid_loss = jax.vmap(CELoss, in_axes=-1, out_axes=-1)(
                logits, rotated_next_v)  # B x nC

            if config.num_classifier_bins > 2:
              loss += jnp.mean(selfsup_mask * mid_loss)
            # Final Classifier Loss
            soft_logits = jax.nn.log_softmax(logits[:, :, -1], axis=-1)  # B x B
            soft_logits = jnp.diag(soft_logits)  # B
            loss += jnp.mean(optax.sigmoid_binary_cross_entropy(
                logits=soft_logits, labels=imp_weight))

          else:
            next_v = jax.nn.sigmoid(next_q)
            next_v = jax.lax.stop_gradient(next_v)  # [B, B, nC]
            next_v = jax.vmap(jnp.diag, -1, -1)(next_v)  # [B, nC]

            w = next_v / (1 - next_v)
            w = jnp.clip(w, 0, w_clipping)  # [B, nC]

            # Mix the clipped ratios over bins, then map the mixed ratio back
            # to a probability-like label in [0, 1).
            imp_weight = (w * selfsup_weights).sum(-1)
            imp_weight = imp_weight / (1 + imp_weight)  # [B]

            # ([B, nC], [B]) -> [B, nC]
            rotated_next_v = jax.vmap(shuffle, in_axes=0, out_axes=0)(
                next_v, selfsup_idx)

            diag_logits = jax.vmap(jnp.diag, -1, -1)(logits)  # [B, nC]

            # Middle Classifier Loss
            if config.num_classifier_bins > 2:
              loss += jnp.mean(
                  selfsup_mask * optax.sigmoid_binary_cross_entropy(
                      logits=diag_logits, labels=rotated_next_v))
            # Final Classifier Loss
            loss += jnp.mean(optax.sigmoid_binary_cross_entropy(
                logits=diag_logits[:, -1], labels=imp_weight))

        if config.use_cpc:
          # Normalize over the goal axis before mixing the classifier bins.
          logits = logits - jax.nn.logsumexp(logits, axis=1, keepdims=True)

        logits = jax.nn.logsumexp(
            a=logits, b=self.classifier_weights[-1, ..., 0], axis=2)

      loss = jnp.mean(loss)
      correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
      logits_pos = jnp.sum(logits * I) / jnp.sum(I)
      logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
      if len(logits.shape) == 3:
        logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
      else:
        logsumexp = jax.nn.logsumexp(logits, axis=1)**2
      metrics = {
          'binary_accuracy': jnp.mean((logits > 0) == I),
          'categorical_accuracy': jnp.mean(correct),
          'logits_pos': logits_pos,
          'logits_neg': logits_neg,
          'logsumexp': logsumexp.mean(),
      }

      return loss, metrics

    def actor_loss(policy_params,
                   q_params,
                   alpha,
                   transitions,
                   key,
                   ):
      """Returns the scalar actor loss (GCBC or entropy-regularized critic)."""
      obs = transitions.observation
      if config.use_gcbc:
        # Goal-conditioned behavioral cloning: maximize the data likelihood.
        dist_params = networks.policy_network.apply(
            policy_params, obs)
        log_prob = networks.log_prob(dist_params, transitions.action)
        loss = -1.0 * jnp.mean(log_prob)
      else:
        state = obs[:, :config.obs_dim]
        goal = obs[:, config.obs_dim:]

        if config.random_goals == 0.0:
          new_state = state
          new_goal = goal
        elif config.random_goals == 0.5:
          # Half of the batch keeps its goal, the other half gets a goal
          # rolled in from a different batch element.
          new_state = jnp.concatenate([state, state], axis=0)
          new_goal = jnp.concatenate([goal, jnp.roll(goal, 1, axis=0)], axis=0)
        else:
          assert config.random_goals == 1.0
          new_state = state
          new_goal = jnp.roll(goal, 1, axis=0)

        new_obs = jnp.concatenate([new_state, new_goal], axis=1)
        dist_params = networks.policy_network.apply(
            policy_params, new_obs)
        action = networks.sample(dist_params, key)
        log_prob = networks.log_prob(dist_params, action)
        q_action = networks.q_network.apply(
            q_params, new_obs, action)  # [B, B, nC, 2]
        if len(q_action.shape) == 4:  # twin q trick
          assert q_action.shape[3] == 2
          q_action = jnp.min(q_action, axis=-1)  # [B, B, nC]

        # Condensing all the classifiers into one.
        if config.use_cpc:
          q_action = q_action - jax.nn.logsumexp(
              q_action, axis=1, keepdims=True)
        q_action = jax.nn.logsumexp(
            a=q_action, b=self.classifier_weights[-1, ..., 0], axis=2)

        loss = alpha * log_prob - jnp.diag(q_action)

        assert 0.0 <= config.bc_coef <= 1.0
        if config.bc_coef > 0:
          orig_action = transitions.action
          if config.random_goals == 0.5:
            orig_action = jnp.concatenate([orig_action, orig_action], axis=0)

          bc_loss = -1.0 * networks.log_prob(dist_params, orig_action)
          loss = (config.bc_coef * bc_loss
                  + (1 - config.bc_coef) * loss)

      return jnp.mean(loss)

    alpha_grad = jax.value_and_grad(alpha_loss)
    critic_grad = jax.value_and_grad(critic_loss, has_aux=True)
    actor_grad = jax.value_and_grad(actor_loss)

    def update_step(
        state,
        transitions,
        critic_flag
    ):
      """Runs one SGD step and returns the new state and a metrics dict.

      The actor is always updated. The critic (and its target network) is only
      updated when `critic_flag` is true and GCBC is disabled; otherwise the
      critic-related state is carried over unchanged and the critic gradient
      metrics are reported as zero.
      """
      key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
      if adaptive_entropy_coefficient:
        alpha_loss_value, alpha_grads = alpha_grad(
            state.alpha_params, state.policy_params, transitions, key_alpha)
        alpha = jnp.exp(state.alpha_params)
      else:
        alpha = config.entropy_coefficient

      update_critic = critic_flag and (not config.use_gcbc)
      if update_critic:
        (critic_loss_value, critic_metrics), critic_grads = critic_grad(
            state.q_params, state.policy_params, state.target_q_params,
            transitions, key_critic)
        critic_flat_grad, _ = ravel_pytree(critic_grads)

      # The actor gradient is taken w.r.t. the pre-update critic parameters.
      actor_loss_value, actor_grads = actor_grad(
          state.policy_params, state.q_params, alpha, transitions, key_actor)
      actor_flat_grad, _ = ravel_pytree(actor_grads)

      # Apply policy gradients
      actor_update, policy_optimizer_state = policy_optimizer.update(
          actor_grads, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, actor_update)

      # Apply critic gradients
      if update_critic:
        critic_update, q_optimizer_state = q_optimizer.update(
            critic_grads, state.q_optimizer_state)

        q_params = optax.apply_updates(state.q_params, critic_update)

        # Soft (Polyak) update of the target critic.
        new_target_q_params = jax.tree_util.tree_map(
            lambda x, y: x * (1 - config.tau) + y * config.tau,
            state.target_q_params, q_params)
        metrics = critic_metrics

        critic_grad_norm = jnp.linalg.norm(critic_flat_grad)
        # NOTE(review): the running mean is scaled by 100 before the
        # comparison; kept as in the original -- confirm the intent.
        critic_grad_var = jnp.linalg.norm(
            critic_flat_grad - 100*state.critic_mean_grad)
        critic_mean_grad = (
            0.99*state.critic_mean_grad + 0.01*critic_flat_grad)
      else:
        # No critic gradient exists on this path. Previously the code below
        # still read `critic_flat_grad`, which raised UnboundLocalError.
        metrics = {}
        critic_loss_value = 0.0
        q_params = state.q_params
        q_optimizer_state = state.q_optimizer_state
        new_target_q_params = state.target_q_params
        critic_grad_norm = 0.0
        critic_grad_var = 0.0
        critic_mean_grad = state.critic_mean_grad

      actor_grad_norm = jnp.linalg.norm(actor_flat_grad)
      actor_grad_var = jnp.linalg.norm(
          actor_flat_grad - 100*state.actor_mean_grad)
      metrics.update({
          'critic_loss': critic_loss_value,
          'actor_loss': actor_loss_value,
          'critic_grad_norm': critic_grad_norm,
          'actor_grad_norm': actor_grad_norm,
          'critic_grad_var': critic_grad_var,
          'actor_grad_var': actor_grad_var,
          'normalized_critic_grad_var': (
              critic_grad_var / (critic_grad_norm + 1e-8)),
          'normalized_actor_grad_var': (
              actor_grad_var / (actor_grad_norm + 1e-8)),
      })

      new_state = TrainingState(
          critic_mean_grad=critic_mean_grad,
          actor_mean_grad=0.99*state.actor_mean_grad + 0.01*actor_flat_grad,
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=new_target_q_params,
          key=key,
      )
      if adaptive_entropy_coefficient:
        # Apply alpha gradients
        alpha_update, alpha_optimizer_state = alpha_optimizer.update(
            alpha_grads, state.alpha_optimizer_state)
        alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
        metrics.update({
            'alpha_loss': alpha_loss_value,
            'alpha': jnp.exp(alpha_params),
        })
        new_state = new_state._replace(
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params)

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray,
        time_delta=10.0)

    # Iterator on demonstration transitions.
    self._iterator = iterator

    # Two variants of the update: one that also trains the critic and one that
    # only trains the policy. `critic_flag` is bound as a Python constant so
    # that each variant traces a single branch.
    critic_update_step = utils.process_multiple_batches(
        functools.partial(update_step, critic_flag=True),
        config.num_sgd_steps_per_step)
    policy_update_step = utils.process_multiple_batches(
        functools.partial(update_step, critic_flag=False),
        config.num_sgd_steps_per_step)
    # Use the JIT compiler.
    if config.jit:
      self._critic_update_step = jax.jit(critic_update_step)
      self._policy_update_step = jax.jit(policy_update_step)
    else:
      self._critic_update_step = critic_update_step
      self._policy_update_step = policy_update_step

    def make_initial_state(key):
      """Initialises the training state (parameters and optimiser state)."""
      key_policy, key_q, key = jax.random.split(key, 3)

      policy_params = networks.policy_network.init(key_policy)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      q_params = networks.q_network.init(key_q)
      q_optimizer_state = q_optimizer.init(q_params)

      # The running gradient means live in flattened parameter space.
      q_flat_params, _ = ravel_pytree(q_params)
      policy_flat_params, _ = ravel_pytree(policy_params)

      state = TrainingState(
          critic_mean_grad=jnp.zeros_like(q_flat_params),
          actor_mean_grad=jnp.zeros_like(policy_flat_params),
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=q_params,
          key=key)

      if adaptive_entropy_coefficient:
        state = state._replace(alpha_optimizer_state=alpha_optimizer_state,
                               alpha_params=log_alpha)
      return state

    # Create initial state.
    self._state = make_initial_state(rng)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None

  def step(self):
    """Runs one learner step on a fresh batch and logs the metrics.

    The batch is first used for a combined critic/policy update and then
    re-used for `policy_frequency - 1` policy-only updates; the reported
    actor loss is the average over all of these updates.
    """
    with jax.profiler.StepTraceAnnotation('step', step_num=self._counter):
      batch = next(self._iterator)
      transitions = types.Transition(*batch.data)
      self._state, metrics = self._critic_update_step(self._state, transitions)
      extra_policy_updates = self.policy_frequency - 1
      for _ in range(extra_policy_updates):
        self._state, policy_metrics = self._policy_update_step(
            self._state, transitions)
        metrics['actor_loss'] += policy_metrics['actor_loss']
      metrics['actor_loss'] /= self.policy_frequency

    # Time since the previous step; zero on the very first call, when no
    # timestamp has been recorded yet.
    now = time.time()
    elapsed_time = now - self._timestamp if self._timestamp else 0
    self._timestamp = now

    # Increment counts and record the elapsed wall time.
    counts = self._counter.increment(steps=1, walltime=elapsed_time)
    metrics['steps_per_second'] = (
        self._num_sgd_steps_per_step / elapsed_time
        if elapsed_time > 0 else 0.)

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})

  def get_variables(self, names):
    """Returns the current parameters for each requested name, in order.

    Supported names are 'policy' and 'critic'; any other name raises KeyError.
    """
    current = self._state
    params_by_name = dict(
        policy=current.policy_params,
        critic=current.q_params,
    )
    return [params_by_name[requested] for requested in names]

  def save(self):
    """Returns the current training state so that it can be checkpointed."""
    snapshot = self._state
    return snapshot

  def restore(self, state):
    """Replaces the learner's training state with a previously saved one."""
    restored = state
    self._state = restored
