from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from .jax_utils import next_rng, JaxRNG, wrap_function_with_rng
from .model import Scalar
from .offline_sac import OfflineSAC


class ConservativeSAC(OfflineSAC):
    """Offline SAC agent with the Conservative Q-Learning (CQL) regularizer.

    The CQL penalty is supplied to the base class through
    ``additional_q_loss``. When ``config.cql_lagrange`` is set, the penalty
    weight is additionally scaled by a learned multiplier ``alpha_prime``
    (stored as ``log_alpha_prime`` with its own train state).
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return a ``ConfigDict`` with the CQL-specific hyper-parameters.

        Args:
            updates: optional mapping merged over the defaults; references in
                it are resolved before merging.

        Returns:
            ConfigDict containing the ``cql_*`` keys set below.

        NOTE(review): only CQL keys are defined here, yet ``__init__`` also
        reads ``optimizer_type`` and ``qf_lr``; presumably those come from the
        ``OfflineSAC`` defaults or the caller's config -- verify.
        """
        config = ConfigDict()

        # Number of sampled actions per state for each out-of-distribution set.
        config.cql_n_actions = 10
        # Subtract sampling log-densities before the logsumexp.
        config.cql_importance_sample = True
        # Learn alpha_prime so the Q-gap tracks cql_target_action_gap.
        config.cql_lagrange = False
        config.cql_target_action_gap = 1.0
        # Temperature of the logsumexp over sampled Q-values.
        config.cql_temp = 1.0
        # Multiplier on the CQL penalty.
        config.cql_min_q_weight = 5.0
        # Not read in this class; presumably consumed by OfflineSAC -- verify.
        config.cql_max_target_backup = False
        # Bounds applied to the per-sample (logsumexp - data Q) difference.
        config.cql_clip_diff_min = -np.inf
        config.cql_clip_diff_max = np.inf

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, policy, qf, **kwargs):
        """Build the base agent and, in Lagrange mode, the alpha_prime state.

        The ``log_alpha_prime`` train state uses the optimizer family named by
        ``config.optimizer_type`` with learning rate ``config.qf_lr``.

        Raises:
            KeyError: if ``config.optimizer_type`` is neither 'adam' nor 'sgd'
                (checked even when ``cql_lagrange`` is False).
        """
        super().__init__(config, policy, qf, **kwargs)

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        if self.config.cql_lagrange:
            self.log_alpha_prime = Scalar(1.0)
            self._train_states['log_alpha_prime'] = TrainState.create(
                params=self.log_alpha_prime.init(next_rng()),
                tx=optimizer_class(self.config.qf_lr),
                apply_fn=None
            )
            # Register the new parameters alongside the base model keys.
            self._model_keys += ('log_alpha_prime',)

    def additional_q_loss(self, train_params, target_qf_params,
                          q1, q2,
                          observations, actions, rewards, next_observations, dones,
                          rng, loss_collection, flags=None):
        """Compute the CQL penalty for both Q-functions.

        Args:
            train_params: dict of parameters; reads 'policy', 'qf1', 'qf2' and,
                in Lagrange mode, 'log_alpha_prime'.
            target_qf_params, rewards, dones, flags: unused here; part of the
                hook signature expected by the base class.
            q1, q2: Q-values of the dataset (observation, action) pairs from
                qf1 / qf2 (assumed shape (batch,) -- confirm with caller).
            observations, next_observations: batches of states.
            actions: dataset actions; only ``actions.shape[0]`` is used.
            rng: JAX PRNG key consumed for action sampling and network calls.
            loss_collection: dict; in Lagrange mode the multiplier loss is
                written to ``loss_collection['log_alpha_prime']``.

        Returns:
            Tuple ``(cql_min_qf1_loss, cql_min_qf2_loss, log_items)``: the
            penalties for qf1 / qf2 and a dict of scalar diagnostics.
        """
        rng_generator = JaxRNG(rng)

        @wrap_function_with_rng(rng_generator())
        def forward_f(rng, f, *args, **kwargs):
            # Applies module `f`, giving it fresh RNG streams for each of its keys.
            return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))

        batch_size = actions.shape[0]
        # Uniform actions in [-1, 1]: shape (batch, cql_n_actions, action_dim).
        cql_random_actions = jax.random.uniform(
            rng_generator(), shape=(batch_size, self.config.cql_n_actions, self.action_dim),
            minval=-1.0, maxval=1.0
        )

        # Policy samples (with log-probs) at the current and the next states.
        cql_current_actions, cql_current_log_pis = forward_f(
            self.policy, train_params['policy'], observations, repeat=self.config.cql_n_actions,
        )
        cql_next_actions, cql_next_log_pis = forward_f(
            self.policy, train_params['policy'], next_observations, repeat=self.config.cql_n_actions,
        )

        cql_q1_rand, _ = forward_f(self.qf, train_params['qf1'], observations, cql_random_actions)
        cql_q2_rand, _ = forward_f(self.qf, train_params['qf2'], observations, cql_random_actions)
        cql_q1_current_actions, _ = forward_f(self.qf, train_params['qf1'], observations, cql_current_actions)
        cql_q2_current_actions, _ = forward_f(self.qf, train_params['qf2'], observations, cql_current_actions)
        # NOTE(review): next-state policy actions are scored at the *current*
        # observations. This appears deliberate (it mirrors the reference CQL
        # code) -- do not switch to next_observations without checking.
        cql_q1_next_actions, _ = forward_f(self.qf, train_params['qf1'], observations, cql_next_actions)
        cql_q2_next_actions, _ = forward_f(self.qf, train_params['qf2'], observations, cql_next_actions)

        # Stack every Q-value per state (dataset action included) to measure
        # their spread for logging.
        cql_cat_q1 = jnp.concatenate(
            [cql_q1_rand, jnp.expand_dims(q1, 1), cql_q1_next_actions, cql_q1_current_actions], axis=1
        )
        cql_cat_q2 = jnp.concatenate(
            [cql_q2_rand, jnp.expand_dims(q2, 1), cql_q2_next_actions, cql_q2_current_actions], axis=1
        )
        cql_std_q1 = jnp.std(cql_cat_q1, axis=1)
        cql_std_q2 = jnp.std(cql_cat_q2, axis=1)

        if self.config.cql_importance_sample:
            # Importance weights: the uniform proposal on [-1, 1]^d has density
            # 0.5 ** d; policy samples use their own log-probabilities. The
            # dataset Q-value is excluded from this set.
            random_density = np.log(0.5 ** self.action_dim)
            cql_cat_q1 = jnp.concatenate(
                [cql_q1_rand - random_density,
                 cql_q1_next_actions - cql_next_log_pis,
                 cql_q1_current_actions - cql_current_log_pis],
                axis=1
            )
            cql_cat_q2 = jnp.concatenate(
                [cql_q2_rand - random_density,
                 cql_q2_next_actions - cql_next_log_pis,
                 cql_q2_current_actions - cql_current_log_pis],
                axis=1
            )

        # Temperature-scaled soft maximum over the sampled actions.
        cql_qf1_ood = (
                jax.scipy.special.logsumexp(cql_cat_q1 / self.config.cql_temp, axis=1)
                * self.config.cql_temp
        )
        cql_qf2_ood = (
                jax.scipy.special.logsumexp(cql_cat_q2 / self.config.cql_temp, axis=1)
                * self.config.cql_temp
        )

        # Subtract the Q-value of the dataset action (the "data" term).
        cql_qf1_diff = cql_qf1_ood - q1
        cql_qf2_diff = cql_qf2_ood - q2

        cql_qf1_diff = jnp.clip(cql_qf1_diff, self.config.cql_clip_diff_min,
                                self.config.cql_clip_diff_max).mean()
        cql_qf2_diff = jnp.clip(cql_qf2_diff, self.config.cql_clip_diff_min,
                                self.config.cql_clip_diff_max).mean()

        if self.config.cql_lagrange:
            # Bounds are passed positionally: the a_min/a_max keywords are
            # deprecated in newer JAX releases, and positional works in both.
            alpha_prime = jnp.clip(
                jnp.exp(self.log_alpha_prime.apply(train_params['log_alpha_prime'])),
                0.0, 1000000.0,
            )
            cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (
                    cql_qf1_diff - self.config.cql_target_action_gap)
            cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (
                    cql_qf2_diff - self.config.cql_target_action_gap)

            # Multiplier objective has the opposite sign of the Q penalty.
            alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss) * 0.5

            loss_collection['log_alpha_prime'] = alpha_prime_loss

        else:
            cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
            cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
            alpha_prime_loss = 0.0
            alpha_prime = 0.0

        log_items = {
            'cql/cql_std_q1': cql_std_q1.mean(),
            'cql/cql_std_q2': cql_std_q2.mean(),
            'cql/cql_q1_rand': cql_q1_rand.mean(),
            'cql/cql_q2_rand': cql_q2_rand.mean(),
            'cql/cql_qf1_diff': cql_qf1_diff.mean(),
            'cql/cql_qf2_diff': cql_qf2_diff.mean(),
            'cql/cql_min_qf1_loss': cql_min_qf1_loss.mean(),
            'cql/cql_min_qf2_loss': cql_min_qf2_loss.mean(),
            'cql/cql_q1_current_actions': cql_q1_current_actions.mean(),
            'cql/cql_q2_current_actions': cql_q2_current_actions.mean(),
            'cql/cql_q1_next_actions': cql_q1_next_actions.mean(),
            'cql/cql_q2_next_actions': cql_q2_next_actions.mean(),
            'cql/alpha_prime': alpha_prime,
            'cql/alpha_prime_loss': alpha_prime_loss,
        }

        return cql_min_qf1_loss, cql_min_qf2_loss, log_items
