from collections import OrderedDict
from copy import deepcopy
from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax
import distrax

from .jax_utils import (
    next_rng, value_and_multi_grad, mse_loss, weighted_mse_loss, weighted_sum_mse_loss, JaxRNG, wrap_function_with_rng,
    collect_jax_metrics
)
from .model import Scalar, update_target_network
from .utils import prefix_metrics


class ConservativeSAC(object):

    @staticmethod
    def get_default_config(updates=None):
        """Build the default hyperparameter config, optionally overridden.

        Args:
            updates: Optional overrides. When not None, it is wrapped in a
                ConfigDict, its references are resolved, and the result is
                merged over the defaults.

        Returns:
            A ConfigDict holding the defaults with any overrides applied.
        """
        # (field name, default value) pairs, applied in this order.
        defaults = (
            ('discount', 0.99),
            ('alpha_multiplier', 1.0),
            ('use_automatic_entropy_tuning', True),
            ('backup_entropy', False),
            ('target_entropy', 0.0),
            ('policy_lr', 3e-4),
            ('qf_lr', 3e-4),
            ('optimizer_type', 'adam'),
            ('soft_target_update_rate', 5e-3),
            ('use_cql', True),
            ('cql_n_actions', 10),
            ('cql_importance_sample', True),
            ('cql_lagrange', False),
            ('cql_target_action_gap', 1.0),
            ('cql_temp', 1.0),
            ('cql_min_q_weight', 5.0),
            ('cql_max_target_backup', False),
            ('cql_clip_diff_min', -np.inf),
            ('cql_clip_diff_max', np.inf),
        )

        cfg = ConfigDict()
        for field_name, field_value in defaults:
            setattr(cfg, field_name, field_value)

        if updates is None:
            return cfg

        overrides = ConfigDict(updates).copy_and_resolve_references()
        cfg.update(overrides)
        return cfg

    def __init__(self, config, policy, qf):
        """Initialise parameters and optimizer states for all trained models.

        Args:
            config: Overrides passed to ``get_default_config``.
            policy: Policy module; must expose ``observation_dim``,
                ``action_dim``, ``rng_keys()`` and ``init``.
            qf: Q-function module; must expose ``rng_keys()`` and ``init``.
                It is initialised twice to obtain two independent critics.

        Raises:
            KeyError: If ``config.optimizer_type`` is not 'adam' or 'sgd'.
        """
        self.config = self.get_default_config(config)
        self.policy = policy
        self.qf = qf
        self.observation_dim = policy.observation_dim
        self.action_dim = policy.action_dim

        self._train_states = {}

        # Resolve the optimizer constructor once; subclasses reuse it.
        available_optimizers = {'adam': optax.adam, 'sgd': optax.sgd}
        make_optimizer = available_optimizers[self.config.optimizer_type]
        self.optimizer_class = make_optimizer

        # Policy parameters are initialised from a dummy batch of 10 rows.
        initial_policy_params = self.policy.init(
            next_rng(self.policy.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
        )
        self._train_states['policy'] = TrainState.create(
            params=initial_policy_params,
            tx=make_optimizer(self.config.policy_lr),
            apply_fn=None,
        )

        # Two critics share the architecture but get separate RNG draws.
        critic_params = {}
        for critic_name in ('qf1', 'qf2'):
            critic_params[critic_name] = self.qf.init(
                next_rng(self.qf.rng_keys()),
                jnp.zeros((10, self.observation_dim)),
                jnp.zeros((10, self.action_dim)),
            )
            self._train_states[critic_name] = TrainState.create(
                params=critic_params[critic_name],
                tx=make_optimizer(self.config.qf_lr),
                apply_fn=None,
            )
        # Target critics start as independent copies of the online critics.
        self._target_qf_params = deepcopy(critic_params)

        trainable = ['policy', 'qf1', 'qf2']

        if self.config.use_automatic_entropy_tuning:
            self.log_alpha = Scalar(0.0)
            self._train_states['log_alpha'] = TrainState.create(
                params=self.log_alpha.init(next_rng()),
                tx=make_optimizer(self.config.policy_lr),
                apply_fn=None,
            )
            trainable.append('log_alpha')

        if self.config.cql_lagrange:
            self.log_alpha_prime = Scalar(1.0)
            self._train_states['log_alpha_prime'] = TrainState.create(
                params=self.log_alpha_prime.init(next_rng()),
                tx=make_optimizer(self.config.qf_lr),
                apply_fn=None,
            )
            trainable.append('log_alpha_prime')

        self._model_keys = tuple(trainable)
        self._total_steps = 0

    def train(self, batch, bc=False, **kwargs):
        """Run one update step on ``batch`` and return the step's metrics.

        Increments the step counter, draws a fresh RNG via ``next_rng()``,
        and replaces the stored train states and target Q parameters with
        the ones returned by ``_train_step``.

        Args:
            batch: Transition batch forwarded to ``_train_step``.
            bc: Forwarded to ``_train_step`` as its ``bc`` flag.
            **kwargs: Accepted and ignored.

        Returns:
            The metrics returned by ``_train_step``.
        """
        self._total_steps += 1
        step_rng = next_rng()
        step_result = self._train_step(
            self._train_states, self._target_qf_params, step_rng, batch, bc
        )
        self._train_states, self._target_qf_params, metrics = step_result
        return metrics

    @partial(jax.jit, static_argnames=('self', 'bc'))
    def _train_step(self, train_states, target_qf_params, rng, batch, bc=False):
        """Compute all losses, apply one gradient step per model, update targets.

        Args:
            train_states: Dict of TrainState objects keyed by model name; the
                entries named in ``self.model_keys`` are read and updated.
            target_qf_params: Dict with 'qf1' and 'qf2' target Q parameters.
            rng: Random key used to seed the ``JaxRNG`` for this step.
            batch: Mapping with 'observations', 'actions', 'rewards',
                'next_observations' and 'dones' entries.
            bc: Static flag. When true, the policy loss uses the policy's
                log-probability of the batch actions instead of Q-values.

        Returns:
            A tuple ``(new_train_states, new_target_qf_params, metrics)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            # NOTE: locals() is returned as the auxiliary output and metrics
            # are selected from it by name below, so renaming a local in this
            # function changes which metrics are reported.
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']

            loss_collection = {}

            @wrap_function_with_rng(rng_generator())
            def forward_policy(rng, *args, **kwargs):
                return self.policy.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.policy.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_qf(rng, *args, **kwargs):
                return self.qf.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.qf.rng_keys())
                )

            new_actions, log_pi = forward_policy(train_params['policy'], observations)

            if self.config.use_automatic_entropy_tuning:
                alpha_loss = -self.log_alpha.apply(train_params['log_alpha']) * (log_pi + self.config.target_entropy).mean()
                loss_collection['log_alpha'] = alpha_loss
                alpha = jnp.exp(self.log_alpha.apply(train_params['log_alpha'])) * self.config.alpha_multiplier
            else:
                alpha_loss = 0.0
                alpha = self.config.alpha_multiplier

            # Policy loss
            if bc:
                log_probs = forward_policy(train_params['policy'], observations, actions, method=self.policy.log_prob)
                policy_loss = (alpha*log_pi - log_probs).mean()
            else:
                q_new_actions = jnp.minimum(
                    forward_qf(train_params['qf1'], observations, new_actions),
                    forward_qf(train_params['qf2'], observations, new_actions),
                )
                policy_loss = (alpha*log_pi - q_new_actions).mean()

            loss_collection['policy'] = policy_loss

            # Q function loss
            q1_pred = forward_qf(train_params['qf1'], observations, actions)
            q2_pred = forward_qf(train_params['qf2'], observations, actions)

            if self.config.cql_max_target_backup:
                # Sample several next actions and back up the one with the
                # highest target Q-value (and its matching log-probability).
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )
                max_target_indices = jnp.expand_dims(jnp.argmax(target_q_values, axis=-1), axis=-1)
                target_q_values = jnp.take_along_axis(target_q_values, max_target_indices, axis=-1).squeeze(-1)
                next_log_pi = jnp.take_along_axis(next_log_pi, max_target_indices, axis=-1).squeeze(-1)
            else:
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )

            if self.config.backup_entropy:
                target_q_values = target_q_values - alpha * next_log_pi

            # The TD target is treated as a constant for differentiation.
            td_target = jax.lax.stop_gradient(
                rewards + (1. - dones) * self.config.discount * target_q_values
            )
            qf1_loss = mse_loss(q1_pred, td_target)
            qf2_loss = mse_loss(q2_pred, td_target)

            ### CQL
            if self.config.use_cql:
                batch_size = actions.shape[0]
                cql_random_actions = jax.random.uniform(
                    rng_generator(), shape=(batch_size, self.config.cql_n_actions, self.action_dim),
                    minval=-1.0, maxval=1.0
                )

                cql_current_actions, cql_current_log_pis = forward_policy(
                    train_params['policy'], observations, repeat=self.config.cql_n_actions,
                )
                cql_next_actions, cql_next_log_pis = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions,
                )

                cql_q1_rand = forward_qf(train_params['qf1'], observations, cql_random_actions)
                cql_q2_rand = forward_qf(train_params['qf2'], observations, cql_random_actions)
                cql_q1_current_actions = forward_qf(train_params['qf1'], observations, cql_current_actions)
                cql_q2_current_actions = forward_qf(train_params['qf2'], observations, cql_current_actions)
                # Actions sampled for next_observations are scored against
                # the current observations here.
                cql_q1_next_actions = forward_qf(train_params['qf1'], observations, cql_next_actions)
                cql_q2_next_actions = forward_qf(train_params['qf2'], observations, cql_next_actions)

                cql_cat_q1 = jnp.concatenate(
                    [cql_q1_rand, jnp.expand_dims(q1_pred, 1), cql_q1_next_actions, cql_q1_current_actions], axis=1
                )
                cql_cat_q2 = jnp.concatenate(
                    [cql_q2_rand, jnp.expand_dims(q2_pred, 1), cql_q2_next_actions, cql_q2_current_actions], axis=1
                )
                cql_std_q1 = jnp.std(cql_cat_q1, axis=1)
                cql_std_q2 = jnp.std(cql_cat_q2, axis=1)

                if self.config.cql_importance_sample:
                    # Log-density of the uniform sampler over [-1, 1] per
                    # action dimension used for cql_random_actions above.
                    random_density = np.log(0.5 ** self.action_dim)
                    cql_cat_q1 = jnp.concatenate(
                        [cql_q1_rand - random_density,
                         cql_q1_next_actions - cql_next_log_pis,
                         cql_q1_current_actions - cql_current_log_pis],
                        axis=1
                    )
                    cql_cat_q2 = jnp.concatenate(
                        [cql_q2_rand - random_density,
                         cql_q2_next_actions - cql_next_log_pis,
                         cql_q2_current_actions - cql_current_log_pis],
                        axis=1
                    )

                cql_qf1_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )
                cql_qf2_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )

                # Subtract the log likelihood of data
                cql_qf1_diff = jnp.clip(
                    cql_qf1_ood - q1_pred,
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ).mean()
                cql_qf2_diff = jnp.clip(
                    cql_qf2_ood - q2_pred,
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ).mean()

                if self.config.cql_lagrange:
                    alpha_prime = jnp.clip(
                        jnp.exp(self.log_alpha_prime.apply(train_params['log_alpha_prime'])),
                        a_min=0.0, a_max=1000000.0
                    )
                    cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf1_diff - self.config.cql_target_action_gap)
                    cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf2_diff - self.config.cql_target_action_gap)

                    alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss)*0.5

                    loss_collection['log_alpha_prime'] = alpha_prime_loss

                else:
                    cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
                    cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
                    alpha_prime_loss = 0.0
                    alpha_prime = 0.0

                qf1_loss = qf1_loss + cql_min_qf1_loss
                qf2_loss = qf2_loss + cql_min_qf2_loss

            loss_collection['qf1'] = qf1_loss
            loss_collection['qf2'] = qf2_loss
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params)

        # grads[i] is the gradient of the i-th loss; each model is updated
        # only with the slice of that gradient belonging to its own params.
        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        new_target_qf_params = {}
        new_target_qf_params['qf1'] = update_target_network(
            new_train_states['qf1'].params, target_qf_params['qf1'],
            self.config.soft_target_update_rate
        )
        new_target_qf_params['qf2'] = update_target_network(
            new_train_states['qf2'].params, target_qf_params['qf2'],
            self.config.soft_target_update_rate
        )

        metrics = collect_jax_metrics(
            aux_values,
            ['log_pi', 'policy_loss', 'qf1_loss', 'qf2_loss', 'alpha_loss',
             'alpha', 'q1_pred', 'q2_pred', 'target_q_values']
        )

        if self.config.use_cql:
            # Every name needs its own trailing comma: adjacent string
            # literals are silently concatenated into a single bogus name.
            metrics.update(collect_jax_metrics(
                aux_values,
                ['cql_std_q1', 'cql_std_q2', 'cql_q1_rand', 'cql_q2_rand',
                 'cql_qf1_diff', 'cql_qf2_diff', 'cql_min_qf1_loss',
                 'cql_min_qf2_loss', 'cql_q1_current_actions', 'cql_q2_current_actions',
                 'cql_q1_next_actions', 'cql_q2_next_actions', 'alpha_prime',
                 'alpha_prime_loss'],
                'cql'
            ))

        return new_train_states, new_target_qf_params, metrics

    @property
    def model_keys(self):
        return self._model_keys

    @property
    def train_states(self):
        return self._train_states

    @property
    def train_params(self):
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        return self._total_steps


class DensityRatioWeightedConservativeSAC(ConservativeSAC):

    def __init__(self, config, policy, qf,
            discriminator, # w(s,a)
            kl_weight, # \lambda_K
            flow_weight, # \lambda_F
            flow_discount,
            discriminator_temp,
            discriminator_clip_ratio,
            discriminator_lr,
        ):
        """Set up the base agent and add a trainable discriminator.

        Args:
            config, policy, qf: Forwarded to ``ConservativeSAC.__init__``.
            discriminator: Module exposing ``rng_keys()`` and ``init``; it is
                initialised on dummy observation and action batches.
            kl_weight: Coefficient of the KL penalty in the discriminator loss.
            flow_weight: Coefficient of the flow-consistency term.
            flow_discount: Discount factor used in the flow-consistency term.
            discriminator_temp: Temperature dividing the discriminator logits.
            discriminator_clip_ratio: Logits are clipped to
                +/- (1 + discriminator_clip_ratio) after temperature scaling.
            discriminator_lr: Learning rate of the discriminator optimizer.
        """
        super().__init__(config, policy, qf)

        self.kl_weight = kl_weight
        self.flow_weight = flow_weight
        self.flow_discount = flow_discount
        self.discriminator_temp = discriminator_temp
        self.discriminator_clip_ratio = discriminator_clip_ratio

        # Echo the weighting hyperparameters to stdout.
        print(f"KL weight={self.kl_weight}")
        print(f"Flow weight={self.flow_weight}")
        print(f"Flow discount={self.flow_discount}")
        print(f"Weight temperature={self.discriminator_temp}")
        print(f"Weight clip ratio={self.discriminator_clip_ratio}")

        self.discriminator = discriminator
        init_rng = next_rng(self.discriminator.rng_keys())
        dummy_observations = jnp.zeros((10, self.observation_dim))
        dummy_actions = jnp.zeros((10, self.action_dim))
        initial_params = self.discriminator.init(
            init_rng, dummy_observations, dummy_actions
        )

        discriminator_optimizer = self.optimizer_class(discriminator_lr)
        self._train_states['discriminator'] = TrainState.create(
            params=initial_params,
            tx=discriminator_optimizer,
            apply_fn=None,
        )
        # Register the discriminator after the models set up by the base class.
        self._model_keys = (*self._model_keys, "discriminator")


    @partial(jax.jit, static_argnames=('self', 'bc', 'weight_policy'))
    def _train_step(self, train_states, target_qf_params, rng, batch, bc=False, weight_policy=True):
        """One update step with losses re-weighted by the discriminator.

        The discriminator's clipped, temperature-scaled logits are passed
        through a softmax over the batch axis; the resulting weights (with
        gradients stopped) replace the uniform batch mean in the entropy,
        policy, TD and CQL losses. The discriminator is trained jointly.

        Args:
            train_states: Dict of TrainState objects keyed by model name; the
                entries named in ``self.model_keys`` are read and updated.
            target_qf_params: Dict with 'qf1' and 'qf2' target Q parameters.
            rng: Random key used to seed the ``JaxRNG`` for this step.
            batch: Mapping with 'observations', 'actions', 'rewards',
                'next_observations', 'dones' and 'normalized_rewards'.
            bc: Static flag. When true, the policy loss uses the policy's
                log-probability of the batch actions instead of Q-values.
            weight_policy: Static flag. When false (and ``bc`` is false) the
                policy loss is an unweighted batch mean.

        Returns:
            A tuple ``(new_train_states, new_target_qf_params, metrics)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            # NOTE: locals() is returned as the auxiliary output and metrics
            # are selected from it by name below, so renaming a local in this
            # function changes which metrics are reported.
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']
            normalized_rewards = batch['normalized_rewards']

            loss_collection = {}

            @wrap_function_with_rng(rng_generator())
            def forward_policy(rng, *args, **kwargs):
                return self.policy.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.policy.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_qf(rng, *args, **kwargs):
                return self.qf.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.qf.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_discriminator(rng, *args, **kwargs):
                sa_logits, s_logits = self.discriminator.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.discriminator.rng_keys())
                )
                return sa_logits, s_logits

            # Predict weightings
            sa_logits, s_logits = forward_discriminator(
                train_params['discriminator'], observations, actions)
            # Only the second output is used for the next observations; the
            # batch actions are passed just to complete the call.
            _, next_s_logits = forward_discriminator(
                train_params['discriminator'], next_observations, actions)

            # exp(f(s)) exp(g(a|s)) = exp(f(s) + g(a|s))
            sa_clipped_logits = jnp.clip(sa_logits / self.discriminator_temp,
                -(1 + self.discriminator_clip_ratio),
                  (1 + self.discriminator_clip_ratio))
            s_clipped_logits = jnp.clip(s_logits / self.discriminator_temp,
                -(1 + self.discriminator_clip_ratio),
                  (1 + self.discriminator_clip_ratio))
            normalized_ratios = jax.nn.softmax(sa_clipped_logits + s_clipped_logits, axis=0)

            # The RL losses must not push gradients into the discriminator.
            weights = jax.lax.stop_gradient(normalized_ratios)

            new_actions, log_pi = forward_policy(train_params['policy'], observations)

            if self.config.use_automatic_entropy_tuning:
                alpha_loss = -self.log_alpha.apply(train_params['log_alpha']) * ((log_pi + self.config.target_entropy) * weights).sum()
                loss_collection['log_alpha'] = alpha_loss
                alpha = jnp.exp(self.log_alpha.apply(train_params['log_alpha'])) * self.config.alpha_multiplier
            else:
                alpha_loss = 0.0
                alpha = self.config.alpha_multiplier

            # Policy loss
            if bc:
                log_probs = forward_policy(train_params['policy'], observations, actions, method=self.policy.log_prob)
                policy_loss = ((alpha*log_pi - log_probs) * weights).sum()
            else:
                q_new_actions = jnp.minimum(
                    forward_qf(train_params['qf1'], observations, new_actions),
                    forward_qf(train_params['qf2'], observations, new_actions),
                )
                if weight_policy:
                    policy_loss = ((alpha*log_pi - q_new_actions) * weights).sum()
                else:
                    policy_loss = (alpha*log_pi - q_new_actions).mean()
            loss_collection['policy'] = policy_loss

            # Q function loss
            q1_pred = forward_qf(train_params['qf1'], observations, actions)
            q2_pred = forward_qf(train_params['qf2'], observations, actions)

            if self.config.cql_max_target_backup:
                # Sample several next actions and back up the one with the
                # highest target Q-value (and its matching log-probability).
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )
                max_target_indices = jnp.expand_dims(jnp.argmax(target_q_values, axis=-1), axis=-1)
                target_q_values = jnp.take_along_axis(target_q_values, max_target_indices, axis=-1).squeeze(-1)
                next_log_pi = jnp.take_along_axis(next_log_pi, max_target_indices, axis=-1).squeeze(-1)
            else:
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )

            if self.config.backup_entropy:
                target_q_values = target_q_values - alpha * next_log_pi

            # The TD target is treated as a constant for differentiation.
            td_target = jax.lax.stop_gradient(
                rewards + (1. - dones) * self.config.discount * target_q_values
            )
            qf1_loss = weighted_sum_mse_loss(q1_pred, td_target, weights)
            qf2_loss = weighted_sum_mse_loss(q2_pred, td_target, weights)

            ### CQL
            if self.config.use_cql:
                batch_size = actions.shape[0]
                cql_random_actions = jax.random.uniform(
                    rng_generator(), shape=(batch_size, self.config.cql_n_actions, self.action_dim),
                    minval=-1.0, maxval=1.0
                )

                cql_current_actions, cql_current_log_pis = forward_policy(
                    train_params['policy'], observations, repeat=self.config.cql_n_actions,
                )
                cql_next_actions, cql_next_log_pis = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions,
                )

                cql_q1_rand = forward_qf(train_params['qf1'], observations, cql_random_actions)
                cql_q2_rand = forward_qf(train_params['qf2'], observations, cql_random_actions)
                cql_q1_current_actions = forward_qf(train_params['qf1'], observations, cql_current_actions)
                cql_q2_current_actions = forward_qf(train_params['qf2'], observations, cql_current_actions)
                # Actions sampled for next_observations are scored against
                # the current observations here.
                cql_q1_next_actions = forward_qf(train_params['qf1'], observations, cql_next_actions)
                cql_q2_next_actions = forward_qf(train_params['qf2'], observations, cql_next_actions)

                cql_cat_q1 = jnp.concatenate(
                    [cql_q1_rand, jnp.expand_dims(q1_pred, 1), cql_q1_next_actions, cql_q1_current_actions], axis=1
                )
                cql_cat_q2 = jnp.concatenate(
                    [cql_q2_rand, jnp.expand_dims(q2_pred, 1), cql_q2_next_actions, cql_q2_current_actions], axis=1
                )
                cql_std_q1 = jnp.std(cql_cat_q1, axis=1)
                cql_std_q2 = jnp.std(cql_cat_q2, axis=1)

                if self.config.cql_importance_sample:
                    # Log-density of the uniform sampler over [-1, 1] per
                    # action dimension used for cql_random_actions above.
                    random_density = np.log(0.5 ** self.action_dim)
                    cql_cat_q1 = jnp.concatenate(
                        [cql_q1_rand - random_density,
                         cql_q1_next_actions - cql_next_log_pis,
                         cql_q1_current_actions - cql_current_log_pis],
                        axis=1
                    )
                    cql_cat_q2 = jnp.concatenate(
                        [cql_q2_rand - random_density,
                         cql_q2_next_actions - cql_next_log_pis,
                         cql_q2_current_actions - cql_current_log_pis],
                        axis=1
                    )

                cql_qf1_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )
                cql_qf2_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )

                # Subtract the log likelihood of data
                cql_qf1_diff = (jnp.clip(
                    (cql_qf1_ood - q1_pred),
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ) * weights).sum()
                cql_qf2_diff = (jnp.clip(
                    (cql_qf2_ood - q2_pred),
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ) * weights).sum()

                if self.config.cql_lagrange:
                    alpha_prime = jnp.clip(
                        jnp.exp(self.log_alpha_prime.apply(train_params['log_alpha_prime'])),
                        a_min=0.0, a_max=1000000.0
                    )
                    cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf1_diff - self.config.cql_target_action_gap)
                    cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf2_diff - self.config.cql_target_action_gap)

                    alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss)*0.5

                    loss_collection['log_alpha_prime'] = alpha_prime_loss

                else:
                    cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
                    cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
                    alpha_prime_loss = 0.0
                    alpha_prime = 0.0

                qf1_loss = qf1_loss + cql_min_qf1_loss
                qf2_loss = qf2_loss + cql_min_qf2_loss

            loss_collection['qf1'] = qf1_loss
            loss_collection['qf2'] = qf2_loss

            # Discriminator training
            reward_loss = -(normalized_ratios * normalized_rewards).sum()
            # Uses the unclipped logits, unlike normalized_ratios above.
            flow_consistency_loss = jnp.square(self.flow_discount * jnp.exp((sa_logits + s_logits) / self.discriminator_temp) - jnp.exp(next_s_logits / self.discriminator_temp)).mean()
            kl_penalty_loss = (normalized_ratios * jnp.log(normalized_ratios)).sum()
            discriminator_loss = reward_loss + self.flow_weight * flow_consistency_loss + self.kl_weight * kl_penalty_loss
            loss_collection['discriminator'] = discriminator_loss

            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params)

        # grads[i] is the gradient of the i-th loss; each model is updated
        # only with the slice of that gradient belonging to its own params.
        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        new_target_qf_params = {}
        new_target_qf_params['qf1'] = update_target_network(
            new_train_states['qf1'].params, target_qf_params['qf1'],
            self.config.soft_target_update_rate
        )
        new_target_qf_params['qf2'] = update_target_network(
            new_train_states['qf2'].params, target_qf_params['qf2'],
            self.config.soft_target_update_rate
        )

        # NOTE(review): "next_s_normalized_ratios" is never assigned in
        # loss_fn, so it is presumably skipped by collect_jax_metrics - confirm.
        metrics = collect_jax_metrics(
            aux_values,
            ['log_pi', 'policy_loss', 'qf1_loss', 'qf2_loss', 'alpha_loss',
             'alpha', 'q1_pred', 'q2_pred', 'target_q_values',
             "weights", "normalized_ratios", "discriminator_loss",
             "reward_loss", "flow_consistency_loss", "kl_penalty_loss", "next_s_normalized_ratios"]
        )

        if self.config.use_cql:
            # Every name needs its own trailing comma: adjacent string
            # literals are silently concatenated into a single bogus name.
            metrics.update(collect_jax_metrics(
                aux_values,
                ['cql_std_q1', 'cql_std_q2', 'cql_q1_rand', 'cql_q2_rand',
                 'cql_qf1_diff', 'cql_qf2_diff', 'cql_min_qf1_loss',
                 'cql_min_qf2_loss', 'cql_q1_current_actions', 'cql_q2_current_actions',
                 'cql_q1_next_actions', 'cql_q2_next_actions', 'alpha_prime',
                 'alpha_prime_loss', "weights", "normalized_ratios", "discriminator_loss",
                 "reward_loss", "flow_consistency_loss", "kl_penalty_loss", "next_s_normalized_ratios"],
                'cql'
            ))

        return new_train_states, new_target_qf_params, metrics

    def train(self, batch, bc=False, weight_policy=True):
        """Run one weighted update step on ``batch`` and return its metrics.

        Increments the step counter, draws a fresh RNG via ``next_rng()``,
        and replaces the stored train states and target Q parameters with
        the ones returned by ``_train_step``.

        Args:
            batch: Transition batch forwarded to ``_train_step``.
            bc: Forwarded to ``_train_step`` as its ``bc`` flag.
            weight_policy: Forwarded to ``_train_step`` as ``weight_policy``.

        Returns:
            The metrics returned by ``_train_step``.
        """
        self._total_steps += 1
        step_rng = next_rng()
        step_result = self._train_step(
            self._train_states,
            self._target_qf_params,
            step_rng,
            batch,
            bc,
            weight_policy,
        )
        self._train_states, self._target_qf_params, metrics = step_result
        return metrics



class OptdiceDensityRatioWeightedConservativeSAC(ConservativeSAC):
    """ConservativeSAC variant whose losses are re-weighted per sample.

    On top of the models registered by ``ConservativeSAC`` this class trains
    two extra networks:

    * ``v_network``, evaluated on initial, current and next observations and
      trained on the ``v_loss`` assembled in ``_train_step``;
    * ``e_network``, evaluated on (observation, action) pairs and regressed
      onto ``rewards + (1 - dones) * discount * v(s') - v(s)``.

    The gradient-stopped quantity ``exp(e_network(s, a) / alpha - 1)`` is used
    as the per-sample weight of the entropy-temperature, policy, TD and CQL
    losses.

    NOTE(review): the class name suggests the weights are OptiDICE density
    ratios; confirm the objective against the reference derivation.
    """

    def __init__(self, config, policy, qf,
            v_network,
            e_network,
            alpha,
            v_lr,
            e_lr,
        ):
        """Build the base agent, then register ``v_network`` and ``e_network``.

        Args:
            config: forwarded to ``ConservativeSAC.__init__``.
            policy: forwarded to ``ConservativeSAC.__init__``.
            qf: forwarded to ``ConservativeSAC.__init__``.
            v_network: module exposing ``init``, ``apply`` and ``rng_keys``;
                initialised here with a dummy batch of observations.
            e_network: module exposing ``init``, ``apply`` and ``rng_keys``;
                initialised here with dummy observations and actions.
            alpha: divisor applied to the residuals before they are turned
                into weights (stored as ``self._alpha``). This is a separate
                quantity from the entropy temperature, which ``_train_step``
                derives from ``log_alpha``.
            v_lr: learning rate passed to the optimizer of ``v_network``.
            e_lr: learning rate passed to the optimizer of ``e_network``.
        """
        super().__init__(config, policy, qf)
        self._alpha = alpha
        print(f"Alpha={alpha}")

        self.v_network = v_network
        # Dummy batch of 10 observations, used only to initialise parameters.
        v_network_params = self.v_network.init(
            next_rng(self.v_network.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
        )
        self._train_states['v_network'] = TrainState.create(
            params=v_network_params,
            tx=self.optimizer_class(v_lr),
            apply_fn=None,
        )
        # Appending the key makes _train_step compute a loss/gradient for it.
        self._model_keys = (*self._model_keys, "v_network")

        self.e_network = e_network
        e_network_params = self.e_network.init(
            next_rng(self.e_network.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
            jnp.zeros((10, self.action_dim)),
        )
        self._train_states['e_network'] = TrainState.create(
            params=e_network_params,
            tx=self.optimizer_class(e_lr),
            apply_fn=None,
        )
        self._model_keys = (*self._model_keys, "e_network")

    @partial(jax.jit, static_argnames=('self', 'bc'))
    def _train_step(self, train_states, target_qf_params, rng, batch, bc=False, weight_policy=True):
        """Compute all losses, apply one gradient step per model, collect metrics.

        Args:
            train_states: dict mapping every key in ``self.model_keys`` to its
                train state.
            target_qf_params: dict with the target parameters under ``'qf1'``
                and ``'qf2'``.
            rng: key used to seed the ``JaxRNG`` that drives all sampling in
                this step.
            batch: mapping that must provide ``'initial_observations'``,
                ``'observations'``, ``'actions'``, ``'rewards'``,
                ``'next_observations'``, ``'dones'`` and
                ``'normalized_rewards'``.
            bc: static flag; when true the policy loss uses the log-probability
                of the dataset actions instead of the Q-value of sampled
                actions.
            weight_policy: not used by this method.

        Returns:
            Tuple ``(new_train_states, new_target_qf_params, metrics)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            # NOTE: this function returns ``locals()`` as its auxiliary output
            # and metrics are later looked up by local-variable name, so the
            # names below are part of the contract.
            initial_observations = batch["initial_observations"]
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']
            normalized_rewards = batch['normalized_rewards']

            loss_collection = {}

            @wrap_function_with_rng(rng_generator())
            def forward_policy(rng, *args, **kwargs):
                return self.policy.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.policy.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_qf(rng, *args, **kwargs):
                return self.qf.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.qf.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_v(rng, *args, **kwargs):
                v_values = self.v_network.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.v_network.rng_keys())
                )
                return v_values

            @wrap_function_with_rng(rng_generator())
            def forward_e(rng, *args, **kwargs):
                e_values = self.e_network.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.e_network.rng_keys())
                )
                return e_values

            # --- Predict weightings and train the weighting networks ---

            f_fn = lambda x: x * jnp.log(x + 1e-10)
            f_prime_inv_fn = lambda x: jnp.exp(x - 1)
            # g_fn(x) equals f_prime_inv_fn(x) * log(f_prime_inv_fn(x)).
            g_fn = lambda x: jnp.exp(x - 1) * (x - 1)
            r_fn = lambda x: f_prime_inv_fn(x)
            log_r_fn = lambda x: x - 1

            initial_v_values = forward_v(train_params["v_network"], initial_observations)
            v_values = forward_v(train_params["v_network"], observations)
            next_v_values = forward_v(train_params["v_network"], next_observations)

            # Residual of v_network along the sampled transition.
            e_v = rewards + (1 - dones) * self.config.discount * next_v_values - v_values
            preactivation_v = (e_v) / self._alpha
            w_v = r_fn(preactivation_v)
            f_w_v = g_fn(preactivation_v)

            e_values = forward_e(train_params["e_network"], observations, actions)
            preactivation_e = (e_values) / self._alpha
            w_e = r_fn(preactivation_e)

            v_loss0 = (1 - self.config.discount) * initial_v_values.mean()
            v_loss1 = (- self._alpha * f_w_v).mean()
            v_loss2 = (w_v * (e_v)).mean()
            v_loss = v_loss0 + v_loss1 + v_loss2

            # e_network regresses onto the residual computed from v_network.
            e_loss = ((e_v - e_values) ** 2).mean()

            loss_collection['v_network'] = v_loss
            loss_collection["e_network"] = e_loss

            # The weights act as constants in every loss below.
            weights = jax.lax.stop_gradient(w_e)

            new_actions, log_pi = forward_policy(train_params['policy'], observations)

            if self.config.use_automatic_entropy_tuning:
                alpha_loss = -self.log_alpha.apply(train_params['log_alpha']) * ((log_pi + self.config.target_entropy) * weights).sum()
                loss_collection['log_alpha'] = alpha_loss
                alpha = jnp.exp(self.log_alpha.apply(train_params['log_alpha'])) * self.config.alpha_multiplier
            else:
                alpha_loss = 0.0
                alpha = self.config.alpha_multiplier

            # --- Policy loss ---
            if bc:
                log_probs = forward_policy(train_params['policy'], observations, actions, method=self.policy.log_prob)
                policy_loss = ((alpha*log_pi - log_probs) * weights).sum()
            else:
                q_new_actions = jnp.minimum(
                    forward_qf(train_params['qf1'], observations, new_actions),
                    forward_qf(train_params['qf2'], observations, new_actions),
                )
                policy_loss = ((alpha*log_pi - q_new_actions) * weights).sum()
            loss_collection['policy'] = policy_loss

            # --- Q function loss ---
            q1_pred = forward_qf(train_params['qf1'], observations, actions)
            q2_pred = forward_qf(train_params['qf2'], observations, actions)

            if self.config.cql_max_target_backup:
                # Sample several next actions and back up the best target value.
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )
                max_target_indices = jnp.expand_dims(jnp.argmax(target_q_values, axis=-1), axis=-1)
                target_q_values = jnp.take_along_axis(target_q_values, max_target_indices, axis=-1).squeeze(-1)
                next_log_pi = jnp.take_along_axis(next_log_pi, max_target_indices, axis=-1).squeeze(-1)
            else:
                new_next_actions, next_log_pi = forward_policy(
                    train_params['policy'], next_observations
                )
                target_q_values = jnp.minimum(
                    forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                    forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
                )

            if self.config.backup_entropy:
                target_q_values = target_q_values - alpha * next_log_pi

            td_target = jax.lax.stop_gradient(
                rewards + (1. - dones) * self.config.discount * target_q_values
            )
            qf1_loss = weighted_sum_mse_loss(q1_pred, td_target, weights)
            qf2_loss = weighted_sum_mse_loss(q2_pred, td_target, weights)

            # --- CQL regulariser ---
            if self.config.use_cql:
                batch_size = actions.shape[0]
                cql_random_actions = jax.random.uniform(
                    rng_generator(), shape=(batch_size, self.config.cql_n_actions, self.action_dim),
                    minval=-1.0, maxval=1.0
                )

                cql_current_actions, cql_current_log_pis = forward_policy(
                    train_params['policy'], observations, repeat=self.config.cql_n_actions,
                )
                cql_next_actions, cql_next_log_pis = forward_policy(
                    train_params['policy'], next_observations, repeat=self.config.cql_n_actions,
                )

                cql_q1_rand = forward_qf(train_params['qf1'], observations, cql_random_actions)
                cql_q2_rand = forward_qf(train_params['qf2'], observations, cql_random_actions)
                cql_q1_current_actions = forward_qf(train_params['qf1'], observations, cql_current_actions)
                cql_q2_current_actions = forward_qf(train_params['qf2'], observations, cql_current_actions)
                # Actions sampled at the next observations are scored at the
                # current observations.
                cql_q1_next_actions = forward_qf(train_params['qf1'], observations, cql_next_actions)
                cql_q2_next_actions = forward_qf(train_params['qf2'], observations, cql_next_actions)

                cql_cat_q1 = jnp.concatenate(
                    [cql_q1_rand, jnp.expand_dims(q1_pred, 1), cql_q1_next_actions, cql_q1_current_actions], axis=1
                )
                cql_cat_q2 = jnp.concatenate(
                    [cql_q2_rand, jnp.expand_dims(q2_pred, 1), cql_q2_next_actions, cql_q2_current_actions], axis=1
                )
                cql_std_q1 = jnp.std(cql_cat_q1, axis=1)
                cql_std_q2 = jnp.std(cql_cat_q2, axis=1)

                if self.config.cql_importance_sample:
                    # Log-density of the uniform distribution on [-1, 1]^action_dim.
                    random_density = np.log(0.5 ** self.action_dim)
                    cql_cat_q1 = jnp.concatenate(
                        [cql_q1_rand - random_density,
                         cql_q1_next_actions - cql_next_log_pis,
                         cql_q1_current_actions - cql_current_log_pis],
                        axis=1
                    )
                    cql_cat_q2 = jnp.concatenate(
                        [cql_q2_rand - random_density,
                         cql_q2_next_actions - cql_next_log_pis,
                         cql_q2_current_actions - cql_current_log_pis],
                        axis=1
                    )

                cql_qf1_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )
                cql_qf2_ood = (
                    jax.scipy.special.logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1)
                    * self.config.cql_temp
                )

                # Subtract the Q-value of the dataset action, clip, then weight.
                cql_qf1_diff = (jnp.clip(
                    (cql_qf1_ood - q1_pred),
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ) * weights).sum()
                cql_qf2_diff = (jnp.clip(
                    (cql_qf2_ood - q2_pred),
                    self.config.cql_clip_diff_min,
                    self.config.cql_clip_diff_max,
                ) * weights).sum()

                if self.config.cql_lagrange:
                    # Bounds are passed positionally (as in the clips above) so
                    # the call does not depend on the keyword names of jnp.clip.
                    alpha_prime = jnp.clip(
                        jnp.exp(self.log_alpha_prime.apply(train_params['log_alpha_prime'])),
                        0.0, 1000000.0
                    )
                    cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf1_diff - self.config.cql_target_action_gap)
                    cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf2_diff - self.config.cql_target_action_gap)

                    alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss)*0.5

                    loss_collection['log_alpha_prime'] = alpha_prime_loss

                else:
                    cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
                    cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
                    alpha_prime_loss = 0.0
                    alpha_prime = 0.0

                qf1_loss = qf1_loss + cql_min_qf1_loss
                qf2_loss = qf2_loss + cql_min_qf2_loss

            loss_collection['qf1'] = qf1_loss
            loss_collection['qf2'] = qf2_loss

            # One loss per model key, in model_keys order; locals() is the aux.
            return tuple(loss_collection[key] for key in self.model_keys), locals()

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params)

        # grads[i] is the gradient of the i-th loss w.r.t. all parameters; only
        # the slice belonging to the matching model is applied.
        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        new_target_qf_params = {}
        new_target_qf_params['qf1'] = update_target_network(
            new_train_states['qf1'].params, target_qf_params['qf1'],
            self.config.soft_target_update_rate
        )
        new_target_qf_params['qf2'] = update_target_network(
            new_train_states['qf2'].params, target_qf_params['qf2'],
            self.config.soft_target_update_rate
        )

        metrics = collect_jax_metrics(
            aux_values,
            ['log_pi', 'policy_loss', 'qf1_loss', 'qf2_loss', 'alpha_loss',
             'alpha', 'q1_pred', 'q2_pred', 'target_q_values',
             "weights", "v_loss", "e_loss"]
        )

        if self.config.use_cql:
            # Every name must be its own list item: adjacent string literals
            # without a comma are concatenated into one name that matches no
            # local of loss_fn.
            metrics.update(collect_jax_metrics(
                aux_values,
                [
                    'cql_std_q1', 'cql_std_q2',
                    'cql_q1_rand', 'cql_q2_rand',
                    'cql_qf1_diff', 'cql_qf2_diff',
                    'cql_min_qf1_loss', 'cql_min_qf2_loss',
                    'cql_q1_current_actions', 'cql_q2_current_actions',
                    'cql_q1_next_actions', 'cql_q2_next_actions',
                    'alpha_prime', 'alpha_prime_loss',
                ],
                'cql'
            ))

        return new_train_states, new_target_qf_params, metrics

