import numpy as np
import torch
from k_level_policy_gradients.src.mixers.mixers import QMixer
from k_level_policy_gradients.src.utils.torch import get_weights, set_weights
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.agent import Agent
from itertools import chain


class QMIX(Agent):
    """
    QMIX trainer: owns a mixing network, its target copy, and one optimizer.

    The optimizer jointly updates the mixer parameters and the parameters of
    the first host agent's approximator network (the host agents are assumed
    to share that network). Training data are whole episodes drawn from the
    replay memory, zero-padded to a common length and masked in the loss.
    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        batch_size,
        replay_memory,
        target_update_frequency,
        tau,
        warmup_replay_size,
        target_update_mode,
        mixing_embed_dim,
        optimizer_params,
        scale_loss,
        grad_norm_clip,
        obs_last_action,
        host_agents,
        use_cuda=False,
    ):
        """
        Build the mixer, the target mixer and the joint optimizer.

        Args:
            mdp_info: environment description; ``state_space.shape``,
                ``n_agents`` and ``gamma`` are read from it.
            idx_agent: index forwarded to the ``Agent`` base class.
            batch_size: number of episodes requested from the replay memory
                per update.
            replay_memory: episode buffer exposing ``add``, ``size`` and
                ``get``.
            target_update_frequency: number of updates between hard target
                copies (only used when ``target_update_mode == "hard"``).
            tau: interpolation coefficient of the soft target update (only
                used when ``target_update_mode == "soft"``).
            warmup_replay_size: training starts once the replay memory size
                is strictly greater than this value.
            target_update_mode: ``"soft"`` or ``"hard"``.
            mixing_embed_dim: forwarded to ``QMixer``.
            optimizer_params: dict with key ``"class"`` (optimizer
                constructor) and ``"params"`` (its keyword arguments).
            scale_loss: if truthy, the loss is divided by the number of
                agents.
            grad_norm_clip: maximum gradient norm, or ``None`` to disable
                clipping.
            obs_last_action: stored on the instance; not used by this class.
            host_agents: the agents whose Q-values are mixed by this network.
            use_cuda: forwarded to ``set_weights`` when copying weights into
                the target mixer.
        """
        super().__init__(mdp_info, policy=None, idx_agent=idx_agent)

        self._batch_size = batch_size
        self._replay_memory = replay_memory
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._target_update_mode = target_update_mode
        self._scale_loss = scale_loss
        self._grad_norm_clip = grad_norm_clip
        self._obs_last_action = obs_last_action
        self._host_agents = host_agents  # The agents using this mixing network
        self._use_cuda = use_cuda

        self._n_updates = 0

        # Flattened state size, used to reshape state batches for the mixer.
        self._state_shape_int = int(np.prod(self.mdp_info.state_space.shape))

        mixer_kwargs = dict(
            state_shape=mdp_info.state_space.shape,
            mixing_embed_dim=mixing_embed_dim,
            n_agents=mdp_info.n_agents,
        )
        self._mixer = QMixer(**mixer_kwargs)
        self._target_mixer = QMixer(**mixer_kwargs)

        self.shared_params_bool = self._host_agents[-1]._primary_agent is not None

        # Assume agents share network parameters (therefore only need primary
        # agent's network). Use a single optimizer for the mixer and shared
        # agent network.
        self.params = list(
            chain(
                host_agents[0].approximator.network.parameters(),
                self._mixer.parameters(),
            )
        )
        self._optimizer = optimizer_params["class"](
            self.params, **optimizer_params["params"]
        )

        # Start with the target mixer identical to the online mixer.
        self.update_target_mixer()

        # NOTE: the optimizer attribute of this class is `_optimizer`; the
        # previously registered `_actor_optimizer` / `_critic_optimizer` do
        # not exist here, so the optimizer was never serialized.
        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="primitive",
            _warmup_replay_size="primitive",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            _mixer="torch",
            _target_mixer="torch",
            _optimizer="torch",
            _use_cuda="primitive",
        )

    def fit(self, dataset):
        """
        Store ``dataset`` in the replay memory and run one QMIX update.

        No update is performed until the replay memory size exceeds the
        warm-up threshold.

        Args:
            dataset: data passed unchanged to ``replay_memory.add``.

        Returns:
            A pair ``(loss, loss)`` with the scalar training loss repeated
            twice, or ``(0, 0)`` when no update was performed.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        episodes = self._replay_memory.get(self._batch_size)
        max_seq_len = max(len(episode) for episode in episodes)

        # Get global batch information
        (
            batch_states_t,
            batch_rewards_t,
            batch_next_states_t,
            batch_absorbings_t,
            pad_masks_t,
        ) = self.get_mixer_episodes(episodes, max_seq_len)

        # Per-agent chosen-action values and bootstrap values
        q_hats = []
        q_nexts = []
        for idx_agent, _ in enumerate(self._host_agents):
            (
                batch_obs_t,
                _batch_action_masks_t,  # current-step masks are not needed here
                batch_actions_t,
                batch_next_obs_t,
                batch_next_action_masks_t,
            ) = self.get_agent_episodes(episodes, idx_agent, max_seq_len)
            q_hat, q_next = self.get_qs(
                idx_agent,
                batch_obs_t,
                batch_actions_t,
                batch_next_obs_t,
                batch_next_action_masks_t,
            )
            q_hats.append(q_hat)
            q_nexts.append(q_next)

        # Stack agents on a new trailing axis, then add a singleton axis.
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)

        q_tot = self.mix(q_hat, batch_states_t.reshape(-1, self._state_shape_int))
        q_tot_next = self.target_mix(
            q_next, batch_next_states_t.reshape(-1, self._state_shape_int)
        )
        # One-step TD target; no bootstrapping from absorbing (or padded) steps.
        q_tot_target = (
            batch_rewards_t + self.mdp_info.gamma * q_tot_next * ~batch_absorbings_t
        ).detach()

        # Zero out padded steps on both sides so they add nothing to the
        # summed squared error, then average over the real steps only.
        q_tot = q_tot * pad_masks_t
        q_tot_target = q_tot_target * pad_masks_t
        loss = F.mse_loss(q_tot, q_tot_target, reduction="sum")
        loss /= pad_masks_t.sum()
        if self._scale_loss:
            loss /= self.mdp_info.n_agents

        self._optimizer.zero_grad()
        loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self._grad_norm_clip)
        self._optimizer.step()

        # Update target mixer
        # NOTE(review): any mode other than "soft"/"hard" leaves the target
        # mixer frozen at its initial weights — confirm this is intended.
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()

        loss_value = loss.item()
        return loss_value, loss_value

    @staticmethod
    def _pad_and_stack(sequences, max_seq_len, constant_values=0):
        """
        Pad each sequence along its first (time) axis and stack time-major.

        Args:
            sequences: one array-like per episode, each with the time step as
                leading axis and identical trailing dimensions.
            max_seq_len: target length of the time axis.
            constant_values: value written into the padded steps.

        Returns:
            A numpy array of shape ``[max_seq_len, n_episodes, ...]``.
        """
        padded = []
        for seq in sequences:
            seq = np.asarray(seq)
            # Pad only at the end of the time axis; leave feature axes alone.
            pad_width = [(0, max_seq_len - seq.shape[0])] + [(0, 0)] * (seq.ndim - 1)
            padded.append(
                np.pad(seq, pad_width, "constant", constant_values=constant_values)
            )
        # Stacking on axis 1 yields [seq_len, batch_size, ...] directly.
        return np.stack(padded, axis=1)

    def get_mixer_episodes(self, episodes, max_seq_len):
        """
        Global batch information for the mixer and agents.

        Each sample is read through the keys ``"state"``, ``"rewards"``,
        ``"next_state"`` and ``"absorbing"``.

        Args:
            episodes: list of episodes, each a sequence of sample dicts.
            max_seq_len: length every episode is padded to.

        Returns:
            Tuple ``(states, rewards, next_states, absorbings, pad_masks)`` of
            time-major tensors ``[seq_len, batch_size, ...]``. Rewards,
            absorbing flags and pad masks carry a trailing singleton axis;
            absorbing flags and pad masks are boolean. Padded steps are
            marked absorbing and have a ``False`` pad mask.
        """

        def collect(extract, constant_values=0):
            return self._pad_and_stack(
                [[extract(sample) for sample in episode] for episode in episodes],
                max_seq_len,
                constant_values=constant_values,
            )

        batch_states = collect(lambda sample: sample["state"])
        # All agents have the same reward, so we just take agent 0's reward.
        batch_rewards = collect(lambda sample: sample["rewards"][0])
        batch_next_states = collect(lambda sample: sample["next_state"])
        # Padded steps are flagged absorbing so they never bootstrap.
        batch_absorbings = collect(
            lambda sample: sample["absorbing"], constant_values=1
        )

        # pad_masks[t, b] is True iff step t is a real step of episode b.
        lengths = np.array([len(episode) for episode in episodes])
        pad_masks = np.arange(max_seq_len)[:, None] < lengths[None, :]

        # Convert to torch tensors
        batch_states_t = torch.tensor(batch_states, dtype=torch.float32)
        batch_rewards_t = torch.tensor(batch_rewards, dtype=torch.float32).unsqueeze(-1)
        batch_next_states_t = torch.tensor(batch_next_states, dtype=torch.float32)
        batch_absorbings_t = torch.tensor(batch_absorbings, dtype=torch.bool).unsqueeze(
            -1
        )
        pad_masks_t = torch.tensor(pad_masks, dtype=torch.bool).unsqueeze(-1)

        return (
            batch_states_t,
            batch_rewards_t,
            batch_next_states_t,
            batch_absorbings_t,
            pad_masks_t,
        )

    def get_agent_episodes(self, episodes, idx_agent, max_seq_len):
        """
        Agent-specific episodes for the agent with index idx_agent.

        Each sample is read through the keys ``"obs"``, ``"action_masks"``,
        ``"actions"``, ``"next_obs"`` and ``"next_action_masks"``, each
        indexed by ``idx_agent``.

        Args:
            episodes: list of episodes, each a sequence of sample dicts.
            idx_agent: index of the agent whose data is extracted.
            max_seq_len: length every episode is padded to.

        Returns:
            Tuple ``(obs, action_masks, actions, next_obs,
            next_action_masks)`` of time-major tensors
            ``[seq_len, batch_size, ...]``. Observations are float32, actions
            are long and action masks are boolean. Padded steps have zero
            observations/actions and all-``True`` action masks.
        """

        def collect(key, constant_values=0):
            return self._pad_and_stack(
                [[sample[key][idx_agent] for sample in episode] for episode in episodes],
                max_seq_len,
                constant_values=constant_values,
            )

        batch_obs = collect("obs")
        # Masks are padded with ones, i.e. every action is marked available
        # on padded steps.
        batch_action_masks = collect("action_masks", constant_values=1)
        batch_actions = collect("actions")
        batch_next_obs = collect("next_obs")
        batch_next_action_masks = collect("next_action_masks", constant_values=1)

        # Convert to torch tensors
        batch_obs_t = torch.tensor(batch_obs, dtype=torch.float32)
        batch_action_masks_t = torch.tensor(batch_action_masks, dtype=torch.bool)
        batch_actions_t = torch.tensor(batch_actions, dtype=torch.long)
        batch_next_obs_t = torch.tensor(batch_next_obs, dtype=torch.float32)
        batch_next_action_masks_t = torch.tensor(
            batch_next_action_masks, dtype=torch.bool
        )

        return (
            batch_obs_t,
            batch_action_masks_t,
            batch_actions_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
        )

    def get_qs(
        self,
        idx_agent,
        batch_obs_t,
        batch_actions_t,
        batch_next_obs_t,
        batch_next_action_masks_t,
    ):
        """
        Get q_hat and q_next for the agent with index idx_agent.

        Args:
            idx_agent: index into the host agents.
            batch_obs_t: observation batch fed to the agent's approximator.
            batch_actions_t: long tensor of taken actions with a trailing
                singleton axis, used as gather index on the last axis.
            batch_next_obs_t: next-observation batch.
            batch_next_action_masks_t: boolean masks for the next step.

        Returns:
            ``(q_hat, q_next)``: the Q-value of the taken action at every
            step, and the value returned by the host agent's ``_next_q``.
        """
        agent = self._host_agents[idx_agent]

        # Get Q predictions and targets for DQN
        q_obs = agent.approximator.predict(
            batch_obs_t,
            hidden=None,
            output_hidden=False,
            output_all=True,
            output_tensor=True,
        )
        # Drop only the gathered action axis. A bare squeeze() would also
        # collapse the time/batch axes whenever one of them has size 1.
        q_hat = q_obs.gather(-1, batch_actions_t).squeeze(-1)

        q_next = agent._next_q(
            batch_obs_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
        )

        return q_hat, q_next

    def mix(self, chosen_action_value, state):
        """Return the online mixer's output for the given agent values and state."""
        return self._mixer(chosen_action_value, state)

    def target_mix(self, chosen_action_value, state):
        """Return the target mixer's output for the given agent values and state."""
        return self._target_mixer(chosen_action_value, state)

    def update_target_mixer(self):
        """Hard update: copy the online mixer weights into the target mixer."""
        w = get_weights(self._mixer.parameters())
        set_weights(self._target_mixer.parameters(), w, use_cuda=self._use_cuda)

    def update_target_mixer_soft(self):
        """Soft update: target <- tau * online + (1 - tau) * target."""
        weights = self._tau * self.get_mixer_weights()
        weights += (1 - self._tau) * get_weights(self._target_mixer.parameters())
        set_weights(self._target_mixer.parameters(), weights, use_cuda=self._use_cuda)

    def get_mixer_weights(self):
        """Return the online mixer weights as produced by ``get_weights``."""
        return get_weights(self._mixer.parameters())

    def compute_grad_norm(self):
        """
        Return the global L2 norm of the gradients of all optimized parameters.

        Parameters that currently have no gradient (``grad is None``) are
        skipped instead of raising.
        """
        total_sq = 0.0
        for param in self.params:
            if param.grad is None:
                continue
            total_sq += param.grad.detach().norm(2).item() ** 2
        return total_sq**0.5
