import numpy as np
import torch
import torch.optim as optim
from k_level_policy_gradients.src.algorithms.actor_critic.maddpg_discrete import (
    MADDPGDiscrete,
)
import torch.nn.functional as F
from copy import deepcopy


class KMADDPG(MADDPGDiscrete):
    """Discrete-action MADDPG with a level-k actor update.

    Once the replay memory is warmed up, each ``fit`` call performs:

    1. A critic update: the critic scores the state together with the
       concatenated actions of all agents and is regressed towards a one-step
       TD target built from the target critic and the target actors.
    2. ``k_level`` actor rounds (skipped on the very first update). In round
       ``k`` every agent's actor is optimised against the *fixed* (detached)
       level-``k`` actions of the other agents. After every round except the
       last, the freshly updated actors produce the level-``k + 1`` actions,
       and then actors and actor optimiser are rolled back to their pre-update
       state. Only the last round's optimiser step is kept.
    3. A target-critic update (soft, or hard every
       ``_target_update_frequency`` updates).
    """

    def __init__(self, k_level, **kwargs):
        """
        Args:
            k_level (int): number of level-k actor rounds per update; with
                ``k_level <= 0`` the actors are never updated.
            **kwargs: forwarded unchanged to ``MADDPGDiscrete.__init__``.
        """
        # Set before super().__init__, presumably in case the parent
        # constructor reads it.
        self._k_level = k_level
        super().__init__(**kwargs)

        # NOTE(review): this optimiser is only zero_grad()-ed in fit(), never
        # stepped, and is not registered with _add_save_attr below, so it is
        # presumably missing on a reloaded agent -- confirm it is still needed.
        self._sgd_actor_optimizer = optim.SGD(self.actor_params, lr=0.0001)

        self._add_save_attr(
            _k_level="primitive",
        )

    def _predict_host_actions(self, obs_t, detach):
        """Return each host agent's actor output on its own observations.

        Args:
            obs_t (list): per-agent observation tensors, indexed like
                ``self._host_agents``.
            detach (bool): if True, cut the outputs from the autograd graph so
                they act as constants in later losses.

        Returns:
            list: one action tensor per host agent.
        """
        actions = []
        for idx_agent, agent in enumerate(self._host_agents):
            action = agent.actor_approximator.predict(
                obs_t[idx_agent], output_tensor=True
            )
            actions.append(action.detach() if detach else action)
        return actions

    def _snapshot_actor_state(self):
        """Return copies of the actor weights and of the actor-optimiser state.

        With shared parameters only agent 0's weights are copied; otherwise a
        list with one weight array per host agent is returned.
        """
        if self.shared_params_bool:
            actor_params = np.copy(
                self._host_agents[0].actor_approximator.get_weights()
            )
        else:
            actor_params = [
                np.copy(agent.actor_approximator.get_weights())
                for agent in self._host_agents
            ]
        return actor_params, deepcopy(self._actor_optimizer.state_dict())

    def _restore_actor_state(self, actor_params, optimizer_state):
        """Reset actors and actor optimiser to a snapshot from ``_snapshot_actor_state``.

        The snapshot is copied again before loading, so it can be reused for
        several restores.
        """
        if self.shared_params_bool:
            self._host_agents[0].actor_approximator.set_weights(
                np.copy(actor_params)
            )
        else:
            for agent, agent_params in zip(self._host_agents, actor_params):
                agent.actor_approximator.set_weights(np.copy(agent_params))
        self._actor_optimizer.load_state_dict(deepcopy(optimizer_state))

    def fit(self, dataset):
        """Store ``dataset`` and, once warmed up, run one critic + level-k actor update.

        Args:
            dataset: transitions forwarded unchanged to ``self._replay_memory.add``.

        Returns:
            tuple: ``(0, 0)`` placeholders, both before and after warm-up.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        n_agents = self.mdp_info.n_agents
        # NOTE(review): this guard was labelled "DEBUGGING" -- the actors are
        # not updated on the very first update; confirm this is intended.
        update_actors = self._n_updates > 0
        # Actors and actor optimiser only need to be rolled back between
        # level-k rounds, i.e. when there are at least two rounds.
        snapshot = (
            self._snapshot_actor_state()
            if update_actors and self._k_level > 1
            else None
        )

        states, obs, actions, rewards, next_states, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        # Only agent 0's reward is used: assumes all agents share the reward.
        states_t = torch.tensor(states, dtype=torch.float32)
        obs_t = [torch.tensor(agent_obs, dtype=torch.float32) for agent_obs in obs]
        actions_t = [
            torch.tensor(agent_actions, dtype=torch.float32)
            for agent_actions in actions
        ]
        rewards_t = torch.tensor(rewards[:, 0], dtype=torch.float32)
        next_states_t = torch.tensor(next_states, dtype=torch.float32)
        next_obs_t = [
            torch.tensor(agent_obs, dtype=torch.float32) for agent_obs in next_obs
        ]
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool)

        # ---- Critic: regress Q(s, a_1..a_n) towards the one-step TD target ----
        actions_cat = torch.cat(actions_t, dim=-1)
        next_actions_cat = torch.cat(
            [
                self._host_agents[idx_agent].target_actor_approximator.predict(
                    next_obs_t[idx_agent], output_tensor=True
                )
                for idx_agent in range(n_agents)
            ],
            dim=-1,
        )
        q_hat = self.critic_approximator.predict(
            states_t, actions_cat, output_tensor=True
        )
        q_next = self.target_critic_approximator.predict(
            next_states_t, next_actions_cat, output_tensor=True
        )
        # Bootstrapping is masked out for absorbing transitions.
        q_target = (rewards_t + self.mdp_info.gamma * q_next * ~absorbing_t).detach()
        critic_loss = F.mse_loss(q_hat, q_target)
        if self._scale_critic_loss:
            critic_loss = critic_loss / n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        # ---- Actors: k_level rounds of level-k responses ----
        if update_actors:
            # Level-0 actions: the current actors' outputs, held fixed.
            actions_k = self._predict_host_actions(obs_t, detach=True)
            for k in range(self._k_level):
                # Differentiable actions of the actors optimised this round.
                actions_update = self._predict_host_actions(obs_t, detach=False)

                # Each agent maximises Q using its own differentiable action
                # and every other agent's fixed level-k action.
                total_actor_loss = 0
                for idx_agent in range(len(self._host_agents)):
                    action_mixed = torch.cat(
                        actions_k[:idx_agent]
                        + [actions_update[idx_agent]]
                        + actions_k[idx_agent + 1 :],
                        dim=-1,
                    )
                    q_actor = self.critic_approximator.predict(
                        states_t, action_mixed, output_tensor=True
                    )
                    total_actor_loss += -q_actor.mean()

                # Scale the summed loss that is actually backpropagated,
                # mirroring the critic-loss scaling above.
                if self._scale_actor_loss:
                    total_actor_loss = total_actor_loss / n_agents
                self._actor_optimizer.zero_grad()
                self._sgd_actor_optimizer.zero_grad()
                total_actor_loss.backward()
                if self._grad_norm_clip is not None:
                    torch.nn.utils.clip_grad_norm_(
                        self.actor_params, self._grad_norm_clip
                    )
                self._actor_optimizer.step()

                if k < self._k_level - 1:
                    # The just-updated actors give the level-(k+1) actions the
                    # next round responds to ...
                    actions_k = self._predict_host_actions(obs_t, detach=True)
                    # ... then this step is undone so every round starts from
                    # the same actors and optimiser state.
                    self._restore_actor_state(*snapshot)

        # ---- Target critic ----
        # NOTE(review): only the target-critic update methods are called here;
        # confirm they also refresh the target actors used in the TD target.
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_critic_soft()
        elif (
            self._target_update_mode == "hard"
            and self._n_updates % self._target_update_frequency == 0
        ):
            self.update_target_critic_hard()

        # TODO: return the actual (actor loss, critic loss) instead of placeholders.
        return 0, 0
