from copy import deepcopy
from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax

from .jax_utils import (
    next_rng, value_and_multi_grad, mse_loss, JaxRNG, wrap_function_with_rng,
    collect_jax_metrics, batch_to_jax, no_nans
)
from .model import Scalar, update_target_network
from .offline_sac import OfflineSAC
from .utils import prefix_metrics


class POPSAC2(OfflineSAC):
    """Offline SAC with a POP-style reweighting of the critic's TD loss.

    In addition to the ``OfflineSAC`` models, this class trains:

    * ``dualf1`` / ``dualf2``: one dual network per critic, applied to that
      critic's features (``self.qf`` returns ``(q_value, features)``).
      Their outputs produce per-sample coefficients that reweight the squared
      TD errors of ``qf1`` / ``qf2``.
    * ``policy_tilde``: a second policy that samples the next actions used
      both for the POP terms and for the TD target. It is trained to stay
      close (in KL) to ``policy`` while maximizing the beta-scaled,
      reweighted POP angle term.
    * ``gf1`` / ``gf2`` (only when ``config.learn_g``): regressors of the
      ``g_tilde`` statistic derived from the dual outputs. When disabled,
      ``g_tilde`` itself is used.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the POPSAC2-specific defaults, optionally overridden.

        Args:
            updates: optional mapping merged on top of the defaults.

        Returns:
            A ``ConfigDict`` holding the keys below.
        """
        config = ConfigDict()

        config.dual_rank = 4  # the rank of A and B
        config.pop_gamma = 1.0  # the discount factor for the pop objective
        config.pop_margin = 0.0  # the margin on the pop semi-definite constraint
        config.learn_g = True  # learn the g function as opposed to sampling it
        config.g_lr_gain = 1e0  # gain for the learning rate of g w.r.t the critic
        config.dual_lr_gain = 1e0  # gain for the learning rate of l w.r.t the critic
        config.g_arch = '1024-1024-1024-1024'
        # Not read by POPSAC2 itself; presumably consumed where dualf is built.
        config.dual_min_val = np.nan  # the minimum value of the dual variables A and B
        config.dual_max_val = np.nan  # the maximum value of the dual variables A and B
        config.dual_grad_clip = -1.0  # clip the gradient of the dual variables (<= 0 disables)
        config.pop_start_epoch = -1  # the number of epochs before starting to learn pop coef
        config.backprop_dual_grad = False  # backprop the gradient of the dual variables (not implemented)
        config.beta_multiplier = 1.0  # weight of the POP term in the policy_tilde loss
        # Automatic KL tuning is not implemented; True raises in __init__.
        config.use_automatic_kl_tuning = False
        config.target_kl = 1.0  # only meaningful with use_automatic_kl_tuning
        config.beta_lr = 1e-2  # only meaningful with use_automatic_kl_tuning

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, policy, qf, dualf, gf):
        """Create the extra train states on top of ``OfflineSAC``.

        Args:
            config: agent configuration (parent keys plus ``get_default_config``).
            policy: policy module; also reused for ``policy_tilde``.
            qf: critic module whose apply returns ``(q_value, features)``.
            dualf: dual network applied to critic features.
            gf: g-regressor applied to ``(observations, actions)``.

        Raises:
            NotImplementedError: if ``config.use_automatic_kl_tuning`` is set.
        """
        super().__init__(config, policy, qf)
        self.dualf = dualf
        self.gf = gf

        if self.config.use_automatic_kl_tuning:
            # _train_step never produces a 'log_beta' loss and always uses
            # beta = config.beta_multiplier. A 'log_beta' train state could
            # therefore never be updated, and its missing loss would make
            # _train_step fail. Fail fast instead of ignoring the flag.
            raise NotImplementedError(
                'use_automatic_kl_tuning is not implemented for POPSAC2; '
                'set beta via beta_multiplier instead'
            )

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        # The dual variables use a hand-built optax chain so that optional
        # global-norm clipping is applied before the Adam/SGD scaling.
        dual_lr = self.config.dual_lr_gain * self.config.qf_lr
        dual_tx_chain = []
        if self.config.dual_grad_clip > 0:
            dual_tx_chain.append(optax.clip_by_global_norm(self.config.dual_grad_clip))
        if self.config.optimizer_type == 'adam':
            dual_tx_chain.append(optax.scale_by_adam())
        else:
            dual_tx_chain.append(optax.identity())
        dual_tx_chain.append(optax.scale(-dual_lr))

        # One dual network per critic; init order (dualf1, then dualf2) keeps
        # the global RNG consumption order unchanged.
        dual_params = {}
        for key in ('dualf1', 'dualf2'):
            dual_params[key] = self.dualf.init(
                next_rng(self.dualf.rng_keys()),
                jnp.zeros((10, self.qf.feature_dim())),
            )
            self._train_states[key] = TrainState.create(
                params=dual_params[key],
                tx=optax.chain(*dual_tx_chain),
                apply_fn=None,
            )
        self._target_qf_params.update(deepcopy(dual_params))
        self._model_keys += ('dualf1', 'dualf2')

        policy_tilde_params = self.policy.init(
            next_rng(self.policy.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
        )
        self._train_states['policy_tilde'] = TrainState.create(
            params=policy_tilde_params,
            tx=optimizer_class(self.config.policy_lr),
            apply_fn=None,
        )
        self._model_keys += ('policy_tilde',)

        if self.config.learn_g:
            g_lr = self.config.g_lr_gain * self.config.qf_lr
            g_params = {}
            for key in ('gf1', 'gf2'):
                g_params[key] = self.gf.init(
                    next_rng(self.gf.rng_keys()),
                    jnp.zeros((10, self.observation_dim)),
                    jnp.zeros((10, self.action_dim)),
                )
                self._train_states[key] = TrainState.create(
                    params=g_params[key],
                    tx=optimizer_class(g_lr),
                    apply_fn=None,
                )
            self._target_qf_params.update(deepcopy(g_params))
            self._model_keys += ('gf1', 'gf2')

    def get_additional_training_flags(self, epoch):
        """Return per-epoch flags (presumably merged into ``_train_step``'s ``flags``).

        ``train_pop`` becomes true once ``epoch`` exceeds ``config.pop_start_epoch``.
        """
        return {
            'train_pop': epoch > self.config.pop_start_epoch,
        }

    @partial(jax.jit, static_argnames=('self', 'flags'))
    def _train_step(self, train_states, target_qf_params, rng, batch, train_step, flags):
        """Run one jitted update of every model in ``self.model_keys``.

        ``loss_fn`` returns one scalar loss per model key. Only the gradient of
        each loss w.r.t. its own key's parameters is applied (``grads[i][key]``).

        Args:
            train_states: mapping model key -> ``TrainState``.
            target_qf_params: mapping of target-network parameters.
            rng: JAX PRNG key for this step.
            batch: dict with 'observations', 'actions', 'rewards',
                'next_observations' and 'dones'.
            train_step: global step counter, used to gate target updates.
            flags: static flags (e.g. 'bc', 'train_pop').

        Returns:
            ``(new_train_states, new_target_qf_params, log_items)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']

            loss_collection = {}
            log_items = {}

            @wrap_function_with_rng(rng_generator())
            def forward_f(rng, f, *args, **kwargs):
                return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))

            new_actions, log_pi = forward_f(self.policy, train_params['policy'], observations)

            # Entropy temperature.
            if self.config.use_automatic_entropy_tuning:
                log_alpha = self.log_alpha.apply(train_params['log_alpha'])
                alpha_loss = -log_alpha * (log_pi + self.config.target_entropy).mean()
                loss_collection['log_alpha'] = alpha_loss
                alpha = jnp.exp(log_alpha) * self.config.alpha_multiplier
            else:
                # An array (not a Python float) so the `.mean()` below works.
                alpha_loss = jnp.zeros(())
                alpha = self.config.alpha_multiplier

            # Policy loss.
            if flags['bc']:
                log_probs = forward_f(self.policy, train_params['policy'], observations, actions,
                                      method=self.policy.log_prob)
                policy_loss = (alpha * log_pi - log_probs).mean()
            else:
                q1, _ = forward_f(self.qf, train_params['qf1'], observations, new_actions)
                q2, _ = forward_f(self.qf, train_params['qf2'], observations, new_actions)
                q_new_actions = jnp.minimum(q1, q2)
                policy_loss = (alpha * log_pi - q_new_actions).mean()

            additional_policy_loss, additional_policy_loss_log_items = self.additional_policy_loss(
                train_params, target_qf_params,
                observations, actions, rewards, next_observations, dones,
                rng, loss_collection, flags
            )
            policy_loss += additional_policy_loss
            log_items.update(additional_policy_loss_log_items)

            loss_collection['policy'] = policy_loss

            log_items.update(prefix_metrics({
                'log_pi': log_pi.mean(),
                'policy_loss': policy_loss.mean(),
                'alpha_loss': alpha_loss.mean(),
                'alpha': alpha,
            }, 'sac'))

            # Policy projection (trains policy_tilde).
            def compute_pop_policy_loss(dual_params, features, next_features, dones, g=None):
                # dualf returns six outputs (m_a, m_b, a, b, a_mag, b_mag);
                # presumably projections of the features on A/B, the A/B
                # factors and their magnitudes -- TODO confirm against dualf.
                m_a, m_b, a, b, a_mag, b_mag = forward_f(self.dualf, dual_params, features)
                next_m_a, _, _, _, _, _ = forward_f(self.dualf, dual_params, next_features)
                # On terminal transitions the current m_a stands in for next_m_a.
                next_m_a = (1. - dones)[:, None] * next_m_a + dones[:, None] * m_a
                feature_term = (m_a ** 2 + m_b ** 2).sum(-1)
                angle_term = self.config.pop_gamma * (m_b * next_m_a).sum(-1)
                if g is None:
                    g = angle_term / (a_mag * b_mag + self.config.numerical_eps)
                margin_term = self.config.pop_margin * (a ** 2 + b ** 2).sum()
                reweight = jnp.exp(feature_term + 2 * a_mag * b_mag * g - margin_term)
                # Self-normalize over the batch, then clip.
                reweight = reweight / (reweight.mean() + self.config.numerical_eps)
                reweight = jnp.clip(reweight, self.config.reweighting_coef_min, self.config.reweighting_coef_max)
                # The weights act as constants; only angle_term is differentiated.
                pop_policy_loss = -jax.lax.stop_gradient(reweight) * angle_term
                return pop_policy_loss, reweight

            next_actions, _ = forward_f(self.policy, train_params['policy_tilde'], next_observations)
            _, features1 = forward_f(self.qf, train_params['qf1'], observations, actions)
            _, features2 = forward_f(self.qf, train_params['qf2'], observations, actions)
            _, next_features1 = forward_f(self.qf, train_params['qf1'], next_observations, next_actions)
            _, next_features2 = forward_f(self.qf, train_params['qf2'], next_observations, next_actions)

            if self.config.learn_g:
                g1, _ = forward_f(self.gf, train_params['gf1'], observations, actions)
                g2, _ = forward_f(self.gf, train_params['gf2'], observations, actions)
            else:
                g1, g2 = None, None

            pop_policy_loss1, reweight1 = compute_pop_policy_loss(
                train_params['dualf1'], features1, next_features1, dones, g1)
            pop_policy_loss2, reweight2 = compute_pop_policy_loss(
                train_params['dualf2'], features2, next_features2, dones, g2)
            # Sample estimate of sum(w log w) for the normalized weights (logged only).
            pop_kl1 = reweight1 * jnp.log(reweight1 + self.config.numerical_eps)
            pop_kl2 = reweight2 * jnp.log(reweight2 + self.config.numerical_eps)
            beta = self.config.beta_multiplier

            policy_dist = forward_f(self.policy, train_params['policy'], observations, method=self.policy.dist)
            policy_dist_tilde = forward_f(self.policy, train_params['policy_tilde'], observations,
                                          method=self.policy.dist)
            policy_kl = policy_dist.kl_divergence(policy_dist_tilde)
            policy_tilde_loss = policy_kl.mean() + beta * (pop_policy_loss1 + pop_policy_loss2).mean()

            log_items.update(prefix_metrics({
                'pop_policy_loss1': pop_policy_loss1.mean(),
                'pop_policy_loss2': pop_policy_loss2.mean(),
                'policy_kl': policy_kl.mean(),
                'pop_kl1': pop_kl1.mean(),
                'pop_kl2': pop_kl2.mean(),
                'pop_beta': beta,
            }, 'pop'))
            loss_collection['policy_tilde'] = policy_tilde_loss

            # Q function loss.
            if not self.config.fixed_qf:
                q1_loss, q2_loss = 0, 0
                q1, features1 = forward_f(self.qf, train_params['qf1'], observations, actions)
                q2, features2 = forward_f(self.qf, train_params['qf2'], observations, actions)

                # Several next actions are sampled from policy_tilde; the one
                # with the largest min-of-targets Q value is kept.
                new_next_actions, next_log_pi = forward_f(
                    self.policy, train_params['policy_tilde'], next_observations,
                    repeat=self.config.n_next_action_samples
                )
                target_next_q1, _ = \
                    forward_f(self.qf, target_qf_params['qf1'], next_observations, new_next_actions)
                target_next_q2, _ = \
                    forward_f(self.qf, target_qf_params['qf2'], next_observations, new_next_actions)
                target_next_q = jnp.minimum(target_next_q1, target_next_q2)
                max_target_indices = jnp.expand_dims(jnp.argmax(target_next_q, axis=-1), axis=-1)
                target_next_q = jnp.take_along_axis(target_next_q, max_target_indices, axis=-1).squeeze(-1)
                next_log_pi = jnp.take_along_axis(next_log_pi, max_target_indices, axis=-1).squeeze(-1)

                # Reweighting (also writes the dual / g losses into loss_collection).
                reweighting_coef, reweighting_coef_log_items = self.get_reweighting_coef(
                    train_params, target_qf_params,
                    observations, actions, rewards, next_observations, dones,
                    rng, loss_collection, flags
                )
                if reweighting_coef is None:
                    # Uniform weights, stacked so that [0] / [1] select one row per critic.
                    reweighting_coef = jnp.stack([jnp.ones_like(q1), jnp.ones_like(q2)])
                else:
                    log_items.update(reweighting_coef_log_items)
                    reweighting_coef = self.normalize_reweighting_coef(reweighting_coef)

                if self.config.backup_entropy:
                    target_next_q = target_next_q - alpha * next_log_pi

                target_q_values = target_next_q
                td_target = jax.lax.stop_gradient(
                    rewards + (1. - dones) * self.config.discount * target_next_q
                )
                q1_loss += jnp.mean(reweighting_coef[0] * jnp.square(q1 - td_target))
                q2_loss += jnp.mean(reweighting_coef[1] * jnp.square(q2 - td_target))

                additional_q1_loss, additional_q2_loss, additional_q_loss_log_items = self.additional_q_loss(
                    train_params, target_qf_params,
                    q1, q2,
                    observations, actions, rewards, next_observations, dones,
                    rng, loss_collection, flags
                )
                q1_loss += additional_q1_loss
                q2_loss += additional_q2_loss

                loss_collection['qf1'] = q1_loss
                loss_collection['qf2'] = q2_loss

                log_items.update(prefix_metrics({
                    'qf1_loss': q1_loss.mean(),
                    'qf2_loss': q2_loss.mean(),
                    'q1': q1.mean(),
                    'q2': q2.mean(),
                    'target_q_values': target_q_values.mean(),
                }, prefix='sac'))
                log_items.update(prefix_metrics({
                    'reweighting_coef_mean': reweighting_coef.mean(),
                    'reweighting_coef_std': reweighting_coef.std(),
                }, prefix='reweighting'))

            return tuple(loss_collection[key] for key in self.model_keys), log_items

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, log_items), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params)

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        # The soft-update rate is non-zero only on steps where
        # train_step % target_update_freq == 0.
        tau = self.config.soft_target_update_rate * (train_step % self.config.target_update_freq == 0)
        new_target_qf_params = {
            key: update_target_network(new_train_states[key].params, target_qf_params[key], tau)
            for key in self._target_qf_params.keys()
        }

        return new_train_states, new_target_qf_params, log_items

    def get_reweighting_coef(self, train_params, target_qf_params,
                             observations, actions, rewards, next_observations, dones,
                             rng, loss_collection, flags=None):
        """Compute the per-sample TD reweighting from the POP dual variables.

        Side effect: writes the losses for 'dualf1' / 'dualf2' (and 'gf1' /
        'gf2' when ``config.learn_g``) into ``loss_collection``. Every model key
        needs a floating-point loss for ``_train_step``.

        Args:
            train_params: mapping model key -> current parameters.
            target_qf_params: target parameters (unused here).
            observations, actions, rewards, next_observations, dones: batch arrays.
            rng: JAX PRNG key.
            loss_collection: dict that losses are written into.
            flags: must provide 'train_pop'.

        Returns:
            ``(reweighting_coef, log_items)``. While ``flags['train_pop']`` is
            true, the coefficient is ``self.reweighting_fuser(reweight1,
            reweight2)`` with gradients stopped. Otherwise it is ``None``, which
            ``_train_step`` treats as uniform weights.

        Raises:
            NotImplementedError: if ``config.backprop_dual_grad`` is set while
                training POP.
        """
        if not flags['train_pop']:
            # The dual / g networks are not trained yet. Their keys still need
            # a float (not int) loss so jax.grad accepts them.
            zero_loss = jnp.zeros(())
            loss_collection['dualf1'] = zero_loss
            loss_collection['dualf2'] = zero_loss
            if self.config.learn_g:
                loss_collection['gf1'] = zero_loss
                loss_collection['gf2'] = zero_loss
            return None, {}

        if self.config.backprop_dual_grad:
            raise NotImplementedError()

        rng_generator = JaxRNG(rng)

        @wrap_function_with_rng(rng_generator())
        def forward_f(rng, f, *args, **kwargs):
            return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))

        def compute_pop_losses(dual_params, features, next_features, dones, g=None):
            # See dualf for the meaning of its six outputs (TODO confirm).
            m_a, m_b, a, b, a_mag, b_mag = forward_f(self.dualf, dual_params, features)
            next_m_a, _, _, _, _, _ = forward_f(self.dualf, dual_params, next_features)
            # On terminal transitions the current m_a stands in for next_m_a.
            next_m_a = (1. - dones)[:, None] * next_m_a + dones[:, None] * m_a
            feature_term = (m_a ** 2 + m_b ** 2).sum(-1)
            angle_term = self.config.pop_gamma * (m_b * next_m_a).sum(-1)
            g_tilde = angle_term / (a_mag * b_mag + self.config.numerical_eps)
            if g is None:
                g = g_tilde
            margin_term = self.config.pop_margin * (a ** 2 + b ** 2).sum()
            reweight = jax.lax.stop_gradient(jnp.exp(feature_term + 2 * a_mag * b_mag * g - margin_term))
            dual_loss = (reweight * (feature_term + 2 * angle_term - margin_term)).mean()
            return dual_loss, g_tilde, reweight, a_mag, b_mag

        _, features1 = forward_f(self.qf, train_params['qf1'], observations, actions)
        _, features2 = forward_f(self.qf, train_params['qf2'], observations, actions)
        next_actions, _ = forward_f(self.policy, train_params['policy_tilde'], next_observations)
        _, next_features1 = forward_f(self.qf, train_params['qf1'], next_observations, next_actions)
        _, next_features2 = forward_f(self.qf, train_params['qf2'], next_observations, next_actions)

        if self.config.learn_g:
            g1, _ = forward_f(self.gf, train_params['gf1'], observations, actions)
            g2, _ = forward_f(self.gf, train_params['gf2'], observations, actions)
        else:
            # Without a learned g, compute_pop_losses falls back to g_tilde.
            g1, g2 = None, None

        dual_loss1, g_tilde1, reweight1, a_norm1, b_norm1 = \
            compute_pop_losses(train_params['dualf1'], features1, next_features1, dones, g1)
        dual_loss2, g_tilde2, reweight2, a_norm2, b_norm2 = \
            compute_pop_losses(train_params['dualf2'], features2, next_features2, dones, g2)
        loss_collection['dualf1'], loss_collection['dualf2'] = dual_loss1, dual_loss2

        pop_objective1 = reweight1.mean()
        pop_objective2 = reweight2.mean()
        reweighting_coef = jax.lax.stop_gradient(self.reweighting_fuser(reweight1, reweight2))

        metrics = {
            'g_tilde1': g_tilde1.mean(),
            'g_tilde2': g_tilde2.mean(),
            'dual_loss1': dual_loss1.mean(),
            'dual_loss2': dual_loss2.mean(),
            'batch_pop_objective1': pop_objective1,
            'batch_pop_objective2': pop_objective2,
            'a_norm1': a_norm1.mean(),
            'b_norm1': b_norm1.mean(),
            'a_norm2': a_norm2.mean(),
            'b_norm2': b_norm2.mean(),
        }

        if self.config.learn_g:
            # Regress g onto the (gradient-stopped) g_tilde statistic.
            g1_loss = jnp.mean((g1 - jax.lax.stop_gradient(g_tilde1)) ** 2)
            g2_loss = jnp.mean((g2 - jax.lax.stop_gradient(g_tilde2)) ** 2)
            loss_collection['gf1'], loss_collection['gf2'] = g1_loss, g2_loss
            metrics['g1_loss'] = g1_loss
            metrics['g2_loss'] = g2_loss

        return reweighting_coef, prefix_metrics(metrics, 'pop')
