from copy import deepcopy
from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from .jax_utils import next_rng, JaxRNG, wrap_function_with_rng
from .model import Scalar
from .conservative_sac import ConservativeSAC
from .utils import prefix_metrics


class POPSAC(ConservativeSAC):
    
    @staticmethod
    def get_default_config(updates=None):
        """Build the default POPSAC hyper-parameter config.

        Starts from ``ConservativeSAC.get_default_config()`` and adds the
        POP-specific fields set below.

        Args:
            updates: optional mapping of overrides. It is wrapped in a
                ``ConfigDict``, its references are resolved, and the result
                is merged into the defaults.

        Returns:
            The config object with the POP defaults (and overrides) applied.
        """
        config = ConservativeSAC.get_default_config()

        config.dual_rank = 4  # the rank of A and B
        config.pop_gamma = 1.0  # the discount factor for the pop objective
        config.pop_margin = 0.0  # the margin on the pop semi-definite constraint
        config.learn_g = True  # learn the g function as opposed to sampling it
        config.g_lr_gain = 1e0  # gain for the learning rate of g w.r.t the critic
        config.dual_lr_gain = 1e0  # gain for the learning rate of l w.r.t the critic
        config.g_arch = '1024-1024-1024-1024'
        config.dual_min_val = np.nan  # the minimum value of the dual variables A and B
        config.dual_max_val = np.nan  # the maximum value of the dual variables A and B
        config.dual_grad_clip = -1.0  # clip the gradient of the dual variables
        config.pop_start_epoch = -1  # the number of epochs before starting to learn pop coef
        config.backprop_dual_grad = False  # backprop the gradient of the dual variables
        config.beta_multiplier = 1.0
        config.use_automatic_kl_tuning = False
        config.target_kl = 1.0
        # Float like beta_max: with an int default, a type-safe ConfigDict
        # would reject a fractional override (e.g. beta_min=0.1) in update().
        config.beta_min = 0.0
        config.beta_max = np.inf
        config.use_cql = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config
    
    def __init__(self, config, policy, qf, dualf, gf):
        """Create the POP-specific train states on top of ConservativeSAC.

        After the parent constructor runs, this registers train states for
        the two dual networks, optionally for the two g networks
        (``config.learn_g``) and for the log-beta scalar
        (``config.use_automatic_kl_tuning``), and appends their names to
        ``self._model_keys``. Dual and g parameters are also deep-copied
        into ``self._target_qf_params``.

        Args:
            config: forwarded to ``ConservativeSAC.__init__``.
            policy: forwarded to ``ConservativeSAC.__init__``.
            qf: forwarded to ``ConservativeSAC.__init__``.
            dualf: module used to initialise both dual networks.
            gf: module used to initialise both g networks.

        Raises:
            KeyError: if ``config.optimizer_type`` is neither 'adam' nor 'sgd'.
            ValueError: if automatic KL tuning is requested while
                ``config.beta_multiplier`` is infinite.
        """
        super().__init__(config, policy, qf)
        self.dualf = dualf
        self.gf = gf

        make_optimizer = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        # The dual networks get a hand-assembled update rule:
        # optional global-norm clipping -> (adam scaling | identity) -> -lr.
        dual_lr = self.config.dual_lr_gain * self.config.qf_lr
        dual_transforms = []
        if self.config.dual_grad_clip > 0:
            dual_transforms.append(optax.clip_by_global_norm(self.config.dual_grad_clip))
        dual_transforms.append(
            optax.scale_by_adam() if self.config.optimizer_type == 'adam' else optax.identity()
        )
        dual_transforms.append(optax.scale(-dual_lr))

        # Initialise the two dual networks in order (dualf1, then dualf2) so
        # the sequence of next_rng() draws is unchanged.
        dual_names = ('dualf1', 'dualf2')
        dual_params = {}
        for name in dual_names:
            dual_params[name] = self.dualf.init(
                next_rng(self.dualf.rng_keys()),
                jnp.zeros((10, self.qf.feature_dim())),
            )
            self._train_states[name] = TrainState.create(
                params=dual_params[name],
                tx=optax.chain(*dual_transforms),
                apply_fn=None,
            )
        self._target_qf_params.update(deepcopy(dual_params))
        self._model_keys += dual_names

        if self.config.learn_g:
            g_lr = self.config.g_lr_gain * self.config.qf_lr

            # Same pattern for the learned g networks (gf1, then gf2).
            g_names = ('gf1', 'gf2')
            g_params = {}
            for name in g_names:
                g_params[name] = self.gf.init(
                    next_rng(self.gf.rng_keys()),
                    jnp.zeros((10, self.observation_dim)),
                    jnp.zeros((10, self.action_dim))
                )
                self._train_states[name] = TrainState.create(
                    params=g_params[name],
                    tx=make_optimizer(g_lr),
                    apply_fn=None,
                )
            self._target_qf_params.update(deepcopy(g_params))
            self._model_keys += g_names

        if not self.config.use_automatic_kl_tuning:
            return

        if self.config.beta_multiplier == np.inf:
            raise ValueError('To use automatic KL tuning, beta_multiplier must be finite')

        # Learnable scalar log(beta), trained with the policy learning rate.
        self.log_beta = Scalar(0.0)
        self._train_states['log_beta'] = TrainState.create(
            params=self.log_beta.init(next_rng()),
            tx=make_optimizer(self.config.policy_lr),
            apply_fn=None
        )
        self._model_keys += ('log_beta',)

    def additional_q_loss(self, train_params, target_qf_params,
                          q1, q2,
                          observations, actions, rewards, next_observations, dones,
                          rng, loss_collection, flags=None):
        if self.config.use_cql:
            return super().additional_q_loss(
                train_params, target_qf_params, q1, q2, observations, actions, rewards, next_observations, dones,
                rng, loss_collection, flags
            )
        else:
            return 0.0, 0.0, {}

    def get_additional_training_flags(self, epoch):
        return {
            'train_pop': epoch > self.config.pop_start_epoch,
        }
    
    def additional_policy_loss(self, train_params, target_qf_params,
                               observations, actions, rewards, next_observations, dones,
                               rng, loss_collection, flags=None):
        """Compute the POP term that is added to the policy loss.

        Args:
            train_params: mapping of model name to parameters; reads 'qf1',
                'qf2', 'policy', 'dualf1', 'dualf2', plus 'gf1'/'gf2' when
                ``config.learn_g`` and 'log_beta' when
                ``config.use_automatic_kl_tuning``.
            target_qf_params: not used by this method.
            observations, actions: current transition inputs fed to the
                Q-networks (and to the g networks when learned).
            rewards: not used by this method.
            next_observations: fed to the policy to sample next actions.
            dones: termination indicators; terms are masked by ``1 - dones``.
            rng: random key used to seed the module forward passes.
            loss_collection: dict mutated in place; receives 'log_beta'
                when automatic KL tuning is enabled.
            flags: not used by this method.

        Returns:
            ``(loss, log_items)``. ``(0.0, {})`` when
            ``config.beta_multiplier == 0.0``; otherwise ``beta`` times the
            batch mean of the two heads' summed POP losses, and a metrics
            dict built with ``prefix_metrics(..., 'pop')``.
        """
        # NOTE(review): this early return also skips writing
        # loss_collection['log_beta'] even if use_automatic_kl_tuning is set
        # -- confirm the caller tolerates that combination.
        if self.config.beta_multiplier == 0.0:
            return 0.0, {}
        else:
            rng_generator = JaxRNG(rng)
            # Helper that applies module `f` with rngs derived from the key
            # supplied by wrap_function_with_rng.
            @wrap_function_with_rng(rng_generator())
            def forward_f(rng, f, *args, **kwargs):
                return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))

            def compute_pop_policy_loss(dual_params, features, next_features, dones, g=None):
                # Per-head POP loss and the reweighting coefficients it uses.
                # The dual module returns a 6-tuple; only m_a is needed for
                # the next-step features.
                m_a, m_b, a, b, a_mag, b_mag = forward_f(self.dualf, dual_params, features)
                next_m_a, _, _, _, _, _ = forward_f(self.dualf, dual_params, next_features)
                # Both terms reduce over the last axis.
                feature_term = (m_a ** 2 + m_b ** 2).sum(-1)
                angle_term = (m_b * next_m_a).sum(-1)
                if g is None:
                    # No learned g supplied: use the normalised angle term.
                    g_tilde = angle_term / (a_mag * b_mag + self.config.numerical_eps)
                    g = g_tilde
                # Sum over every entry of a and b (no axis given).
                margin_term = self.config.pop_margin * (a ** 2 + b ** 2).sum()
                if self.config.beta_multiplier == np.inf:
                    # Infinite multiplier: uniform weights.
                    reweight = jnp.ones_like(dones)
                else:
                    reweight = jnp.exp(feature_term + 2 * a_mag * b_mag * (1. - dones) * self.config.pop_gamma * g - margin_term)
                    # Normalise by the batch mean, then clip to the configured range.
                    reweight = reweight / (reweight.mean() + self.config.numerical_eps)
                    reweight = jnp.clip(reweight, self.config.reweighting_coef_min, self.config.reweighting_coef_max)
                # The weights are treated as constants here: gradients flow
                # only through angle_term.
                pop_policy_loss = -jax.lax.stop_gradient(reweight) * (1. - dones) * self.config.pop_gamma * angle_term
                return pop_policy_loss, reweight

            # Features of the current (s, a) pair and of (s', a') with a'
            # sampled from the current policy, for both Q-networks.
            _, features1 = forward_f(self.qf, train_params['qf1'], observations, actions)
            _, features2 = forward_f(self.qf, train_params['qf2'], observations, actions)
            next_actions, _ = forward_f(self.policy, train_params['policy'], next_observations)
            _, next_features1 = forward_f(self.qf, train_params['qf1'], next_observations, next_actions)
            _, next_features2 = forward_f(self.qf, train_params['qf2'], next_observations, next_actions)

            if self.config.learn_g:
                g1, _ = forward_f(self.gf, train_params['gf1'], observations, actions)
                g2, _ = forward_f(self.gf, train_params['gf2'], observations, actions)
            else:
                g1, g2 = None, None

            pop_policy_loss1, reweight1 = compute_pop_policy_loss(train_params['dualf1'], features1, next_features1, dones, g1)
            pop_policy_loss2, reweight2 = compute_pop_policy_loss(train_params['dualf2'], features2, next_features2, dones, g2)
            # Element-wise w * log(w + eps); logged as pop_kl and used as the
            # quantity compared against target_kl below.
            pop_kl1 = reweight1 * jnp.log(reweight1 + self.config.numerical_eps)
            pop_kl2 = reweight2 * jnp.log(reweight2 + self.config.numerical_eps)

            if self.config.use_automatic_kl_tuning:
                log_beta = self.log_beta.apply(train_params['log_beta'])
                # Loss for the log-beta scalar, averaged over the two heads.
                beta_loss1 = -log_beta * (pop_kl1 - self.config.target_kl).mean()
                beta_loss2 = -log_beta * (pop_kl2 - self.config.target_kl).mean()
                beta_loss = (beta_loss1 + beta_loss2) / 2
                loss_collection['log_beta'] = beta_loss
                beta = jnp.clip(jnp.exp(log_beta) * self.config.beta_multiplier, self.config.beta_min, self.config.beta_max)
            else:
                beta_loss = 0.0
                if self.config.beta_multiplier == np.inf:
                    # Weights are uniform in this case, so use a unit scale.
                    beta = 1.0
                else:
                    beta = self.config.beta_multiplier

            log_items = prefix_metrics({
                'pop_policy_loss1': pop_policy_loss1.mean(),
                'pop_policy_loss2': pop_policy_loss2.mean(),
                'pop_kl1': pop_kl1.mean(),
                'pop_kl2': pop_kl2.mean(),
                'pop_beta_loss': beta_loss,
                'pop_beta': beta,
            }, 'pop')
            return beta * (pop_policy_loss1 + pop_policy_loss2).mean(), log_items

    def get_reweighting_coef(self, train_params, target_qf_params,
                             observations, actions, rewards, next_observations, dones,
                             rng, loss_collection, flags=None):
        rng_generator = JaxRNG(rng)
        @wrap_function_with_rng(rng_generator())
        def forward_f(rng, f, *args, **kwargs):
            return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))
        
        def compute_pop_losses(dual_params, features, next_features, dones, g=None):
            m_a, m_b, a, b, a_mag, b_mag = forward_f(self.dualf, dual_params, features)
            next_m_a, _, _, _, _, _ = forward_f(self.dualf, dual_params, next_features)
            feature_term = (m_a ** 2 + m_b ** 2).sum(-1)
            angle_term = (m_b * next_m_a).sum(-1)
            g_tilde = angle_term / (a_mag * b_mag + self.config.numerical_eps)
            if g is None:
                g = g_tilde
            margin_term = self.config.pop_margin * (a ** 2 + b ** 2).sum()
            if self.config.beta_multiplier == np.inf:
                reweight = jnp.ones_like(dones)
            else:
                reweight = jax.lax.stop_gradient(jnp.exp(feature_term + 2 * a_mag * b_mag * (1. - dones) * self.config.pop_gamma * g - margin_term))
            dual_loss = (reweight * (feature_term + 2 * (1. - dones) * self.config.pop_gamma * angle_term - margin_term)).mean()
            return dual_loss, g_tilde, reweight, a_mag, b_mag
        
        if flags['train_pop']:
            _, features1 = forward_f(self.qf, train_params['qf1'], observations, actions)
            _, features2 = forward_f(self.qf, train_params['qf2'], observations, actions)
            next_actions, _ = forward_f(self.policy, train_params['policy'], next_observations)
            _, next_features1 = forward_f(self.qf, train_params['qf1'], next_observations, next_actions)
            _, next_features2 = forward_f(self.qf, train_params['qf2'], next_observations, next_actions)
            
            if self.config.learn_g:
                g1, _ = forward_f(self.gf, train_params['gf1'], observations, actions)
                g2, _ = forward_f(self.gf, train_params['gf2'], observations, actions)
                dual_loss1, g_tilde1, reweight1, a_norm1, b_norm1 = \
                    compute_pop_losses(train_params['dualf1'], features1, next_features1, dones, g1)
                dual_loss2, g_tilde2, reweight2, a_norm2, b_norm2 = \
                    compute_pop_losses(train_params['dualf2'], features2, next_features2, dones, g2)
                g1_loss = jnp.mean((g1 - jax.lax.stop_gradient(g_tilde1)) ** 2)
                g2_loss = jnp.mean((g2 - jax.lax.stop_gradient(g_tilde2)) ** 2)
                loss_collection['gf1'], loss_collection['gf2'] = g1_loss, g2_loss
            else:
                dual_loss1, g_tilde1, reweight1, a_norm1, b_norm1 = \
                    compute_pop_losses(train_params['dualf1'], features1, next_features1, dones)
                dual_loss2, g_tilde2, reweight2, a_norm2, b_norm2 = \
                    compute_pop_losses(train_params['dualf2'], features2, next_features2, dones)
                g1_loss, g2_loss = 0, 0
            
            pop_objective1 = reweight1.mean()
            pop_objective2 = reweight2.mean()
            reweight = self.reweighting_fuser(reweight1, reweight2)
            reweighting_coef = jax.lax.stop_gradient(reweight)
            
            loss_collection['dualf1'], loss_collection['dualf2'] = dual_loss1, dual_loss2
            
            log_items = prefix_metrics({
                'g_tilde1': g_tilde1.mean(),
                'g_tilde2': g_tilde2.mean(),
                'dual_loss1': dual_loss1,
                'dual_loss2': dual_loss2,
                'g1_loss': g1_loss,
                'g2_loss': g2_loss,
                'batch_pop_objective1': pop_objective1.mean(),
                'batch_pop_objective2': pop_objective2.mean(),
                'a_norm1': a_norm1.mean(),
                'b_norm1': b_norm1.mean(),
                'a_norm2': a_norm2.mean(),
                'b_norm2': b_norm2.mean(),
            }, 'pop')
            
            if self.config.backprop_dual_grad:
                raise NotImplementedError()
        else:
            reweighting_coef = 1
            if self.config.learn_g:
                loss_collection['gf1'], loss_collection['gf1'] = 0, 0
            
            log_items = {}
        
        return reweighting_coef, log_items
