import numpy as np
import torch
from k_level_policy_gradients.src.mixers.mixers import QMixer
from k_level_policy_gradients.src.utils.torch import get_weights, set_weights
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.agent import Agent
from k_level_policy_gradients.src.distributions.gumbel import GumbelSoftmax

from itertools import chain


class FACMAC(Agent):
    """
    FACMAC mixer agent that trains the actors and critics of its host agents.

    It owns a ``QMixer`` (plus a target copy) that combines the per-agent
    critic values into a joint value. The critic optimizer updates the host
    agents' critic networks together with the mixer. The actor optimizer
    updates the host agents' actor networks. When the host agents share
    parameters (the last host agent has a ``_primary_agent``), only the first
    host agent's actor/critic networks are registered with the optimizers.
    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        batch_size,
        replay_memory,
        target_update_frequency,
        tau,
        warmup_replay_size,
        target_update_mode,
        mixing_embed_dim,
        actor_optimizer_params,
        critic_optimizer_params,
        scale_critic_loss,
        scale_actor_loss,
        centralized_critic,
        grad_norm_clip,
        obs_last_action,
        critic_obs_last_action,
        critic_agent_encoding,
        host_agents,
        use_cuda=False,
    ):
        """
        Build the mixers and the optimizers, and precompute the critic
        observation cutoffs.

        Args:
            mdp_info: environment info. ``state_space``, ``observation_space``,
                ``action_space``, ``n_agents`` and ``gamma`` are read from it.
            idx_agent: index forwarded to ``Agent.__init__``.
            batch_size: number of episodes sampled from the replay memory per
                update.
            replay_memory: episode buffer. Its ``add``, ``get`` and ``size``
                members are used.
            target_update_frequency: number of updates between hard copies of
                the mixer into the target mixer.
            tau: interpolation factor for soft target-mixer updates.
            warmup_replay_size: updates only run once ``replay_memory.size``
                exceeds this value.
            target_update_mode: ``"soft"`` or ``"hard"``. Validated in ``fit``.
            mixing_embed_dim: embedding size passed to ``QMixer``.
            actor_optimizer_params: dict with ``"class"`` (optimizer
                constructor) and ``"params"`` (its keyword arguments).
            critic_optimizer_params: same layout as ``actor_optimizer_params``.
            scale_critic_loss: if true, divide the critic loss by ``n_agents``.
            scale_actor_loss: if true, divide the actor loss by ``n_agents``.
            centralized_critic: if true, every critic receives the
                concatenated actions of all host agents.
            grad_norm_clip: max gradient norm, or ``None`` to disable clipping.
            obs_last_action: stored as ``_obs_last_action``. Not read
                elsewhere in this class.
            critic_obs_last_action: whether the critic keeps the last-action
                part of the actor observation.
            critic_agent_encoding: whether the critic keeps the agent-encoding
                part of the actor observation.
            host_agents: the agents whose actors/critics this mixer trains.
            use_cuda: forwarded to ``set_weights`` on target-mixer updates.
        """
        super().__init__(mdp_info, policy=None, idx_agent=idx_agent)

        self._batch_size = batch_size
        self._replay_memory = replay_memory
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._target_update_mode = target_update_mode
        self._scale_critic_loss = scale_critic_loss
        self._scale_actor_loss = scale_actor_loss
        self._centralized_critic = centralized_critic
        self._grad_norm_clip = grad_norm_clip
        self._obs_last_action = obs_last_action
        self._critic_obs_last_action = critic_obs_last_action
        self._critic_agent_encoding = critic_agent_encoding
        self._host_agents = host_agents  # The agents using this mixing network
        self._use_cuda = use_cuda

        self._n_updates = 0

        # Flattened state size. Used to reshape batched states into the 2-D
        # (-1, state_size) input the mixers receive.
        self._state_shape_int = int(np.prod(self.mdp_info.state_space.shape))

        self._mixer = QMixer(
            state_shape=mdp_info.state_space.shape,
            mixing_embed_dim=mixing_embed_dim,
            n_agents=mdp_info.n_agents,
        )
        self._target_mixer = QMixer(
            state_shape=mdp_info.state_space.shape,
            mixing_embed_dim=mixing_embed_dim,
            n_agents=mdp_info.n_agents,
        )

        self.shared_params_bool = self._host_agents[-1]._primary_agent is not None

        if self.shared_params_bool:
            # Shared parameters: the first host agent's networks stand in for all.
            self.actor_params = list(
                host_agents[0].actor_approximator.network.parameters()
            )
            self.critic_params = list(
                chain(
                    host_agents[0].critic_approximator.network.parameters(),
                    self._mixer.parameters(),
                )
            )
        else:
            self.actor_params = list(
                chain(
                    *[
                        agent.actor_approximator.network.parameters()
                        for agent in host_agents
                    ]
                )
            )
            self.critic_params = list(
                chain(
                    *[
                        agent.critic_approximator.network.parameters()
                        for agent in host_agents
                    ],
                    self._mixer.parameters(),
                )
            )

        self._actor_optimizer = actor_optimizer_params["class"](
            self.actor_params, **actor_optimizer_params["params"]
        )
        self._critic_optimizer = critic_optimizer_params["class"](
            self.critic_params, **critic_optimizer_params["params"]
        )

        # Start the target mixer as an exact copy of the online mixer.
        self.update_target_mixer()

        # Critic obs modification.
        # critic_obs_cutoff: how much of each actor observation (along the last
        # axis) is passed on to that agent's critic.
        self.critic_obs_cutoff_list = []
        for idx_host, _ in enumerate(self._host_agents):
            if self._critic_obs_last_action and self._critic_agent_encoding:
                # NOTE(review): slicing with [:, :, :-1] drops the last actor-obs
                # feature. If the intent is "keep the whole observation", the
                # cutoff would need to be None. Confirm against the critic's
                # input size before changing.
                critic_obs_cutoff = -1
            elif self._critic_agent_encoding:
                critic_obs_cutoff = (
                    self.mdp_info.observation_space[idx_host].shape[0]
                    + self.mdp_info.n_agents
                )
            elif self._critic_obs_last_action:
                critic_obs_cutoff = (
                    self.mdp_info.observation_space[idx_host].shape[0]
                    + self.mdp_info.action_space[idx_host].n
                )
            else:
                critic_obs_cutoff = self.mdp_info.observation_space[idx_host].shape[0]
            self.critic_obs_cutoff_list.append(critic_obs_cutoff)

        # NOTE(review): only the attributes below are registered for saving.
        # The loss/clip/mode hyperparameters, the host agents and the critic
        # cutoffs are not, so a reloaded instance lacks attributes fit() reads.
        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="primitive",
            _warmup_replay_size="primitive",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            _mixer="torch",
            _target_mixer="torch",
            _actor_optimizer="torch",
            _critic_optimizer="torch",
            _use_cuda="primitive",
        )

    def fit(self, dataset):
        """
        Store ``dataset``. Once past warmup, run one critic/mixer update and
        one actor update.

        Args:
            dataset: transitions to add to the replay memory.

        Returns:
            tuple: ``(actor_loss, critic_loss)`` as Python floats. Returns
            ``(0, 0)`` while ``replay_memory.size <= warmup_replay_size``.

        Raises:
            ValueError: if ``target_update_mode`` is neither ``"soft"`` nor
                ``"hard"``. It is raised after this call's optimizer steps
                have been applied.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        episodes = self._replay_memory.get(self._batch_size)
        max_seq_len = max(len(episode) for episode in episodes)
        n_hosts = len(self._host_agents)

        # Global batch information (shared by all agents).
        (
            batch_states_t,
            batch_rewards_t,
            batch_next_states_t,
            batch_absorbings_t,
            pad_masks_t,
        ) = self.get_mixer_episodes(episodes, max_seq_len)
        flat_states_t = batch_states_t.reshape(-1, self._state_shape_int)
        flat_next_states_t = batch_next_states_t.reshape(-1, self._state_shape_int)

        # Agent-specific batch information.
        batch_obs_t_list = []
        batch_action_masks_t_list = []
        batch_actions_t_list = []
        batch_next_obs_t_list = []
        batch_next_action_masks_t_list = []
        for idx_host in range(n_hosts):
            (
                batch_obs_t_agent,
                batch_action_masks_t_agent,
                batch_actions_t_agent,
                batch_next_obs_t_agent,
                batch_next_action_masks_t_agent,
            ) = self.get_agent_episodes(episodes, idx_host, max_seq_len)
            batch_obs_t_list.append(batch_obs_t_agent)
            batch_action_masks_t_list.append(batch_action_masks_t_agent)
            batch_actions_t_list.append(batch_actions_t_agent)
            batch_next_obs_t_list.append(batch_next_obs_t_agent)
            batch_next_action_masks_t_list.append(batch_next_action_masks_t_agent)

        # Truncate each agent's observations to what its critic should see.
        critic_batch_obs_t_list = []
        critic_batch_next_obs_t_list = []
        for idx_host in range(n_hosts):
            cutoff = self.critic_obs_cutoff_list[idx_host]
            critic_batch_obs_t_list.append(batch_obs_t_list[idx_host][:, :, :cutoff])
            critic_batch_next_obs_t_list.append(
                batch_next_obs_t_list[idx_host][:, :, :cutoff]
            )

        # Target actions for the bootstrapped next-step values.
        target_actions_t_list = []
        for idx_host, agent in enumerate(self._host_agents):
            target_action = agent._draw_target_action(
                batch_obs_t_list[idx_host],
                batch_next_obs_t_list[idx_host],
                batch_next_action_masks_t_list[idx_host],
            )
            target_actions_t_list.append(target_action)

        # Per-agent critic inputs: the joint action (same tensor for every
        # agent) when centralized, otherwise each agent's own action.
        if self._centralized_critic:
            critic_actions = [torch.cat(batch_actions_t_list, dim=-1)] * n_hosts
            critic_target_actions = [
                torch.cat(target_actions_t_list, dim=-1)
            ] * n_hosts
        else:
            critic_actions = batch_actions_t_list
            critic_target_actions = target_actions_t_list

        # Per-agent critic values for the buffer actions and the target actions.
        q_hats = []
        q_nexts = []
        for idx_host, agent in enumerate(self._host_agents):
            q_hats.append(
                agent.critic_approximator.predict(
                    critic_batch_obs_t_list[idx_host],
                    critic_actions[idx_host],
                    output_tensor=True,
                )
            )
            q_nexts.append(
                agent.target_critic_approximator.predict(
                    critic_batch_next_obs_t_list[idx_host],
                    critic_target_actions[idx_host],
                    output_tensor=True,
                )
            )

        # Mix per-agent values into joint values and build the TD target.
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)
        q_tot = self.mix(q_hat, flat_states_t)
        q_tot_next = self.target_mix(q_next, flat_next_states_t)
        # ~absorbing zeroes the bootstrap term on terminal (and padded) steps.
        q_tot_target = (
            batch_rewards_t + self.mdp_info.gamma * q_tot_next * ~batch_absorbings_t
        ).detach()

        # Masked MSE, averaged over valid (non-padded) steps only.
        q_tot = q_tot * pad_masks_t
        q_tot_target = q_tot_target * pad_masks_t
        critic_loss = F.mse_loss(q_tot, q_tot_target, reduction="sum")
        critic_loss /= pad_masks_t.sum()
        if self._scale_critic_loss:
            critic_loss /= self.mdp_info.n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        # Actor update: differentiable actions from the current policies.
        actions_update = [
            self.get_actions(
                idx_agent=idx_host,
                batch_obs_t=batch_obs_t_list[idx_host],
                batch_action_masks_t=batch_action_masks_t_list[idx_host],
            )
            for idx_host in range(n_hosts)
        ]
        if self._centralized_critic:
            # No agent's action is detached, so gradients reach every actor
            # through every critic. NOTE(review): if the other agents' actions
            # are meant to be constants for agent i's critic, detach them here.
            actor_critic_actions = [torch.cat(actions_update, dim=-1)] * n_hosts
        else:
            actor_critic_actions = actions_update
        q_actors = [
            agent.critic_approximator.predict(
                critic_batch_obs_t_list[idx_host],
                actor_critic_actions[idx_host],
                output_tensor=True,
            )
            for idx_host, agent in enumerate(self._host_agents)
        ]

        q_actor = torch.stack(q_actors, dim=-1).unsqueeze(-1)
        q_tot_actor = self.mix(q_actor, flat_states_t)
        # Maximize the mixed value over valid (non-padded) steps.
        q_valid = q_tot_actor[pad_masks_t]
        actor_loss = -q_valid.mean()
        if self._scale_actor_loss:
            actor_loss /= self.mdp_info.n_agents
        self._actor_optimizer.zero_grad()
        actor_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.actor_params, self._grad_norm_clip)
        self._actor_optimizer.step()

        # Update the target mixer.
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()
        else:
            raise ValueError(
                f"Target update mode {self._target_update_mode} not recognised."
            )

        return actor_loss.item(), critic_loss.item()

    def get_actions(self, idx_agent, batch_obs_t, batch_action_masks_t):
        """
        Sample differentiable one-hot actions for host agent ``idx_agent``.

        Illegal actions (False in ``batch_action_masks_t``) get ``-inf`` logits.
        A Gumbel-Softmax sample is drawn and turned into a hard one-hot via the
        straight-through trick: the forward value is the hard one-hot, and the
        gradient is that of the soft sample.

        Args:
            idx_agent: index into ``self._host_agents``.
            batch_obs_t: observation batch passed to the agent's actor.
            batch_action_masks_t: boolean mask of legal actions.

        Returns:
            torch.Tensor: one-hot actions that carry gradients to the actor.
        """
        host_agent = self._host_agents[idx_agent]
        batch_logits = host_agent.actor_approximator.predict(
            batch_obs_t,
            None,
            output_all=True,
            output_hidden=False,
            output_tensor=True,
        )
        # NOTE(review): squeeze() removes every size-1 dim. With a batch or
        # sequence length of 1 the logits may no longer align with the mask.
        # TODO confirm the actor's output shape.
        batch_logits_mask = torch.where(
            batch_action_masks_t,
            batch_logits.squeeze(),
            torch.tensor(float("-inf")),
        )
        actions_probs_gumbel = GumbelSoftmax(
            logits=batch_logits_mask
        ).gumbel_softmax_sample()
        # Hard one-hot at the arg-max of the sample (ties would give >1 hot entry).
        batch_actions_gumbel_hard = (
            torch.max(actions_probs_gumbel, dim=-1, keepdim=True)[0]
            == actions_probs_gumbel
        ).float()
        # Straight-through estimator.
        batch_actions_backprop = (
            batch_actions_gumbel_hard - actions_probs_gumbel
        ).detach() + actions_probs_gumbel

        return batch_actions_backprop

    def get_mixer_episodes(self, episodes, max_seq_len):
        """
        Build padded global batch tensors for the mixer.

        Each episode is padded at the end to ``max_seq_len``. Padded steps have
        zero state/reward, are marked absorbing, and are False in the pad mask.

        Args:
            episodes: sequence of episodes. Each is a sequence of dict samples
                with ``"state"``, ``"rewards"``, ``"next_state"`` and
                ``"absorbing"`` keys.
            max_seq_len: length to pad every episode to.

        Returns:
            tuple of torch.Tensor:
                states ``[seq_len, batch, state_dim]`` (float32),
                rewards ``[seq_len, batch, 1]`` (float32),
                next states ``[seq_len, batch, state_dim]`` (float32),
                absorbing flags ``[seq_len, batch, 1]`` (bool),
                pad mask ``[seq_len, batch, 1]`` (bool, True on real steps).
        """
        (
            batch_states,
            batch_rewards,
            batch_next_states,
            batch_absorbings,
            pad_masks,
        ) = ([], [], [], [], [])
        for episode in episodes:
            seq_len = len(episode)
            pad_len = max_seq_len - seq_len

            state_seq = np.array([sample["state"] for sample in episode])
            # All agents receive the same reward, so agent 0's reward is used.
            rewards_seq = np.array([sample["rewards"][0] for sample in episode])
            next_state_seq = np.array([sample["next_state"] for sample in episode])
            absorbing_seq = np.array([sample["absorbing"] for sample in episode])
            mask = np.concatenate([np.ones(seq_len), np.zeros(pad_len)])

            # Pad to max_seq_len along the time axis only.
            state_pad = np.pad(state_seq, ((0, pad_len), (0, 0)), "constant")
            rewards_pad = np.pad(rewards_seq, (0, pad_len), "constant")
            next_state_pad = np.pad(
                next_state_seq, ((0, pad_len), (0, 0)), "constant"
            )
            # Padded steps count as absorbing so they never bootstrap.
            absorbing_pad = np.pad(
                absorbing_seq,
                (0, pad_len),
                "constant",
                constant_values=1,
            )

            batch_states.append(state_pad)
            batch_rewards.append(rewards_pad)
            batch_next_states.append(next_state_pad)
            batch_absorbings.append(absorbing_pad)
            pad_masks.append(mask)

        # Transpose from [batch, seq_len, ...] to [seq_len, batch, ...].
        batch_states = np.array(batch_states).transpose(1, 0, 2)
        batch_rewards = np.array(batch_rewards).transpose((1, 0))
        batch_next_states = np.array(batch_next_states).transpose(1, 0, 2)
        batch_absorbings = np.array(batch_absorbings).transpose((1, 0))
        pad_masks = np.array(pad_masks).transpose((1, 0))

        batch_states_t = torch.tensor(batch_states, dtype=torch.float32)
        batch_rewards_t = torch.tensor(batch_rewards, dtype=torch.float32).unsqueeze(-1)
        batch_next_states_t = torch.tensor(batch_next_states, dtype=torch.float32)
        batch_absorbings_t = torch.tensor(batch_absorbings, dtype=torch.bool).unsqueeze(
            -1
        )
        pad_masks_t = torch.tensor(pad_masks, dtype=torch.bool).unsqueeze(-1)

        return (
            batch_states_t,
            batch_rewards_t,
            batch_next_states_t,
            batch_absorbings_t,
            pad_masks_t,
        )

    def get_agent_episodes(self, episodes, idx_agent, max_seq_len):
        """
        Build padded batch tensors for host agent ``idx_agent``.

        Args:
            episodes: sequence of episodes of dict samples. Per-agent entries
                are read at index ``idx_agent`` from ``"obs"``,
                ``"action_masks"``, ``"actions"``, ``"next_obs"`` and
                ``"next_action_masks"``.
            idx_agent: agent index into the samples and into
                ``mdp_info.action_space``.
            max_seq_len: length to pad every episode to.

        Returns:
            tuple of torch.Tensor, all shaped ``[seq_len, batch, ...]``:
                observations (float32), action masks (bool),
                one-hot actions (long), next observations (float32),
                next action masks (bool).
        """
        (
            batch_obs,
            batch_action_masks,
            batch_actions,
            batch_next_obs,
            batch_next_action_masks,
        ) = ([], [], [], [], [])

        # One-hot lookup table sized by THIS agent's action space.
        one_hot_table = np.eye(self.mdp_info.action_space[idx_agent].n)

        for episode in episodes:
            seq_len = len(episode)
            pad_len = max_seq_len - seq_len

            obs_seq = np.array([sample["obs"][idx_agent] for sample in episode])
            action_mask_seq = np.array(
                [sample["action_masks"][idx_agent] for sample in episode]
            )
            actions_seq = np.array([sample["actions"][idx_agent] for sample in episode])
            # squeeze(1) assumes each stored action is a length-1 array — TODO confirm.
            actions_seq_one_hot = one_hot_table[actions_seq].squeeze(1)
            next_obs_seq = np.array(
                [sample["next_obs"][idx_agent] for sample in episode]
            )
            next_action_masks_seq = np.array(
                [sample["next_action_masks"][idx_agent] for sample in episode]
            )

            obs_pad = np.pad(obs_seq, ((0, pad_len), (0, 0)), "constant")
            # Masks are padded with True, presumably so that padded steps never
            # have every logit set to -inf.
            action_masks_pad = np.pad(
                action_mask_seq,
                ((0, pad_len), (0, 0)),
                "constant",
                constant_values=1,
            )
            actions_one_hot_pad = np.pad(
                actions_seq_one_hot,
                ((0, pad_len), (0, 0)),
                "constant",
            )
            next_obs_pad = np.pad(next_obs_seq, ((0, pad_len), (0, 0)), "constant")
            next_action_masks_pad = np.pad(
                next_action_masks_seq,
                ((0, pad_len), (0, 0)),
                "constant",
                constant_values=1,
            )

            batch_obs.append(obs_pad)
            batch_action_masks.append(action_masks_pad)
            batch_actions.append(actions_one_hot_pad)
            batch_next_obs.append(next_obs_pad)
            batch_next_action_masks.append(next_action_masks_pad)

        # Transpose from [batch, seq_len, dim] to [seq_len, batch, dim].
        batch_obs = np.array(batch_obs).transpose((1, 0, 2))
        batch_action_masks = np.array(batch_action_masks).transpose((1, 0, 2))
        batch_actions = np.array(batch_actions).transpose((1, 0, 2))
        batch_next_obs = np.array(batch_next_obs).transpose((1, 0, 2))
        batch_next_action_masks = np.array(batch_next_action_masks).transpose(
            (1, 0, 2)
        )

        batch_obs_t = torch.tensor(batch_obs, dtype=torch.float32)
        batch_action_masks_t = torch.tensor(batch_action_masks, dtype=torch.bool)
        batch_actions_t = torch.tensor(batch_actions, dtype=torch.long)
        batch_next_obs_t = torch.tensor(batch_next_obs, dtype=torch.float32)
        batch_next_action_masks_t = torch.tensor(
            batch_next_action_masks, dtype=torch.bool
        )

        return (
            batch_obs_t,
            batch_action_masks_t,
            batch_actions_t,
            batch_next_obs_t,
            batch_next_action_masks_t,
        )

    def mix(self, chosen_action_value, state):
        """Combine per-agent values into a joint value with the online mixer."""
        return self._mixer(chosen_action_value, state)

    def target_mix(self, chosen_action_value, state):
        """Combine per-agent values into a joint value with the target mixer."""
        return self._target_mixer(chosen_action_value, state)

    def update_target_mixer(self):
        """Copy the online mixer's weights into the target mixer (hard update)."""
        w = get_weights(self._mixer.parameters())
        set_weights(self._target_mixer.parameters(), w, use_cuda=self._use_cuda)

    def update_target_mixer_soft(self):
        """Set target weights to ``tau * online + (1 - tau) * target`` (soft update)."""
        weights = self._tau * self.get_mixer_weights()
        weights += (1 - self._tau) * get_weights(self._target_mixer.parameters())
        set_weights(self._target_mixer.parameters(), weights, use_cuda=self._use_cuda)

    def get_mixer_weights(self):
        """Return the online mixer's weights as produced by ``get_weights``."""
        return get_weights(self._mixer.parameters())

    @staticmethod
    def _grad_l2_norm(params):
        """Return the global L2 norm of the gradients of ``params``, skipping params without a grad."""
        total_sq = 0.0
        for param in params:
            if param.grad is None:
                continue
            total_sq += param.grad.detach().norm(2).item() ** 2
        return total_sq**0.5

    def actor_param(self):
        """Return all actor parameters flattened into one 1-D tensor."""
        return torch.cat([param.flatten() for param in self.actor_params])

    def actor_gradient(self):
        """Return the actor gradients flattened into one 1-D tensor (params without grad are skipped)."""
        return torch.cat(
            [
                param.grad.view(-1)
                for param in self.actor_params
                if param.grad is not None
            ]
        )

    def actor_grad_norm(self):
        """Return the global L2 norm of the actor gradients as a Python float."""
        return self._grad_l2_norm(self.actor_params)

    def critic_param(self):
        """Return all critic (and mixer) parameters flattened into one 1-D tensor."""
        return torch.cat([param.flatten() for param in self.critic_params])

    def critic_gradient(self):
        """Return the critic gradients flattened into one 1-D tensor (params without grad are skipped)."""
        return torch.cat(
            [
                param.grad.view(-1)
                for param in self.critic_params
                if param.grad is not None
            ]
        )

    def critic_grad_norm(self):
        """Return the global L2 norm of the critic gradients as a Python float."""
        return self._grad_l2_norm(self.critic_params)
