import torch
import numpy as np
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.actor_critic.facmac import FACMAC
from copy import deepcopy
import torch.optim as optim
from k_level_policy_gradients.src.distributions.gumbel import GumbelSoftmax

# import matplotlib.pyplot as plt

# Short alias for PyTorch's cosine similarity. No live code in the visible
# part of this module calls it.
css = torch.nn.functional.cosine_similarity


class KFACMAC(FACMAC):
    """
    FACMAC variant whose actor update is computed over ``k_level`` reasoning
    steps.

    The critics and the mixer are trained exactly once per call to ``fit``.
    The actor is then updated ``k_level`` times. At every step each agent's
    loss is evaluated with its own differentiable action while the other
    agents' actions are held fixed at the (detached) actions produced by the
    previous step. After every step except the last one, the actions of the
    freshly updated actor are recorded and the actor weights and actor
    optimiser state are rolled back to their values at the start of ``fit``,
    so only the final step leaves a lasting change on the actor.

    Args:
        k_level (int): number of actor update steps performed per ``fit``
            call. With a value <= 0 no actor update is performed.
        **kwargs: forwarded unchanged to ``FACMAC.__init__``.
    """

    def __init__(self, k_level, **kwargs):
        # Set before the parent constructor runs, as in the original code.
        self._k_level = k_level
        super().__init__(**kwargs)

        # NOTE: fit() as written never steps this optimiser; every k step uses
        # ``self._actor_optimizer``. It is kept for an experimental variant in
        # which the intermediate k steps would use plain SGD instead.
        self._sgd_actor_optimizer = optim.SGD(self.actor_params, lr=0.0025)

        self._add_save_attr(
            _k_level="primitive",
        )

    def _k_level_actions(
        self, batch_obs_t_list, batch_action_masks_t_list, detach=False
    ):
        """
        Query ``self.get_actions`` once per host agent, in agent order.

        Args:
            batch_obs_t_list (list): per-agent observation batches.
            batch_action_masks_t_list (list): per-agent action-mask batches.
            detach (bool): when True the returned actions are detached from
                the autograd graph.

        Returns:
            list: one action tensor per host agent.
        """
        actions = []
        for idx_agent, _ in enumerate(self._host_agents):
            action = self.get_actions(
                idx_agent=idx_agent,
                batch_obs_t=batch_obs_t_list[idx_agent],
                batch_action_masks_t=batch_action_masks_t_list[idx_agent],
            )
            actions.append(action.detach() if detach else action)
        return actions

    def fit(self, dataset):
        """
        Add ``dataset`` to the replay memory and, once the memory holds more
        than ``_warmup_replay_size`` entries, run one critic/mixer update
        followed by the k-level actor update.

        Args:
            dataset: data passed straight to ``self._replay_memory.add``.

        Returns:
            tuple: ``(actor_loss, critic_loss)``. Both are 0 while warming up;
            ``actor_loss`` is 0 as well when no actor step was executed
            (``k_level`` <= 0). Otherwise they are the Python floats of the
            last actor loss and of the critic loss.

        Raises:
            ValueError: if ``self._target_update_mode`` is neither ``"soft"``
                nor ``"hard"``.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        # Snapshot the actor so intermediate k steps can be rolled back.
        # NOTE(review): only agent 0's approximator is saved and restored,
        # while the optimiser covers ``self.actor_params``; presumably the
        # actor is shared between agents - confirm.
        initial_actor_params = np.copy(
            self._host_agents[0].actor_approximator.get_weights()
        )
        initial_actor_optimizer_state = deepcopy(self._actor_optimizer.state_dict())

        episodes = self._replay_memory.get(self._batch_size)
        max_seq_len = max(len(episode) for episode in episodes)

        # Global (mixer) batch information.
        (
            batch_states_t,
            batch_rewards_t,
            batch_next_states_t,
            batch_absorbings_t,
            pad_masks_t,
        ) = self.get_mixer_episodes(episodes, max_seq_len)

        # Agent-specific batch information.
        batch_obs_t_list = []
        batch_action_masks_t_list = []
        batch_actions_t_list = []
        batch_next_obs_t_list = []
        batch_next_action_masks_t_list = []
        for idx_agent, _ in enumerate(self._host_agents):
            (
                obs_t,
                action_masks_t,
                actions_t,
                next_obs_t,
                next_action_masks_t,
            ) = self.get_agent_episodes(episodes, idx_agent, max_seq_len)
            batch_obs_t_list.append(obs_t)
            batch_action_masks_t_list.append(action_masks_t)
            batch_actions_t_list.append(actions_t)
            batch_next_obs_t_list.append(next_obs_t)
            batch_next_action_masks_t_list.append(next_action_masks_t)

        # Critics only see the leading ``cutoff`` features of each observation.
        critic_batch_obs_t_list = []
        critic_batch_next_obs_t_list = []
        for idx_agent, _ in enumerate(self._host_agents):
            cutoff = self.critic_obs_cutoff_list[idx_agent]
            critic_batch_obs_t_list.append(batch_obs_t_list[idx_agent][:, :, :cutoff])
            critic_batch_next_obs_t_list.append(
                batch_next_obs_t_list[idx_agent][:, :, :cutoff]
            )

        # Target actions used for bootstrapping.
        target_actions_t_list = [
            agent._draw_target_action(
                batch_obs_t_list[idx_agent],
                batch_next_obs_t_list[idx_agent],
                batch_next_action_masks_t_list[idx_agent],
            )
            for idx_agent, agent in enumerate(self._host_agents)
        ]

        # The joint action does not depend on the agent, so build it once.
        if self._centralized_critic:
            centralized_actions = torch.cat(batch_actions_t_list, dim=-1)
            centralized_target_actions = torch.cat(target_actions_t_list, dim=-1)

        # Per-agent critic predictions for the current and the next step.
        q_hats = []
        q_nexts = []
        for idx_agent, agent in enumerate(self._host_agents):
            if self._centralized_critic:
                critic_actions = centralized_actions
                critic_target_actions = centralized_target_actions
            else:
                critic_actions = batch_actions_t_list[idx_agent]
                critic_target_actions = target_actions_t_list[idx_agent]
            q_hats.append(
                agent.critic_approximator.predict(
                    critic_batch_obs_t_list[idx_agent],
                    critic_actions,
                    output_tensor=True,
                )
            )
            q_nexts.append(
                agent.target_critic_approximator.predict(
                    critic_batch_next_obs_t_list[idx_agent],
                    critic_target_actions,
                    output_tensor=True,
                )
            )

        # Mixer predictions and TD target.
        mixer_states = batch_states_t.reshape(-1, self._state_shape_int)
        mixer_next_states = batch_next_states_t.reshape(-1, self._state_shape_int)
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)
        q_tot = self.mix(q_hat, mixer_states)
        q_tot_next = self.target_mix(q_next, mixer_next_states)
        q_tot_target = (
            batch_rewards_t + self.mdp_info.gamma * q_tot_next * ~batch_absorbings_t
        ).detach()

        # Critic loss: summed squared error over the masked entries, divided
        # by the sum of the pad mask.
        q_tot = q_tot * pad_masks_t
        q_tot_target = q_tot_target * pad_masks_t
        critic_loss = F.mse_loss(q_tot, q_tot_target, reduction="sum")
        critic_loss /= pad_masks_t.sum()
        if self._scale_critic_loss:
            critic_loss /= self.mdp_info.n_agents
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.critic_params, self._grad_norm_clip)
        self._critic_optimizer.step()

        # Level-0 actions: what the current actor would do, held fixed.
        actions_k = self._k_level_actions(
            batch_obs_t_list, batch_action_masks_t_list, detach=True
        )

        # Stays None when k_level <= 0 so the return below cannot hit an
        # unbound local.
        total_actor_loss = None
        for k in range(self._k_level):
            # Differentiable actions of the actor in its current state.
            actions_update = self._k_level_actions(
                batch_obs_t_list, batch_action_masks_t_list
            )

            total_actor_loss = 0
            for idx_agent, _ in enumerate(self._host_agents):
                if self._centralized_critic:
                    # Joint action: own differentiable action, everyone else
                    # fixed at the previous level's actions.
                    action_mixed = (
                        actions_k[:idx_agent]
                        + [actions_update[idx_agent]]
                        + actions_k[idx_agent + 1 :]
                    )
                    centralized_action_agent = torch.cat(action_mixed, dim=-1)
                    qs = [
                        agent_q.critic_approximator.predict(
                            critic_batch_obs_t_list[idx_agent_q],
                            centralized_action_agent,
                            output_tensor=True,
                        )
                        for idx_agent_q, agent_q in enumerate(self._host_agents)
                    ]
                else:
                    qs = []
                    for idx_agent_q, agent_q in enumerate(self._host_agents):
                        if idx_agent_q == idx_agent:
                            q_update = agent_q.critic_approximator.predict(
                                critic_batch_obs_t_list[idx_agent_q],
                                actions_update[idx_agent_q],
                                output_tensor=True,
                            )
                        else:
                            # Each critic is evaluated on its OWN agent's
                            # observations, as in the critic update above and
                            # the centralised branch. The previous code passed
                            # the updating agent's observations here.
                            q_update = agent_q.critic_approximator.predict(
                                critic_batch_obs_t_list[idx_agent_q],
                                actions_k[idx_agent_q],
                                output_tensor=True,
                            ).detach()
                        qs.append(q_update)

                q_actor = torch.stack(qs, dim=-1).unsqueeze(-1)
                q_tot_actor = self.mix(q_actor, mixer_states)
                # Keep only the entries selected by the pad mask.
                q_valid = q_tot_actor[pad_masks_t]
                total_actor_loss = total_actor_loss - q_valid.mean()

            if self._scale_actor_loss:
                total_actor_loss = total_actor_loss / self.mdp_info.n_agents

            # One optimiser step on the loss accumulated over all agents.
            self._actor_optimizer.zero_grad()
            total_actor_loss.backward()
            if self._grad_norm_clip is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.actor_params, self._grad_norm_clip
                )
            self._actor_optimizer.step()

            if k < self._k_level - 1:  # Not the final k step
                # Record the updated actor's actions as the next level's
                # fixed actions ...
                actions_k = self._k_level_actions(
                    batch_obs_t_list, batch_action_masks_t_list, detach=True
                )

                # ... then roll the actor and its optimiser back, so that the
                # next step starts again from the initial parameters.
                self._host_agents[0].actor_approximator.set_weights(
                    np.copy(initial_actor_params)
                )
                self._actor_optimizer.load_state_dict(
                    deepcopy(initial_actor_optimizer_state)
                )

        # Update target mixer
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()
        else:
            raise ValueError(
                f"Target update mode {self._target_update_mode} not recognised."
            )

        actor_loss = total_actor_loss.item() if total_actor_loss is not None else 0
        return actor_loss, critic_loss.item()
