from typing import Union

import math
import torch as pt
from torch import nn
from torch.nn import Parameter

from peagang.models.components.utilities_functions import sn_wrap, l2normalize, l2normalizenonDiff

# An ordinary implementation of Swish function
# from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py
class Swish(nn.Module):
    """Plain Swish activation: ``x * sigmoid(x)``, differentiated by autograd."""

    def forward(self, x):
        # The sigmoid acts as a smooth gate on the input itself.
        gate = pt.sigmoid(x)
        return x * gate

# A memory-efficient implementation of Swish function
# from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py
class SwishImplementation(pt.autograd.Function):
    """Custom autograd function for Swish, ``f(x) = x * sigmoid(x)``.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than being stored from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        result = i * pt.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input.

        Uses d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))) with s = sigmoid.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = pt.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


# MemoryEfficientSwish
# from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py
class MESwish(nn.Module):
    """Module wrapper that applies Swish through ``SwishImplementation``."""

    def forward(self, x):
        # Custom autograd functions are invoked via `.apply`, not by instantiation.
        activated = SwishImplementation.apply(x)
        return activated

# Keep the plain autograd-based implementation reachable under an explicit name ...
InEfficientSwish=Swish
# ... and rebind the module-level name `Swish` to the memory-efficient variant,
# so any later reference to `Swish` resolves to `MESwish`.
Swish=MESwish

class SkipBlock(pt.nn.Module):
    """Residual wrapper: returns ``skip(x) + inner(x)``.

    :param inner: callable/module applied to the input.
    :param proj: optional argument tuple forwarded to ``nn.Linear``; when given,
        the skip path is that linear projection of the input, otherwise the
        input itself is added unchanged.
    """

    def __init__(self,inner,proj=None) -> None:
        super().__init__()
        self.inner = inner
        self.proj = nn.Linear(*proj) if proj is not None else None

    def forward(self,x):
        # Evaluate the wrapped branch first, then the skip path.
        branch = self.inner(x)
        shortcut = x if self.proj is None else self.proj(x)
        return shortcut + branch



class PermuteBatchnorm1d(pt.nn.BatchNorm1d):
    """
    BatchNorm1d for tensors laid out as B N F.

    BatchNorm1d normalises dimension 1, so the last two axes are swapped before
    the parent forward pass and swapped back afterwards.
    pytorch_geometric actually has an implementation of this...should double check
    """

    def forward(self, input: pt.Tensor) -> pt.Tensor:
        # B N F -> B F N so the feature axis sits where BatchNorm1d expects it
        features_second = input.permute(0, 2, 1)
        normalised = super().forward(features_second)
        # B F N -> B N F, restoring the caller's layout
        return normalised.permute(0, 2, 1)


class NodeFeatNorm(pt.nn.Module):
    """Normalisation layer for node features, selected by ``mode``.

    Only ``mode="instance"`` is implemented; it delegates to
    ``torch.nn.InstanceNorm1d(feat_dim)``.

    :param feat_dim: value passed to ``InstanceNorm1d`` as ``num_features``.
    :param mode: normalisation flavour; currently only ``"instance"``.
    """

    def __init__(self, feat_dim, mode="instance"):
        super().__init__()
        self.mode = mode
        self.feat_dim = feat_dim
        if mode == "instance":
            self.norm = pt.nn.InstanceNorm1d(feat_dim)

    def forward(self, X):
        """Normalise ``X`` according to ``self.mode``.

        :raises ValueError: if ``self.mode`` is not a supported mode (this used
            to fall through and silently return ``None``).
        """
        # X: B N F
        # NOTE(review): InstanceNorm1d treats dim 1 as the channel axis, so for a
        # B N F input the statistics are taken over the last axis -- confirm
        # this is the intended normalisation axis.
        if self.mode == "instance":
            return self.norm(X)
        raise ValueError(f"Unsupported NodeFeatNorm mode: {self.mode!r}")


class FeedForward(nn.Module):
    """Stack of ``Linear`` layers with an activation and dropout after each one.

    :param dimensions: layer widths, input width first; consecutive pairs define
        the ``Linear`` layers, so ``len(dimensions) - 1`` layers are created.
    :param n_layers: must equal ``len(dimensions)``; ``self.n_layers`` stores
        ``n_layers - 1``, the number of ``Linear`` layers.
    :param dropout: dropout probability applied after every activation.
    """

    def __init__(self, dimensions: list, n_layers: int, dropout: float = 0.0):
        super().__init__()

        self.n_layers = n_layers - 1
        assert len(dimensions) == n_layers, (
            f"expected {n_layers} dimensions, got {len(dimensions)}"
        )

        self.layers = pt.nn.ModuleList(
            pt.nn.Linear(dimensions[idx], dimensions[idx + 1])
            for idx in range(self.n_layers)
        )
        self.dropout = pt.nn.Dropout(dropout)

    def forward(self, x, activation=pt.nn.ReLU()):
        """Apply ``dropout(activation(linear(x)))`` for each layer in order.

        The activation (and dropout) is applied after every layer, the last one
        included. The default ``ReLU`` instance is created once at definition
        time and shared between calls; it holds no state.
        """
        for idx in range(self.n_layers):
            x = self.dropout(activation(self.layers[idx](x)))

        return x


class PointNetBlock(nn.Module):
    """Permutation-equivariant linear layer over a set of node features.

    Computes ``X @ A + mean_over_nodes(X) @ B + cT`` (the ``B`` term is dropped
    when ``no_B`` is set), optionally followed by ``activation``.

    :param input_feat_dim: size of the last axis of the input.
    :param output_feat_dim: size of the last axis of the output.
    :param no_B: if true, ``B`` is registered as ``None`` and the pooled term is
        skipped.
    :param spectral_norm: forwarded to the project helper ``sn_wrap`` for each
        weight when not ``None``.
    :param activation: optional callable applied to the result.
    """

    def __init__(self, input_feat_dim, output_feat_dim, no_B=False, spectral_norm=None,activation=None):
        super().__init__()
        self.input_feat_dim = input_feat_dim
        # The original statement here was a self-assignment of the local name.
        self.output_feat_dim = output_feat_dim
        # Storage is uninitialised until reset_parameters() below fills it.
        self.A = nn.Parameter(pt.empty(input_feat_dim, output_feat_dim))
        self.cT = nn.Parameter(pt.empty(1, output_feat_dim))
        self.activation=activation
        if no_B:
            self.register_parameter("B", None)
        else:
            self.B = nn.Parameter(pt.empty(input_feat_dim, output_feat_dim))
        self.reset_parameters()
        self.spectral_norm = spectral_norm
        # Plain list: the wrappers are deliberately not registered as submodules.
        self._sns = []
        if spectral_norm is not None:
            self._sns.extend(
                [
                    sn_wrap(self, spectral_norm, name="A"),
                    sn_wrap(self, spectral_norm, name="cT"),
                ])
            if not no_B:
                self._sns.append(sn_wrap(self, spectral_norm, name="B"))

            # Re-register every parameter the wrappers report on this block,
            # with the "module." prefix stripped from its name.
            for s in self._sns:
                for n, p in s.named_parameters():
                    self.register_parameter(n.replace("module.", ""), p)

    def reset_parameters(self):
        # following https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
        pt.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        fan_in, _ = pt.nn.init._calculate_fan_in_and_fan_out(self.A)
        bound = 1 / math.sqrt(fan_in)
        pt.nn.init.uniform_(self.cT, -bound, bound)
        if self.B is not None:
            pt.nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))

    def forward(self, X):
        """
        Expecting shape B N F or N F
        following eq 8 in https://openreview.net/pdf?id=HkxTwkrKDB
        :param X: node features, nodes on the second-to-last axis
        :return: tensor shaped like ``X`` with the last axis replaced by the
            output feature size
        """
        if self.spectral_norm:
            # Refresh the spectrally normalised weights before they are used.
            for n in self._sns:
                n._update_u_v()
        xa = X @ self.A
        if self.B is not None:
            # (1/N) * 1 1^T X is the node-wise mean repeated on every row; take
            # the mean directly and let broadcasting repeat it. This avoids the
            # N x N ones matrix and follows X's dtype and device.
            xb = X.mean(dim=-2, keepdim=True) @ self.B
            out = xa + xb + self.cT
        else:
            # cT has shape (1, F_out) and broadcasts over the node axis.
            out = xa + self.cT
        if self.activation:
            out=self.activation(out)
        return out


class LinearTransmissionLayer(nn.Module):
    """Layer that broadcasts a linear map of the node-wise mean to every node.

    Computes ``mean_over_nodes(X) @ B + cT`` and repeats it for each node, with
    optional dropout on the input and an optional activation on the output.

    :param input_feat_dim: size of the last axis of the input.
    :param output_feat_dim: size of the last axis of the output.
    :param dropout: dropout probability applied to the input; falsy disables it.
    :param activation: optional callable applied to the result.
    :param spectral_norm: forwarded to the project helper ``sn_wrap`` for each
        weight when not ``None``.
    """

    def __init__(self, input_feat_dim, output_feat_dim, dropout=None,activation=None, spectral_norm=None):
        super().__init__()
        self.input_feat_dim = input_feat_dim
        self.output_feat_dim = output_feat_dim
        # Storage is uninitialised until reset_parameters() below fills it.
        self.B = nn.Parameter(pt.empty(self.input_feat_dim, self.output_feat_dim))
        self.cT = nn.Parameter(pt.empty(1, self.output_feat_dim))
        self.reset_parameters()
        if dropout:
            self.dropout = pt.nn.Dropout(dropout)
        else:
            self.dropout = None
        self.activation=activation
        self.spectral_norm = spectral_norm
        # Plain list: the wrappers are deliberately not registered as submodules.
        self._sns = []
        if spectral_norm is not None:
            self._sns.extend(
                [
                    sn_wrap(self, spectral_norm, name="B"),
                    sn_wrap(self, spectral_norm, name="cT"),
                ]
            )
            # Re-register every parameter the wrappers report on this layer,
            # with the "module." prefix stripped from its name.
            for s in self._sns:
                for n, p in s.named_parameters():
                    self.register_parameter(n.replace("module.", ""), p)

    def reset_parameters(self):
        # following https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
        pt.nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))
        fan_in, _ = pt.nn.init._calculate_fan_in_and_fan_out(self.B)
        bound = 1 / math.sqrt(fan_in)
        pt.nn.init.uniform_(self.cT, -bound, bound)

    def forward(self, X):
        """
        Expecting shape B N F or N F
        following eq 8 in https://openreview.net/pdf?id=HkxTwkrKDB
        :param X: node features, nodes on the second-to-last axis
        :return: tensor shaped like ``X`` with the last axis replaced by the
            output feature size; every node row holds the same values
        """
        if self.spectral_norm:
            # Refresh the spectrally normalised weights before use, as
            # PointNetBlock.forward does for its wrappers.
            # NOTE(review): assumes sn_wrap returns objects exposing
            # `_update_u_v`, as PointNetBlock relies on -- confirm.
            for n in self._sns:
                n._update_u_v()
        if self.dropout:
            X = self.dropout(X)
        # (1/N) * 1 1^T X is the node-wise mean repeated on every row; compute
        # the mean once instead of materialising the N x N ones matrix. This
        # also follows X's dtype and device.
        pooled = X.mean(dim=-2, keepdim=True) @ self.B
        # Expand back to one row per node; adding cT yields a fresh dense tensor.
        out = pooled.expand(*X.shape[:-1], -1) + self.cT
        if self.activation:
            out=self.activation(out)
        return out

    def extra_repr(self) -> str:
        return f"in_features={self.B.shape[0]},out_features={self.B.shape[1]}"


class SpectralNorm(nn.Module):
    """
    # from https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    Differentiable version?

    Wraps ``module`` and replaces its parameter ``name`` with three parameters:
    ``<name>_bar`` (the raw weight) plus ``<name>_u`` / ``<name>_v`` (power
    iteration vectors, ``requires_grad=False``). Every update sets ``<name>`` on
    the wrapped module to ``<name>_bar / sigma`` with ``sigma = u . (W v)``.
    """

    _SUFFIXES = ("_u", "_v", "_bar")

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run the power iterations and store the rescaled weight on the module."""
        target, base = self.module, self.name
        u = getattr(target, base + "_u")
        v = getattr(target, base + "_v")
        w = getattr(target, base + "_bar")

        rows = w.data.shape[0]
        # Detached 2-D view of the weight; the iterations below bypass autograd.
        w_flat = w.view(rows, -1).data
        for _ in range(self.power_iterations):
            v.data = l2normalize(pt.mv(pt.t(w_flat), u.data))
            u.data = l2normalize(pt.mv(w_flat, v.data))

        # sigma is computed on the non-detached weight, so the division below
        # stays part of the autograd graph.
        sigma = u.dot(w.view(rows, -1).mv(v))
        setattr(target, base, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True when the wrapped module already carries u, v and w_bar."""
        return all(
            hasattr(self.module, self.name + suffix) for suffix in self._SUFFIXES
        )

    def _make_params(self):
        """Swap the original parameter for ``_u``, ``_v`` and ``_bar`` parameters."""
        w = getattr(self.module, self.name)

        rows = w.data.shape[0]
        cols = w.view(rows, -1).data.shape[1]

        # Random unit vectors as the power-iteration starting point.
        u = Parameter(w.data.new(rows).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # The name is freed so _update_u_v can set a plain tensor under it.
        del self.module._parameters[self.name]

        for suffix, param in zip(self._SUFFIXES, (u, v, w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class SpectralNormNonDiff(nn.Module):
    """
    # from https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Variant that rescales the wrapped module's parameter ``name`` in place
    (through ``.data``) instead of building a differentiable normalised copy.
    The power-iteration vector ``u`` is kept as a buffer ``<name>_u`` on the
    wrapped module and is created lazily on the first update.

    NOTE(review): the file imports ``l2normalizenonDiff`` but this class uses
    ``l2normalize`` -- confirm which helper is intended here.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations

    def _update_u_v(self):
        """Run the power iterations and divide the weight data by sigma."""
        if not self._made_params():
            self._make_params()
        w = getattr(self.module, self.name)
        u = getattr(self.module, self.name + "_u")

        height = w.data.shape[0]
        # Detached 2-D view; it is only read before w.data is reassigned below.
        w_mat = w.view(height, -1).data
        v = None
        for _ in range(self.power_iterations):
            v = l2normalize(pt.mv(pt.t(w_mat), u))
            u = l2normalize(pt.mv(w_mat, v))
        if v is None:
            # power_iterations < 1: derive v from the stored u without
            # refreshing u, instead of failing on an unbound `v`.
            v = l2normalize(pt.mv(pt.t(w_mat), u))

        setattr(self.module, self.name + "_u", u)
        w.data = w.data / pt.dot(u, pt.mv(w_mat, v))

    def _made_params(self):
        """Return True when the ``<name>_u`` buffer already exists."""
        return hasattr(self.module, self.name + "_u")

    def _make_params(self):
        """Register a random unit vector as the ``<name>_u`` buffer."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        u = l2normalize(w.data.new(height).normal_(0, 1))

        self.module.register_buffer(self.name + "_u", u)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class ConcatAggregate(nn.Module):
    """Aggregate a sequence of tensors by concatenating them along ``dim``."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # `x` is the sequence of tensors to join, not a single tensor.
        axis = self.dim
        joined = pt.cat(x, dim=axis)
        return joined


class DenseSequential(nn.Module):
    """
    Densely connected sequential container.

    Each layer receives the aggregation of the original input and the outputs
    of all previous layers; the output of the last layer is returned.

    :param layers: the layers, as an ``nn.ModuleList`` or any list/tuple of
        modules (wrapped into an ``nn.ModuleList`` so they are registered).
    :param aggregate: callable that combines a list of tensors into one;
        defaults to ``ConcatAggregate()``.
    """

    def __init__(self, layers: Union[nn.ModuleList, list, tuple], aggregate=None):
        super().__init__()
        if not isinstance(layers, nn.ModuleList):
            # ModuleList takes ONE iterable of modules; unpacking the sequence
            # into positional arguments fails for a list or tuple of layers.
            layers = nn.ModuleList(layers)
        self.layers = layers
        if aggregate is None:
            aggregate = ConcatAggregate()
        self.agg = aggregate

    def forward(self, x):
        """Run the layers in order; return ``x`` itself if there are none."""
        outs = [x]
        for layer in self.layers:
            # Every layer sees the input plus all earlier outputs, aggregated.
            outs.append(layer(self.agg(outs)))
        return outs[-1]
