import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
import torch.optim as optim
from k_level_policy_gradients.src.algorithms.actor_critic.facmac import FACMAC


class KFACMACContinuous(FACMAC):
    """
    FACMAC for continuous actions with a k-level actor update.

    Each ``fit`` call (after warm-up) performs one critic/mixer step and then
    ``k_level`` actor steps. At step ``k`` every agent's actor is trained
    against the mixed Q-value obtained when the other agents play their fixed
    level-``k`` actions (level 0 = the current policies' actions). After each
    intermediate step the new policies' actions become the next level, and the
    actor parameters and actor-optimiser state are rolled back to their values
    at the start of ``fit``; only the update from the final step is kept.
    """

    def __init__(self, k_level, **kwargs):
        """
        Args:
            k_level (int): number of k-level actor steps per ``fit`` call;
                must be at least 1.
            **kwargs: forwarded unchanged to ``FACMAC.__init__``.

        Raises:
            ValueError: if ``k_level`` is smaller than 1 (no actor update would
                run and ``fit`` would have no actor loss to report).
        """
        if k_level < 1:
            raise ValueError(f"k_level must be >= 1, got {k_level}.")
        self._k_level = k_level
        super().__init__(**kwargs)

        # Plain-SGD optimiser for a variant where intermediate k-steps use SGD.
        # That variant is currently disabled: within this class the optimiser
        # is only zeroed, never stepped.
        self._sgd_actor_optimizer = optim.SGD(self.actor_params, lr=0.0001)

        self._add_save_attr(
            _k_level="primitive",
        )

    def fit(self, dataset):
        """
        Store ``dataset`` and, once past warm-up, run one full training update.

        The update consists of one critic/mixer step, ``k_level`` actor steps
        (intermediate ones rolled back, see class docstring) and a target-mixer
        update.

        Args:
            dataset: transitions in the format accepted by
                ``self._replay_memory.add``.

        Returns:
            tuple: ``(actor_loss, critic_loss)`` as Python floats, where
            ``actor_loss`` is the (possibly scaled) summed actor loss of the
            final k-step; ``(0, 0)`` while the replay memory size does not
            exceed the warm-up size.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        # Snapshot every agent's actor weights and the actor optimiser state so
        # intermediate k-level steps can be rolled back. np.copy because the
        # same snapshot is restored several times.
        initial_actor_params = [
            np.copy(agent.actor_approximator.get_weights())
            for agent in self._host_agents
        ]
        initial_actor_optimizer_state = deepcopy(self._actor_optimizer.state_dict())

        (
            states_t,
            obs_t,
            actions_t,
            rewards_t,
            next_states_t,
            next_obs_t,
            absorbing_t,
        ) = self._klevel_sample_batch()

        critic_loss = self._klevel_critic_step(
            states_t,
            obs_t,
            actions_t,
            rewards_t,
            next_states_t,
            next_obs_t,
            absorbing_t,
        )

        # Level-0 actions: the current policies' (detached) actions.
        actions_k = self._klevel_policy_actions(obs_t)

        for k in range(self._k_level):
            total_actor_loss = self._klevel_actor_loss(obs_t, states_t, actions_k)
            if self._scale_actor_loss:
                # Scale the summed loss that is actually backpropagated.
                total_actor_loss = total_actor_loss / self.mdp_info.n_agents

            self._actor_optimizer.zero_grad()
            self._sgd_actor_optimizer.zero_grad()
            total_actor_loss.backward()
            if self._grad_norm_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.actor_params, self._grad_norm_clip)
            # Every k-step uses the main actor optimiser.
            self._actor_optimizer.step()

            if k < self._k_level - 1:  # not the final k-step
                # The updated policies define the next level of actions ...
                actions_k = self._klevel_policy_actions(obs_t)
                # ... after which the actors and their optimiser are rolled back.
                for agent, weights in zip(self._host_agents, initial_actor_params):
                    agent.actor_approximator.set_weights(np.copy(weights))
                self._actor_optimizer.load_state_dict(
                    deepcopy(initial_actor_optimizer_state)
                )

        self._klevel_update_target_mixer()

        return total_actor_loss.item(), critic_loss.item()

    def _klevel_sample_batch(self):
        """Sample a replay batch and convert it to tensors.

        Returns:
            tuple: ``(states_t, obs_t, actions_t, rewards_t, next_states_t,
            next_obs_t, absorbing_t)`` where ``obs_t``, ``actions_t`` and
            ``next_obs_t`` are per-agent lists of float32 tensors, ``rewards_t``
            and ``absorbing_t`` have a trailing singleton dimension added.
        """
        states, obs, actions, rewards, next_states, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        states_t = torch.tensor(states, dtype=torch.float32)
        obs_t = [
            torch.tensor(obs[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(obs))
        ]
        actions_t = [
            torch.tensor(actions[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(actions))
        ]
        # Use rewards of agent 0, assume all agents have the same reward
        rewards_t = torch.tensor(rewards[:, 0], dtype=torch.float32).unsqueeze(-1)
        next_states_t = torch.tensor(next_states, dtype=torch.float32)
        next_obs_t = [
            torch.tensor(next_obs[idx_agent], dtype=torch.float32)
            for idx_agent in range(len(next_obs))
        ]
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool).unsqueeze(-1)

        return (
            states_t,
            obs_t,
            actions_t,
            rewards_t,
            next_states_t,
            next_obs_t,
            absorbing_t,
        )

    def _klevel_critic_step(
        self,
        states_t,
        obs_t,
        actions_t,
        rewards_t,
        next_states_t,
        next_obs_t,
        absorbing_t,
    ):
        """Take one optimiser step on the critics and mixer toward the TD target.

        Returns:
            torch.Tensor: the scalar critic loss (divided by ``n_agents`` when
            ``self._scale_critic_loss`` is set).
        """
        agents = self._host_agents

        # TD target: no gradient is needed through the target networks.
        with torch.no_grad():
            target_actions = [
                agent._draw_target_action(next_obs_t[idx_agent])
                for idx_agent, agent in enumerate(agents)
            ]
            if self._centralized_critic:
                joint_target_action = torch.cat(target_actions, dim=-1)
                q_nexts = [
                    agent.target_critic_approximator.predict(
                        next_obs_t[idx_agent], joint_target_action, output_tensor=True
                    )
                    for idx_agent, agent in enumerate(agents)
                ]
            else:
                q_nexts = [
                    agent.target_critic_approximator.predict(
                        next_obs_t[idx_agent],
                        target_actions[idx_agent],
                        output_tensor=True,
                    )
                    for idx_agent, agent in enumerate(agents)
                ]
            q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)
            q_tot_next = self.target_mix(
                q_next, next_states_t.reshape(-1, self._state_shape_int)
            )
            q_tot_target = rewards_t + self.mdp_info.gamma * q_tot_next * ~absorbing_t

        # Current joint Q-value for the sampled actions.
        if self._centralized_critic:
            joint_action = torch.cat(actions_t, dim=-1)
            q_hats = [
                agent.critic_approximator.predict(
                    obs_t[idx_agent], joint_action, output_tensor=True
                )
                for idx_agent, agent in enumerate(agents)
            ]
        else:
            q_hats = [
                agent.critic_approximator.predict(
                    obs_t[idx_agent], actions_t[idx_agent], output_tensor=True
                )
                for idx_agent, agent in enumerate(agents)
            ]
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_tot = self.mix(q_hat, states_t.reshape(-1, self._state_shape_int))

        critic_loss = F.mse_loss(q_tot, q_tot_target)
        if self._scale_critic_loss:
            critic_loss = critic_loss / self.mdp_info.n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        return critic_loss

    def _klevel_policy_actions(self, obs_t, detach=True):
        """Return each host agent's current actor output for its observations.

        Args:
            obs_t (list): per-agent observation tensors.
            detach (bool): detach the actions from the graph (used for the
                fixed level-k actions); keep them differentiable otherwise.
        """
        actions = []
        for idx_agent, agent in enumerate(self._host_agents):
            action = agent.actor_approximator.predict(
                obs_t[idx_agent], output_tensor=True
            )
            actions.append(action.detach() if detach else action)
        return actions

    def _klevel_actor_loss(self, obs_t, states_t, actions_k):
        """Sum over agents of the negated mean mixed Q-value for one k-step.

        For agent ``i`` the joint action consists of agent ``i``'s current,
        differentiable policy action and every other agent's fixed level-k
        action from ``actions_k``.
        """
        agents = self._host_agents
        actions_update = self._klevel_policy_actions(obs_t, detach=False)
        mixer_states = states_t.reshape(-1, self._state_shape_int)

        total_actor_loss = 0
        for idx_agent in range(len(agents)):
            if self._centralized_critic:
                joint_action = torch.cat(
                    actions_k[:idx_agent]
                    + [actions_update[idx_agent]]
                    + actions_k[idx_agent + 1 :],
                    dim=-1,
                )
                # Every critic sees agent idx_agent's action, so none is detached.
                qs = [
                    agent_q.critic_approximator.predict(
                        obs_t[idx_q], joint_action, output_tensor=True
                    )
                    for idx_q, agent_q in enumerate(agents)
                ]
            else:
                qs = []
                for idx_q, agent_q in enumerate(agents):
                    if idx_q == idx_agent:
                        q = agent_q.critic_approximator.predict(
                            obs_t[idx_q], actions_update[idx_q], output_tensor=True
                        )
                        qs.append(q)
                    else:
                        # Each critic is fed its own agent's observation; other
                        # agents' Q-values are held fixed.
                        q = agent_q.critic_approximator.predict(
                            obs_t[idx_q], actions_k[idx_q], output_tensor=True
                        )
                        qs.append(q.detach())

            q_actor = torch.stack(qs, dim=-1).unsqueeze(-1)
            q_tot_actor = self.mix(q_actor, mixer_states)
            actor_loss = -q_tot_actor.mean()
            total_actor_loss = total_actor_loss + actor_loss

        return total_actor_loss

    def _klevel_update_target_mixer(self):
        """Count one update and refresh the target mixer (soft or periodic hard).

        Raises:
            ValueError: if ``self._target_update_mode`` is neither ``"soft"``
                nor ``"hard"``.
        """
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()
        else:
            raise ValueError(
                f"Target update mode {self._target_update_mode} not recognised."
            )
