from typing import (
    Dict,
    List,
    Optional,
    Tuple,
)

import torch
from transfer.sac.utils import get_last_q_layer_name


class RegularizationHelper:
    """Base class for quadratic parameter-regularization penalties.

    On construction a detached copy of every parameter of ``ac`` is stored. The
    penalty is the importance-weighted squared distance between the current
    parameters and a reference copy. ``importance`` starts out as ``None`` and
    has to be assigned (a mapping from parameter name to a weight tensor)
    before the loss can be evaluated.
    """

    def __init__(self, ac, actor_memory_weight, critic_memory_weight, l2_mode=False):
        """Snapshot the parameters of ``ac`` and store the regularization settings.

        Args:
          ac: Object exposing ``named_parameters()``; its parameters are the
            ones being regularized.
          actor_memory_weight: Stored on the instance; not used by this base
            class itself (subclasses may use it to scale actor importance).
          critic_memory_weight: Stored on the instance; not used by this base
            class itself (subclasses may use it to scale critic importance).
          l2_mode: Flag stored on the instance for subclasses to consult.
        """
        self.actor_memory_weight = actor_memory_weight
        self.critic_memory_weight = critic_memory_weight
        self.ac = ac

        # Detached clones, so later optimizer steps do not move the reference point.
        self.old_params = {k: v.clone().detach() for k, v in ac.named_parameters()}
        self.importance = None
        self.l2_mode = l2_mode

    def regularization_loss(self) -> torch.Tensor:
        """Return the penalty measured against the snapshot taken at construction."""
        return self._regularize(self.old_params)

    def compute_importance(self, *args, **kwargs):
        """Compute per-parameter importance weights; subclasses must override this.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _regularize(self, old_params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculate the regularization loss based on previous parameters and parameter weights.

        Args:
          old_params: Mapping from parameter name to the reference value the
            current parameter is compared against.

        Returns:
          A 0-dim tensor: the sum over all regularized parameters of
          ``importance[name] * (param - old_params[name]) ** 2``.

        Raises:
          RuntimeError: If ``self.importance`` has not been set yet.
        """
        if self.importance is None:
            raise RuntimeError(
                "importance weights are not set; assign `importance` before "
                "computing the regularization loss"
            )

        reg_loss = None
        for name, param in self.ac.named_parameters():
            # Parameters without an importance entry (e.g. the alpha params)
            # are not regularized.
            if name not in self.importance:
                continue
            diffs = (param - old_params[name]) ** 2
            term = torch.sum(self.importance[name] * diffs)
            if reg_loss is None:
                # Create the accumulator on the device of the first term so the
                # in-place update below also works for models that are not on
                # the CPU.
                reg_loss = torch.zeros([], device=term.device)
            reg_loss += term

        if reg_loss is None:
            # Nothing was regularized.
            return torch.zeros([])
        return reg_loss


class EWCHelper(RegularizationHelper):
    """Importance weights built from per-output gradients of an actor-critic.

    The wrapped ``ac`` object is expected to expose three sub-modules:
    ``pi`` (called as ``pi(obs, return_dist=True)`` and returning a 4-tuple
    whose last two entries are the mean and log-std), ``q1`` and ``q2`` (both
    called as ``q(obs, actions)``).

    ``get_grads`` extracts the gradients, ``compute_importance`` turns them
    into a ``{parameter name: weight tensor}`` mapping. ``compute_importance``
    returns that mapping and does not assign ``self.importance`` itself.
    """

    def _batched_param_grads(
        self,
        output: torch.Tensor,
        module,
        *,
        allow_unused: bool,
        retain_graph: bool,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Differentiate every element of ``output`` w.r.t. every parameter of ``module``.

        Args:
          output: Tensor produced by a forward pass through ``module``.
          module: Object exposing ``named_parameters()``.
          allow_unused: Forwarded to ``torch.autograd.grad``; when True,
            parameters that did not contribute to ``output`` map to ``None``.
          retain_graph: Forwarded to ``torch.autograd.grad``.

        Returns:
          Mapping from parameter name to a tensor of shape
          ``output.shape + param.shape`` (or ``None`` for unused parameters).
        """
        n_out = output.numel()
        # Row i selects output element i. Building the selector from the
        # output's own dtype and device keeps this working when the model does
        # not live on the CPU. Note the selector has n_out * n_out entries.
        selectors = torch.eye(n_out, dtype=output.dtype, device=output.device)
        selectors = selectors.view([n_out] + list(output.shape))

        named = list(module.named_parameters())
        grads = torch.autograd.grad(
            [output],
            [param for _, param in named],
            grad_outputs=selectors,
            is_grads_batched=True,
            allow_unused=allow_unused,
            retain_graph=retain_graph,
        )

        # Unflatten the leading batched-gradient axis back into the output's shape.
        out_shape = list(output.shape)
        return {
            name: None if grad is None else grad.reshape(out_shape + list(grad.shape[1:]))
            for (name, _), grad in zip(named, grads)
        }

    def get_grads(
        self,
        obs: torch.Tensor,
        next_obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        done: torch.Tensor,
        **kwargs,
    ) -> Tuple[
        Dict[str, Optional[torch.Tensor]],
        Dict[str, Optional[torch.Tensor]],
        Dict[str, Optional[torch.Tensor]],
        Dict[str, Optional[torch.Tensor]],
        torch.Tensor,
    ]:
        """Compute per-output gradients of the policy and both critics.

        Args:
          obs: Observations fed to ``ac.pi``, ``ac.q1`` and ``ac.q2``.
          next_obs: Unused; accepted so a whole batch can be passed through.
          actions: Actions fed to ``ac.q1`` and ``ac.q2``.
          rewards: Unused.
          done: Unused.
          **kwargs: Unused.

        Returns:
          ``(mu_gs, std_gs, q1_gs, q2_gs, std)``. The first four are dicts
          keyed by the parameter names of ``ac.pi`` / ``ac.pi`` / ``ac.q1`` /
          ``ac.q2`` holding gradients of shape ``output.shape + param.shape``;
          actor entries are ``None`` for parameters that do not influence the
          respective output. ``std`` is the detached policy standard deviation.
        """
        # Main outputs from computation graph
        _, _, mu, log_std = self.ac.pi(obs, return_dist=True)
        std = log_std.exp()

        q1_vals = self.ac.q1(obs, actions)
        q2_vals = self.ac.q2(obs, actions)

        # mu and std come from the same forward pass, so the graph has to
        # survive the first differentiation. Some policy parameters feed only
        # one of the two heads, hence allow_unused for the actor.
        mu_gs = self._batched_param_grads(mu, self.ac.pi, allow_unused=True, retain_graph=True)
        std_gs = self._batched_param_grads(std, self.ac.pi, allow_unused=True, retain_graph=True)

        # Every critic parameter is required to contribute to its output.
        q1_gs = self._batched_param_grads(q1_vals, self.ac.q1, allow_unused=False, retain_graph=True)
        q2_gs = self._batched_param_grads(q2_vals, self.ac.q2, allow_unused=False, retain_graph=False)
        return mu_gs, std_gs, q1_gs, q2_gs, std.detach()

    @staticmethod
    def _actor_fisher(
        mu_g: Optional[torch.Tensor],
        std_g: Optional[torch.Tensor],
        std: torch.Tensor,
    ) -> torch.Tensor:
        """Combine mean and std gradients into a per-sample Fisher term for one parameter.

        Raises:
          ValueError: If both gradients are ``None``.
        """
        if mu_g is None and std_g is None:
            raise ValueError("Both mu and std gradients are None!")
        # A missing gradient means the parameter does not affect that head.
        if mu_g is None:
            mu_g = torch.zeros_like(std_g)
        if std_g is None:
            std_g = torch.zeros_like(mu_g)

        # Append singleton axes so std broadcasts over the parameter axes.
        trailing_axes = [1] * int(mu_g.ndim - std.ndim)
        broad_std = std.view(list(std.shape) + trailing_axes)

        # Fisher information, see the derivation
        fisher = 1 / (broad_std**2 + 1e-6) * (mu_g**2 + 2 * std_g**2)

        # Sum over the output dimensions
        # (assumes axis 1 is the action axis, i.e. mu is (batch, act_dim) -- TODO confirm)
        fisher = fisher.sum(1)

        # Clip from below
        return torch.clip(fisher, min=1e-5)

    def _scaled_importance(self, fisher: torch.Tensor, memory_weight) -> torch.Tensor:
        """Average ``fisher`` over axis 0 and scale it; in ``l2_mode`` use uniform weights."""
        batch_mean = fisher.mean(0)
        if self.l2_mode:
            return memory_weight * torch.ones_like(batch_mean)
        return memory_weight * batch_mean

    @torch.no_grad()
    def compute_importance(self, grads) -> Dict[str, torch.Tensor]:
        """Turn the output of ``get_grads`` into per-parameter importance weights.

        Args:
          grads: The 5-tuple ``(mu_gs, std_gs, q1_gs, q2_gs, std)`` returned
            by ``get_grads``.

        Returns:
          Dict keyed by ``"pi." / "q1." / "q2." + parameter name`` holding
          detached weight tensors. Policy parameters whose name contains
          ``"mu_layer"`` or ``"std_layer"`` and critic parameters whose name
          contains the value of ``get_last_q_layer_name(self.ac)`` are left out.

        Raises:
          ValueError: If a regularized policy parameter has neither a mean nor
            a std gradient.
        """
        actor_mu_gs, actor_std_gs, q1_gs, q2_gs, std = grads

        last_q_layer_name = get_last_q_layer_name(self.ac)
        reg_weights = {}

        for param_name, _ in self.ac.pi.named_parameters():
            # Do not regularize the last layer
            if "mu_layer" in param_name or "std_layer" in param_name:
                continue
            fisher = self._actor_fisher(actor_mu_gs[param_name], actor_std_gs[param_name], std)
            reg_weights["pi." + param_name] = self._scaled_importance(fisher, self.actor_memory_weight)

        critics = (
            ("q1.", self.ac.q1, q1_gs),
            ("q2.", self.ac.q2, q2_gs),
        )
        for prefix, critic, critic_gs in critics:
            for param_name, _ in critic.named_parameters():
                # Do not regularize the last layer.
                if last_q_layer_name in param_name:
                    continue
                fisher = critic_gs[param_name] ** 2
                reg_weights[prefix + param_name] = self._scaled_importance(fisher, self.critic_memory_weight)

        return {k: v.detach() for k, v in reg_weights.items()}
