from copy import deepcopy
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from ml_collections import ConfigDict

from algos.model import Scalar, update_target_network
from utilities.jax_utils import mse_loss, next_rng, value_and_multi_grad
from utilities.utils import prefix_metrics


class ConservativeSAC(object):
  """Conservative Q-Learning (CQL) on top of Soft Actor-Critic with a shared encoder.

  Observations are mapped to an embedding by ``encoder``; ``policy`` and the two
  Q-functions (``qf1``/``qf2``, both built from ``qf``) consume that embedding.
  Every network has its own ``TrainState`` in ``self._train_states``. Target
  copies of the Q-function parameters live in ``self._target_qf_params`` and are
  moved towards the online parameters with ``update_target_network`` after each
  training step.

  The encoder is updated by both the critic step and the actor step.
  """

  @staticmethod
  def get_default_config(updates=None):
    """Return the default hyper-parameters as a ``ConfigDict``.

    Args:
      updates: optional mapping / ``ConfigDict`` of overrides. References inside
        it are resolved before it is merged into the defaults.

    Returns:
      A ``ConfigDict`` holding the defaults, overridden by ``updates`` if given.
    """
    config = ConfigDict()
    config.nstep = 1
    config.discount = 0.99
    config.alpha_multiplier = 1.0
    config.use_automatic_entropy_tuning = True
    config.backup_entropy = False
    config.target_entropy = 0.0
    config.encoder_lr = 3e-4
    config.policy_lr = 3e-4
    config.qf_lr = 3e-4
    config.optimizer_type = 'adam'  # 'adam' or 'sgd'
    config.soft_target_update_rate = 5e-3
    config.use_cql = True
    config.cql_n_actions = 10
    config.cql_importance_sample = True
    config.cql_lagrange = False
    config.cql_target_action_gap = 1.0
    config.cql_temp = 1.0
    config.cql_min_q_weight = 5.0
    # Scale applied to the CQL out-of-distribution term when weight_constraint == 1.
    config.cql_coff = 1.0
    config.cql_max_target_backup = False
    config.cql_clip_diff_min = -np.inf
    config.cql_clip_diff_max = np.inf
    config.bc_mode = 'mse'  # 'mse' or 'mle'
    config.bc_weight = 0.
    config.res_hidden_size = 1024
    config.encoder_blocks = 1
    config.head_blocks = 1

    if updates is not None:
      config.update(ConfigDict(updates).copy_and_resolve_references())
    return config

  def __init__(self, config, encoder, policy, qf, decoupled_q=False):
    """Initialise parameters and optimiser states for every network.

    Args:
      config: overrides merged into ``get_default_config()``.
      encoder: observation encoder (presumably a Flax module; ``init``/``apply``
        are used).
      policy: policy over embeddings; must expose ``input_size``,
        ``observation_dim``, ``embedding_dim`` and ``action_dim``.
      qf: Q-function over (embedding, action); initialised twice (qf1 / qf2)
        with independent keys.
      decoupled_q: stored on the instance; not used elsewhere in this class.
    """
    self.config = self.get_default_config(config)
    self.decoupled_q = decoupled_q
    self.policy = policy
    self.qf = qf
    self.encoder = encoder
    self.observation_dim = policy.input_size
    self.embedding_dim = policy.embedding_dim
    self.action_dim = policy.action_dim

    self._train_states = {}

    optimizer_class = {
      'adam': optax.adam,
      'sgd': optax.sgd,
    }[self.config.optimizer_type]

    # Dummy batches of 10 are only used for shape inference during init.
    encoder_params = self.encoder.init(
      next_rng(), jnp.zeros((10, self.policy.observation_dim))
    )
    self._train_states['encoder'] = TrainState.create(
      params=encoder_params,
      tx=optimizer_class(self.config.encoder_lr),
      apply_fn=None
    )

    policy_params = self.policy.init(
      next_rng(), next_rng(), jnp.zeros((10, self.embedding_dim))
    )
    self._train_states['policy'] = TrainState.create(
      params=policy_params,
      tx=optimizer_class(self.config.policy_lr),
      apply_fn=None
    )

    qf1_params = self.qf.init(
      next_rng(), jnp.zeros((10, self.embedding_dim)),
      jnp.zeros((10, self.action_dim))
    )
    self._train_states['qf1'] = TrainState.create(
      params=qf1_params,
      tx=optimizer_class(self.config.qf_lr),
      apply_fn=None,
    )
    qf2_params = self.qf.init(
      next_rng(), jnp.zeros((10, self.embedding_dim)),
      jnp.zeros((10, self.action_dim))
    )
    self._train_states['qf2'] = TrainState.create(
      params=qf2_params,
      tx=optimizer_class(self.config.qf_lr),
      apply_fn=None,
    )
    self._target_qf_params = deepcopy({'qf1': qf1_params, 'qf2': qf2_params})

    model_keys = ['policy', 'qf1', 'qf2', 'encoder']
    # Networks updated by the actor step and by the critic step respectively;
    # the encoder appears in both.
    self.actor_model_keys = ['policy', 'encoder']
    self.critic_model_keys = ['qf1', 'qf2', 'encoder']

    if self.config.use_automatic_entropy_tuning:
      self.log_alpha = Scalar(0.0)
      self._train_states['log_alpha'] = TrainState.create(
        params=self.log_alpha.init(next_rng()),
        tx=optimizer_class(self.config.policy_lr),
        apply_fn=None
      )
      model_keys.append('log_alpha')
      self.actor_model_keys.append('log_alpha')

    if self.config.cql_lagrange:
      self.log_alpha_prime = Scalar(1.0)
      self._train_states['log_alpha_prime'] = TrainState.create(
        params=self.log_alpha_prime.init(next_rng()),
        tx=optimizer_class(self.config.qf_lr),
        apply_fn=None
      )
      model_keys.append('log_alpha_prime')
      self.critic_model_keys.append('log_alpha_prime')

    self._model_keys = tuple(model_keys)
    self._total_steps = 0

  def train(self, batch, weight_eval, weight_improve, weight_constraint, bc=False):
    """Run one critic update, one actor update and one target-network update.

    Args:
      batch: dict with 'observations', 'actions', 'rewards',
        'next_observations', 'dones' and 'weights' arrays.
      weight_eval: if truthy, ``batch['weights']`` is passed to the TD loss.
      weight_improve: if truthy, the Q-values in the policy objective are
        multiplied by ``batch['weights']``.
      weight_constraint: 1 -> the dataset Q-values inside the CQL penalty are
        multiplied by the weights and the OOD term by ``config.cql_coff``;
        2 -> the clipped CQL difference is multiplied by the weights;
        anything else -> no weighting of the CQL penalty.
      bc: if True, the actor's RL term is replaced by
        ``alpha * log_pi - log_prob(actions)``.

    The four flags are static jit arguments: they must be hashable, and every
    new combination triggers a retrace.

    Returns:
      Dict of scalar metrics from the critic and actor steps.
    """
    self._total_steps += 1
    self._train_states, critic_metrics = self._train_critic_step(
      self._train_states, self._target_qf_params, next_rng(), batch,
      weight_eval, weight_constraint, bc
    )
    self._train_states, actor_metrics = self._train_actor_step(
      self._train_states, self._target_qf_params, next_rng(), batch,
      weight_improve, bc
    )
    self._target_qf_params = self._update_target(
      self.train_states, self._target_qf_params
    )
    return {**critic_metrics, **actor_metrics}

  # static_argnames must name only parameters this method takes; extra names
  # trigger JAX's "invalid argnames" warning.
  @partial(
    jax.jit, static_argnames=('self', 'weight_eval', 'weight_constraint', 'bc')
  )
  def _train_critic_step(self, train_states, target_qf_params, rng, batch, weight_eval, weight_constraint, bc=False):
    """One jitted gradient step for every key in ``self.critic_model_keys``.

    The TD loss is computed against the minimum of the two target critics. When
    ``config.use_cql`` is set, the CQL penalty (logsumexp over sampled actions
    minus the dataset Q-value) is added. ``bc`` is accepted for signature
    symmetry with the actor step and is unused here.

    Returns:
      ``(new_train_states, metrics)``.
    """
    def loss_fn(train_params, rng):
      observations = batch['observations']
      actions = batch['actions']
      rewards = batch['rewards']
      next_observations = batch['next_observations']
      dones = batch['dones']
      weights = batch['weights']

      loss_collection = {}

      # This split's key is unused; it is kept so the downstream key sequence
      # stays the same.
      rng, split_rng = jax.random.split(rng)
      embedding = self.encoder.apply(train_params['encoder'], observations)
      next_embedding = self.encoder.apply(
        train_params['encoder'], next_observations
      )

      # log_alpha is not a critic key, so it is read from train_states and
      # receives no gradient here.
      if self.config.use_automatic_entropy_tuning:
        alpha = jnp.exp(
          self.log_alpha.apply(train_states['log_alpha'].params)
        ) * self.config.alpha_multiplier
      else:
        alpha = self.config.alpha_multiplier

      # Q-function (TD) loss.
      q1_pred = self.qf.apply(train_params['qf1'], embedding, actions)
      q2_pred = self.qf.apply(train_params['qf2'], embedding, actions)

      # The policy is not a critic key either, so its parameters also come
      # from train_states.
      rng, split_rng = jax.random.split(rng)
      if self.config.cql_max_target_backup:
        new_next_actions, next_log_pi = self.policy.apply(
          train_states['policy'].params,
          split_rng,
          next_embedding,
          repeat=self.config.cql_n_actions
        )
        target_q_values = jnp.minimum(
          self.qf.apply(
            target_qf_params['qf1'], next_embedding, new_next_actions
          ),
          self.qf.apply(
            target_qf_params['qf2'], next_embedding, new_next_actions
          ),
        )
        # Keep the sampled action with the largest target Q along the last
        # axis (presumably the cql_n_actions samples).
        max_target_indices = jnp.expand_dims(
          jnp.argmax(target_q_values, axis=-1), axis=-1
        )
        target_q_values = jnp.take_along_axis(
          target_q_values, max_target_indices, axis=-1
        ).squeeze(-1)
        next_log_pi = jnp.take_along_axis(
          next_log_pi, max_target_indices, axis=-1
        ).squeeze(-1)
      else:
        new_next_actions, next_log_pi = self.policy.apply(
          train_states['policy'].params, split_rng, next_embedding
        )
        target_q_values = jnp.minimum(
          self.qf.apply(
            target_qf_params['qf1'], next_embedding, new_next_actions
          ),
          self.qf.apply(
            target_qf_params['qf2'], next_embedding, new_next_actions
          ),
        )

      if self.config.backup_entropy:
        target_q_values = target_q_values - alpha * next_log_pi

      discount = self.config.discount**self.config.nstep
      td_target = jax.lax.stop_gradient(
        rewards + (1. - dones) * discount * target_q_values
      )
      if weight_eval:
        qf1_loss = mse_loss(q1_pred, td_target, weights)
        qf2_loss = mse_loss(q2_pred, td_target, weights)
      else:
        qf1_loss = mse_loss(q1_pred, td_target)
        qf2_loss = mse_loss(q2_pred, td_target)

      # Conservative (CQL) regulariser.
      if self.config.use_cql:
        batch_size = actions.shape[0]
        rng, split_rng = jax.random.split(rng)
        cql_random_actions = jax.random.uniform(
          split_rng,
          shape=(batch_size, self.config.cql_n_actions, self.action_dim),
          minval=-1.0,
          maxval=1.0
        )

        rng, split_rng = jax.random.split(rng)
        cql_current_actions, cql_current_log_pis = self.policy.apply(
          train_states['policy'].params,
          split_rng,
          embedding,
          repeat=self.config.cql_n_actions
        )
        rng, split_rng = jax.random.split(rng)
        cql_next_actions, cql_next_log_pis = self.policy.apply(
          train_states['policy'].params,
          split_rng,
          next_embedding,
          repeat=self.config.cql_n_actions
        )

        # Every sampled action set (including those drawn at the next state)
        # is scored at the current embedding.
        cql_q1_rand = self.qf.apply(
          train_params['qf1'], embedding, cql_random_actions
        )
        cql_q2_rand = self.qf.apply(
          train_params['qf2'], embedding, cql_random_actions
        )
        cql_q1_current_actions = self.qf.apply(
          train_params['qf1'], embedding, cql_current_actions
        )
        cql_q2_current_actions = self.qf.apply(
          train_params['qf2'], embedding, cql_current_actions
        )
        cql_q1_next_actions = self.qf.apply(
          train_params['qf1'], embedding, cql_next_actions
        )
        cql_q2_next_actions = self.qf.apply(
          train_params['qf2'], embedding, cql_next_actions
        )

        # Unweighted concatenation (random, dataset, next, current); its
        # spread is logged as cql_std_q*.
        cql_cat_q1 = jnp.concatenate(
          [
            cql_q1_rand,
            jnp.expand_dims(q1_pred, 1),
            cql_q1_next_actions,
            cql_q1_current_actions,
          ],
          axis=1
        )
        cql_cat_q2 = jnp.concatenate(
          [
            cql_q2_rand,
            jnp.expand_dims(q2_pred, 1),
            cql_q2_next_actions,
            cql_q2_current_actions,
          ],
          axis=1
        )
        cql_std_q1 = jnp.std(cql_cat_q1, axis=1)
        cql_std_q2 = jnp.std(cql_cat_q2, axis=1)

        if self.config.cql_importance_sample:
          # Subtract each proposal's log-density: uniform on [-1, 1]^action_dim
          # for the random actions, the policy log-probs for the sampled ones.
          # The dataset Q-value is not part of this set.
          random_density = np.log(0.5**self.action_dim)
          cql_cat_q1 = jnp.concatenate(
            [
              cql_q1_rand - random_density,
              cql_q1_next_actions - cql_next_log_pis,
              cql_q1_current_actions - cql_current_log_pis,
            ],
            axis=1
          )
          cql_cat_q2 = jnp.concatenate(
            [
              cql_q2_rand - random_density,
              cql_q2_next_actions - cql_next_log_pis,
              cql_q2_current_actions - cql_current_log_pis,
            ],
            axis=1
          )
        # Otherwise the unweighted concatenation above is used as-is.

        cql_qf1_ood = (
          jax.scipy.special
          .logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1) *
          self.config.cql_temp
        )
        cql_qf2_ood = (
          jax.scipy.special
          .logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1) *
          self.config.cql_temp
        )

        # Dataset Q-values entering the penalty. These are kept separate from
        # q1_pred/q2_pred so that the average_qf* metrics stay unweighted.
        cql_q1_data = q1_pred
        cql_q2_data = q2_pred
        if weight_constraint == 1:
          cql_q1_data = q1_pred * weights
          cql_q2_data = q2_pred * weights
          cql_qf1_ood = cql_qf1_ood * self.config.cql_coff
          cql_qf2_ood = cql_qf2_ood * self.config.cql_coff

        # Subtract the Q-value of the dataset actions.
        cql_qf1_diff = jnp.clip(
          cql_qf1_ood - cql_q1_data,
          self.config.cql_clip_diff_min,
          self.config.cql_clip_diff_max,
        )
        cql_qf2_diff = jnp.clip(
          cql_qf2_ood - cql_q2_data,
          self.config.cql_clip_diff_min,
          self.config.cql_clip_diff_max,
        )
        if weight_constraint == 2:
          cql_qf1_diff = cql_qf1_diff * weights
          cql_qf2_diff = cql_qf2_diff * weights
        cql_qf1_diff = cql_qf1_diff.mean()
        cql_qf2_diff = cql_qf2_diff.mean()

        if self.config.cql_lagrange:
          alpha_prime = jnp.clip(
            jnp.exp(
              self.log_alpha_prime.apply(train_params['log_alpha_prime'])
            ),
            0.0,
            1000000.0,
          )
          cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (
            cql_qf1_diff - self.config.cql_target_action_gap
          )
          cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (
            cql_qf2_diff - self.config.cql_target_action_gap
          )

          alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss) * 0.5
          loss_collection['log_alpha_prime'] = alpha_prime_loss
        else:
          cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
          cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
          alpha_prime_loss = 0.0
          alpha_prime = 0.0

        qf1_loss = qf1_loss + cql_min_qf1_loss
        qf2_loss = qf2_loss + cql_min_qf2_loss

      loss_collection['qf1'] = qf1_loss
      loss_collection['qf2'] = qf2_loss
      # NOTE(review): the /3 pairs with the actor step's policy_loss / 3
      # (three losses share the encoder); confirm this scaling is intended.
      loss_collection['encoder'] = (
        loss_collection['qf1'] + loss_collection['qf2']
      ) / 3
      return tuple(loss_collection[key] for key in self.critic_model_keys), locals()

    train_params = {key: train_states[key].params for key in self.critic_model_keys}
    (_, aux_values), grads = value_and_multi_grad(
      loss_fn, len(self.critic_model_keys), has_aux=True
    )(train_params, rng)

    # One gradient pytree per loss (index i); each network is stepped only
    # with the gradient of its own loss.
    new_train_states = dict(train_states)
    for i, key in enumerate(self.critic_model_keys):
      new_train_states[key] = train_states[key].apply_gradients(grads=grads[i][key])

    metrics = dict(
      qf1_loss=aux_values['qf1_loss'],
      qf2_loss=aux_values['qf2_loss'],
      average_qf1=aux_values['q1_pred'].mean(),
      average_qf2=aux_values['q2_pred'].mean(),
      average_target_q=aux_values['target_q_values'].mean(),
    )

    if self.config.use_cql:
      metrics.update(
        prefix_metrics(
          dict(
            cql_std_q1=aux_values['cql_std_q1'].mean(),
            cql_std_q2=aux_values['cql_std_q2'].mean(),
            cql_q1_rand=aux_values['cql_q1_rand'].mean(),
            cql_q2_rand=aux_values['cql_q2_rand'].mean(),
            cql_qf1_diff=aux_values['cql_qf1_diff'].mean(),
            cql_qf2_diff=aux_values['cql_qf2_diff'].mean(),
            cql_min_qf1_loss=aux_values['cql_min_qf1_loss'].mean(),
            cql_min_qf2_loss=aux_values['cql_min_qf2_loss'].mean(),
            cql_q1_current_actions=aux_values['cql_q1_current_actions'].mean(),
            cql_q2_current_actions=aux_values['cql_q2_current_actions'].mean(),
            cql_q1_next_actions=aux_values['cql_q1_next_actions'].mean(),
            cql_q2_next_actions=aux_values['cql_q2_next_actions'].mean(),
            alpha_prime=aux_values['alpha_prime'],
            alpha_prime_loss=aux_values['alpha_prime_loss'],
          ), 'cql'
        )
      )
    return new_train_states, metrics

  @partial(jax.jit, static_argnames=('self',))
  def _update_target(self, train_states, target_qf_params):
    """Move the target critics towards the online critics.

    Uses ``update_target_network`` with ``config.soft_target_update_rate``.

    Returns:
      New ``{'qf1': ..., 'qf2': ...}`` target parameter dict.
    """
    new_target_qf_params = {}
    new_target_qf_params['qf1'] = update_target_network(
      train_states['qf1'].params, target_qf_params['qf1'],
      self.config.soft_target_update_rate
    )
    new_target_qf_params['qf2'] = update_target_network(
      train_states['qf2'].params, target_qf_params['qf2'],
      self.config.soft_target_update_rate
    )
    return new_target_qf_params

  # static_argnames must name only parameters this method takes; extra names
  # trigger JAX's "invalid argnames" warning.
  @partial(jax.jit, static_argnames=('self', 'weight_improve', 'bc'))
  def _train_actor_step(self, train_states, target_qf_params, rng, batch, weight_improve, bc=False):
    """One jitted gradient step for every key in ``self.actor_model_keys``.

    The policy loss is ``rl_loss + config.bc_weight * bc_loss``. When automatic
    entropy tuning is enabled, ``log_alpha`` is trained as well.
    ``target_qf_params`` is accepted for signature symmetry with the critic
    step and is unused here.

    Returns:
      ``(new_train_states, metrics)``.
    """
    def loss_fn(train_params, rng):
      observations = batch['observations']
      actions = batch['actions']
      weights = batch['weights']

      loss_collection = {}

      rng, split_rng = jax.random.split(rng)
      embedding = self.encoder.apply(train_params['encoder'], observations)
      new_actions, log_pi = self.policy.apply(
        train_params['policy'], split_rng, embedding
      )

      if self.config.use_automatic_entropy_tuning:
        alpha_loss = -self.log_alpha.apply(train_params['log_alpha']) * (
          log_pi + self.config.target_entropy
        ).mean()
        loss_collection['log_alpha'] = alpha_loss
        alpha = jnp.exp(
          self.log_alpha.apply(train_params['log_alpha'])
        ) * self.config.alpha_multiplier
      else:
        alpha_loss = 0.0
        alpha = self.config.alpha_multiplier

      # Behaviour-cloning loss; it is scaled by config.bc_weight in policy_loss.
      if self.config.bc_mode == 'mle':
        log_probs = self.policy.apply(
          train_params['policy'],
          embedding,
          actions,
          method=self.policy.log_prob
        )
        bc_loss = (alpha * log_pi - log_probs).mean()
      elif self.config.bc_mode == 'mse':
        bc_loss = mse_loss(actions, new_actions)
      else:
        raise RuntimeError('{} not implemented!'.format(self.config.bc_mode))

      # (Offline) RL loss.
      if bc:
        log_probs = self.policy.apply(
          train_params['policy'],
          embedding,
          actions,
          method=self.policy.log_prob
        )
        rl_loss = (alpha * log_pi - log_probs).mean()
      else:
        # The critics are not actor keys, so their parameters come from
        # train_states and are not updated by this loss.
        q_new_actions = jnp.minimum(
          self.qf.apply(train_states['qf1'].params, embedding, new_actions),
          self.qf.apply(train_states['qf2'].params, embedding, new_actions),
        )
        if weight_improve:
          q_new_actions = q_new_actions * weights
        rl_loss = (alpha * log_pi - q_new_actions).mean()

      policy_loss = rl_loss + self.config.bc_weight * bc_loss
      loss_collection['policy'] = policy_loss
      loss_collection['bc_loss'] = bc_loss
      loss_collection['rl_loss'] = rl_loss
      # NOTE(review): the /3 pairs with the critic step's (qf1 + qf2) / 3;
      # confirm this scaling is intended.
      loss_collection['encoder'] = loss_collection['policy'] / 3
      return tuple(loss_collection[key] for key in self.actor_model_keys), locals()

    train_params = {key: train_states[key].params for key in self.actor_model_keys}
    (_, aux_values), grads = value_and_multi_grad(
      loss_fn, len(self.actor_model_keys), has_aux=True
    )(train_params, rng)

    # Each network is stepped only with the gradient of its own loss.
    new_train_states = dict(train_states)
    for i, key in enumerate(self.actor_model_keys):
      new_train_states[key] = train_states[key].apply_gradients(grads=grads[i][key])

    metrics = dict(
      log_pi=aux_values['log_pi'].mean(),
      policy_loss=aux_values['policy_loss'],
      bc_loss=aux_values['bc_loss'],
      rl_loss=aux_values['rl_loss'],
      alpha_loss=aux_values['alpha_loss'],
      alpha=aux_values['alpha'],
    )
    return new_train_states, metrics

  @property
  def model_keys(self):
    """Tuple of every trainable network key."""
    return self._model_keys

  @property
  def train_states(self):
    """Dict mapping network key to its ``TrainState``."""
    return self._train_states

  @property
  def train_params(self):
    """Dict mapping each key in ``model_keys`` to its current parameters."""
    return {key: self.train_states[key].params for key in self.model_keys}

  @property
  def total_steps(self):
    """Number of ``train`` calls made so far."""
    return self._total_steps
