import d3rlpy

from d3rlpy.algos.cql import DiscreteCQL
from d3rlpy.algos.torch.cql_impl import DiscreteCQLImpl, CQLImpl
from typing import Optional, Sequence

import torch


class CQLImpl_Penalty(CQLImpl):
    """Continuous (SAC-based) CQL implementation with an uncertainty penalty.

    The critic loss produced by :class:`CQLImpl` is reduced by
    ``penalty_hyperparameter * compute_uncertainty(...)``. The uncertainty is
    evaluated on the batch's observations and actions, with
    ``self._predict_best_action`` passed as the ``policy`` keyword.

    Args:
        compute_uncertainty: Callable invoked as
            ``compute_uncertainty(observations, actions, policy=...)``. When
            ``None``, no penalty is applied and the plain CQL loss is returned.
        penalty_hyperparameter: Multiplier applied to the uncertainty term.
        **kwargs: Forwarded unchanged to :class:`CQLImpl`.
    """

    def __init__(self, compute_uncertainty=None, penalty_hyperparameter=1, **kwargs):
        super().__init__(**kwargs)
        self.compute_uncertainty = compute_uncertainty
        self.penalty_hyperparameter = penalty_hyperparameter

    def _penalized_critic_loss(self, batch, q_tpn: torch.Tensor) -> torch.Tensor:
        """Return the parent CQL critic loss minus the scaled uncertainty."""
        # The parent signature is (batch, q_tpn); a third positional argument
        # is not accepted by CQLImpl.compute_critic_loss in d3rlpy 1.x.
        loss = super().compute_critic_loss(batch, q_tpn)
        if self.compute_uncertainty is None:
            # No uncertainty model supplied: behave like plain CQL.
            return loss
        # NOTE(review): this subtracts the penalty from the *loss*. If the
        # uncertainty does not depend on critic parameters it will not change
        # critic gradients, and if it is per-sample (non-scalar) the returned
        # loss is non-scalar -- confirm both against compute_uncertainty.
        uncertainty = self.compute_uncertainty(
            batch.observations, batch.actions, policy=self._predict_best_action
        )
        return loss - self.penalty_hyperparameter * uncertainty

    def compute_critic_loss(
        self, batch, q_tpn: torch.Tensor, no_grad: bool = False
    ) -> torch.Tensor:
        """Compute the penalized critic loss.

        Args:
            batch: Mini-batch exposing ``observations`` and ``actions``.
            q_tpn: Target Q-values passed through to the parent loss.
            no_grad: When True, the whole computation runs under
                ``torch.no_grad()``.

        Returns:
            The parent CQL critic loss minus the scaled uncertainty penalty.
        """
        if no_grad:
            with torch.no_grad():
                return self._penalized_critic_loss(batch, q_tpn)
        return self._penalized_critic_loss(batch, q_tpn)
        

class DiscreteCQLImpl_Penalty(DiscreteCQLImpl):
    """Discrete CQL implementation whose target is reduced by an uncertainty penalty.

    Args:
        compute_uncertainty: Callable invoked as
            ``compute_uncertainty(observations, actions, policy=...)`` and
            expected to return a tensor. When ``None``, no penalty is applied
            and the parent's target is returned unchanged.
        penalty_hyperparameter: Multiplier applied to the uncertainty term.
        **kwargs: Forwarded unchanged to :class:`DiscreteCQLImpl`.
    """

    def __init__(self, compute_uncertainty=None, penalty_hyperparameter=1, **kwargs):
        super().__init__(**kwargs)
        self.compute_uncertainty = compute_uncertainty
        self.penalty_hyperparameter = penalty_hyperparameter

    def compute_target(self, batch):
        """Return the parent target minus ``penalty_hyperparameter * u``.

        ``u`` is ``compute_uncertainty(batch.observations, batch.actions,
        policy=self._predict_best_action)``. Everything runs under
        ``torch.no_grad()``. A 1-D ``u`` gains a trailing axis so that it
        broadcasts across the target's last dimension. A ``u`` that already
        has that axis, e.g. ``(N, 1)``, is used as-is.
        """
        with torch.no_grad():
            target = super().compute_target(batch)
            if self.compute_uncertainty is None:
                # No uncertainty model supplied: behave like plain DiscreteCQL.
                return target
            # NOTE(review): the penalty is evaluated at the batch's current
            # (observations, actions); confirm this is the intended point for
            # the penalized Bellman target.
            u = self.compute_uncertainty(
                batch.observations, batch.actions, policy=self._predict_best_action
            )
            if u.dim() == 1:
                # (N,) -> (N, 1); broadcasting then matches the original
                # explicit expand(-1, target.shape[-1]).
                u = u.unsqueeze(1)
            return target - self.penalty_hyperparameter * u
        


class DelphicCQL(DiscreteCQL):
    """DiscreteCQL variant backed by the uncertainty-penalized implementation.

    Behaves like :class:`DiscreteCQL` except that ``_create_impl`` builds a
    :class:`DiscreteCQLImpl_Penalty`. That impl additionally receives the two
    constructor arguments below.

    Args:
        compute_uncertainty: Forwarded to the impl as ``compute_uncertainty``.
        penalty_hyperparameter: Forwarded to the impl as
            ``penalty_hyperparameter``.
        **kwargs: Passed straight through to :class:`DiscreteCQL`.
    """

    def __init__(
        self, compute_uncertainty=None, penalty_hyperparameter=1.0, **kwargs
    ):
        super().__init__(**kwargs)
        self._compute_uncertainty = compute_uncertainty
        self._penalty_hyperparameter = penalty_hyperparameter

    def _create_impl(
        self, observation_shape: Sequence[int], action_size: int
    ) -> None:
        """Create, store and build the penalized impl for these shapes."""
        # Standard DiscreteCQL hyperparameters read from this instance, plus
        # the two penalty-specific settings captured in __init__.
        impl_kwargs = {
            "observation_shape": observation_shape,
            "action_size": action_size,
            "learning_rate": self._learning_rate,
            "optim_factory": self._optim_factory,
            "encoder_factory": self._encoder_factory,
            "q_func_factory": self._q_func_factory,
            "gamma": self._gamma,
            "n_critics": self._n_critics,
            "alpha": self._alpha,
            "use_gpu": self._use_gpu,
            "scaler": self._scaler,
            "reward_scaler": self._reward_scaler,
            "compute_uncertainty": self._compute_uncertainty,
            "penalty_hyperparameter": self._penalty_hyperparameter,
        }
        impl = DiscreteCQLImpl_Penalty(**impl_kwargs)
        self._impl = impl
        impl.build()
