"""
# https://raw.githubusercontent.com/vsitzmann/siren/master/modules.py
"""
import torch
from torch import nn, nn as nn
from torchmeta.modules import (MetaModule, MetaSequential)
import numpy as np
from collections import OrderedDict
import math
import torch.nn.functional as F
from typing import Callable


class Sine(nn.Module):
    """
    Elementwise sinusoidal activation ``sin(w0 * x)``.

    The SIREN paper discusses a frequency factor of 30 (paper sec. 3.2, supplement sec. 1.5);
    this module defaults to ``w0 = 4``.
    """

    def __init__(self, sine_frequency=4.):
        """
        :param sine_frequency: multiplier ``w0`` applied to the input before ``torch.sin``;
            stored as ``self.SINE_FREQUENCY``.
        """
        super().__init__()
        self.SINE_FREQUENCY = sine_frequency

    def forward(self, input):
        # Scale first, then apply the sine; the module holds no trainable parameters.
        scaled = input * self.SINE_FREQUENCY
        return torch.sin(scaled)


class FCLayer(torch.nn.Module):
    """
    Fully connected network built only from plain ``torch.nn`` modules. It also supports sine
    activations, which need a specific initialization.

    Layout of ``self.net`` (a ``torch.nn.Sequential`` of stages):
      * ``net[0]``: ``Linear(in_features, hidden_features)`` + nonlinearity
      * ``net[1]`` .. ``net[num_hidden_layers]``: ``Linear(hidden_features, hidden_features)`` + nonlinearity
      * ``net[-1]``: ``Linear(hidden_features, out_features)``, followed by ``outermost_activation()``
        when one is given.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_activation: Callable = None, nonlinearity="tanh", **kwargs):
        """
        :param in_features: width of the input's last dimension.
        :param out_features: width of the output's last dimension.
        :param num_hidden_layers: number of hidden -> hidden stages.
        :param hidden_features: width of every hidden stage.
        :param outermost_activation: zero-argument factory (e.g. a module class such as
            ``torch.nn.Sigmoid``). Its result is appended after the output layer. A falsy value
            means the output stays linear.
        :param nonlinearity: one of 'sine', 'relu', 'sigmoid', 'tanh', 'selu', 'softplus', 'elu'.
        :param kwargs: ``sine_frequency`` (default 7) is read for the 'sine' nonlinearity;
            other keys are ignored.
        :raises KeyError: if ``nonlinearity`` is not one of the names above.
        """
        super().__init__()

        sine_frequency = kwargs.get("sine_frequency", 7)
        # name -> (activation module, initializer for Module.apply, optional first-layer initializer)
        nls_and_inits = {'sine': (Sine(sine_frequency),
                                  gen_sine_init(sine_frequency),
                                  first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        self.weight_init = nl_weight_init

        # All stages are plain nn.Sequential. This class never passes meta-parameters, so the
        # hidden stages do not need torchmeta's MetaSequential (previously used inconsistently).
        stages = [torch.nn.Sequential(torch.nn.Linear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            stages.append(torch.nn.Sequential(torch.nn.Linear(hidden_features, hidden_features), nl))

        output_stage = [torch.nn.Linear(hidden_features, out_features)]
        if outermost_activation:
            output_stage.append(outermost_activation())
        stages.append(torch.nn.Sequential(*output_stage))

        self.net = torch.nn.Sequential(*stages)

        self.net.apply(nl_weight_init)
        if first_layer_init is not None:  # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords, **kwargs):
        """Run ``coords`` through the network; extra keyword arguments are ignored."""
        return self.net(coords)


class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
    hypernetwork.'''
    # NOTE: this assignment replaces the docstring above, so BatchLinear.__doc__ is nn.Linear's.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        """Compute ``input @ weight^T (+ bias)`` using own or externally supplied parameters.

        :param input: tensor whose last dimension is ``in_features``.
        :param params: optional mapping with a ``'weight'`` entry and an optional ``'bias'`` entry
            (e.g. produced by a hypernetwork). Any leading dimensions of ``weight`` act as batch
            dimensions via ``matmul`` broadcasting. Defaults to this module's own parameters.
        :return: ``input.matmul(weight.transpose(-1, -2))``, plus the bias when one is present.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # Swap only the last two dims so any leading batch dims of ``weight`` stay in place.
        output = input.matmul(weight.transpose(-1, -2))
        # A layer built with bias=False has no 'bias' entry; previously this crashed on None.
        if bias is not None:
            # unsqueeze(-2) lets a (..., out_features) bias broadcast over the row dimension;
            # the in-place add keeps ``output``'s shape fixed.
            output += bias.unsqueeze(-2)
        return output


########################
# Initialization methods
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from N(mean, std**2) truncated to ``[a, b]``.

    Uses the inverse-CDF method. The upper and lower standard-normal CDF values of the bounds
    are computed first. Uniform samples are drawn in the matching ``erf`` range and mapped
    through ``erfinv``. The result is then rescaled to ``mean``/``std`` and clamped to
    ``[a, b]`` to guard against rounding. All tensor ops run under ``torch.no_grad()``, so
    this also works on leaf tensors that require grad.

    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf;
    copied from upstream PyTorch (the original comment attributes its use to PINNs,
    Raissi et al. 2019).

    :return: ``tensor`` itself, after modification.
    """
    def std_normal_cdf(z):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(z / math.sqrt(2.))) / 2.

    lower_cdf = std_normal_cdf((a - mean) / std)
    upper_cdf = std_normal_cdf((b - mean) / std)

    with torch.no_grad():
        # Uniform in [2l-1, 2u-1], then erfinv gives a truncated standard normal.
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
        tensor.erfinv_()
        # Rescale to the requested distribution, then clamp into [a, b].
        tensor.mul_(std * math.sqrt(2.)).add_(mean)
        return tensor.clamp_(min=a, max=b)


def init_weights_trunc_normal(m):
    """Truncated-normal weight init for exact ``BatchLinear`` / ``nn.Linear`` modules.

    Intended for ``Module.apply``. The weight is drawn from N(0, std**2) with
    ``std = sqrt(2 / (fan_in + fan_out))``, truncated to ``[-2 std, 2 std]``. This mirrors
    ``tf.truncated_normal``, which re-picks values more than 2 standard deviations from the mean.
    Biases, subclasses of these two types, and all other modules are left untouched.
    (The upstream comment attributes this scheme to PINNs, Raissi et al. 2019.)
    """
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return

    fan_out, fan_in = m.weight.size(0), m.weight.size(1)
    std = math.sqrt(2.0 / float(fan_in + fan_out))
    _no_grad_trunc_normal_(m.weight, 0., std, -2 * std, 2 * std)


def init_weights_normal(m):
    """Kaiming-normal (fan_in, ReLU gain) weight init for exact ``BatchLinear`` / ``nn.Linear``.

    Intended for ``Module.apply``. Biases, subclasses of these two types, and all other
    modules are left untouched.
    """
    is_linear = type(m) in (BatchLinear, nn.Linear)
    if is_linear and hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    """Normal weight init with ``std = 1 / sqrt(fan_in)`` for exact ``BatchLinear`` / ``nn.Linear``.

    Intended for ``Module.apply`` with the SELU nonlinearity. ``fan_in`` is
    ``weight.size(-1)``. Biases and all other modules are left untouched.
    """
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    """Normal weight init scaled for ELU, for exact ``BatchLinear`` / ``nn.Linear`` modules.

    Intended for ``Module.apply``. The std is ``sqrt(1.5505188080679277) / sqrt(fan_in)``, with
    ``fan_in = weight.size(-1)``. Biases and all other modules are left untouched.
    """
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    # NOTE(review): the variance constant is carried over from the upstream SIREN code; its
    # derivation is not documented here.
    gain = math.sqrt(1.5505188080679277)
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=gain / math.sqrt(fan_in))


def init_weights_xavier(m):
    """Xavier/Glorot-normal weight init for exact ``BatchLinear`` / ``nn.Linear`` modules.

    Intended for ``Module.apply``. Biases, subclasses of these two types, and all other
    modules are left untouched.
    """
    is_linear = type(m) in (BatchLinear, nn.Linear)
    if is_linear and hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)


def gen_sine_init(SINE_FREQUENCY=7):
    """Build a SIREN-style hidden-layer initializer for sine frequency ``SINE_FREQUENCY``.

    The returned callable is meant for ``Module.apply``. Any module with a ``weight``
    attribute has its weight redrawn uniformly from
    ``[-sqrt(6 / fan_in) / SINE_FREQUENCY, sqrt(6 / fan_in) / SINE_FREQUENCY]``, where
    ``fan_in = weight.size(-1)``. Biases are left untouched. See SIREN supplement sec. 1.5;
    the paper discusses a factor of 30, but here the factor is the ``SINE_FREQUENCY``
    argument.
    """
    def sine_init(m):
        if not hasattr(m, 'weight'):
            return
        with torch.no_grad():
            fan_in = m.weight.size(-1)
            bound = np.sqrt(6 / fan_in) / SINE_FREQUENCY
            m.weight.uniform_(-bound, bound)

    return sine_init


def first_layer_sine_init(m):
    """SIREN first-layer initializer for use with ``Module.apply``.

    Any module with a ``weight`` attribute has its weight redrawn uniformly from
    ``[-1 / fan_in, 1 / fan_in]``, where ``fan_in = weight.size(-1)``. Biases are left
    untouched. See SIREN paper sec. 3.2 (final paragraph) and supplement sec. 1.5.
    """
    if hasattr(m, 'weight'):
        with torch.no_grad():
            limit = 1 / m.weight.size(-1)
            m.weight.uniform_(-limit, limit)


class FourierEmbedding(nn.Module):
    """Fourier-feature embedding ``[sin(2*pi*Bx), cos(2*pi*Bx)]``.

    ``B`` is the bias-free learnable linear map ``self.f`` (``input_dim -> num_frequencies``),
    so the output's last dimension is ``2 * num_frequencies`` (``self.out_features``).
    """

    def __init__(self, input_dim, num_frequencies=20):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.out_features = 2 * self.num_frequencies
        self.f = torch.nn.Linear(in_features=input_dim, out_features=num_frequencies, bias=False)

    def forward(self, x):
        angles = 2 * np.pi * self.f(x)
        # Sines first, then cosines, along the last dimension.
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class NVidiaAttention(nn.Module):
    """Gated network driven by a Fourier embedding of the input.

    Two Softplus projections ``t1`` / ``t2`` of ``FourierEmbedding(x)`` are blended by a stack
    of gate networks: starting from ``h = x``, each gate computes ``z = Z_net(h)`` and sets
    ``h = (1 - z) * t1 + z * t2``. There is always at least one gate (reading the
    ``input_dim``-wide input). ``num_layers - 1`` further gates read the
    ``embedding_dim``-wide blended state.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40, num_frequencies=20):
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = self.embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # Linear(n_in -> embedding_dim) followed by the shared Softplus module.
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.fourier_embed = FourierEmbedding(input_dim, num_frequencies=num_frequencies)
        self.transform1 = dense(self.fourier_embed.out_features)
        self.transform2 = dense(self.fourier_embed.out_features)

        gate_widths = [self.in_features] + [embedding_dim] * (num_layers - 1)
        self.Z_nets = torch.nn.ModuleList([dense(width) for width in gate_widths])

    def forward(self, x):
        features = self.fourier_embed(x)
        t1, t2 = self.transform1(features), self.transform2(features)
        h = x
        for gate_net in self.Z_nets:
            gate = gate_net(h)
            h = (1 - gate) * t1 + gate * t2
        return h


class AttentionLayer(nn.Module):
    """Gated mixing of three Softplus projections of the input.

    ``U``, ``H`` and ``V`` come from ``key_net``, ``query_net`` and ``val_net`` respectively.
    Each of the ``num_layers`` gate networks then computes ``Z = Z_net(H)`` and updates
    ``H = (1 - Z) * U + Z * V``; the final ``H`` is returned.
    """

    def __init__(self, input_dim, num_layers=3, embedding_dim=40):
        super().__init__()
        self.in_features = input_dim
        self.embedding_dim = embedding_dim
        self.out_features = embedding_dim
        self.activation = torch.nn.Softplus()

        def dense(n_in):
            # Linear(n_in -> embedding_dim) followed by the shared Softplus module.
            return torch.nn.Sequential(torch.nn.Linear(n_in, embedding_dim), self.activation)

        self.query_net = dense(self.in_features)
        self.key_net = dense(self.in_features)
        self.val_net = dense(self.in_features)

        self.Z_nets = torch.nn.ModuleList([dense(embedding_dim) for _ in range(num_layers)])

    def forward(self, x):
        keys = self.key_net(x)
        state = self.query_net(x)
        values = self.val_net(x)
        for gate_net in self.Z_nets:
            gate = gate_net(state)
            state = (1 - gate) * keys + gate * values
        return state


class FCBlock(MetaModule):
    '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork.
    Can be used just as a normal neural network though, as well.

    ``self.net`` is a ``MetaSequential`` of stages, each itself a ``MetaSequential``:

    * ``net[0]``: ``BatchLinear(in_features, hidden_features)`` + nonlinearity
    * ``net[1]`` .. ``net[num_hidden_layers]``: ``BatchLinear(hidden_features, hidden_features)`` + nonlinearity
    * ``net[-1]``: ``BatchLinear(hidden_features, out_features)``. It is followed by the
      nonlinearity when ``outermost_linear`` is False, by ``Softplus`` when ``outermost_linear``
      and ``outermost_positive`` are both True, and by nothing otherwise.
    '''
    # Sine frequency used for the 'sine' nonlinearity when no ``sine_frequency`` kwarg is given.
    default_sine_freq = 7

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, outermost_positive=False, nonlinearity='relu', weight_init=None,
                 **kwargs
                 ):
        """
        :param in_features: width of the input's last dimension.
        :param out_features: width of the output's last dimension.
        :param num_hidden_layers: number of hidden -> hidden stages between the first and last stage.
        :param hidden_features: width of every hidden stage.
        :param outermost_linear: if True, the last stage has no nonlinearity (see ``outermost_positive``).
        :param outermost_positive: only read when ``outermost_linear`` is True; appends a ``Softplus``.
        :param nonlinearity: one of 'sine', 'relu', 'sigmoid', 'tanh', 'selu', 'softplus', 'elu'.
        :param weight_init: optional callable for ``Module.apply`` that replaces the nonlinearity's
            default initializer. The 'sine' first-layer initializer is still applied to ``net[0]``
            afterwards.
        :param kwargs: ``sine_frequency`` is read for the 'sine' nonlinearity; other keys are ignored.
        :raises KeyError: if ``nonlinearity`` is not one of the names above.
        """
        super().__init__()

        self.first_layer_init = None
        sine_frequency = kwargs.get("sine_frequency", FCBlock.default_sine_freq)
        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme
        nls_and_inits = {'sine': (Sine(sine_frequency),
                                  gen_sine_init(sine_frequency),
                                  first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # Overwrite weight init if passed.
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # The same parameter-free ``nl`` module instance is reused by every stage.
        stages = [MetaSequential(BatchLinear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            stages.append(MetaSequential(BatchLinear(hidden_features, hidden_features), nl))

        if not outermost_linear:
            # Last layer has the activation.
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features), nl))
        elif outermost_positive:
            # Linear output forced positive.
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features), torch.nn.Softplus()))
        else:
            # No last activation.
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features)))

        self.net = MetaSequential(*stages)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None:  # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        """Run ``coords`` through the network.

        :param coords: input tensor whose last dimension is ``in_features``.
        :param params: optional mapping of parameter names (as in ``named_parameters()``, e.g.
            ``'net.0.0.weight'``) to tensors, such as hypernetwork outputs. Defaults to this
            module's own parameters.
        :param kwargs: ignored.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        output = self.net(coords, params=self.get_subdict(params, 'net'))
        return output

    def forward_with_activations(self, coords, params=None, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.

        :param coords: input tensor. A detached copy that requires grad is used as the input.
        :param params: optional parameter mapping, as in ``forward``.
        :param retain_grad: if True, call ``retain_grad()`` on every recorded intermediate tensor.
        :return: ``OrderedDict`` with key ``'input'`` plus one entry per sublayer, keyed
            ``"<sublayer class repr>_<stage index>"``. The last entry is the model output.
        '''
        if params is None:
            params = OrderedDict(self.named_parameters())

        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            subdict = self.get_subdict(params, 'net.%d' % i)
            for j, sublayer in enumerate(layer):
                if isinstance(sublayer, BatchLinear):
                    x = sublayer(x, params=self.get_subdict(subdict, '%d' % j))
                else:
                    # ReLU/SELU/ELU are built with inplace=True. Applied directly, they would
                    # overwrite the tensor already recorded for the preceding BatchLinear,
                    # so give them a copy.
                    if getattr(sublayer, 'inplace', False):
                        x = x.clone()
                    x = sublayer(x)

                if retain_grad:
                    x.retain_grad()
                activations['_'.join((str(sublayer.__class__), "%d" % i))] = x
        return activations
