"""
# https://raw.githubusercontent.com/vsitzmann/siren/master/modules.py
"""
import torch
from torch import nn
from torchmeta.modules import (MetaModule, MetaSequential)
import numpy as np
from collections import OrderedDict
import math
import torch.nn.functional as F


class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
    hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        """
        Apply the affine map ``input @ weight^T (+ bias)``.

        :param input: tensor whose last axis is multiplied with the weight matrix
        :param params: optional mapping with a 'weight' entry and, optionally, a 'bias' entry that is used
            instead of the parameters of this layer; if None, the layer's own parameters are used
        :return: the transformed tensor
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # Only the last two axes of the weight are swapped, so leading (batch) axes stay where they are.
        leading_axes = range(len(weight.shape) - 2)
        output = input.matmul(weight.permute(*leading_axes, -1, -2))

        # A layer built with bias=False (or a params mapping without 'bias') has nothing to add.
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output


class Sine(nn.Module):
    """
    Element-wise sine activation ``sin(sine_frequency * x)``.
    """

    def __init__(self, sine_frequency=4.):
        """

        :param sine_frequency: factor the input is multiplied with before the sine is taken
        """
        super().__init__()
        self.SINE_FREQUENCY = sine_frequency

    def forward(self, input):
        # The frequency factor is configurable per instance (see the constructor argument).
        scaled = self.SINE_FREQUENCY * input
        return torch.sin(scaled)


class ScalingLayer(nn.Module):
    """
    Layer that scales the input by multiplying with a scalar.
    """
    def __init__(self, init_value=1.0, trainable=False):
        """

        :param init_value: initial value of the scaling factor
        :param trainable: whether the scaling factor receives gradients
        """
        super().__init__()
        # requires_grad has to be handed to nn.Parameter itself: its constructor defaults to
        # requires_grad=True and does not take the flag over from the wrapped tensor.
        self.scaling_factor = nn.Parameter(torch.tensor(init_value), requires_grad=trainable)

    def forward(self, input):
        """
        :param input: tensor to scale
        :return: ``scaling_factor * input``
        """
        return self.scaling_factor * input


class SoftThreshold(nn.Module):
    """
    Layer that implements a logistic function  1./(1+exp(-k * (x-x_0))) to be used as a soft threshold.
    https://en.wikipedia.org/wiki/Logistic_function
    """
    def __init__(self, k=1.0, x0=1., trainable=True):
        """

        :param k: the logistic growth rate or steepness of the curve
        :param x0: the x value of the sigmoid's midpoint
        :param trainable: whether k and x0 are trainable
        """
        super().__init__()
        # requires_grad has to be handed to nn.Parameter itself: its constructor defaults to
        # requires_grad=True and does not take the flag over from the wrapped tensor.
        self.k = nn.Parameter(torch.tensor(float(k)), requires_grad=trainable)
        self.x0 = nn.Parameter(torch.tensor(float(x0)), requires_grad=trainable)

    def forward(self, input):
        """
        :param input: tensor of values to threshold
        :return: ``1 / (1 + exp(-k * (input - x0)))``, element-wise
        """
        # torch.sigmoid evaluates the same logistic function, but keeps the gradient finite where the
        # explicit exp() would overflow to inf (which turns the backward pass into inf * 0 = NaN).
        return torch.sigmoid(self.k * (input - self.x0))

def identity(x):
    """Return ``x`` unchanged; serves as the no-op choice where a transformation callable is expected."""
    return x

class InverseStandardizeLayer(nn.Module):
    """
    Layer that standardizes values, i.e. subtracts the mean and divides by the standard deviation, with an
    optional transformation before doing that, and that can undo this again (the default direction).
    Implemented as layer, since we also need the true untransformed units for calculating the fluxes etc,
    but want to train the network on normalized / scaled values.
    """

    def __init__(self, X, activation=None, apply_activation_to=None, trainable=False, no_shift=False):
        """

        :param X: numpy array of reference data; mean and standard deviation are computed over axis 0 of the
            transformed data
        :param activation: transformation applied before standardizing: None, "sqrt" or "log1p"
        :param apply_activation_to: indices stored on the layer (defaults to all of the last axis); they are
            currently not used by the transformation itself
        :param trainable: whether shift and scale receive gradients
        :param no_shift: if True the shift is set to zero, so values are only scaled
        :raises NotImplementedError: for any other value of ``activation``
        """
        super().__init__()
        self.dim = X.shape[-1]
        self.apply_activation_to = apply_activation_to if apply_activation_to is not None else np.arange(self.dim)
        self.trainable = trainable

        # torch and numpy flavours of the transformation and of its inverse
        self.activation = identity
        self.activation_inv = identity
        self.activation_np = identity
        self.activation_inv_np = identity

        if activation == "sqrt":
            self.activation = torch.sqrt
            self.activation_inv = torch.square
            self.activation_np = np.sqrt
            self.activation_inv_np = np.square
        elif activation == "log1p":
            self.activation = torch.log1p
            self.activation_inv = torch.expm1
            self.activation_np = np.log1p
            self.activation_inv_np = np.expm1
        elif activation is not None:
            raise NotImplementedError(f"Unknown activation: {activation}")

        X = self.activation_np(X)

        # shift holds the negated mean, so that standardizing reads (value + shift) / scale
        shift = -1 * np.mean(X, axis=0)
        scale = np.std(X, axis=0)

        if no_shift:
            shift *= 0.
        shift = torch.FloatTensor(shift)
        scale = torch.FloatTensor(scale)

        self.shift = nn.Parameter(shift, requires_grad=trainable)
        self.scale = nn.Parameter(scale, requires_grad=trainable)

    def _transform(self, input, invert, use_numpy):
        """Shared implementation of forward / forward_np; see forward for the meaning of the arguments."""
        if use_numpy:
            act, act_inv = self.activation_np, self.activation_inv_np
        else:
            act, act_inv = self.activation, self.activation_inv

        scale, shift = self.scale, self.shift
        if not torch.is_tensor(input):
            # Non-tensor (numpy) input: use numpy copies of the statistics so the arithmetic stays in numpy
            scale = scale.detach().cpu().numpy()
            shift = shift.detach().cpu().numpy()

        if invert:
            # undo the standardization first, then undo the transformation
            return act_inv(input * scale - shift)
        return (act(input) + shift) / scale

    def forward(self, input, invert:bool=True, do_np:bool=False):
        """
        :param input: values to transform
        :param invert: if True map standardized values back to the original units, i.e.
            ``activation_inv(input * scale - shift)``; if False standardize, i.e.
            ``(activation(input) + shift) / scale``
        :param do_np: use the numpy implementation of the transformation instead of the torch one
        :return: the transformed values
        """
        return self._transform(input, invert, do_np)

    def forward_np(self, input, invert:bool=True, do_np:bool=True):
        """
        Same as forward, but uses the numpy implementation of the transformation by default.

        :param input: values to transform
        :param invert: see forward
        :param do_np: use the numpy implementation of the transformation instead of the torch one
        :return: the transformed values
        """
        return self._transform(input, invert, do_np)


class NormalizeLayer(nn.Module):
    """
    Layer that normalizes the input to the range [-1, 1]; Required for the input of SIREN layers.
    Implemented as layer, because we need the values of the untransformed input domain for calculating derivatives
    with respect to the signal domain.
    """
    def __init__(self, min, max, dim, ignore_dims=None, trainable=False):
        """

        :param min: per-dimension minimum of the input domain (array-like of length dim)
        :param max: per-dimension maximum of the input domain (array-like of length dim)
        :param dim: number of input dimensions
        :param ignore_dims: index / indices of dimensions that are passed through unchanged, or None to
            normalize every dimension
        :param trainable: whether min, max and shift receive gradients
        """
        super().__init__()
        self.trainable = trainable
        self.dim = dim

        shift = -1 * np.ones(dim)
        # Float copies, so that the arrays handed in by the caller are not modified below
        lower = np.array(min, dtype=float)
        upper = np.array(max, dtype=float)

        # Indexing with None would address the whole array, so the overrides are only applied when
        # dimensions were actually named.
        if ignore_dims is not None:
            # min=0, max=2, shift=0 turns forward() into the identity for these dimensions
            shift[ignore_dims] = 0.
            lower[ignore_dims] = 0.
            upper[ignore_dims] = 2.

        self.shift = nn.Parameter(torch.FloatTensor(shift), requires_grad=trainable)
        self.min = nn.Parameter(torch.FloatTensor(lower), requires_grad=trainable)
        self.max = nn.Parameter(torch.FloatTensor(upper), requires_grad=trainable)

    def forward(self, input):
        """
        x in [xmin, xmax] -> x_norm in [-1, 1]

        NOTE: a dimension with max == min leads to a division by zero here.

        :param input: tensor whose last axis has length dim
        :return: the normalized tensor
        """
        return 2 * (input - self.min) / (self.max - self.min) + self.shift

    def _inverse(self, output):
        """
        x_norm in [-1, +1] -> x in [xmin, xmax]

        :param output: normalized tensor as returned by forward
        :return: the tensor in the original input domain
        """
        unshifted_output = output - self.shift
        return (1. / 2) * (unshifted_output) * (self.max - self.min) + self.min

    @classmethod
    def from_data(cls, X, **kwargs):
        """
        Initialize minimum and maximum values for normalization from a numpy array of data.
        Columns indicate different variables, rows different samples.
        Each column is normalized independently.
        :param X: two-dimensional numpy array (samples x variables)
        :param kwargs: forwarded to the constructor (e.g. ignore_dims, trainable)
        :return: a new layer instance
        """
        return cls(min=np.min(X, 0), max=np.max(X, 0), dim=X.shape[1], **kwargs)


class FCBlock(MetaModule):
    '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork.
    Can be used just as a normal neural network though, as well.
    '''
    default_sine_freq = 7

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, outermost_positive=False, nonlinearity='relu', weight_init=None,
                 **kwargs
                 ):
        """
        :param in_features: size of the input of the first linear layer
        :param out_features: size of the output of the last linear layer
        :param num_hidden_layers: number of hidden_features -> hidden_features layers between first and last
        :param hidden_features: width of the hidden layers
        :param outermost_linear: if True the last layer is not followed by the chosen nonlinearity
        :param outermost_positive: only used together with outermost_linear; appends a Softplus to the last layer
        :param nonlinearity: one of 'sine', 'relu', 'sigmoid', 'tanh', 'selu', 'softplus', 'elu'
        :param weight_init: optional initialization function that replaces the one belonging to the nonlinearity
        :param kwargs: "sine_frequency" is read for the 'sine' entry (default: FCBlock.default_sine_freq)
        """
        super().__init__()

        self.first_layer_init = None

        sine_frequency = kwargs.get("sine_frequency", FCBlock.default_sine_freq)
        # name -> (activation module, weight initialization, special first-layer initialization or None)
        nls_and_inits = {
            'sine': (Sine(sine_frequency), gen_sine_init(sine_frequency), first_layer_sine_init),
            'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
            'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
            'tanh': (nn.Tanh(), init_weights_xavier, None),
            'selu': (nn.SELU(inplace=True), init_weights_selu, None),
            'softplus': (nn.Softplus(), init_weights_normal, None),
            'elu': (nn.ELU(inplace=True), init_weights_elu, None),
        }
        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # An explicitly passed initialization wins over the default of the nonlinearity
        self.weight_init = nl_weight_init if weight_init is None else weight_init

        stages = [MetaSequential(BatchLinear(in_features, hidden_features), nl)]
        stages.extend(
            MetaSequential(BatchLinear(hidden_features, hidden_features), nl)
            for _ in range(num_hidden_layers)
        )

        if not outermost_linear:
            # last layer keeps the activation
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features), nl))
        elif outermost_positive:
            # no activation, but the output is additionally forced to be positive
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features), torch.nn.Softplus()))
        else:
            # plain linear output
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features)))

        self.net = MetaSequential(*stages)

        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        # Special initialization of the first layer, where the nonlinearity defines one
        if first_layer_init is not None:
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        """
        :param coords: input tensor
        :param params: optional parameter mapping to use instead of the module's own parameters
        :param kwargs: accepted but not used
        :return: output of the network
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        return self.net(coords, params=self.get_subdict(params, 'net'))

    def forward_with_activations(self, coords, params=None, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.'''
        if params is None:
            params = OrderedDict(self.named_parameters())

        activations = OrderedDict()

        # Work on a detached copy of the input that tracks gradients itself
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        for layer_idx, layer in enumerate(self.net):
            layer_params = self.get_subdict(params, 'net.%d' % layer_idx)
            for sub_idx, sublayer in enumerate(layer):
                if isinstance(sublayer, BatchLinear):
                    x = sublayer(x, params=self.get_subdict(layer_params, '%d' % sub_idx))
                else:
                    x = sublayer(x)

                if retain_grad:
                    x.retain_grad()
                key = '_'.join((str(sublayer.__class__), "%d" % layer_idx))
                activations[key] = x
        return activations

class PosEncodingNeRF(nn.Module):
    '''Module to add positional encoding as in NeRF [Mildenhall et al. 2020].'''

    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True, donotuse=False):
        """
        :param in_features: number of coordinate dimensions
        :param sidelength: required for in_features == 2; an int or a pair of ints
        :param fn_samples: required for in_features == 1
        :param use_nyquist: derive the number of frequencies from sidelength / fn_samples instead of using 4
        :param donotuse: if True no frequencies are added, i.e. the coordinates are passed through
        """
        super().__init__()

        self.in_features = in_features

        if donotuse or self.in_features >= 3:
            # three or more input dimensions (or encoding switched off): no additional channels
            self.num_frequencies = 0
        elif self.in_features == 2:
            assert sidelength is not None
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
            else:
                self.num_frequencies = 4
        elif self.in_features == 1:
            assert fn_samples is not None
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
            else:
                self.num_frequencies = 4

        self.sigma = 1e-2
        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        """Return floor(log2(nyquist_rate)) with nyquist_rate = 1 / (2 * (2 / samples))."""
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        return int(math.floor(math.log(nyquist_rate, 2)))

    def forward(self, coords):
        """
        :param coords: tensor that can be viewed as (coords.shape[0], -1, in_features)
        :return: tensor of shape (coords.shape[0], -1, out_dim): the coordinates followed by one sin and one cos
            channel per frequency and input dimension
        """
        coords = coords.view(coords.shape[0], -1, self.in_features)
        sigmas = torch.ones((self.num_frequencies, self.in_features)) * self.sigma

        channels = [coords]
        for freq_idx in range(self.num_frequencies):
            for dim_idx in range(self.in_features):
                component = coords[..., dim_idx]
                phase = (2 ** (sigmas[freq_idx, dim_idx] * freq_idx / self.num_frequencies)) * np.pi * component
                channels.append(torch.unsqueeze(torch.sin(phase), -1))
                channels.append(torch.unsqueeze(torch.cos(phase), -1))

        # Without additional channels the coordinates themselves are returned
        encoded = channels[0] if len(channels) == 1 else torch.cat(channels, -1)
        return encoded.reshape(coords.shape[0], -1, self.out_dim)


class FourierEmbedding(nn.Module):
    """
    Maps the input through a bias-free linear layer and returns the sine and cosine of 2*pi times the result,
    concatenated along the last axis.

    NOTE(review): a second, identical definition of this class appears further down in this file; at import
    time that later definition replaces this one.
    """

    def __init__(self, input_dim, num_frequencies=20):
        """
        :param input_dim: size of the last axis of the input
        :param num_frequencies: number of output features of the linear layer; the embedding has twice as many
        """
        super().__init__()
        self.num_frequencies = num_frequencies
        self.out_features = self.num_frequencies * 2
        self.f = torch.nn.Linear(input_dim, num_frequencies, bias=False)

    def forward(self, x):
        phase = 2 * np.pi * self.f(x)
        sin_part = torch.sin(phase)
        cos_part = torch.cos(phase)
        return torch.cat([sin_part, cos_part], -1)


class NVidiaAttention(nn.Module):
    """
    Computes two transforms of a Fourier embedding of the input and blends them repeatedly with gates
    ``z``: ``x <- (1 - z) * t1 + z * t2``, where ``z`` comes from a small dense layer applied to the current ``x``.

    NOTE(review): a second, identical definition of this class appears further down in this file; at import
    time that later definition replaces this one.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40, num_frequencies=20):
        """
        :param input_dim: size of the last axis of the input
        :param num_layers: number of gate layers; the first one always exists, num_layers - 1 more are added
        :param embedding_dim: width of the transforms, the gates and the output
        :param num_frequencies: passed to the Fourier embedding
        """
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = self.embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # linear layer followed by the shared activation module
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.fourier_embed = FourierEmbedding(input_dim, num_frequencies=num_frequencies)
        self.transform1 = dense(self.fourier_embed.out_features)
        self.transform2 = dense(self.fourier_embed.out_features)

        # The first gate reads the raw input, all further ones read the blended embedding
        gates = [dense(self.in_features)]
        gates.extend(dense(embedding_dim) for _ in range(num_layers - 1))
        self.Z_nets = torch.nn.ModuleList(gates)

    def forward(self, x):
        embedded = self.fourier_embed(x)
        t1 = self.transform1(embedded)
        t2 = self.transform2(embedded)

        state = x
        for gate_net in self.Z_nets:
            gate = gate_net(state)
            state = (1 - gate) * t1 + gate * t2
        return state


class AttentionLayer(nn.Module):
    """
    Computes three dense projections of the input (key ``U``, query ``H``, value ``V``) and updates ``H``
    repeatedly with gates ``Z = Z_net(H)``: ``H <- (1 - Z) * U + Z * V``.

    NOTE(review): a second, identical definition of this class appears further down in this file; at import
    time that later definition replaces this one.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40):
        """
        :param input_dim: size of the last axis of the input
        :param num_layers: number of gated update steps
        :param embedding_dim: width of the projections and of the output
        """
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # linear layer followed by the shared activation module
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.query_net = dense(self.in_features)
        self.key_net = dense(self.in_features)
        self.val_net = dense(self.in_features)

        self.Z_nets = torch.nn.ModuleList([dense(embedding_dim) for _ in range(num_layers)])

    def forward(self, x):
        key = self.key_net(x)
        state = self.query_net(x)
        value = self.val_net(x)
        for gate_net in self.Z_nets:
            gate = gate_net(state)
            state = (1 - gate) * key + gate * value
        return state
from typing import Callable

class FCLayer(torch.nn.Module):
    """
    Fully Connected Block, that also supports sine activations (they need a specific initialization)
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_activation: Callable = None, nonlinearity="tanh", **kwargs):
        """
        :param in_features: size of the input of the first linear layer
        :param out_features: size of the output of the last linear layer
        :param num_hidden_layers: number of hidden_features -> hidden_features layers between first and last
        :param hidden_features: width of the hidden layers
        :param outermost_activation: optional callable that is called without arguments and whose result is
            appended after the last linear layer
        :param nonlinearity: one of 'sine', 'relu', 'sigmoid', 'tanh', 'selu', 'softplus', 'elu'
        :param kwargs: "sine_frequency" is read for the 'sine' entry (default: 7)
        """
        super().__init__()

        sine_frequency = kwargs.get("sine_frequency", 7)
        # name -> (activation module, weight initialization, special first-layer initialization or None)
        nls_and_inits = {
            'sine': (Sine(sine_frequency), gen_sine_init(sine_frequency), first_layer_sine_init),
            'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
            'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
            'tanh': (nn.Tanh(), init_weights_xavier, None),
            'selu': (nn.SELU(inplace=True), init_weights_selu, None),
            'softplus': (nn.Softplus(), init_weights_normal, None),
            'elu': (nn.ELU(inplace=True), init_weights_elu, None),
        }
        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        self.weight_init = nl_weight_init

        stages = [torch.nn.Sequential(torch.nn.Linear(in_features, hidden_features), nl)]
        stages.extend(
            torch.nn.Sequential(torch.nn.Linear(hidden_features, hidden_features), nl)
            for _ in range(num_hidden_layers)
        )

        # The output layer is linear unless an outermost activation was requested
        if outermost_activation:
            head = torch.nn.Sequential(torch.nn.Linear(hidden_features, out_features), outermost_activation())
        else:
            head = torch.nn.Sequential(torch.nn.Linear(hidden_features, out_features))
        stages.append(head)

        self.net = torch.nn.Sequential(*stages)

        self.net.apply(nl_weight_init)
        # Special initialization of the first layer, where the nonlinearity defines one
        if first_layer_init is not None:
            self.net[0].apply(first_layer_init)

    def forward(self, coords):
        """Run ``coords`` through the stacked layers."""
        return self.net(coords)


class FourierEmbedding(nn.Module):
    """
    Maps the input through a bias-free linear layer and returns the sine and cosine of 2*pi times the result,
    concatenated along the last axis.

    NOTE(review): an identical class of the same name is also defined earlier in this file; this later
    definition is the one the module ends up exposing.
    """

    def __init__(self, input_dim, num_frequencies=20):
        """
        :param input_dim: size of the last axis of the input
        :param num_frequencies: number of output features of the linear layer; the embedding has twice as many
        """
        super().__init__()
        self.num_frequencies = num_frequencies
        self.out_features = self.num_frequencies * 2
        self.f = torch.nn.Linear(input_dim, num_frequencies, bias=False)

    def forward(self, x):
        phase = 2 * np.pi * self.f(x)
        sin_part = torch.sin(phase)
        cos_part = torch.cos(phase)
        return torch.cat([sin_part, cos_part], -1)


class NVidiaAttention(nn.Module):
    """
    Computes two transforms of a Fourier embedding of the input and blends them repeatedly with gates
    ``z``: ``x <- (1 - z) * t1 + z * t2``, where ``z`` comes from a small dense layer applied to the current ``x``.

    NOTE(review): an identical class of the same name is also defined earlier in this file; this later
    definition is the one the module ends up exposing.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40, num_frequencies=20):
        """
        :param input_dim: size of the last axis of the input
        :param num_layers: number of gate layers; the first one always exists, num_layers - 1 more are added
        :param embedding_dim: width of the transforms, the gates and the output
        :param num_frequencies: passed to the Fourier embedding
        """
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = self.embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # linear layer followed by the shared activation module
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.fourier_embed = FourierEmbedding(input_dim, num_frequencies=num_frequencies)
        self.transform1 = dense(self.fourier_embed.out_features)
        self.transform2 = dense(self.fourier_embed.out_features)

        # The first gate reads the raw input, all further ones read the blended embedding
        gates = [dense(self.in_features)]
        gates.extend(dense(embedding_dim) for _ in range(num_layers - 1))
        self.Z_nets = torch.nn.ModuleList(gates)

    def forward(self, x):
        embedded = self.fourier_embed(x)
        t1 = self.transform1(embedded)
        t2 = self.transform2(embedded)

        state = x
        for gate_net in self.Z_nets:
            gate = gate_net(state)
            state = (1 - gate) * t1 + gate * t2
        return state


class AttentionLayer(nn.Module):
    """
    Computes three dense projections of the input (key ``U``, query ``H``, value ``V``) and updates ``H``
    repeatedly with gates ``Z = Z_net(H)``: ``H <- (1 - Z) * U + Z * V``.

    NOTE(review): an identical class of the same name is also defined earlier in this file; this later
    definition is the one the module ends up exposing.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40):
        """
        :param input_dim: size of the last axis of the input
        :param num_layers: number of gated update steps
        :param embedding_dim: width of the projections and of the output
        """
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # linear layer followed by the shared activation module
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.query_net = dense(self.in_features)
        self.key_net = dense(self.in_features)
        self.val_net = dense(self.in_features)

        self.Z_nets = torch.nn.ModuleList([dense(embedding_dim) for _ in range(num_layers)])

    def forward(self, x):
        key = self.key_net(x)
        state = self.query_net(x)
        value = self.val_net(x)
        for gate_net in self.Z_nets:
            gate = gate_net(state)
            state = (1 - gate) * key + gate * value
        return state

########################
# Initialization methods
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # For PINNet, Raissi et al. 2019
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # grab from upstream pytorch branch and paste here for now
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def init_weights_trunc_normal(m):
    """
    Initialize the weight of a BatchLinear / nn.Linear module with a truncated normal distribution; modules of
    any other type are left untouched.

    For PINNet, Raissi et al. 2019
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    :param m: module, typically handed in by ``Module.apply``
    """
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return

    fan_out, fan_in = m.weight.size(0), m.weight.size(1)
    std = math.sqrt(2.0 / float(fan_in + fan_out))
    mean = 0.
    # Same behavior as tf.truncated_normal:
    # "The generated values follow a normal distribution with specified mean and
    # standard deviation, except that values whose magnitude is more than 2
    # standard deviations from the mean are dropped and re-picked."
    _no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)


def init_weights_normal(m):
    """Kaiming-normal (fan_in, relu) initialization of the weight of a BatchLinear / nn.Linear module;
    modules of any other type are left untouched."""
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    """Draw the weight of a BatchLinear / nn.Linear module from a normal distribution with
    std = 1 / sqrt(fan_in); modules of any other type are left untouched."""
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    """Draw the weight of a BatchLinear / nn.Linear module from a normal distribution with
    std = sqrt(1.5505188080679277) / sqrt(fan_in); modules of any other type are left untouched."""
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    """Xavier-normal initialization of the weight of a BatchLinear / nn.Linear module; modules of any other
    type are left untouched."""
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    nn.init.xavier_normal_(m.weight)


def gen_sine_init(SINE_FREQUENCY=7):
    """
    Build an initialization function for layers that are followed by a sine activation.

    :param SINE_FREQUENCY: frequency factor of the sine activation; the weight range is divided by it
    :return: function that fills ``m.weight`` uniformly in +-sqrt(6 / fan_in) / SINE_FREQUENCY for every
        module ``m`` that has a ``weight`` attribute
    """
    def sine_init(m):
        with torch.no_grad():
            if not hasattr(m, 'weight'):
                return
            fan_in = m.weight.size(-1)
            bound = np.sqrt(6 / fan_in) / SINE_FREQUENCY
            m.weight.uniform_(-bound, bound)

    return sine_init


def first_layer_sine_init(m):
    """
    Initialization for the first layer of a sine network: fills ``m.weight`` uniformly in
    [-1 / fan_in, 1 / fan_in]. Modules without a ``weight`` attribute are left untouched.

    :param m: module, typically handed in by ``Module.apply``
    """
    with torch.no_grad():
        if not hasattr(m, 'weight'):
            return
        bound = 1 / m.weight.size(-1)
        m.weight.uniform_(-bound, bound)


###################
# Complex operators
def compl_conj(x):
    """
    Complex conjugate of a tensor that stores real parts at the even and imaginary parts at the odd positions
    of its last axis.

    :param x: tensor with interleaved real / imaginary parts
    :return: a new tensor in which the entries at the odd positions are negated
    """
    imag_slots = (..., slice(1, None, 2))
    conjugated = x.clone()
    conjugated[imag_slots] = -1 * conjugated[imag_slots]
    return conjugated


def compl_div(x, y):
    """
    x / y for tensors that store real parts at the even and imaginary parts at the odd positions of their
    last axis.

    :param x: numerator, interleaved real / imaginary parts
    :param y: denominator, interleaved real / imaginary parts
    :return: new tensor shaped like x that holds the quotient in the same layout
    """
    x_re, x_im = x[..., ::2], x[..., 1::2]
    y_re, y_im = y[..., ::2], y[..., 1::2]

    # (a + ib) / (c + id) = ((ac + bd) + i(bc - ad)) / (c^2 + d^2)
    squared_norm = y_re ** 2 + y_im ** 2

    quotient = torch.zeros_like(x)
    quotient[..., ::2] = (x_re * y_re + x_im * y_im) / squared_norm
    quotient[..., 1::2] = (x_im * y_re - x_re * y_im) / squared_norm
    return quotient


def compl_mul(x, y):
    """
    x * y for tensors that store real parts at the even and imaginary parts at the odd positions of their
    last axis.

    :param x: first factor, interleaved real / imaginary parts
    :param y: second factor, interleaved real / imaginary parts
    :return: new tensor shaped like x that holds the product in the same layout
    """
    x_re, x_im = x[..., ::2], x[..., 1::2]
    y_re, y_im = y[..., ::2], y[..., 1::2]

    re_re = x_re * y_re
    im_im = x_im * y_im

    product = torch.zeros_like(x)
    product[..., ::2] = re_re - im_im
    # imaginary part via (a + b)(c + d) - ac - bd, which equals ad + bc
    product[..., 1::2] = (x_re + x_im) * (y_re + y_im) - re_re - im_im
    return product
