import numpy as np
import torch
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.actor_critic.discrete_ddpg import (
    DiscreteDDPG,
)
from k_level_policy_gradients.src.distributions.gumbel import GumbelSoftmax


class GRUDiscreteDDPG(DiscreteDDPG):
    """
    Deep Deterministic Policy Gradient algorithm for discrete actions, trained
    on whole episodes.
    "Continuous Control with Deep Reinforcement Learning".
    Lillicrap T. P. et al.. 2016.

    Episodes sampled from the replay memory are zero-padded to a common length
    and arranged sequence-first before being fed to the actor and the critic;
    padded steps are excluded from both losses. The class name and the hidden
    state arguments suggest a GRU-based (recurrent) actor -- confirm against
    the approximator actually configured.

    """

    def draw_action_hidden(self, state, hidden_state, action_mask=None):
        """
        Return the action to execute in the given state by delegating to
        ``self.policy.draw_action_hidden``.

        Args:
            state (np.ndarray): the state where the agent is.
            hidden_state: the recurrent hidden state, forwarded unchanged to
                the policy.
            action_mask (np.ndarray, None): the mask to apply to the action space.

        Returns:
            Whatever the policy returns (presumably the action to be executed
            together with the updated hidden state -- confirm against the
            policy implementation).

        """
        return self.policy.draw_action_hidden(state, hidden_state, action_mask)

    def _collate_episodes(self, episodes):
        """
        Pad the given episodes to a common length and stack them into
        sequence-first tensors.

        Args:
            episodes (list): the sampled episodes; each one is a sequence of
                samples exposing the keys "obs", "action_mask", "action",
                "reward", "next_obs", "next_action_mask" and "absorbing".

        Returns:
            The tuple ``(obs, action_masks, actions, rewards, next_obs,
            next_action_masks, absorbings, pad_masks)``. ``actions`` holds
            float one-hot vectors and ``pad_masks`` is True on real steps and
            False on padding. Rewards, absorbing flags and pad masks have
            shape [max_seq_len, batch_size]; the others carry an extra
            trailing feature axis.

        """
        n_actions = self.mdp_info.action_space[self._idx_agent].n
        one_hot_rows = np.eye(n_actions)
        max_seq_len = max(len(episode) for episode in episodes)

        (
            batch_obs,
            batch_action_masks,
            batch_actions,
            batch_rewards,
            batch_next_obs,
            batch_next_action_masks,
            batch_absorbings,
            pad_masks,
        ) = ([], [], [], [], [], [], [], [])

        for episode in episodes:
            seq_len = len(episode)
            n_pad = max_seq_len - seq_len
            pad_rows = ((0, n_pad), (0, 0))  # pad the time axis of 2-D data
            pad_flat = (0, n_pad)  # pad the time axis of 1-D data

            # Prepare the data arrays
            obs_seq = np.array([sample["obs"] for sample in episode])
            action_masks_seq = np.array(
                [sample["action_mask"] for sample in episode]
            )
            actions_seq = np.array([sample["action"] for sample in episode])
            # Flatten to one index per step so that both scalar actions and
            # actions stored with shape (1,) yield a [seq_len, n_actions] array.
            actions_seq_one_hot = one_hot_rows[actions_seq.reshape(seq_len)]
            rewards_seq = np.array([sample["reward"] for sample in episode])
            next_obs_seq = np.array([sample["next_obs"] for sample in episode])
            next_action_masks_seq = np.array(
                [sample["next_action_mask"] for sample in episode]
            )
            absorbing_seq = np.array([sample["absorbing"] for sample in episode])

            # Pad to max_seq_len. Action masks are padded with ones so that
            # padded steps never end up with every logit masked to -inf;
            # absorbing flags are padded with ones so that padded steps do not
            # bootstrap from the next state.
            batch_obs.append(np.pad(obs_seq, pad_rows, "constant"))
            batch_action_masks.append(
                np.pad(action_masks_seq, pad_rows, "constant", constant_values=1)
            )
            batch_actions.append(np.pad(actions_seq_one_hot, pad_rows, "constant"))
            batch_rewards.append(np.pad(rewards_seq, pad_flat, "constant"))
            batch_next_obs.append(np.pad(next_obs_seq, pad_rows, "constant"))
            batch_next_action_masks.append(
                np.pad(
                    next_action_masks_seq, pad_rows, "constant", constant_values=1
                )
            )
            batch_absorbings.append(
                np.pad(absorbing_seq, pad_flat, "constant", constant_values=1)
            )
            pad_masks.append(
                np.concatenate([np.ones(seq_len), np.zeros(n_pad)])
            )

        # Stacking along axis 1 yields the sequence-first layout directly:
        # [max_seq_len, batch_size, ...].
        batch_obs_t = torch.tensor(np.stack(batch_obs, axis=1), dtype=torch.float32)
        batch_action_masks_t = torch.tensor(
            np.stack(batch_action_masks, axis=1), dtype=torch.bool
        )
        # One-hot actions are floats, like the actions given to the critic in
        # the actor loss and in the target computation.
        batch_actions_t = torch.tensor(
            np.stack(batch_actions, axis=1), dtype=torch.float32
        )
        batch_rewards_t = torch.tensor(
            np.stack(batch_rewards, axis=1), dtype=torch.float32
        )
        batch_next_obs_t = torch.tensor(
            np.stack(batch_next_obs, axis=1), dtype=torch.float32
        )
        batch_next_action_masks_t = torch.tensor(
            np.stack(batch_next_action_masks, axis=1), dtype=torch.bool
        )
        batch_absorbings_t = torch.tensor(
            np.stack(batch_absorbings, axis=1), dtype=torch.bool
        )
        pad_masks_t = torch.tensor(np.stack(pad_masks, axis=1), dtype=torch.bool)

        return (
            batch_obs_t,
            batch_action_masks_t,
            batch_actions_t,
            batch_rewards_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
            batch_absorbings_t,
            pad_masks_t,
        )

    def _fit(self):
        """
        Run one critic update followed by one actor update on a batch of
        episodes sampled from the replay memory.

        Returns:
            The tuple ``(actor_loss, critic_loss)`` as Python floats, or
            ``(0, 0)`` while the replay memory holds no more than
            ``_warmup_replay_size`` elements.

        """
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        episodes = self._replay_memory.get(self._batch_size)
        (
            batch_obs_t,
            batch_action_masks_t,
            batch_actions_t,
            batch_rewards_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
            batch_absorbings_t,
            pad_masks_t,
        ) = self._collate_episodes(episodes)

        # Get the action logits from the actor
        batch_logits = self.actor_approximator.predict(
            batch_obs_t,
            None,
            output_all=True,
            output_hidden=False,
            output_tensor=True,
        )
        # Align the logits with the mask explicitly. A bare squeeze() would
        # also drop the batch or time axis whenever its size is one.
        batch_logits = batch_logits.reshape(batch_action_masks_t.shape)
        batch_logits_mask = torch.where(
            batch_action_masks_t,
            batch_logits,
            torch.tensor(float("-inf")),
        )
        batch_actions_gumbel = GumbelSoftmax(
            logits=batch_logits_mask
        ).gumbel_softmax_sample()
        # torch.max with a dim returns (values, indices): compare against the
        # values only to obtain the hard one-hot sample.
        batch_max_values = torch.max(
            batch_actions_gumbel, dim=-1, keepdim=True
        ).values
        batch_actions_gumbel_hard = (
            batch_actions_gumbel == batch_max_values
        ).float()
        # Straight-through estimator: the forward value is the hard one-hot
        # sample, the gradient is the one of the soft sample.
        batch_actions_backprop = (
            batch_actions_gumbel_hard - batch_actions_gumbel
        ).detach() + batch_actions_gumbel

        # Critic update
        q_hat = self.critic_approximator.predict(
            batch_obs_t, batch_actions_t, output_tensor=True
        )
        # The bootstrap target is a constant for the critic loss, so no graph
        # needs to be recorded for the target networks.
        with torch.no_grad():
            q_next = self._next_q(
                batch_obs_t=batch_obs_t,
                batch_next_obs_t=batch_next_obs_t,
                batch_next_action_masks_t=batch_next_action_masks_t,
            )
            q_target = (
                batch_rewards_t
                + self.mdp_info.gamma * q_next * ~batch_absorbings_t
            )

        # Mask out the padded values
        # NOTE(review): assumes the critic output has the same shape as the
        # pad mask, i.e. [seq_len, batch_size] -- confirm.
        q_hat = q_hat * pad_masks_t
        q_target = q_target * pad_masks_t

        # Compute loss and backpropagate; the summed error is averaged over
        # the real (non padded) steps only.
        critic_loss = F.mse_loss(q_hat, q_target, reduction="sum")
        critic_loss = critic_loss / pad_masks_t.sum()
        self.critic_approximator._optimizer.zero_grad()
        critic_loss.backward()
        self.critic_approximator._optimizer.step()

        # Actor update
        actor_loss = self._loss(batch_obs_t, batch_actions_backprop, pad_masks_t)
        self._optimizer.zero_grad()
        actor_loss.backward()
        self._optimizer.step()

        self._n_updates += 1

        return actor_loss.item(), critic_loss.item()

    def _loss(self, batch_obs, batch_actions_one_hot, pad_masks):
        """
        Compute the actor loss, i.e. the negated mean critic value over the
        non padded steps.

        Args:
            batch_obs (torch.Tensor): the batch of observations.
            batch_actions_one_hot (torch.Tensor): the (differentiable) one-hot
                actions evaluated by the critic.
            pad_masks (torch.Tensor): boolean mask selecting the real steps.

        Returns:
            The scalar loss tensor.

        """
        q = self.critic_approximator.predict(
            batch_obs, batch_actions_one_hot, output_tensor=True
        )

        # Boolean indexing keeps the real steps only.
        return -q[pad_masks].mean()

    def _draw_target_action(
        self,
        batch_obs_t,
        batch_next_obs_t,
        batch_next_action_masks_t,
    ):
        """
        Compute the greedy one-hot actions of the target actor in the next
        states.

        Args:
            batch_obs_t (torch.Tensor): the batch of observations; only its
                first time step is used.
            batch_next_obs_t (torch.Tensor): the batch of next observations.
            batch_next_action_masks_t (torch.Tensor): boolean masks of the
                actions available in the next states.

        Returns:
            The float one-hot encoding of the highest-probability available
            action for every next state.

        """
        # Prepend the first observation so that the target actor processes
        # the sequence from its beginning before reaching the next states.
        batch_obs_all_timesteps = torch.cat([batch_obs_t[:1], batch_next_obs_t], dim=0)
        batch_logits_all_timesteps = self.target_actor_approximator.predict(
            batch_obs_all_timesteps,
            hidden=None,
            output_hidden=False,
            output_all=True,
            output_tensor=True,
        )
        batch_next_logits = batch_logits_all_timesteps[1:]  # don't need first state
        batch_next_logits_mask = torch.where(
            batch_next_action_masks_t,
            batch_next_logits,
            torch.tensor(float("-inf")),
        )
        target_action_probs = F.softmax(batch_next_logits_mask, dim=-1)
        greedy_actions = torch.argmax(target_action_probs, dim=-1)
        batch_next_actions = F.one_hot(
            greedy_actions,
            num_classes=batch_next_logits_mask.size(2),
        ).float()

        return batch_next_actions

    def _next_q(
        self,
        batch_obs_t,
        batch_next_obs_t,
        batch_next_action_masks_t,
    ):
        """
        Evaluate the target critic in the next states with the actions chosen
        by the target actor.

        Args:
            batch_obs_t (torch.Tensor): the batch of observations.
            batch_next_obs_t (torch.Tensor): the batch of next observations.
            batch_next_action_masks_t (torch.Tensor): boolean masks of the
                actions available in the next states.

        Returns:
            The target critic prediction for the next states.

        """
        batch_next_actions = self._draw_target_action(
            batch_obs_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
        )

        return self.target_critic_approximator.predict(
            batch_next_obs_t, batch_next_actions, output_tensor=True
        )

    def _update_targets(self):
        """
        Update the target networks by copying the current weights of the
        actor and of the critic into them.

        """
        self.target_actor_approximator.set_weights(
            self.actor_approximator.get_weights()
        )
        self.target_critic_approximator.set_weights(
            self.critic_approximator.get_weights()
        )

    def _post_load(self):
        """
        Rebind the actor approximator to the one of the policy and refresh
        the parameters tracked by the optimizer.

        """
        # NOTE(review): this sets `_actor_approximator`, while the other
        # methods of this class read `actor_approximator` -- confirm that the
        # parent class keeps the two names in sync.
        self._actor_approximator = self.policy._approximator
        self._update_optimizer_parameters(
            self._actor_approximator.model.network.parameters()
        )
