from collections import OrderedDict
from copy import deepcopy
from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax
import distrax

from .jax_utils import (
    next_rng, value_and_multi_grad, mse_loss, JaxRNG, wrap_function_with_rng,
    collect_jax_metrics, batch_to_jax, no_nans
)
from .model import Scalar, update_target_network
from .utils import prefix_metrics
from .replay_buffer import index_batch

# Ways to combine two coefficient arrays ``a`` and ``b`` into one:
#   'none' - keep both, stacked along a new leading axis of size 2;
#   'mean' - element-wise average (written out as (a + b) / 2);
#   'max' / 'min' - element-wise maximum / minimum.
FUSERS = {
    'none': lambda a, b: jnp.stack([a, b]),
    'mean': lambda a, b: (a + b) / 2,
    'max': jnp.maximum,
    'min': jnp.minimum,
}


class OfflineSAC(object):
    """Offline Soft Actor-Critic with twin Q-functions and optional TD-error reweighting.

    Training behaviour can be customised by subclasses through the hook methods
    ``get_additional_training_flags``, ``additional_policy_loss``,
    ``additional_q_loss`` and ``get_reweighting_coef``. The implementations here
    are no-ops, which yields plain SAC trained on a fixed batch stream.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyper-parameter ``ConfigDict``.

        Args:
            updates: optional mapping whose entries override the defaults.

        Returns:
            A ``ConfigDict`` with the defaults merged with ``updates``.
        """
        config = ConfigDict()
        config.fixed_policy = False  # use a fixed policy
        config.fixed_qf = False  # use a fixed q function
        config.discount = 0.99
        config.alpha_multiplier = 1.0
        config.use_automatic_entropy_tuning = True
        config.backup_entropy = False
        config.target_entropy = 0.0
        config.policy_lr = 3e-4
        config.qf_lr = 3e-4
        config.n_next_action_samples = 1
        config.optimizer_type = 'adam'
        config.target_update_freq = 1
        config.soft_target_update_rate = 5e-3
        config.bc_epochs = 0  # number of epochs to train the policy with behavioral cloning
        # True: standardise the coefs per row, then shift so the mean is 1.
        # False: divide each row by its mean (mean also becomes 1).
        config.normalize_reweighting_coef = False
        config.reweighting_fuser = 'none'
        config.reweighting_coef_min = -np.inf  # clip the reweighting coefs to be between min and max
        config.reweighting_coef_max = np.inf
        config.reweight_cql_loss = False  # use the reweighting coefs in the CQL loss
        config.numerical_eps = 1e-8  # epsilon for numerical stability

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, policy, qf):
        """Initialise the parameters and optimiser states of every trained model.

        Args:
            config: overrides merged into ``get_default_config()``.
            policy: Flax module exposing ``observation_dim``, ``action_dim``,
                ``rng_keys()`` and a ``log_prob`` method.
            qf: Flax module called as ``qf(observations, actions)`` and returning
                ``(q_values, features)``. It is initialised twice ('qf1' and 'qf2')
                with separate RNG keys.
        """
        self.config = self.get_default_config(config)
        self.policy = policy
        self.qf = qf
        self.observation_dim = policy.observation_dim
        self.action_dim = policy.action_dim

        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        # A dummy batch of 10 is used only to trace parameter shapes.
        policy_params = self.policy.init(
            next_rng(self.policy.rng_keys()),
            jnp.zeros((10, self.observation_dim))
        )
        assert no_nans(policy_params['params']['base_network'])

        self._train_states['policy'] = TrainState.create(
            params=policy_params,
            tx=optimizer_class(self.config.policy_lr),
            apply_fn=None
        )
        assert no_nans(self._train_states['policy'].params['params']['base_network'])

        qf1_params = self.qf.init(
            next_rng(self.qf.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
            jnp.zeros((10, self.action_dim))
        )
        self._train_states['qf1'] = TrainState.create(
            params=qf1_params,
            tx=optimizer_class(self.config.qf_lr),
            apply_fn=None,
        )
        qf2_params = self.qf.init(
            next_rng(self.qf.rng_keys()),
            jnp.zeros((10, self.observation_dim)),
            jnp.zeros((10, self.action_dim))
        )
        self._train_states['qf2'] = TrainState.create(
            params=qf2_params,
            tx=optimizer_class(self.config.qf_lr),
            apply_fn=None,
        )
        # Target copies of both Q-functions and of the policy; they are updated
        # towards the online parameters in ``_train_step``.
        self._target_qf_params = deepcopy({'qf1': qf1_params, 'qf2': qf2_params})
        self._target_qf_params.update(deepcopy({'policy': policy_params}))

        self.reweighting_fuser = FUSERS[self.config.reweighting_fuser]

        model_keys = ['policy', 'qf1', 'qf2']

        if self.config.use_automatic_entropy_tuning:
            self.log_alpha = Scalar(0.0)
            self._train_states['log_alpha'] = TrainState.create(
                params=self.log_alpha.init(next_rng()),
                tx=optimizer_class(self.config.policy_lr),
                apply_fn=None
            )
            model_keys.append('log_alpha')

        self._model_keys = tuple(model_keys)
        self._total_steps = 0

    def get_additional_training_flags(self, epoch):
        """Hook: extra static flags for ``epoch``, merged into the training flags."""
        return {}

    def train(self, batch, epoch):
        """Run one optimisation step on ``batch`` and return its metrics.

        The policy is trained by behavioural cloning while
        ``epoch < config.bc_epochs``. ``epoch`` is also forwarded to
        ``get_additional_training_flags``.
        """
        self._total_steps += 1
        flags = {'bc': epoch < self.config.bc_epochs}
        flags.update(self.get_additional_training_flags(epoch))
        self._train_states, self._target_qf_params, metrics = self._train_step(
            self._train_states, self._target_qf_params, next_rng(), batch, self._total_steps,
            flax.core.FrozenDict(flags)
        )
        return metrics

    def additional_policy_loss(self, train_params, target_qf_params,
                               observations, actions, rewards, next_observations, dones,
                               rng, loss_collection, flags=None):
        """Hook: extra policy loss term. Returns ``(loss, log_items)``; no-op here."""
        return 0.0, {}

    def additional_q_loss(self, train_params, target_qf_params,
                          q1, q2,
                          observations, actions, rewards, next_observations, dones,
                          rng, loss_collection, flags=None):
        """Hook: extra Q losses. Returns ``(qf1_loss, qf2_loss, log_items)``; no-op here."""
        return 0.0, 0.0, {}

    def get_reweighting_coef(self, train_params, target_qf_params,
                             observations, actions, rewards, next_observations, dones,
                             rng, loss_collection, flags=None):
        """Hook: per-sample TD-error weights.

        Returns ``(coef, log_items)``. Row 0 of ``coef`` weights qf1 and row 1
        weights qf2. ``coef=None`` means uniform weighting.
        """
        return None, {}

    def normalize_reweighting_coef(self, reweighting_coef):
        """Normalise each row of ``reweighting_coef`` to mean 1, then clip.

        With ``config.normalize_reweighting_coef`` set, each row is standardised
        and shifted by +1. Otherwise each row is divided by its mean. In both
        cases the result is clipped to
        ``[config.reweighting_coef_min, config.reweighting_coef_max]``.

        Uses ``jnp`` throughout so it is safe to call on traced values inside
        ``jax.jit`` as well as on concrete NumPy/JAX arrays.
        """
        reweighting_coef_mean = jnp.mean(reweighting_coef, axis=1, keepdims=True)
        reweighting_coef_std = jnp.std(reweighting_coef, axis=1, keepdims=True)
        if self.config.normalize_reweighting_coef:
            reweighting_coef = ((reweighting_coef - reweighting_coef_mean) / (
                    reweighting_coef_std + self.config.numerical_eps)) + 1
        else:
            reweighting_coef = reweighting_coef / (reweighting_coef_mean + self.config.numerical_eps)
        return jnp.clip(reweighting_coef, self.config.reweighting_coef_min,
                        self.config.reweighting_coef_max)

    @partial(jax.jit, static_argnames=('self', 'flags'))
    def _train_step(self, train_states, target_qf_params, rng, batch, train_step, flags):
        """Jitted update of every model in ``self.model_keys`` plus the target networks.

        Args:
            train_states: dict of ``TrainState`` keyed by model name.
            target_qf_params: dict of target parameters, keyed by model name.
            rng: JAX PRNG key for this step.
            batch: dict with 'observations', 'actions', 'rewards',
                'next_observations' and 'dones'.
            train_step: step counter. Targets move only when it is a multiple of
                ``config.target_update_freq``.
            flags: static ``FrozenDict``. 'bc' is read when the policy is trained.

        Returns:
            ``(new_train_states, new_target_qf_params, log_items)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']

            loss_collection = {}
            log_items = {}

            @wrap_function_with_rng(rng_generator())
            def forward_f(rng, f, *args, **kwargs):
                # ``rng`` is presumably injected per call by wrap_function_with_rng;
                # it is expanded into the module's declared rng streams.
                return f.apply(*args, **kwargs, rngs=JaxRNG(rng)(f.rng_keys()))

            new_actions, log_pi = forward_f(self.policy, train_params['policy'], observations)

            # Temperature alpha: learned through log_alpha, or fixed.
            if self.config.use_automatic_entropy_tuning:
                log_alpha = self.log_alpha.apply(train_params['log_alpha'])
                alpha_loss = -log_alpha * (log_pi + self.config.target_entropy).mean()
                loss_collection['log_alpha'] = alpha_loss
                alpha = jnp.exp(log_alpha) * self.config.alpha_multiplier
            else:
                # An array rather than a Python float, so ``alpha_loss.mean()`` in the
                # metrics below is valid.
                alpha_loss = jnp.zeros(())
                alpha = self.config.alpha_multiplier

            # Policy loss
            if not self.config.fixed_policy:
                if flags['bc']:
                    # Behavioural cloning: raise the likelihood of dataset actions.
                    log_probs = forward_f(self.policy, train_params['policy'], observations, actions,
                                          method=self.policy.log_prob)
                    policy_loss = (alpha * log_pi - log_probs).mean()
                else:
                    q1, _ = forward_f(self.qf, train_params['qf1'], observations, new_actions)
                    q2, _ = forward_f(self.qf, train_params['qf2'], observations, new_actions)
                    q_new_actions = jnp.minimum(q1, q2)
                    policy_loss = (alpha * log_pi - q_new_actions).mean()

                additional_policy_loss, additional_policy_loss_log_items = self.additional_policy_loss(
                    train_params, target_qf_params,
                    observations, actions, rewards, next_observations, dones,
                    rng, loss_collection, flags
                )
                policy_loss += additional_policy_loss
                log_items.update(additional_policy_loss_log_items)

                loss_collection['policy'] = policy_loss

                log_items.update(prefix_metrics({
                    'log_pi': log_pi.mean(),
                    'policy_loss': policy_loss.mean(),
                    'alpha_loss': alpha_loss.mean(),
                    'alpha': alpha,
                }, 'sac'))
            else:
                loss_collection['policy'] = 0.0

            # Q function loss
            if not self.config.fixed_qf:
                q1_loss, q2_loss = 0, 0
                q1, _ = forward_f(self.qf, train_params['qf1'], observations, actions)
                q2, _ = forward_f(self.qf, train_params['qf2'], observations, actions)

                new_next_actions, next_log_pi = forward_f(
                    self.policy, train_params['policy'], next_observations,
                    repeat=self.config.n_next_action_samples,
                )
                target_next_q1, _ = forward_f(
                    self.qf, target_qf_params['qf1'], next_observations, new_next_actions)
                target_next_q2, _ = forward_f(
                    self.qf, target_qf_params['qf2'], next_observations, new_next_actions)
                target_next_q = jnp.minimum(target_next_q1, target_next_q2)
                # Keep, per transition, the sampled next action with the highest
                # clipped double-Q value, along with its log-probability. This assumes
                # the samples lie on the last axis; TODO confirm against the policy's
                # ``repeat`` output.
                max_target_indices = jnp.expand_dims(jnp.argmax(target_next_q, axis=-1), axis=-1)
                target_next_q = jnp.take_along_axis(target_next_q, max_target_indices, axis=-1).squeeze(-1)
                next_log_pi = jnp.take_along_axis(next_log_pi, max_target_indices, axis=-1).squeeze(-1)

                # Reweighting
                reweighting_coef, reweighting_coef_log_items = self.get_reweighting_coef(
                    train_params, target_qf_params,
                    observations, actions, rewards, next_observations, dones,
                    rng, loss_collection, flags
                )
                if reweighting_coef is None:
                    # Uniform weights in the (2, ...) layout indexed below:
                    # row 0 weights qf1's squared TD errors, row 1 weights qf2's.
                    reweighting_coef = jnp.ones((2,) + q1.shape, dtype=q1.dtype)
                else:
                    log_items.update(reweighting_coef_log_items)
                    reweighting_coef = self.normalize_reweighting_coef(reweighting_coef)

                if self.config.backup_entropy:
                    target_next_q = target_next_q - alpha * next_log_pi

                target_q_values = target_next_q
                td_target = jax.lax.stop_gradient(
                    rewards + (1. - dones) * self.config.discount * target_next_q
                )
                q1_loss += jnp.mean(reweighting_coef[0] * jnp.square(q1 - td_target))
                q2_loss += jnp.mean(reweighting_coef[1] * jnp.square(q2 - td_target))

                additional_q1_loss, additional_q2_loss, additional_q_loss_log_items = self.additional_q_loss(
                    train_params, target_qf_params,
                    q1, q2,
                    observations, actions, rewards, next_observations, dones,
                    rng, loss_collection, flags
                )
                log_items.update(additional_q_loss_log_items)
                q1_loss += additional_q1_loss
                q2_loss += additional_q2_loss

                loss_collection['qf1'] = q1_loss
                loss_collection['qf2'] = q2_loss

                log_items.update(prefix_metrics({
                    'qf1_loss': q1_loss.mean(),
                    'qf2_loss': q2_loss.mean(),
                    'q1': q1.mean(),
                    'q2': q2.mean(),
                    'target_q_values': target_q_values.mean(),
                }, prefix='sac'))
                log_items.update(prefix_metrics({
                    'reweighting_coef_mean': reweighting_coef.mean(),
                    'reweighting_coef_std': reweighting_coef.std(),
                }, prefix='reweighting'))
            else:
                loss_collection['qf1'] = 0.0
                loss_collection['qf2'] = 0.0

            return tuple(loss_collection[key] for key in self.model_keys), log_items

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, log_items), grads = value_and_multi_grad(loss_fn, len(self.model_keys), has_aux=True)(train_params)

        # grads[i] holds the gradients of the i-th loss. Each model is stepped only
        # with the gradient of its own loss with respect to its own parameters.
        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }

        # Move every target towards its freshly updated online counterpart. The rate
        # is zeroed on steps that are not a multiple of ``target_update_freq``. Keys
        # come from the traced argument rather than from instance state.
        target_update_rate = self.config.soft_target_update_rate * (
                train_step % self.config.target_update_freq == 0)
        new_target_qf_params = {
            key: update_target_network(new_train_states[key].params, target_qf_params[key], target_update_rate)
            for key in target_qf_params
        }

        return new_train_states, new_target_qf_params, log_items

    @partial(jax.jit, static_argnames=('self',))
    def _compute_F_batch(self, train_params, observations, actions, next_observations, dones, reweight=None):
        """Sum per-transition block matrices built from both Q-networks' features.

        For a transition with features ``phi = qf(s, a)`` and
        ``phi' = qf(s', pi(s'))`` (deterministic policy action), the block matrix is
        ``[[phi phi^T, (1 - d) phi phi'^T], [(1 - d) phi' phi^T, phi phi^T]]``.

        Returns:
            ``(F_mu1, F_mu2)`` batch sums when ``reweight`` is None. Otherwise
            ``(F_mu1, F_mu2, F_q1, F_q2)``, where the ``F_q`` sums weight each
            transition by ``reweight[0]`` (qf1) and ``reweight[1]`` (qf2).
        """
        def compute_single_F_batch(phi, phi_next, dones, reweight=None):
            bs, feature_dim = phi.shape

            # Column vectors so that ``@`` forms per-sample outer products.
            phi = phi.reshape([bs, feature_dim, 1])
            phi_next = phi_next.reshape([bs, feature_dim, 1])

            A = phi @ phi.transpose(0, 2, 1)
            B = (1. - dones).reshape([-1, 1, 1]) * (phi @ phi_next.transpose(0, 2, 1))
            F_s = jnp.concatenate([jnp.concatenate([A, B], axis=2),
                                   jnp.concatenate([B.transpose(0, 2, 1), A], axis=2)], axis=1)

            if reweight is None:
                return F_s.sum(0)
            return F_s.sum(0), (F_s * reweight.reshape([bs, 1, 1])).sum(0)

        new_next_actions, _ = self.policy.apply(train_params['policy'], next_observations, deterministic=True)
        _, phi1 = self.qf.apply(train_params['qf1'], observations, actions)
        _, phi_next1 = self.qf.apply(train_params['qf1'], next_observations, new_next_actions)
        _, phi2 = self.qf.apply(train_params['qf2'], observations, actions)
        _, phi_next2 = self.qf.apply(train_params['qf2'], next_observations, new_next_actions)

        if reweight is None:
            return compute_single_F_batch(phi1, phi_next1, dones), compute_single_F_batch(phi2, phi_next2, dones)
        F_mu1_batch, F_q1_batch = compute_single_F_batch(phi1, phi_next1, dones, reweight[0])
        F_mu2_batch, F_q2_batch = compute_single_F_batch(phi2, phi_next2, dones, reweight[1])
        return F_mu1_batch, F_mu2_batch, F_q1_batch, F_q2_batch

    @partial(jax.jit, static_argnames=('self', 'flags'))
    def _get_reweighting_coef_jax(self, train_params, target_params, observations, actions, rewards, next_observations,
                                   dones, rng, loss_collection, flags):
        """Jitted wrapper around ``get_reweighting_coef`` for use outside ``_train_step``."""
        return self.get_reweighting_coef(train_params, target_params,
                                         observations, actions, rewards, next_observations, dones,
                                         rng, loss_collection, flags)

    def get_F_matrices(self, dataset, epoch, batch_size=256, use_reweighting=False):
        """Average the ``_compute_F_batch`` matrices of both Q-networks over ``dataset``.

        Args:
            dataset: dict of arrays indexed by transition, with at least the keys
                used by ``_compute_F_batch``. When ``use_reweighting`` is set it must
                also have the keys accepted by ``get_reweighting_coef``.
            epoch: current epoch, used to build the same flags as ``train``.
            batch_size: number of transitions processed per jitted call.
            use_reweighting: also compute the reweighted matrices.

        Returns:
            ``(F_mu1, F_mu2)``, or with reweighting
            ``(F_mu1, F_mu2, F_q1, F_q2, mean_coef_qf1, mean_coef_qf2)``, where the
            coefficient means are taken before normalisation.

        Raises:
            ValueError: if ``dataset`` contains no transitions.
        """
        rng = next_rng()
        train_params = {key: self._train_states[key].params for key in self.model_keys}
        target_params = self._target_qf_params

        n = dataset['rewards'].shape[0]
        if n == 0:
            raise ValueError('get_F_matrices requires a non-empty dataset')

        # Half-open [start, stop) windows covering every transition exactly once.
        # Unlike ``range(n // batch_size + 1)``, this never yields an empty trailing batch.
        windows = [(start, min(start + batch_size, n)) for start in range(0, n, batch_size)]

        F_mu1_sum = 0.0
        F_mu2_sum = 0.0
        F_q1_sum = 0.0
        F_q2_sum = 0.0

        if use_reweighting:
            # Build the flags the same way ``train`` does, so hooks see a consistent flag set.
            flags = {'bc': epoch < self.config.bc_epochs}
            flags.update(self.get_additional_training_flags(epoch))
            flags = flax.core.FrozenDict(flags)

            reweighting_coef = np.ones((2, n))
            for start, stop in windows:
                batch = batch_to_jax(index_batch(dataset, np.arange(start, stop)))
                reweighting_coef_batch, _ = self._get_reweighting_coef_jax(
                    train_params, target_params, **batch,
                    rng=rng, loss_collection={}, flags=flags)
                reweighting_coef[:, start:stop] = reweighting_coef_batch

            norm_reweighting_coef = self.normalize_reweighting_coef(reweighting_coef)

        for start, stop in windows:
            batch = batch_to_jax(index_batch(dataset, np.arange(start, stop)))
            observations, actions, next_observations, dones = \
                batch['observations'], batch['actions'], batch['next_observations'], batch['dones']

            if use_reweighting:
                F_mu1_batch, F_mu2_batch, F_q1_batch, F_q2_batch = \
                    self._compute_F_batch(train_params, observations, actions, next_observations, dones,
                                          reweight=norm_reweighting_coef[:, start:stop])
                F_q1_sum += F_q1_batch
                F_q2_sum += F_q2_batch
            else:
                F_mu1_batch, F_mu2_batch = \
                    self._compute_F_batch(train_params, observations, actions, next_observations, dones)
            F_mu1_sum += F_mu1_batch
            F_mu2_sum += F_mu2_batch

        if use_reweighting:
            return (F_mu1_sum / n, F_mu2_sum / n, F_q1_sum / n, F_q2_sum / n,
                    reweighting_coef[0].mean(), reweighting_coef[1].mean())
        return F_mu1_sum / n, F_mu2_sum / n

    @property
    def model_keys(self):
        """Names of the models that receive gradient updates, in loss order."""
        return self._model_keys

    @property
    def train_states(self):
        """Dict of ``TrainState`` keyed by model name."""
        return self._train_states

    @property
    def train_params(self):
        """Current parameters of every trained model, keyed by model name."""
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        """Number of ``train`` calls so far."""
        return self._total_steps
