import torch
import torch.nn.functional as F

from expground.types import DataArray, Dict, Any
from expground.logger import Log
from expground.utils import data
from expground.utils.data import EpisodeKeys
from expground.algorithms import misc
from expground.algorithms.loss_func import LossFunc


class DQNLoss(LossFunc):
    """One-step temporal-difference (DQN) loss.

    The loss is the mean-squared error between ``Q(s, a)`` produced by
    ``policy.critic`` for the taken actions and the bootstrap target
    ``r + gamma * (1 - done) * max_a' Q_target(s', a')`` produced by
    ``policy.target_critic``. A single optimizer (key ``"critic"``) updates the
    critic parameters.
    """

    def __init__(self, mute_critic_loss: bool = False):
        """Create the loss function.

        Args:
            mute_critic_loss: Accepted for interface compatibility with other
                loss functions but ignored: the critic loss is the only loss
                here, so the base class always receives ``False``.
        """
        super().__init__(mute_critic_loss=False)

    def setup_optimizers(self, *args, **kwargs):
        """Create the critic optimizer for the currently assigned policy.

        Called whenever a new policy is assigned to this loss function. The
        optimizer class is looked up by name in ``torch.optim`` using
        ``self._params["optimizer"]``, with learning rate
        ``self._params["critic_lr"]``.
        """
        optim_cls = getattr(torch.optim, self._params["optimizer"])
        self.optimizers = {
            "critic": optim_cls(
                self.policy.critic.parameters(), lr=self._params["critic_lr"]
            )
        }

    def step(self) -> Any:
        """Back-propagate the loss from the last ``__call__`` and update the critic.

        Gradients are cleared before ``backward()`` so gradients left over from
        earlier updates cannot accumulate into this one; this is a no-op if the
        caller has already zeroed them.
        """
        optimizer = self.optimizers["critic"]
        optimizer.zero_grad()
        self.loss[0].backward()
        optimizer.step()

    @data.tensor_cast(callback=lambda x: Log.debug(f"Training info: {x}"))
    def __call__(self, batch: Dict[str, DataArray]) -> Dict[str, Any]:
        """Compute the DQN TD loss for ``batch`` and store it for ``step``.

        Args:
            batch: Mapping of episode keys to tensors. Reads the action,
                observation, next-observation, reward and done entries, plus an
                optional ``"next_action_mask"`` (1 = legal, 0 = illegal) that
                excludes illegal actions from the bootstrap ``max``.

        Returns:
            Scalar training statistics: the loss and the min/mean/max of the
            predicted Q-values, the TD targets and the rewards.
        """
        self.loss = []

        device = "cuda" if self.policy.use_cuda else "cpu"
        actions = batch[EpisodeKeys.ACTION.value].to(device)
        observations = batch[EpisodeKeys.OBSERVATION.value].to(device)
        next_observations = batch[EpisodeKeys.NEXT_OBSERVATION.value].to(device)

        # Q(s, a) for the actions actually taken, flattened to shape (N,).
        state_action_values = (
            self.policy.critic(observations)
            .gather(-1, actions.long().view((-1, 1)))
            .view(-1)
        )
        dtype = state_action_values.dtype

        # Flatten to (N,) so an (N, 1) column cannot broadcast against the (N,)
        # Q-values into an (N, N) target; cast so bool/int/double inputs work
        # with the arithmetic below.
        rewards = (
            batch[EpisodeKeys.REWARD.value].to(device=device, dtype=dtype).reshape(-1)
        )
        dones = (
            batch[EpisodeKeys.DONE.value].to(device=device, dtype=dtype).reshape(-1)
        )

        # The bootstrap target is never differentiated, so skip building a graph.
        with torch.no_grad():
            next_state_q = self.policy.target_critic(next_observations)
            next_action_mask = batch.get("next_action_mask", None)
            if next_action_mask is not None:
                illegal_action_mask = 1.0 - next_action_mask.to(
                    device=device, dtype=next_state_q.dtype
                )
                # Push illegal actions far down so max() never selects them.
                next_state_q = next_state_q - illegal_action_mask * 1e9

            next_state_action_values = next_state_q.max(-1)[0]
            expected_state_values = (
                rewards
                + self._params["gamma"] * (1.0 - dones) * next_state_action_values
            )

        loss = F.mse_loss(state_action_values, expected_state_values)
        self.loss.append(loss)

        return {
            "loss": loss.item(),
            "mean_target": expected_state_values.mean().item(),
            "mean_eval": state_action_values.mean().item(),
            "min_eval": state_action_values.min().item(),
            "max_eval": state_action_values.max().item(),
            "max_target": expected_state_values.max().item(),
            "min_target": expected_state_values.min().item(),
            "mean_reward": rewards.mean().item(),
            "min_reward": rewards.min().item(),
            "max_reward": rewards.max().item(),
        }
