from logging import Logger, getLogger
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
import pfrl

from pfrl.explorer import Explorer
from pfrl.replay_buffers import PrioritizedReplayBuffer
from pfrl.utils.batch_states import batch_states
from pfrl.agents.categorical_double_dqn import CategoricalDoubleDQN
from pfrl.agents.dqn import _mean_or_nan
from pfrl.agents.categorical_dqn import compute_value_loss, compute_weighted_value_loss

from buffer import GreedyReplayBuffer
from utils import batch_experiences


class ERBLearnCategoricalDoubleDQN(CategoricalDoubleDQN):
    """Categorical Double DQN that cooperates with ``GreedyReplayBuffer``.

    Differences from pfrl's ``CategoricalDoubleDQN`` visible in this class:

    * When the replay buffer is a ``GreedyReplayBuffer``, transitions are
      appended with ``agent=self``.
    * On prioritized (importance-weighted) updates, a ``GreedyReplayBuffer``
      receives, besides the per-sample errors, the softmax(Q) probability of
      each sampled action and an extra per-sample error list (``kls``).
    * Tracks the peak replay-buffer length (``max_rlen``) and reports it,
      together with the buffer length and ``memory.count``, in
      ``get_statistics``.
    """

    def __init__(
        self,
        q_function: torch.nn.Module,
        optimizer: torch.optim.Optimizer,  # type: ignore  # somehow mypy complains
        replay_buffer: pfrl.replay_buffer.AbstractReplayBuffer,
        gamma: float,
        explorer: Explorer,
        gpu: Optional[int] = None,
        replay_start_size: int = 50000,
        minibatch_size: int = 32,
        update_interval: int = 1,
        target_update_interval: int = 10000,
        clip_delta: bool = True,
        phi: Callable[[Any], Any] = lambda x: x,
        target_update_method: str = "hard",
        soft_update_tau: float = 1e-2,
        n_times_update: int = 1,
        batch_accumulator: str = "mean",
        episodic_update_len: Optional[int] = None,
        logger: Logger = getLogger(__name__),
        batch_states: Callable[
            [Sequence[Any], torch.device, Callable[[Any], Any]], Any
        ] = batch_states,
        recurrent: bool = False,
        max_grad_norm: Optional[float] = None,
    ) -> None:
        """Forward every argument positionally to ``CategoricalDoubleDQN``."""
        super().__init__(
            q_function,
            optimizer,  # type: ignore  # somehow mypy complains
            replay_buffer,
            gamma,
            explorer,
            gpu,
            replay_start_size,
            minibatch_size,
            update_interval,
            target_update_interval,
            clip_delta,
            phi,
            target_update_method,
            soft_update_tau,
            n_times_update,
            batch_accumulator,
            episodic_update_len,
            logger,
            batch_states,
            recurrent,
            max_grad_norm,
        )

        # Largest replay-buffer length observed so far (reported as "max_rlen").
        self.max_rlen = 0
        # Incremented once per observed training step. It is not read inside
        # this class; presumably it is consumed by GreedyReplayBuffer through
        # the ``agent=self`` handle (TODO confirm).
        self.last_scan_count = 0

    def _compute_loss(self, exp_batch, errors_out=None, errors_kl=None):
        """Compute a loss of categorical DQN.

        Args:
            exp_batch: Batched experiences. If it contains ``"weights"``, an
                importance-weighted loss is returned.
            errors_out: If not None, cleared and refilled with one value per
                sample: the atom-wise cross-entropy between the target
                distribution ``t`` and the prediction ``y``, summed over atoms.
            errors_kl: If not None, cleared and refilled with the same
                per-sample values as ``errors_out``. NOTE(review): despite the
                name this is the cross-entropy, not KL(t || y); the two differ
                by the entropy of ``t``. Confirm which one the buffer expects.

        Returns:
            Scalar loss tensor, reduced according to ``self.batch_accumulator``.
        """
        y, t = self._compute_y_and_t(exp_batch)
        # Minimize the cross entropy
        # y is clipped to avoid log(0)
        eltwise_loss = -t * torch.log(torch.clamp(y, 1e-10, 1.0))

        if errors_out is not None or errors_kl is not None:
            # The loss per example is the sum of the atom-wise loss. It is
            # computed once and shared by both output lists.
            per_sample = eltwise_loss.sum(dim=1).detach().cpu().numpy()
            if errors_out is not None:
                del errors_out[:]
                errors_out.extend(per_sample)
            if errors_kl is not None:
                del errors_kl[:]
                errors_kl.extend(per_sample)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                eltwise_loss,
                y.shape[0],
                exp_batch["weights"],
                batch_accumulator=self.batch_accumulator,
            )
        return compute_value_loss(
            eltwise_loss, batch_accumulator=self.batch_accumulator
        )

    def _onpolicy_action_probs(self, exp_batch):
        """Return softmax(Q)(s, a) for each sampled (state, action) pair as a numpy array."""
        with torch.no_grad():
            q_values = self.model(exp_batch["state"].detach()).q_values
            # Shift by the per-row max before the softmax (numerical stability).
            q_probs = F.softmax(
                (q_values - torch.max(q_values, dim=-1, keepdim=True)[0]), dim=-1
            )
            # exp_batch["action"] is indexed as a 1-D tensor of action indices.
            picked = torch.gather(q_probs, 1, exp_batch["action"][:, None])[:, 0]
            return picked.detach().cpu().numpy()

    def update(
        self, experiences: List[List[Dict[str, Any]]], errors_out: Optional[list] = None
    ) -> None:
        """Run one gradient step on a minibatch of experiences.

        If the sampled experiences carry a ``"weight"`` key, the loss is
        importance-weighted and the replay buffer's errors are refreshed. A
        ``GreedyReplayBuffer`` additionally receives the softmax(Q)
        probability of each sampled action and the ``errors_kl`` values.

        Args:
            experiences: Minibatch of experiences. Each element is a list of
                transition dicts, and only the first dict is read for
                ``"weight"``.
            errors_out: Optional list that receives per-sample errors. It is
                allocated internally when weights are present and none is
                given.
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            gamma=self.gamma,
        )
        # errors_kl is passed to _compute_loss on every path, so it must always
        # be bound. Leaving it unbound raised UnboundLocalError for unweighted
        # batches, or when a caller supplied errors_out.
        errors_kl: Optional[list] = None
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
            errors_kl = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out, errors_kl=errors_kl)
        if has_weight:
            assert isinstance(self.replay_buffer, PrioritizedReplayBuffer)
            if isinstance(self.replay_buffer, GreedyReplayBuffer):
                self.replay_buffer.update_errors(
                    errors_out,
                    self._onpolicy_action_probs(exp_batch),
                    kls=errors_kl,
                )
            else:
                self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.optim_t += 1

    def _batch_observe_train(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        """Record one step from each env, store transitions and maybe update.

        For each env ``i``, the step proceeds as follows:

        * Advance the step counters and sync the target network on schedule.
        * Append the transition (last obs/action to ``batch_obs[i]``) to the
          replay buffer, passing ``agent=self`` for a ``GreedyReplayBuffer``.
        * On ``done`` or ``reset``, close the episode in the buffer.
        * Let the replay updater decide whether to run ``update``.

        NOTE(review): this override never reads ``self.recurrent``, so no
        recurrent-state entries are added to transitions.
        """
        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }

                if isinstance(self.replay_buffer, GreedyReplayBuffer):
                    self.replay_buffer.append(env_id=i, agent=self, **transition)
                else:
                    self.replay_buffer.append(env_id=i, **transition)

                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)
            self.max_rlen = max(self.max_rlen, len(self.replay_buffer))
            self.last_scan_count += 1

    def get_statistics(self):
        """Return training statistics as a list of ``(name, value)`` pairs.

        ``"count"`` is ``replay_buffer.memory.count`` when that attribute
        exists, and 0 otherwise, including when the buffer has no ``memory``.
        """
        memory = getattr(self.replay_buffer, "memory", None)
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_loss", _mean_or_nan(self.loss_record)),
            ("cumulative_steps", self.cumulative_steps),
            ("n_updates", self.optim_t),
            ("rlen", len(self.replay_buffer)),
            ("count", getattr(memory, "count", 0)),
            ("max_rlen", self.max_rlen),
        ]