import numpy as np
import torch
from k_level_policy_gradients.src.algorithms.actor_critic.maddpg_discrete import (
    MADDPGDiscrete,
)
import torch.nn.functional as F


class MADDPG(MADDPGDiscrete):
    """MADDPG learner that overrides ``fit`` of ``MADDPGDiscrete``.

    A single critic (``self.critic_approximator``) is evaluated on the global
    state together with the concatenation of every agent's action, while each
    entry of ``self._host_agents`` provides its own actor and target actor.
    """

    # Number of updates during which only the critic is trained. This used to
    # be a hard-coded debugging threshold inside ``fit``; naming it keeps the
    # behaviour identical while making the delay visible and overridable.
    _ACTOR_UPDATE_DELAY = 100

    def fit(self, dataset):
        """Store ``dataset`` in the replay memory and run one update step.

        No learning happens until the replay memory holds more than
        ``self._warmup_replay_size`` samples. After that, each call samples a
        batch of ``self._batch_size`` transitions, updates the critic, updates
        the actors (only once ``self._n_updates`` exceeds
        ``_ACTOR_UPDATE_DELAY``), and finally refreshes the target critic.

        Args:
            dataset: transitions to append to the replay memory.

        Returns:
            The placeholder tuple ``(0, 0)`` in every case.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        states, obs, actions, rewards, next_states, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )
        n_agents = self.mdp_info.n_agents

        def as_float(array):
            # Copy the sampled array into a float32 tensor.
            return torch.tensor(array, dtype=torch.float32)

        # Convert to tensors
        # Use rewards of agent 0, assume all agents have the same reward
        states_t = as_float(states)
        obs_t = [as_float(agent_obs) for agent_obs in obs]
        actions_t = [as_float(agent_actions) for agent_actions in actions]
        rewards_t = as_float(rewards[:, 0])
        next_states_t = as_float(next_states)
        next_obs_t = [as_float(agent_next_obs) for agent_next_obs in next_obs]
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool)

        # Update critic
        actions_cat = torch.cat(actions_t, dim=-1)
        next_actions_t = [
            self._host_agents[idx_agent].target_actor_approximator.predict(
                next_obs_t[idx_agent], output_tensor=True
            )
            for idx_agent in range(n_agents)
        ]
        next_actions_cat = torch.cat(next_actions_t, dim=-1)
        q_hat = self.critic_approximator.predict(
            states_t, actions_cat, output_tensor=True
        )
        q_next = self.target_critic_approximator.predict(
            next_states_t, next_actions_cat, output_tensor=True
        )
        # One-step TD target; absorbing transitions drop the bootstrap term.
        # NOTE(review): rewards_t is 1-D here; this assumes q_next, absorbing_t
        # and q_hat share that shape so nothing broadcasts -- TODO confirm.
        q_target = (
            rewards_t + self.mdp_info.gamma * q_next * ~absorbing_t
        ).detach()
        critic_loss = F.mse_loss(q_hat, q_target)
        if self._scale_critic_loss:
            critic_loss /= n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(
                self.critic_params, self._grad_norm_clip
            )
        self._critic_optimizer.step()

        # DEBUGGING: actors are left untouched for the first updates so that
        # only the critic is trained during that period.
        if self._n_updates > self._ACTOR_UPDATE_DELAY:
            # Update actors: re-compute the actions from the current actors so
            # the critic value can be differentiated w.r.t. actor parameters.
            actions_backprop = [
                self._host_agents[idx_agent].actor_approximator.predict(
                    obs_t[idx_agent], output_tensor=True
                )
                for idx_agent in range(n_agents)
            ]
            actions_backprop_cat = torch.cat(actions_backprop, dim=-1)
            q = self.critic_approximator.predict(
                states_t, actions_backprop_cat, output_tensor=True
            )
            # Maximising Q is expressed as minimising its negated mean.
            actor_loss = -q.mean()
            if self._scale_actor_loss:
                actor_loss /= n_agents
            self._actor_optimizer.zero_grad()
            actor_loss.backward()
            if self._grad_norm_clip is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.actor_params, self._grad_norm_clip
                )
            self._actor_optimizer.step()

        # Update target critic (soft: every call; hard: every
        # ``_target_update_frequency`` updates). Any other mode does nothing.
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_critic_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_critic_hard()

        return 0, 0

    def optimal_q(self, states, actions_a, actions_b):
        """Return a closed-form value computed from states and two actions.

        For every row, the result is the Euclidean norm of the first two
        components of ``states + actions_a - actions_b``.

        Args:
            states: 2-D tensor; only columns 0 and 1 are read.
            actions_a: 2-D tensor; only columns 0 and 1 are read.
            actions_b: 2-D tensor; only columns 0 and 1 are read.

        Returns:
            Tensor with one value per row of the inputs.
        """
        offset_0 = states[:, 0] + actions_a[:, 0] - actions_b[:, 0]
        offset_1 = states[:, 1] + actions_a[:, 1] - actions_b[:, 1]
        return torch.sqrt(offset_0 ** 2 + offset_1 ** 2)
