from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict

from minto.networks.architectures.dqn import DQNNet
from minto.sample_collection.fixed_replay_buffer import FixedReplayBuffer
from minto.sample_collection.replay_buffer import ReplayElement
from minto.networks.cql import CQL


class FRCQL(CQL):
    """CQL with functional regularization (FR-CQL).

    Differences from the parent ``CQL`` implemented in this class:

    * The bootstrap target (``compute_target``) is computed from the *online*
      parameters under ``jax.lax.stop_gradient``; the target parameters are
      not used for bootstrapping.
    * The target parameters only enter through a functional-regularization
      penalty ``kappa / 2 * (Q(s, a) - Q_target(s, a)) ** 2`` that keeps the
      online Q-values close to the last snapshot taken in
      ``update_target_params``.

    Per-sample loss: ``td_loss + alpha_cql * bc_loss + kappa / 2 * fr_loss``.
    """

    def __init__(
        self,
        key: jax.random.PRNGKey,
        observation_dim,
        n_actions,
        features: list,
        layer_norm: bool,
        architecture_type: str,
        learning_rate: float,
        gamma: float,
        update_horizon: int,
        target_update_frequency: int,
        alpha_cql: float,
        adam_eps: float = 0.0003125,
        target_function: str = "default",
        kappa: float = 1.0,
    ):
        """Build the agent; all arguments except ``kappa`` are forwarded to ``CQL``.

        Args:
            kappa: Weight of the functional-regularization penalty. It is
                applied as ``kappa / 2`` in the loss and in the logged
                ``fr_loss``.
            target_function: Must be ``"default"``. This class overrides
                ``compute_target`` and never consults this value.

        Raises:
            AssertionError: if ``target_function`` is not ``"default"``.
        """
        # for Functional Regularized CQL (FR-CQL)
        self.kappa = kappa
        print(f"Using FRCQL with kappa={self.kappa}")
        assert target_function == "default", "FRCQL only works with the default target function in this class."
        super().__init__(
            key, observation_dim, n_actions, features, layer_norm, architecture_type, learning_rate,
            gamma, update_horizon, target_update_frequency, alpha_cql, adam_eps, target_function
        )
        # One running sum each for the TD, BC (CQL) and FR components of the loss.
        self.cumulated_loss = np.zeros(3)
        # Also accumulate the target network's Q(s, a), next to whatever info keys the parent tracks.
        self.cumulated_info.update({"q_value_target": 0})

    def update_target_params(self, **kwargs):
        """Snapshot the online parameters into the target parameters and return logs.

        Every accumulator is averaged over ``target_update_frequency`` and then
        reset. The BC and FR components are reported with the weights they
        carry in the total loss (``alpha_cql`` and ``kappa / 2``); ``td_loss``
        is unweighted.

        NOTE(review): the accumulators are filled outside this class,
        presumably by the parent's training step. Confirm that they are summed
        exactly ``target_update_frequency`` times between calls.

        Args:
            **kwargs: Unused; accepted for compatibility with the caller.

        Returns:
            dict mapping metric names to their per-step averages since the
            previous call.
        """
        self.target_params = self.params.copy()

        steps = self.target_update_frequency
        logs = {
            "td_loss": self.cumulated_loss[0] / steps,
            "bc_loss": self.alpha_cql * self.cumulated_loss[1] / steps,
            "fr_loss": self.kappa / 2 * self.cumulated_loss[2] / steps,
            "variance": self.cumulated_variance / steps,
        }
        logs.update({name: total / steps for name, total in self.cumulated_info.items()})

        # Reset all running sums for the next window.
        self.cumulated_loss = np.zeros_like(self.cumulated_loss)
        self.cumulated_variance = 0
        self.cumulated_info = {name: 0 for name in self.cumulated_info}

        return logs

    def loss(self, params: FrozenDict, params_target: FrozenDict, sample: ReplayElement):
        """Compute the FR-CQL loss for one transition.

        NOTE(review): ``q_values[sample.action]`` suggests that ``sample`` holds
        a single transition (scalar action). Batching is presumably done by the
        caller, e.g. with ``jax.vmap``. Confirm against the parent class.

        Args:
            params: Online network parameters (the ones being optimized).
            params_target: Target-network snapshot, used only for the FR anchor.
            sample: Transition providing ``state``, ``action``, ``reward``,
                ``next_state`` and ``is_terminal``.

        Returns:
            A 4-tuple ``(total_loss, components, aux, info)``:

            * ``total_loss``: ``td + alpha_cql * bc + kappa / 2 * fr``.
            * ``components``: ``jnp.array([td_loss, bc_loss, fr_loss])``,
              unweighted.
            * ``aux``: ``target**2 - target * Q(s, a)``; its consumer lives
              outside this class.
            * ``info``: ``q_value``, ``target`` and ``q_value_target`` for
              logging.
        """
        target, info = self.compute_target(params_target, params, sample)
        q_values = self.network.apply(params, sample.state)
        q_sa = q_values[sample.action]
        # Q(s, a) under the frozen target snapshot. It anchors the FR penalty, so no
        # gradient should flow into it. This is explicit here even though params_target
        # is already a separate (non-differentiated) argument.
        q_value_target = jax.lax.stop_gradient(
            self.network.apply(params_target, sample.state)[sample.action]
        )

        td_loss = jnp.square(target - q_sa)
        # CQL term: log-sum-exp over actions minus the Q-value of the dataset action.
        bc_loss = jax.scipy.special.logsumexp(q_values, axis=-1) - q_sa
        # Functional regularization: keep Q(s, a) close to the target network's value.
        fr_loss = jnp.square(q_sa - q_value_target)

        # add Q(s,a), y and Q_target(s,a) to info
        info.update({"q_value": q_sa, "target": target, "q_value_target": q_value_target})

        return (
            td_loss + self.alpha_cql * bc_loss + self.kappa / 2 * fr_loss,
            jnp.array([td_loss, bc_loss, fr_loss]),
            target**2 - target * q_sa,
            info,
        )

    def compute_target(self, target_params: FrozenDict, online_params: FrozenDict, sample: ReplayElement):
        """Return the bootstrapped TD target and an empty info dict.

        The next-state value is the max over actions of the *online* network's
        Q-values, wrapped in ``stop_gradient`` so that the target is treated as
        a constant. ``target_params`` is accepted for signature compatibility
        but is unused. The bootstrap term is discounted by
        ``gamma ** update_horizon`` and multiplied by ``1 - is_terminal``, so
        it vanishes for terminal transitions.
        """
        next_q_values = jax.lax.stop_gradient(self.network.apply(online_params, sample.next_state))
        discount = (1 - sample.is_terminal) * (self.gamma**self.update_horizon)
        return sample.reward + discount * jnp.max(next_q_values), {}