from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict

from minto.networks.architectures.dqn import DQNNet
from minto.sample_collection.fixed_replay_buffer import FixedReplayBuffer
from minto.sample_collection.replay_buffer import ReplayElement


class MaxMinCQL:
    """Ensemble of ``n_qs`` Q-networks trained with a TD loss plus a CQL-style term.

    All ensemble members share one ``DQNNet`` architecture; their parameters are
    stacked along a leading axis of size ``n_qs``.  Bootstrap targets take the
    element-wise minimum over the target ensemble (optionally also over the
    online member being updated) before maximising over actions.  At each
    gradient step a single, randomly chosen member is updated.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        n_actions,
        features: list,
        layer_norm: bool,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        target_update_frequency: int,
        alpha_cql: float,
        adam_eps: float = 0.0003125,
        target_function: str = "default",
        n_qs: int = 2,
    ):
        """Build the networks, optimizer and logging accumulators.

        Args:
            key: PRNG key used to initialise the networks and to seed the
                per-update randomness.
            observation_dim: shape passed to ``jnp.zeros`` to build the dummy
                observation used for parameter initialisation.
            n_actions: forwarded to ``DQNNet``.
            features: forwarded to ``DQNNet``.
            layer_norm: forwarded to ``DQNNet``.
            architecture_type: forwarded to ``DQNNet``.
            learning_rate: Adam learning rate.
            gamma: discount factor; raised to ``update_horizon`` in the target.
            update_horizon: exponent applied to ``gamma`` in the target.
            target_update_frequency: divisor used to average the accumulated
                logs in ``update_target_params``.
            alpha_cql: weight of the logsumexp regulariser in the loss.
            adam_eps: Adam epsilon.
            target_function: ``"default"`` or ``"min"``; selects the bootstrap
                target implementation.
            n_qs: number of ensemble members.

        Raises:
            KeyError: if ``target_function`` is not a known name.
        """
        self.n_qs = n_qs
        print(f"Using MaxMinCQL with n_qs={self.n_qs}")

        self.network = DQNNet(features, architecture_type, n_actions, layer_norm)
        # Split once into (n_qs + 1) keys: one is kept for the update-time
        # randomness and the rest initialise the ensemble.  Storing `key` itself
        # and splitting it again later would consume the same key twice.
        split_keys = jax.random.split(key, self.n_qs + 1)
        self.network_key = split_keys[0]
        self.params = jax.vmap(self.network.init, (0, None))(
            split_keys[1:], jnp.zeros(observation_dim, dtype=jnp.float32)
        )

        self.optimizer = optax.adam(learning_rate, eps=adam_eps)
        self.optimizer_state = self.optimizer.init(self.params)
        self.target_params = self.params.copy()

        self.gamma = gamma
        self.update_horizon = update_horizon
        self.target_update_frequency = target_update_frequency
        self.cumulated_loss = np.zeros(2)  # one entry each for TD and BC component
        self.cumulated_variance = 0
        self.alpha_cql = alpha_cql

        # Running sums of per-update diagnostics; reset in update_target_params.
        self.cumulated_info = {
            "grad_norm": 0,
            "param_norm": 0,
            "online_fraction": 0,
            "q_value": 0,
            "target": 0,
        }

        target_functions = {
            "default": self.compute_target,
            "min": self.compute_target_min,  # minto target function
        }
        self.compute_target_fn = target_functions[target_function]
        print(f"Using {target_function} target function")

    @partial(jax.jit, static_argnames="self")
    def apply_multiple_updates(self, params, params_target, optimizer_state, batches, key):
        """Run one gradient step per element of ``batches`` inside a ``lax.scan``.

        Args:
            params: stacked online parameters.
            params_target: stacked target parameters (held fixed for all steps).
            optimizer_state: optimizer state matching ``params``.
            batches: sequence of batches with identical pytree structure.
            key: PRNG key, split into one key per update.

        Returns:
            ``(params, optimizer_state, loss, variance, info)`` where ``loss``,
            ``variance`` and every entry of ``info`` are summed over the updates.
        """

        def apply_single_update(state, batch_key):
            batch, step_key = batch_key
            new_params, new_optimizer_state, loss, variance, info = self.learn_on_batch(
                state[0], params_target, state[1], batch, step_key
            )
            return (new_params, new_optimizer_state), (loss, variance, info)

        # Stack the sequence of batches into a single pytree whose leaves have
        # shape (n_batches, ...) so that it can be scanned over.
        batches = jax.tree.map(lambda *batch: jnp.stack(batch), *batches)
        keys = jax.random.split(key, jax.tree.leaves(batches)[0].shape[0])
        (final_params, final_optimizer_state), (loss, variance, info) = jax.lax.scan(
            apply_single_update, (params, optimizer_state), (batches, keys)
        )

        # Sum over the updates; the average is taken later in update_target_params.
        info = {k: v.sum(axis=0) for k, v in info.items()}

        return final_params, final_optimizer_state, loss.sum(axis=0), variance.sum(axis=0), info

    def n_updates_online_params(self, n_updates: int, replay_buffer: FixedReplayBuffer):
        """Sample ``n_updates`` batches, apply the updates and accumulate the logs."""
        batches = replay_buffer.sample(n_updates)
        self.network_key, key = jax.random.split(self.network_key)
        self.params, self.optimizer_state, loss, variance, info = self.apply_multiple_updates(
            self.params, self.target_params, self.optimizer_state, batches, key
        )
        self.cumulated_loss += loss
        self.cumulated_variance += variance

        # Accumulate only the diagnostics that are tracked.
        for name, value in info.items():
            if name in self.cumulated_info:
                self.cumulated_info[name] += value

    def update_target_params(self, **kwargs):
        """Copy the online parameters to the target and return the averaged logs.

        Every accumulator is divided by ``target_update_frequency`` and then
        reset.  ``kwargs`` are accepted and ignored.

        Returns:
            dict with ``td_loss``, ``bc_loss`` (scaled by ``alpha_cql``),
            ``variance`` and one entry per key of ``cumulated_info``.
        """
        self.target_params = self.params.copy()

        logs = {
            "td_loss": self.cumulated_loss[0] / self.target_update_frequency,
            "bc_loss": self.alpha_cql * self.cumulated_loss[1] / self.target_update_frequency,
            "variance": self.cumulated_variance / self.target_update_frequency,
        }
        self.cumulated_loss = np.zeros_like(self.cumulated_loss)
        self.cumulated_variance = 0

        logs.update({k: v / self.target_update_frequency for k, v in self.cumulated_info.items()})
        self.cumulated_info = {k: 0 for k in self.cumulated_info}

        return logs

    def learn_on_batch(self, params: FrozenDict, params_target: FrozenDict, optimizer_state, batch_samples, key):
        """Apply one optimizer step on ``batch_samples``.

        Returns:
            ``(params, optimizer_state, losses, variance, info)``; ``info``
            additionally carries the global norms of the gradient and of the
            updated parameters.
        """
        grad_loss, (losses, variance, info) = jax.grad(self.loss_on_batch, has_aux=True)(
            params, params_target, batch_samples, key
        )
        updates, optimizer_state = self.optimizer.update(grad_loss, optimizer_state)
        params = optax.apply_updates(params, updates)

        info.update({"grad_norm": optax.global_norm(grad_loss)})
        info.update({"param_norm": optax.global_norm(params)})

        return params, optimizer_state, losses, variance, info

    def loss_on_batch(self, params: FrozenDict, params_target: FrozenDict, samples, key):
        """Batch-averaged loss for one randomly selected ensemble member.

        Returns:
            ``(loss, (loss_terms, variance, info))`` with everything averaged
            over the batch; ``loss_terms`` holds the TD and BC components.
        """
        # Only the member picked here receives a non-zero gradient.
        idx = jax.random.randint(key, (), 0, self.n_qs)
        params = jax.tree.map(lambda x: x[idx], params)
        losses, loss_terms, variance, info = jax.vmap(self.loss, in_axes=(None, None, 0))(
            params, params_target, samples
        )
        info = {k: v.mean() for k, v in info.items()}  # average the info over the batch
        return losses.mean(axis=0), (loss_terms.mean(axis=0), variance.mean(axis=0), info)

    def loss(self, params: FrozenDict, params_target: FrozenDict, sample: ReplayElement):
        """Per-sample loss.

        Args:
            params: parameters of a single (un-stacked) ensemble member.
            params_target: stacked target parameters.
            sample: transition exposing ``state``, ``action``, ``reward``,
                ``next_state`` and ``is_terminal``.

        Returns:
            ``(total_loss, [td_loss, bc_loss], variance_term, info)``.
        """
        q_values = self.network.apply(params, sample.state)
        target, info = self.compute_target_fn(params_target, params, sample)
        # The bootstrap target must be a constant for the gradient: the "min"
        # target evaluates the online parameters on the next state, so without
        # this the TD loss would also be differentiated through the target.
        target = jax.lax.stop_gradient(target)

        q_taken = q_values[sample.action]
        td_loss = jnp.square(target - q_taken)
        bc_loss = jax.scipy.special.logsumexp(q_values, axis=-1) - q_taken

        # add Q(s,a) and y to info
        info.update({"q_value": q_taken})
        info.update({"target": target})

        return (
            td_loss + self.alpha_cql * bc_loss,
            jnp.array([td_loss, bc_loss]),
            target**2 - target * q_taken,
            info,
        )

    def compute_target(self, target_params: FrozenDict, online_params: FrozenDict, sample: ReplayElement):
        """Bootstrap target using the minimum over the target ensemble.

        ``online_params`` is unused; it is accepted so that both target
        functions share one signature.

        Returns:
            ``(target, info)`` with an empty ``info`` dict.
        """
        qss_target = jax.vmap(self.network.apply, (0, None))(target_params, sample.next_state)
        qs_target = jnp.min(qss_target, axis=0)
        q_next = jnp.max(qs_target)
        discount = (1 - sample.is_terminal) * (self.gamma**self.update_horizon)
        return sample.reward + discount * q_next, {}

    # minto target function
    def compute_target_min(self, target_params: FrozenDict, online_params: FrozenDict, sample: ReplayElement):
        """Bootstrap target using the minimum over target ensemble and online member.

        Returns:
            ``(target, info)`` where ``info["online_fraction"]`` is 1.0 when the
            selected next-state value equals the largest online estimate among
            the actions where the online estimate does not exceed the target
            ensemble minimum, and 0.0 otherwise.
        """
        qss_target = jax.vmap(self.network.apply, (0, None))(target_params, sample.next_state)
        qs_online = self.network.apply(online_params, sample.next_state)

        qs_target = jnp.min(qss_target, axis=0)
        qs = jnp.minimum(qs_target, qs_online)
        q_next = jnp.max(qs)

        # Best value among the actions for which the online estimate is the minimum.
        best_online = jnp.max(jnp.where(qs_online <= qs_target, qs_online, -jnp.inf))
        info = {"online_fraction": jnp.equal(q_next, best_online).astype(jnp.float32)}

        discount = (1 - sample.is_terminal) * (self.gamma**self.update_horizon)
        return sample.reward + discount * q_next, info

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, **kwargs):
        """Greedy action for a single state w.r.t. the ensemble minimum."""
        qss = jax.vmap(self.network.apply, (0, None))(params, state)
        qs = jnp.min(qss, axis=0)
        return jnp.argmax(qs)

    def get_model(self):
        """Return the stacked online parameters wrapped in a dict."""
        return {"params": self.params}
