import math
import torch
from torch import nn
from typing import Optional


def inject(module, module_init, *args, **kwargs):
    """Wrap ``module`` in a ``ParameterInjector`` when it is a supported layer.

    Only ``nn.Linear`` and ``nn.Conv2d`` modules whose weight has at least
    one element are wrapped; for anything else ``None`` is returned.
    Extra positional and keyword arguments are forwarded unchanged to the
    ``ParameterInjector`` constructor.
    """
    is_supported = isinstance(module, (nn.Linear, nn.Conv2d))
    if not is_supported or torch.numel(module.weight) == 0:
        return None

    return ParameterInjector(module, module_init, *args, **kwargs)


# In-place modification that injects noise perturbation functions to `net`
def inject_net(
    net,
    net_init=None,
    depth=0,
    layers=None,
    perturb_power=0.018,
    *args,
    **kwargs
):
    """Recursively replace the children of ``net`` with injector wrappers.

    Every child for which ``inject`` returns a wrapper is replaced in place
    (via ``net.add_module``) and the wrapper inherits the child's training
    mode. After the whole tree has been visited (``depth == 0`` only), each
    wrapper's noise norm is set to ``get_param_norm() * perturb_power``.

    Args:
        net: module whose children are modified in place.
        net_init: optional module with the same submodule names as ``net``;
            the matching submodule is handed to each wrapper.
        depth: current recursion depth; external callers leave it at 0.
        layers: accumulator used by the recursion. It is always reset to a
            fresh list when ``depth == 0``.
        perturb_power: multiplier applied to each layer's parameter norm.
        *args, **kwargs: forwarded to ``inject`` for every child.

    Returns:
        The list of injected wrappers in visiting order, or ``False`` when
        the recursion depth exceeds 100.
    """
    # A fresh list per top-level call; never share a default across calls.
    if depth == 0 or layers is None:
        layers = []

    if depth > 100:
        return False

    # Snapshot the children: add_module below rebinds entries while we loop.
    for name, child in list(net.named_children()):
        child_init = (
            net_init.get_submodule(name) if net_init is not None else None
        )
        new_module: Optional[nn.Module] = inject(
            child, child_init, *args, **kwargs
        )

        if new_module is not None:
            new_module.train(mode=child.training)
            net.add_module(name, new_module)

            # Keep injected layers in the order they are visited.
            layers.append(new_module)

        # Recurse into the original child. perturb_power is passed explicitly
        # so that *args are not shifted into its positional slot.
        inject_net(
            child,
            child_init,
            depth + 1,
            layers,
            perturb_power,
            *args,
            **kwargs
        )

    if depth == 0:
        # Apply layer-wise scaling.
        param_norms = torch.Tensor(
            [layer.get_param_norm() for layer in layers]
        )
        for layer, layer_norm in zip(layers, param_norms * perturb_power):
            layer.set_norm(layer_norm)

    return layers


def get_states(module):
    """Collect ``get_state()`` from every ParameterInjector in ``module``.

    The result follows the order of ``module.modules()``.
    """
    return [
        submodule.get_state()
        for submodule in module.modules()
        if isinstance(submodule, ParameterInjector)
    ]


def set_states(module, states):
    """Apply ``states[i]`` to the i-th ParameterInjector in ``module``.

    Injectors are numbered in ``module.modules()`` order and each state is
    unpacked as keyword arguments to ``set_norm``.
    """
    injectors = (
        submodule
        for submodule in module.modules()
        if isinstance(submodule, ParameterInjector)
    )
    for position, injector in enumerate(injectors):
        injector.set_norm(**states[position])


def set_perturb_norm(module, *args, **kwargs):
    """Call ``set_norm(*args, **kwargs)`` on every ParameterInjector in ``module``."""
    for submodule in module.modules():
        if not isinstance(submodule, ParameterInjector):
            continue
        submodule.set_norm(*args, **kwargs)


def enable_perturb(module, *args, **kwargs):
    """Call ``enable(*args, **kwargs)`` on every ParameterInjector in ``module``."""
    for submodule in module.modules():
        if not isinstance(submodule, ParameterInjector):
            continue
        submodule.enable(*args, **kwargs)


def disable_perturb(module, *args, **kwargs):
    """Call ``disable(*args, **kwargs)`` on every ParameterInjector in ``module``."""
    for submodule in module.modules():
        if not isinstance(submodule, ParameterInjector):
            continue
        submodule.disable(*args, **kwargs)


def resample_perturb(module, *args, **kwargs):
    """Call ``sample(*args, **kwargs)`` on every ParameterInjector in ``module``."""
    for submodule in module.modules():
        if not isinstance(submodule, ParameterInjector):
            continue
        submodule.sample(*args, **kwargs)


class ParameterInjector(nn.Module):
    """
    Wrap a layer and optionally add a sampled perturbation to its weight.

    Optional keyword arguments:

    noise_norm: float (default 0.1), scale of the perturbation
    noise_pattern: str (default 'indep'), one of
        'prop', 'indep', 'subtract', 'prop-deterministic'
    noise_norm_ex: float (default 1.0), extra factor used by the
        'subtract' and 'prop-deterministic' patterns

    Any other ``noise_pattern`` value leaves the freshly drawn
    standard-normal noise unscaled when ``sample`` is called.

    The unperturbed weight is kept in ``weight_original``; all noise and
    norm computations are based on it, never on ``module.weight``, because
    ``forward`` overwrites ``module.weight`` while perturbation is enabled.
    """

    def __init__(self, moduleToWrap, moduleToWrap_init, *args, **kwargs):
        super().__init__()
        self.module = moduleToWrap
        self.module_init = moduleToWrap_init

        self.noise_norm = kwargs.get("noise_norm", 0.1)
        self.noise_pattern = kwargs.get("noise_pattern", "indep")
        self.noise_norm_ex = kwargs.get("noise_norm_ex", 1.0)

        # weight
        self.weight_inject = torch.zeros_like(self.module.weight)
        self.weight_original = self.module.weight

        # bias
        if self.module.bias is not None:
            self.bias_inject = torch.zeros_like(self.module.bias)
            self.bias_original = self.module.bias

        # Perturbation is off until enable() is called.
        self.enabled = False

    def get_state(self):
        """Return the noise settings as a dict accepted by ``set_norm(**state)``."""
        return {
            "noise_norm": self.noise_norm,
            "noise_pattern": self.noise_pattern,
            "noise_norm_ex": self.noise_norm_ex,
        }

    def set_norm(self, noise_norm, noise_pattern=None, noise_norm_ex=None):
        """Update the noise settings; arguments that are ``None`` are left unchanged."""
        if noise_norm is not None:
            self.noise_norm = noise_norm

        if noise_pattern is not None:
            self.noise_pattern = noise_pattern

        if noise_norm_ex is not None:
            self.noise_norm_ex = noise_norm_ex

    def enable(self, *args, **kwargs):
        """Make ``forward`` use the perturbed weight."""
        self.enabled = True

    def disable(self, *args, **kwargs):
        """Make ``forward`` use the original weight."""
        self.enabled = False

    def sample(self, sample_gamma=1.0, *args, **kwargs):
        """Draw a new perturbation and store it in ``weight_inject``.

        ``sample_gamma`` is an additional multiplier that is applied only
        for the 'indep' pattern.
        """
        # Always start from the clean weight: module.weight may currently
        # hold a perturbed copy left behind by an enabled forward pass.
        base_weight = self.weight_original
        noise = torch.randn(
            *base_weight.shape,
            device=base_weight.device,
            dtype=base_weight.dtype,
        )

        pattern = self.noise_pattern
        if pattern == "prop":
            # Noise proportional to the magnitude of each weight.
            noise = self.noise_norm * noise * torch.abs(base_weight)

        elif pattern == "indep":
            # Noise independent of the weight values.
            noise = self.noise_norm * (noise * sample_gamma)

        elif pattern == "subtract":
            noise = (
                self.noise_norm
                - self.noise_norm
                * self.noise_norm_ex
                * torch.abs(base_weight)
            ) * noise

        elif pattern == "prop-deterministic":
            # No randomness: a scaled copy of the weight, or of its
            # displacement from module_init when that is available.
            if self.module_init is not None:
                noise = (
                    self.noise_norm
                    * self.noise_norm_ex
                    * (base_weight - self.module_init.weight)
                )
            else:
                noise = self.noise_norm * self.noise_norm_ex * base_weight

        self.weight_inject = noise

    def get_param_norm(self, *args, **kwargs):
        """Return the reference scale used for layer-wise noise scaling.

        With ``module_init`` this is the mean absolute difference between
        the original weight and ``module_init.weight``; otherwise it is
        ``1 / sqrt(numel)`` of the weight.
        """
        if self.module_init is not None:
            return (
                (self.weight_original - self.module_init.weight).abs().mean()
            )
        return 1 / math.sqrt(torch.numel(self.weight_original))

    def get_param_true_norm(self, *args, **kwargs):
        """Return the norm of the original weight divided by ``sqrt(numel)``."""
        return self.weight_original.norm() / math.sqrt(
            torch.numel(self.weight_original)
        )

    def forward(self, x):
        """Run the wrapped module with the perturbed or the original weight."""
        if self.enabled:
            self.module.weight = nn.Parameter(
                self.weight_original + self.weight_inject
            )
        else:
            self.module.weight = self.weight_original
        return self.module(x)
