import torch
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.actor_critic.facmac import FACMAC


class FACMACContinuous(FACMAC):
    """
    FACMAC learner for continuous actions without recurrent networks.

    Only ``fit`` is defined here; the replay memory, host agents, mixers,
    optimizers and hyper-parameters it uses are expected to be provided by
    the ``FACMAC`` base class.
    """

    def fit(self, dataset):
        """
        Add ``dataset`` to the replay memory and run one critic/actor update.

        No update is performed until the replay memory holds more than
        ``self._warmup_replay_size`` samples. After that, each call samples a
        batch of ``self._batch_size`` transitions and:

        1. regresses the mixed critic value ``q_tot`` onto the one-step
           target built from the target critics and the target mixer;
        2. updates the actors by maximising the mixed value of the actions
           they currently propose;
        3. updates the target mixer (soft every call, or hard every
           ``self._target_update_frequency`` updates).

        Args:
            dataset: transitions to store; passed unchanged to
                ``self._replay_memory.add``.

        Returns:
            A tuple ``(actor_loss, critic_loss)`` of Python floats, or
            ``(0, 0)`` while the replay memory is still warming up.

        Raises:
            ValueError: if ``self._target_update_mode`` is neither ``"soft"``
                nor ``"hard"``. Note that this is raised after the critic and
                actor optimizer steps have already been applied.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        states, obs, actions, rewards, next_states, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        # Convert the sampled batch to tensors. Per-agent containers are
        # indexed by position (as the original code did), each one using its
        # own length.
        states_t = torch.tensor(states, dtype=torch.float32)
        obs_t = [
            torch.tensor(obs[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(obs))
        ]
        actions_t = [
            torch.tensor(actions[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(actions))
        ]
        # Use rewards of agent 0, assume all agents have the same reward
        rewards_t = torch.tensor(rewards[:, 0], dtype=torch.float32).unsqueeze(-1)
        next_states_t = torch.tensor(next_states, dtype=torch.float32)
        next_obs_t = [
            torch.tensor(next_obs[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(next_obs))
        ]
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool).unsqueeze(-1)

        # Both mixer calls on the current state take the same flattened view.
        flat_states = states_t.reshape(-1, self._state_shape_int)
        flat_next_states = next_states_t.reshape(-1, self._state_shape_int)

        centralized = self._centralized_critic

        # Target actions only feed the bootstrap target, which is never
        # differentiated, so no autograd graph is needed for them.
        with torch.no_grad():
            target_actions = [
                agent._draw_target_action(next_obs_t[idx_agent])
                for idx_agent, agent in enumerate(self._host_agents)
            ]

        # The joint action vectors are identical for every agent; build them
        # once instead of once per loop iteration.
        if centralized:
            joint_actions = torch.cat(actions_t, dim=-1)
            joint_target_actions = torch.cat(target_actions, dim=-1)
        else:
            joint_actions = None
            joint_target_actions = None

        # Update critic and mixer
        q_hats = []
        q_nexts = []
        for idx_agent, agent in enumerate(self._host_agents):
            if centralized:
                critic_action = joint_actions
                target_critic_action = joint_target_actions
            else:
                critic_action = actions_t[idx_agent]
                target_critic_action = target_actions[idx_agent]

            q_hats.append(
                agent.critic_approximator.predict(
                    obs_t[idx_agent], critic_action, output_tensor=True
                )
            )
            with torch.no_grad():
                q_nexts.append(
                    agent.target_critic_approximator.predict(
                        next_obs_t[idx_agent],
                        target_critic_action,
                        output_tensor=True,
                    )
                )

        # Stack per-agent values along a new trailing agent axis, then add one
        # more trailing axis before mixing.
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_tot = self.mix(q_hat, flat_states)

        # One-step target; transitions flagged as absorbing do not bootstrap.
        # NOTE(review): assumes the mixers return a tensor that broadcasts
        # element-wise against the (batch, 1) rewards - TODO confirm.
        with torch.no_grad():
            q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)
            q_tot_next = self.target_mix(q_next, flat_next_states)
            q_tot_target = (
                rewards_t + self.mdp_info.gamma * q_tot_next * ~absorbing_t
            )

        # Compute critic loss and backpropagate
        critic_loss = F.mse_loss(q_tot, q_tot_target)
        if self._scale_critic_loss:
            critic_loss /= self.mdp_info.n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        # Actor update: evaluate the actions the current actors propose.
        actions_update = [
            agent.actor_approximator.predict(obs_t[idx_agent], output_tensor=True)
            for idx_agent, agent in enumerate(self._host_agents)
        ]
        q_actors = []
        for idx_agent, agent in enumerate(self._host_agents):
            if centralized:
                # Only this agent's own action keeps its gradient inside its
                # critic input; the other agents' actions are detached.
                actions_update_agent = [
                    action if idx == idx_agent else action.detach()
                    for idx, action in enumerate(actions_update)
                ]
                actor_critic_action = torch.cat(actions_update_agent, dim=-1)
            else:
                actor_critic_action = actions_update[idx_agent]
            q_actors.append(
                agent.critic_approximator.predict(
                    obs_t[idx_agent], actor_critic_action, output_tensor=True
                )
            )
        q_actor = torch.stack(q_actors, dim=-1).unsqueeze(-1)
        q_tot_actor = self.mix(q_actor, flat_states)

        # Gradient ascent on the mixed value == descent on its negation.
        actor_loss = -q_tot_actor.mean()
        if self._scale_actor_loss:
            actor_loss /= self.mdp_info.n_agents
        self._actor_optimizer.zero_grad()
        actor_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.actor_params, self._grad_norm_clip)
        self._actor_optimizer.step()

        # Update target mixer
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()
        else:
            raise ValueError(
                f"Target update mode {self._target_update_mode} not recognised."
            )

        return actor_loss.item(), critic_loss.item()
