from copy import deepcopy
from functools import partial

from ml_collections import ConfigDict

import numpy as np
import jax
from flax.training.train_state import TrainState
import optax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from icml_supplies.contextformer.JaxCQL.jax_utils import (
    next_rng, value_and_multi_grad, mse_loss, JaxRNG, wrap_function_with_rng,
    collect_jax_metrics,batch_to_array
)
from icml_supplies.contextformer.JaxCQL.model import update_target_network


class ConservativeSAC(object):
    """Actor-critic trainer with twin Q-functions and a behaviour-cloning policy term.

    Despite the class name, this variant forces ``use_cql`` and
    ``use_automatic_entropy_tuning`` to ``False`` in ``__init__``.  The policy
    objective maximises a normalised Q-value while regressing the policy's
    actions onto the dataset actions (it resembles the TD3+BC actor loss,
    weighted by ``config.td3_alpha``); the two critics are trained with a
    plain clipped-double-Q Bellman regression.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default hyper-parameter ``ConfigDict``.

        Args:
            updates: optional mapping / ``ConfigDict`` whose entries override
                the defaults after its references are resolved.

        Returns:
            A ``ConfigDict`` holding the defaults merged with ``updates``.

        Note:
            Of these fields, only ``discount``, ``policy_lr``, ``qf_lr``,
            ``optimizer_type``, ``soft_target_update_rate`` and ``td3_alpha``
            are read elsewhere in this class; the remaining entries are kept
            so existing configs keep loading.
        """
        config = ConfigDict()
        config.discount = 0.99
        config.alpha_multiplier = 1.0
        config.use_automatic_entropy_tuning = True
        config.backup_entropy = False
        config.target_entropy = 0.0
        config.policy_lr = 1e-4
        config.qf_lr = 3e-4
        config.optimizer_type = 'adam'
        config.soft_target_update_rate = 5e-3
        config.cql_n_actions = 10
        config.cql_importance_sample = True
        config.cql_lagrange = False
        config.cql_target_action_gap = 1.0
        config.cql_temp = 1.0
        config.cql_max_target_backup = True
        config.cql_clip_diff_min = -np.inf
        config.cql_clip_diff_max = np.inf
        config.td3_alpha = 1.5
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())

        return config

    def __init__(self, config, policy, qf):
        """Initialise parameters, optimisers and target networks.

        Args:
            config: overrides passed to ``get_default_config`` (may be ``None``).
            policy: policy module exposing ``init``, ``apply``, ``rng_keys``,
                ``observation_dim`` and ``action_dim``.
            qf: Q-function module exposing ``init``, ``apply`` and ``rng_keys``;
                it is initialised twice to obtain two independent critics.

        Raises:
            KeyError: if ``config.optimizer_type`` is neither ``'adam'`` nor
                ``'sgd'``.
        """
        self.config = self.get_default_config(config)
        # This variant always disables CQL and entropy tuning, whatever the
        # caller passed in ``config``.
        self.config.use_cql = False
        self.config.use_automatic_entropy_tuning = False
        self.td3_alpha = self.config.td3_alpha
        self.policy = policy
        self.qf = qf
        self.observation_dim = policy.observation_dim
        self.action_dim = policy.action_dim

        self._train_states = {}

        optimizer_class = {
            'adam': optax.adam,
            'sgd': optax.sgd,
        }[self.config.optimizer_type]

        # Dummy inputs used only to initialise the network parameters.
        dummy_observations = jnp.zeros((10, self.observation_dim))
        dummy_actions = jnp.zeros((10, self.action_dim))

        policy_params = self.policy.init(
            next_rng(self.policy.rng_keys()),
            dummy_observations
        )
        self._train_states['policy'] = TrainState.create(
            params=policy_params,
            tx=optimizer_class(self.config.policy_lr),
            apply_fn=None
        )

        # Each critic draws its own RNG so the two are initialised differently.
        qf1_params = self.qf.init(
            next_rng(self.qf.rng_keys()),
            dummy_observations,
            dummy_actions
        )
        qf2_params = self.qf.init(
            next_rng(self.qf.rng_keys()),
            dummy_observations,
            dummy_actions
        )
        # Initial critic parameters stay available as public attributes.
        self.qf1_params = qf1_params
        self.qf2_params = qf2_params
        for key, params in (('qf1', qf1_params), ('qf2', qf2_params)):
            self._train_states[key] = TrainState.create(
                params=params,
                tx=optimizer_class(self.config.qf_lr),
                apply_fn=None,
            )
        self._target_qf_params = deepcopy({'qf1': qf1_params, 'qf2': qf2_params})

        self._model_keys = ('policy', 'qf1', 'qf2')
        self._total_steps = 0

    def compute_q(self, train_params, observations, actions):
        """Evaluate the element-wise minimum of two critics on ``(observations, actions)``.

        Args:
            train_params: a pair ``(qf1_params, qf2_params)`` of critic parameters.
            observations: observations forwarded to ``qf.apply``.
            actions: actions forwarded to ``qf.apply``.

        Returns:
            ``batch_to_array`` applied to ``min(Q1, Q2)``.
        """
        rng_generator = JaxRNG(next_rng())

        @wrap_function_with_rng(rng_generator())
        def forward_qf(rng, *args, **kwargs):
            return self.qf.apply(
                *args, **kwargs,
                rngs=JaxRNG(rng)(self.qf.rng_keys())
            )

        qf1_params, qf2_params = train_params
        min_q = jnp.minimum(
            forward_qf(qf1_params, observations, actions),
            forward_qf(qf2_params, observations, actions),
        )
        return batch_to_array(min_q)

    def train(self, batch, use_cql=True, cql_min_q_weight=5.0, enable_calql=False):
        """Run one gradient step on ``batch`` and return the step's metrics.

        Args:
            batch: mapping with keys ``'observations'``, ``'actions'``,
                ``'rewards'``, ``'next_observations'`` and ``'dones'``.
            use_cql, cql_min_q_weight, enable_calql: accepted for interface
                compatibility and forwarded to ``_train_step``, which does not
                use them in its computation.

        Returns:
            The metrics dictionary produced by ``_train_step``.
        """
        self._total_steps += 1
        self._train_states, self._target_qf_params, metrics = self._train_step(
            self._train_states,
            self._target_qf_params,
            next_rng(),
            batch, use_cql,
            cql_min_q_weight, enable_calql
        )
        # Expose the latest policy loss as a concrete value.  It is set here,
        # outside the jitted step, so no tracer is ever stored on the instance.
        self.policy_loss = metrics.get('policy_loss')
        return metrics

    @partial(jax.jit, static_argnames=('self', 'use_cql', 'cql_min_q_weight', 'enable_calql'))
    def _train_step(self, train_states,
                          target_qf_params,
                          rng, batch, use_cql=True,
                          cql_min_q_weight=5.0,
                          enable_calql=False):
        """Jitted update of the policy, both critics and the target critics.

        Args:
            train_states: dict of ``TrainState`` keyed by ``self.model_keys``.
            target_qf_params: dict with target parameters under ``'qf1'``/``'qf2'``.
            rng: PRNG key used to seed the forward passes.
            batch: transition batch (see ``train``).
            use_cql, cql_min_q_weight, enable_calql: unused here.  They are
                static arguments, so every distinct value triggers a
                re-compilation of this function.

        Returns:
            ``(new_train_states, new_target_qf_params, metrics)``.
        """
        rng_generator = JaxRNG(rng)

        def loss_fn(train_params):
            # Must stay free of side effects: it is traced by both jit and
            # value_and_grad, once per model key.
            observations = batch['observations']
            actions = batch['actions']
            rewards = batch['rewards']
            next_observations = batch['next_observations']
            dones = batch['dones']
            loss_collection = {}

            @wrap_function_with_rng(rng_generator())
            def forward_policy(rng, *args, **kwargs):
                return self.policy.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.policy.rng_keys())
                )

            @wrap_function_with_rng(rng_generator())
            def forward_qf(rng, *args, **kwargs):
                return self.qf.apply(
                    *args, **kwargs,
                    rngs=JaxRNG(rng)(self.qf.rng_keys())
                )

            # The log-probabilities returned by the policy are not needed.
            new_actions, _ = forward_policy(train_params['policy'], observations)

            # ---- Policy loss ----
            q_new_actions = jnp.minimum(
                forward_qf(train_params['qf1'], observations, new_actions),
                forward_qf(train_params['qf2'], observations, new_actions),
            )
            # Scale the Q term by td3_alpha / mean|Q|; the scale itself is
            # treated as a constant (no gradient flows through it).
            q_scale = jax.lax.stop_gradient(
                -self.td3_alpha / jnp.abs(q_new_actions).mean()
            )
            bc_loss = jnp.mean(jnp.square(new_actions - actions))
            policy_loss = q_scale * q_new_actions.mean() + bc_loss
            loss_collection['policy'] = policy_loss

            # ---- Q-function loss ----
            q1_pred = forward_qf(train_params['qf1'], observations, actions)
            q2_pred = forward_qf(train_params['qf2'], observations, actions)

            new_next_actions, _ = forward_policy(
                train_params['policy'], next_observations
            )
            target_q_values = jnp.minimum(
                forward_qf(target_qf_params['qf1'], next_observations, new_next_actions),
                forward_qf(target_qf_params['qf2'], next_observations, new_next_actions),
            )

            td_target = jax.lax.stop_gradient(
                rewards + (1. - dones) * self.config.discount * target_q_values
            )
            qf1_loss = mse_loss(q1_pred, td_target)
            qf2_loss = mse_loss(q2_pred, td_target)
            loss_collection['qf1'] = qf1_loss
            loss_collection['qf2'] = qf2_loss

            # Return only the arrays that are logged, rather than locals(),
            # which would also carry `self`, the batch and the closures
            # through the autodiff aux channel.
            aux = {
                'policy_loss': policy_loss,
                'qf1_loss': qf1_loss,
                'qf2_loss': qf2_loss,
                'q1_pred': q1_pred,
                'q2_pred': q2_pred,
                'target_q_values': target_q_values,
            }
            return tuple(loss_collection[key] for key in self.model_keys), aux

        train_params = {key: train_states[key].params for key in self.model_keys}
        (_, aux_values), grads = value_and_multi_grad(
            loss_fn, len(self.model_keys), has_aux=True
        )(train_params)

        def grad_norm(key):
            # grads[i] holds the gradient of the i-th loss w.r.t. all params;
            # only the slice belonging to that loss's own model is used.
            own_grads = grads[self.model_keys.index(key)][key]
            return jnp.linalg.norm(ravel_pytree(own_grads)[0])

        policy_loss_gradient = grad_norm('policy')
        qf1_loss_gradient = grad_norm('qf1')
        qf2_loss_gradient = grad_norm('qf2')

        new_train_states = {
            key: train_states[key].apply_gradients(grads=grads[i][key])
            for i, key in enumerate(self.model_keys)
        }
        # Move each target critic towards its freshly updated online critic.
        new_target_qf_params = {
            key: update_target_network(
                new_train_states[key].params, target_qf_params[key],
                self.config.soft_target_update_rate
            )
            for key in ('qf1', 'qf2')
        }

        metrics = collect_jax_metrics(
            aux_values,
            ['policy_loss',
             'qf1_loss', 'qf2_loss',
             'q1_pred', 'q2_pred',
             'target_q_values'])
        metrics.update(policy_loss_gradient=policy_loss_gradient,
                       qf1_loss_gradient=qf1_loss_gradient,
                       qf2_loss_gradient=qf2_loss_gradient)

        return new_train_states, new_target_qf_params, metrics

    @property
    def model_keys(self):
        """Tuple of trainable model names, in loss/gradient order."""
        return self._model_keys

    @property
    def train_states(self):
        """Dict mapping each model key to its current ``TrainState``."""
        return self._train_states

    @property
    def train_params(self):
        """Dict mapping each model key to its current parameters."""
        return {key: self.train_states[key].params for key in self.model_keys}

    @property
    def total_steps(self):
        """Number of times ``train`` has been called."""
        return self._total_steps
