import torch
import numpy as np
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.value.abstract_dqn import AbstractDQN


class GRUDQN(AbstractDQN):
    def draw_action_hidden(self, state, hidden_state, action_mask=None):
        """
        Return the action to execute in the given state and hidden
        state for a recurrent network agent. It is the action returned by the policy or the action
        set by the algorithm (e.g. in the case of SARSA).

        Args:
            state (np.ndarray): the state where the agent is.
            hidden_state (np.ndarray): the hidden state of the recurrent network.
            action_mask (np.ndarray, None): the mask to apply to the action space.

        Returns:
            The action to be executed.

        """
        return self.policy.draw_action_hidden(state, hidden_state, action_mask)

    def _fit(self):
        if self._replay_memory.size > self._warmup_replay_size:
            episodes = self._replay_memory.get(self._batch_size)
            max_seq_len = max(len(episode) for episode in episodes)

            (
                batch_obs_t,
                batch_actions_t,
                batch_rewards_t,
                batch_next_obs_t,
                batch_next_action_masks_t,
                batch_absorbings_t,
                pad_masks_t,
            ) = self.get_episodes(episodes, max_seq_len)

            # Get Q predictions and targets for DQN
            q_obs = self.approximator.predict(
                batch_obs_t,
                hidden=None,  # start from zero hidden state
                output_hidden=False,
                output_all=True,
                output_tensor=True,
            )
            q_hat = torch.squeeze(q_obs.gather(-1, batch_actions_t))
            q_next = self._next_q(
                batch_obs_t,
                batch_next_obs_t,
                batch_next_action_masks_t,
            )
            if batch_absorbings_t.any():
                action_space_dim = self.mdp_info.action_space[self._idx_agent].n
                expanded_absorbings = batch_absorbings_t.unsqueeze(-1).expand(
                    -1, -1, action_space_dim
                )
                q_next *= ~expanded_absorbings

            q_target = (batch_rewards_t + self.mdp_info.gamma * q_next).detach()

            # Mask out the padded values
            q_hat = q_hat * pad_masks_t
            q_target = q_target * pad_masks_t

            # Compute loss and backpropagate
            loss = F.smooth_l1_loss(q_hat, q_target, reduction="sum")
            loss /= pad_masks_t.sum()
            self.approximator._optimizer.zero_grad()
            loss.backward()
            self.approximator._optimizer.step()

            return loss.item()
        else:
            return 0

    def get_episodes(self, episodes, max_seq_len):
        """
        Batch information for the update
        """
        (
            batch_obs,
            batch_actions,
            batch_rewards,
            batch_next_obs,
            batch_next_action_masks,
            batch_absorbings,
            pad_masks,
        ) = ([], [], [], [], [], [], [])

        # Pad each episode to max_seq_len with zeros
        for episode in episodes:
            seq_len = len(episode)

            # Prepare the data arrays
            obs_seq = np.array([sample["obs"] for sample in episode])
            actions_seq = np.array([sample["action"] for sample in episode])
            rewards_seq = np.array([sample["reward"] for sample in episode])
            next_obs_seq = np.array([sample["next_obs"] for sample in episode])
            next_action_masks_seq = np.array(
                [sample["next_action_mask"] for sample in episode]
            )
            absorbing_seq = np.array([sample["absorbing"] for sample in episode])

            # Pad to max_seq_len
            obs_pad = np.pad(obs_seq, ((0, max_seq_len - seq_len), (0, 0)), "constant")
            actions_pad = np.pad(
                actions_seq, ((0, max_seq_len - seq_len), (0, 0)), "constant"
            )
            rewards_pad = np.pad(rewards_seq, (0, max_seq_len - seq_len), "constant")
            next_obs_pad = np.pad(
                next_obs_seq, ((0, max_seq_len - seq_len), (0, 0)), "constant"
            )
            next_action_masks_pad = np.pad(
                next_action_masks_seq,
                ((0, max_seq_len - seq_len), (0, 0)),
                "constant",
                constant_values=1,
            )
            absorbing_pad = np.pad(
                absorbing_seq,
                (0, max_seq_len - seq_len),
                "constant",
                constant_values=1,
            )

            mask = np.concatenate([np.ones(seq_len), np.zeros(max_seq_len - seq_len)])

            batch_obs.append(obs_pad)
            batch_actions.append(actions_pad)
            batch_rewards.append(rewards_pad)
            batch_next_obs.append(next_obs_pad)
            batch_next_action_masks.append(next_action_masks_pad)
            batch_absorbings.append(absorbing_pad)
            pad_masks.append(mask)

        # Convert lists to numpy arrays for batch processing (sequence-first format)
        batch_obs = np.array(batch_obs).transpose(
            (1, 0, 2)
        )  # Shape: [seq_len, batch_size, obs_dim]
        batch_actions = np.array(batch_actions).transpose(
            (1, 0, 2)
        )  # Shape: [seq_len, batch_size, 1]
        batch_rewards = np.array(batch_rewards).transpose(
            (1, 0)
        )  # Shape: [seq_len, batch_size]
        batch_next_obs = np.array(batch_next_obs).transpose(
            (1, 0, 2)
        )  # Shape: [seq_len, batch_size, obs_dim]
        batch_next_action_masks = np.array(batch_next_action_masks).transpose(
            (1, 0, 2)
        )  # Shape: [seq_len, batch_size, action_dim]
        batch_absorbings = np.array(batch_absorbings).transpose(
            (1, 0)
        )  # Shape: [seq_len, batch_size]
        pad_masks = np.array(pad_masks).transpose(
            (1, 0)
        )  # Shape: [seq_len, batch_size]

        # Convert numpy arrays to torch tensors
        batch_obs_t = torch.tensor(batch_obs, dtype=torch.float32)
        batch_actions_t = torch.tensor(batch_actions, dtype=torch.long)
        batch_rewards_t = torch.tensor(batch_rewards, dtype=torch.float32)
        batch_next_obs_t = torch.tensor(batch_next_obs, dtype=torch.float32)
        batch_next_action_masks_t = torch.tensor(
            batch_next_action_masks, dtype=torch.bool
        )
        batch_absorbings_t = torch.tensor(batch_absorbings, dtype=torch.bool)
        pad_masks_t = torch.tensor(pad_masks, dtype=torch.bool)

        return (
            batch_obs_t,
            batch_actions_t,
            batch_rewards_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
            batch_absorbings_t,
            pad_masks_t,
        )

    def _next_q(
        self,
        batch_obs_t,
        batch_next_obs_t,
        batch_next_action_masks=None,
    ):
        batch_obs_all_timesteps = torch.cat([batch_obs_t[:1], batch_next_obs_t], dim=0)
        batch_qs_all_timesteps = self.target_approximator.predict(
            batch_obs_all_timesteps,
            hidden=None,
            output_hidden=False,
            output_all=True,
            output_tensor=True,
        )
        batch_qs_next = batch_qs_all_timesteps[1:]  # exclude first state
        if batch_next_action_masks is not None:
            q_next_mask = torch.where(
                batch_next_action_masks, batch_qs_next, torch.tensor(float("-inf"))
            )
            max_q_values, _ = q_next_mask.max(2)
        else:
            max_q_values, _ = batch_qs_next.max(2)
        return max_q_values  # only needed for next_obs
