import torch
import numpy as np
import torch.nn.functional as F
from k_level_policy_gradients.src.algorithms.agent import Agent
from k_level_policy_gradients.src.mixers.mixers import QMixer
from k_level_policy_gradients.src.utils.torch import get_weights, set_weights
from itertools import chain


class COMIX(Agent):
    """
    COMIX learner: a QMixer mixing network (with a target copy) trained jointly
    with the Q-network of the host agents, for continuous actions and without
    recurrency.

    Per-agent Q-values are combined by the mixer into a joint value that is
    regressed onto a one-step TD target. The per-agent maximisation over
    continuous next actions is approximated with the cross-entropy method
    (sample actions, keep the best ones, refit a Gaussian).
    """

    def __init__(
        self,
        mdp_info,
        idx_agent,
        batch_size,
        replay_memory,
        target_update_frequency,
        tau,
        warmup_replay_size,
        target_update_mode,
        mixing_embed_dim,
        optimizer_params,
        scale_loss,
        grad_norm_clip,
        obs_last_action,
        host_agents,
        use_cuda=False,
    ):
        """
        Constructor.

        Args:
            mdp_info: environment information; ``state_space``,
                ``action_space``, ``n_agents`` and ``gamma`` are read from it;
            idx_agent: agent index forwarded to the ``Agent`` base class;
            batch_size: number of transitions requested from the replay
                memory at every update;
            replay_memory: buffer providing ``add``, ``size`` and ``get``;
            target_update_frequency: number of updates between two hard
                copies of the mixer into the target mixer;
            tau: interpolation coefficient of the soft target update;
            warmup_replay_size: replay memory size that must be exceeded
                before any update is performed;
            target_update_mode: ``"soft"`` or ``"hard"``; with any other value
                the target mixer is never updated after construction;
            mixing_embed_dim: forwarded to ``QMixer``;
            optimizer_params: dict holding the optimizer ``"class"`` and its
                keyword arguments under ``"params"``;
            scale_loss: if truthy, the loss is divided by the number of agents;
            grad_norm_clip: maximum gradient norm, or None to disable clipping;
            obs_last_action: stored as ``_obs_last_action``; not read anywhere
                in this class;
            host_agents: the agents using this mixing network;
            use_cuda: forwarded to ``set_weights`` when the target mixer is
                updated.
        """
        super().__init__(mdp_info, policy=None, idx_agent=idx_agent)

        self._batch_size = batch_size
        self._replay_memory = replay_memory
        self._target_update_frequency = target_update_frequency
        self._tau = tau
        self._warmup_replay_size = warmup_replay_size
        self._target_update_mode = target_update_mode
        self._scale_loss = scale_loss
        self._grad_norm_clip = grad_norm_clip
        self._obs_last_action = obs_last_action
        self._host_agents = host_agents  # The agents using this mixing network
        self._use_cuda = use_cuda

        self._n_updates = 0

        self._state_shape_int = int(np.prod(self.mdp_info.state_space.shape))

        # The online and target mixers are built with identical arguments.
        mixer_kwargs = dict(
            state_shape=mdp_info.state_space.shape,
            mixing_embed_dim=mixing_embed_dim,
            n_agents=mdp_info.n_agents,
        )
        self._mixer = QMixer(**mixer_kwargs)
        self._target_mixer = QMixer(**mixer_kwargs)

        self.shared_params_bool = self._host_agents[-1]._primary_agent is not None

        # Assume agents share network parameters (therefore only need primary agent's network)
        # Use a single optimizer for the mixer and shared agent network
        self.params = list(
            chain(
                host_agents[0].approximator.network.parameters(),
                self._mixer.parameters(),
            )
        )
        self._optimizer = optimizer_params["class"](
            self.params, **optimizer_params["params"]
        )

        # Start with the target mixer equal to the online mixer.
        self.update_target_mixer()

        # NOTE(review): several attributes read by fit() (_target_update_mode,
        # _scale_loss, _grad_norm_clip, _host_agents, params, _state_shape_int)
        # are not registered here -- confirm this is intended for save/load.
        self._add_save_attr(
            _batch_size="primitive",
            _target_update_frequency="primitive",
            _tau="primitive",
            _warmup_replay_size="primitive",
            _replay_memory="mushroom!",
            _n_updates="primitive",
            _mixer="torch",
            _target_mixer="torch",
            _optimizer="torch",
            _use_cuda="primitive",
        )

    def fit(self, dataset):
        """
        Store ``dataset`` in the replay memory and, once the memory holds more
        than ``warmup_replay_size`` transitions, perform one gradient step on
        the mixer and the host agents' network.

        Args:
            dataset: the data handed to ``replay_memory.add``.

        Returns:
            A pair holding the scalar loss twice, or ``(0, 0)`` when no update
            was performed because the replay memory is still warming up.
        """
        self._replay_memory.add(dataset)
        if self._replay_memory.size <= self._warmup_replay_size:
            return 0, 0

        states, obs, actions, rewards, next_states, next_obs, absorbing, _ = (
            self._replay_memory.get(self._batch_size)
        )

        # Convert to tensors (created on the default device).
        states_t = torch.tensor(states, dtype=torch.float32)
        obs_t = self._per_agent_tensors(obs, len(obs))
        actions_t = self._per_agent_tensors(actions, len(actions))
        # Use rewards of agent 0, assume all agents have the same reward
        rewards_t = torch.tensor(rewards[:, 0], dtype=torch.float32).unsqueeze(-1)
        next_states_t = torch.tensor(next_states, dtype=torch.float32)
        next_obs_t = self._per_agent_tensors(next_obs, len(obs))
        absorbing_t = torch.tensor(absorbing, dtype=torch.bool).unsqueeze(-1)

        # Per-agent Q-values of the stored actions and of the (approximately)
        # best next actions.
        q_hats = []
        q_nexts = []
        for idx_agent, agent in enumerate(self._host_agents):
            q_hat = agent.approximator.predict(
                obs_t[idx_agent], actions_t[idx_agent], output_tensor=True
            )
            q_next = self._cem_target_q(agent, idx_agent, next_obs_t[idx_agent])

            # NOTE(review): squeeze() also drops the batch axis when the batch
            # holds a single transition.
            q_hats.append(q_hat.squeeze())
            q_nexts.append(q_next.squeeze())

        # Compute mixer predictions
        q_hat = torch.stack(q_hats, dim=-1).unsqueeze(-1)
        q_tot = self.mix(q_hat, states_t.reshape(-1, self._state_shape_int))

        # The TD target is a constant for the optimisation: no graph is needed.
        with torch.no_grad():
            q_next = torch.stack(q_nexts, dim=-1).unsqueeze(-1)
            q_tot_next = self.target_mix(
                q_next, next_states_t.reshape(-1, self._state_shape_int)
            )
            q_tot_target = (
                rewards_t + self.mdp_info.gamma * q_tot_next * ~absorbing_t
            )

        # Compute critic loss and backpropagate
        loss = F.mse_loss(q_tot, q_tot_target)
        if self._scale_loss:
            loss = loss / self.mdp_info.n_agents
        self._optimizer.zero_grad()
        loss.backward()
        if self._grad_norm_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self._grad_norm_clip)
        self._optimizer.step()

        # Update target mixer
        self._n_updates += 1
        if self._target_update_mode == "soft":
            self.update_target_mixer_soft()
        elif self._target_update_mode == "hard":
            if self._n_updates % self._target_update_frequency == 0:
                self.update_target_mixer()

        loss_value = loss.item()
        return loss_value, loss_value

    @staticmethod
    def _per_agent_tensors(per_agent_data, n_agents):
        """Convert the first ``n_agents`` entries of ``per_agent_data`` to float32 tensors."""
        return [
            torch.tensor(per_agent_data[idx_agent], dtype=torch.float32)
            for idx_agent in range(n_agents)
        ]

    @torch.no_grad()
    def _cem_target_q(self, agent, idx_agent, next_obs):
        """
        Approximate the maximum target Q-value of ``agent`` in ``next_obs``
        with the cross-entropy method.

        Args:
            agent: host agent; its ``target_approximator`` and the ``_N``,
                ``_Ne`` and ``_max_its`` settings of its policy are used;
            idx_agent: index of the agent in ``mdp_info.action_space``;
            next_obs: 2-D tensor of next observations, one row per transition.

        Returns:
            The largest target Q-value found in the last iteration, for every
            transition of the batch (sample axis kept with size 1).

        Raises:
            ValueError: if the policy requests fewer than one iteration.
        """
        n_samples = agent.policy._N
        n_elites = agent.policy._Ne
        max_its = agent.policy._max_its
        if max_its < 1:
            raise ValueError("The policy must perform at least one CEM iteration.")

        action_dim = self.mdp_info.action_space[idx_agent].shape[0]
        n_transitions = next_obs.shape[0]

        # Start the search from a standard normal in pre-tanh action space.
        mu = torch.zeros((n_transitions, action_dim), dtype=torch.float32)
        std = torch.ones((n_transitions, action_dim), dtype=torch.float32)

        # Every sampled action is evaluated against the same observation.
        next_obs_expanded = (
            next_obs.unsqueeze(1).expand(-1, n_samples, -1).contiguous()
        )

        for _ in range(max_its):
            sampler = torch.distributions.Normal(mu, std)
            # (samples, batch, action) -> (batch, samples, action)
            samples = sampler.sample((n_samples,)).permute(1, 0, 2)
            next_qs = agent.target_approximator.predict(
                next_obs_expanded, torch.tanh(samples), output_tensor=True
            )
            # Refit the Gaussian on the pre-tanh samples with the highest value.
            _, elite_idxs = torch.topk(next_qs, n_elites, dim=1)
            elites = samples.gather(
                1, elite_idxs.repeat(1, 1, action_dim).long()
            )
            mu = torch.mean(elites, dim=1)
            std = torch.std(elites, dim=1)

        q_next, _ = torch.topk(next_qs, 1, dim=1)
        return q_next

    def mix(self, chosen_action_value, state):
        """Return the output of the online mixer for the given agent values and state."""
        return self._mixer(chosen_action_value, state)

    def target_mix(self, chosen_action_value, state):
        """Return the output of the target mixer for the given agent values and state."""
        return self._target_mixer(chosen_action_value, state)

    def update_target_mixer(self):
        """Copy the weights of the online mixer into the target mixer (hard update)."""
        w = get_weights(self._mixer.parameters())
        set_weights(self._target_mixer.parameters(), w, use_cuda=self._use_cuda)

    def update_target_mixer_soft(self):
        """Move the target mixer towards the online mixer: ``tau * online + (1 - tau) * target``."""
        weights = self._tau * self.get_mixer_weights()
        weights += (1 - self._tau) * get_weights(self._target_mixer.parameters())
        set_weights(self._target_mixer.parameters(), weights, use_cuda=self._use_cuda)

    def get_mixer_weights(self):
        """Return the weights of the online mixer as given by ``get_weights``."""
        return get_weights(self._mixer.parameters())

    def compute_grad_norm(self):
        """
        Return the global L2 norm of the gradients of ``self.params``.

        Parameters that currently hold no gradient are skipped, so the result
        is ``0.0`` before the first backward pass.
        """
        squared_norm = 0.0
        for param in self.params:
            if param.grad is None:
                continue
            squared_norm += param.grad.detach().norm(2).item() ** 2
        return squared_norm**0.5
