import torch
from torch import nn
#from torchmeta.modules import (MetaModule, MetaSequential)
#from torchmeta.modules.utils import get_subdict
import numpy as np
from collections import OrderedDict
import math
import torch.nn.functional as F
import torch
from torch import nn
from collections import OrderedDict

def laplace(y, x):
    """Return the Laplacian of ``y`` with respect to ``x``.

    The result is the divergence of the gradient; both steps are delegated to
    the ``gradient`` and ``divergence`` helpers of this module, which
    differentiate through ``torch.autograd`` with ``create_graph=True``.
    """
    return divergence(gradient(y, x), x)


def divergence(y, x):
    """Sum, over the last axis of ``y``, the derivative of component ``i`` w.r.t. ``x[..., i]``.

    For every index ``i`` of the last dimension of ``y`` the gradient of
    ``y[..., i]`` with respect to ``x`` is taken (seeded with ones) and its
    ``i``-th slice, kept as a trailing dimension of size one, is accumulated.
    ``create_graph=True`` keeps the result differentiable.

    Returns the float ``0.`` if the last dimension of ``y`` is empty.
    """
    total = 0.
    for axis in range(y.shape[-1]):
        component = y[..., axis]
        (d_component,) = torch.autograd.grad(
            component, x, torch.ones_like(component), create_graph=True)
        total += d_component[..., axis:axis + 1]
    return total


def gradient(y, x, grad_outputs=None):
    """Differentiate ``y`` with respect to ``x`` via ``torch.autograd.grad``.

    Args:
        y: Tensor to differentiate.
        x: Tensor the derivative is taken with respect to.
        grad_outputs: Upstream gradient for ``y``; a tensor of ones shaped
            like ``y`` is used when omitted.

    Returns:
        The gradient tensor for ``x``. The graph is kept (``create_graph=True``)
        so the result can be differentiated again.
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    return torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)[0]

class Sine(nn.Module):
    """Sine activation module computing ``sin(30 * input)``."""

    # The constructor used to be spelled ``__init``; Python name-mangled it to
    # ``_Sine__init`` so it was never invoked as the constructor.
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)



class BatchLinear(nn.Linear):
    # A linear layer that can deal with batched weight matrices and biases, as for instance output by a
    # hypernetwork. Weights of shape (..., out_features, in_features) are applied by swapping their
    # last two dimensions and right-multiplying the input.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        # params: optional mapping with a 'weight' entry and an optional 'bias' entry that override
        # the layer's own parameters; when omitted the layer's registered parameters are used.
        if params is None:
            params = OrderedDict(self.named_parameters())

        weight = params['weight']
        bias = params.get('bias', None)

        # Swap the last two weight dims so that (..., in) @ (..., in, out) -> (..., out).
        output = input.matmul(weight.transpose(-1, -2))

        # A layer built with bias=False (or a params mapping without 'bias') has no bias to add;
        # previously this path raised AttributeError on None.
        if bias is not None:
            output = output + bias.unsqueeze(-2)
        return output


class SineLayer(nn.Module):
    """Linear layer followed by a scaled sine: ``sin(omega_0 * linear(input))``.

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    Weight initialisation depends on ``is_first``:

    * ``is_first=True``: weights are drawn uniformly from
      ``[-1 / in_features, 1 / in_features]``.
    * ``is_first=False``: weights are drawn uniformly from
      ``[-sqrt(6 / in_features) / omega_0, sqrt(6 / in_features) / omega_0]``.

    A ``LayerNorm`` is registered as ``self.bn`` but is not applied in
    ``forward``; it still contributes parameters to the module.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30.):
        super().__init__()
        self.omega_0, self.is_first = omega_0, is_first
        self.in_features, self.out_features = in_features, out_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.bn = nn.LayerNorm(out_features)  # registered but unused by forward
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from the uniform range described on the class."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def _scaled_linear(self, input):
        # Pre-activation shared by both forward variants.
        return self.omega_0 * self.linear(input)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        return torch.sin(self._scaled_linear(input))

    def forward_with_intermediate(self, input):
        """Return ``(activation, pre_activation)``; used to visualise activation distributions."""
        pre_activation = self._scaled_linear(input)
        return torch.sin(pre_activation), pre_activation

class BaseDeepSDFSiren(nn.Module):
    """Stack of ``SineLayer`` blocks fed with coordinates concatenated with a latent embedding.

    The first layer consumes ``in_features + latent_size`` values. Every layer
    whose index is listed in ``latent_in`` emits ``hidden_features - in_features``
    values, and ``forward`` concatenates the coordinates back in front of that
    output, so the next layer again receives ``hidden_features`` values.

    Args:
        in_features: Size of the last dimension of ``coords``.
        latent_size: Size of the last dimension of ``embedding``.
        hidden_features: Width of the hidden layers.
        hidden_layers: Number of hidden ``SineLayer`` blocks after the first layer.
        out_features: Size of the network output.
        latent_in: Indices (into ``self.net``) of the layers after which the
            coordinates are concatenated again. Defaults to ``[4]``.
        outermost_linear: If true the last layer is a plain ``nn.Linear``;
            otherwise it is ``nn.Linear`` followed by ``nn.Tanh``.
        first_omega_0: ``omega_0`` of the first ``SineLayer``.
        hidden_omega_0: ``omega_0`` of the hidden ``SineLayer`` blocks.
        zero_init_last_layer: Accepted for interface compatibility; currently
            has no effect (the code that used it is disabled).
    """

    def __init__(self,
                 in_features,
                 latent_size,
                 hidden_features,
                 hidden_layers,
                 out_features,
                 latent_in=None,
                 outermost_linear=False,
                 first_omega_0=30,
                 hidden_omega_0=30.,
                 zero_init_last_layer=False):
        super().__init__()

        self.in_features = in_features
        # A None sentinel replaces the former mutable ``[4]`` default so that
        # instances never share (and cannot accidentally mutate) one list.
        self.latent_in = [4] if latent_in is None else latent_in

        layers = [SineLayer(in_features + latent_size, hidden_features,
                            is_first=True, omega_0=first_omega_0)]

        for i in range(hidden_layers):
            # Layer i + 1 leaves room for the coordinates when it is a re-injection point.
            if i + 1 in self.latent_in:
                width = hidden_features - in_features
            else:
                width = hidden_features
            layers.append(SineLayer(hidden_features, width,
                                    is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            layers.append(nn.Linear(hidden_features, out_features))
        else:
            layers.append(nn.Sequential(nn.Linear(hidden_features, out_features), nn.Tanh()))

        self.net = nn.Sequential(*layers)
        print(self.net)  # kept from the original: prints the architecture on construction

    def forward(self, embedding, coords):
        """Evaluate the network.

        Args:
            embedding: Latent code; concatenated after ``coords`` along the last dimension.
            coords: Input coordinates; also re-concatenated (in front of the
                activations) after every layer whose index is in ``latent_in``.

        Returns:
            Output of the last layer.
        """
        model_input = torch.cat((coords, embedding), dim=-1)

        for net_i in range(len(self.net) - 1):
            output = self.net[net_i](model_input)
            if net_i in self.latent_in:
                model_input = torch.cat((coords, output), dim=-1)
            else:
                model_input = output

        return self.net[-1](model_input)




class Siren(nn.Module):
    """Plain SIREN-style network: ``SineLayer`` blocks followed by a linear head.

    The network is one input ``SineLayer`` (``omega_0=first_omega_0``),
    ``hidden_layers`` hidden ``SineLayer`` blocks (``omega_0=hidden_omega_0``)
    and a final ``nn.Linear``. Unless ``outermost_linear`` is true, an
    ``nn.Tanh`` is appended after the final linear layer. All ``SineLayer``
    blocks, including the first, are built with ``is_first=False``.

    ``composer`` and ``zero_init_last_layer`` are accepted but not used.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, composer, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30., zero_init_last_layer=False):
        super().__init__()

        self.in_features = in_features

        layers = [SineLayer(in_features, hidden_features,
                            is_first=False, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))

        layers.append(nn.Linear(hidden_features, out_features))
        if not outermost_linear:
            layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Run ``coords`` through the whole network and return its output."""
        return self.net(coords)

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns the intermediate activations (not just the model output).
        Only used for visualizing activations later!

        The input is detached and cloned with gradients enabled and stored
        under ``'input'``. Each ``SineLayer`` contributes two entries (its
        pre-activation, then its output); every other layer contributes one.
        Keys are ``"<layer class>_<running counter>"``.'''
        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        activation_count = 0
        for layer in self.net:
            layer_tag = str(layer.__class__)

            if isinstance(layer, SineLayer):
                x, pre_activation = layer.forward_with_intermediate(x)
                if retain_grad:
                    x.retain_grad()
                    pre_activation.retain_grad()
                activations['_'.join((layer_tag, "%d" % activation_count))] = pre_activation
                activation_count += 1
            else:
                x = layer(x)
                if retain_grad:
                    x.retain_grad()

            activations['_'.join((layer_tag, "%d" % activation_count))] = x
            activation_count += 1

        return activations


class BaseFCBlock(nn.Module):
    '''A fully connected neural network with a selectable nonlinearity.

    ``self.net`` is an ``nn.Sequential`` of ``nn.Sequential`` blocks: an input
    block, ``num_hidden_layers`` hidden blocks (each ``nn.Linear`` followed by
    the nonlinearity) and an output block that is either a bare ``nn.Linear``
    (``outermost_linear=True``) or ``nn.Linear`` followed by ``nn.Tanh``.
    '''

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='relu', weight_init=None):
        '''
        Args:
            in_features: Size of the network input.
            out_features: Size of the network output.
            num_hidden_layers: Number of hidden blocks after the input block.
            hidden_features: Width of the hidden blocks.
            outermost_linear: If true, no ``Tanh`` is applied to the output.
            nonlinearity: One of ``'sine'``, ``'relu'``, ``'sigmoid'``,
                ``'tanh'``, ``'selu'``, ``'softplus'``, ``'elu'``.
            weight_init: Optional callable applied to every submodule via
                ``Module.apply``; overrides the initialiser that would
                otherwise be chosen from ``nonlinearity``.

        Raises:
            KeyError: If ``nonlinearity`` is not one of the supported names.
        '''
        super().__init__()

        self.in_features = in_features
        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(alpha=0.2,inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # An explicitly passed initialiser overrides the per-nonlinearity default.
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # The same nonlinearity instance is reused in every block.
        layers = [nn.Sequential(nn.Linear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            layers.append(nn.Sequential(nn.Linear(hidden_features, hidden_features), nl))

        if outermost_linear:
            layers.append(nn.Sequential(nn.Linear(hidden_features, out_features)))
        else:
            layers.append(nn.Sequential(nn.Linear(hidden_features, out_features), nn.Tanh()))

        self.net = nn.Sequential(*layers)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None:  # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords):
        '''Run ``coords`` through the network and return its output.'''
        return self.net(coords)

    def forward_with_activations(self, coords, params=None, retain_grad=False):
        '''Returns the intermediate activations (not just the model output).

        Args:
            coords: Network input; a detached clone with gradients enabled is
                fed to the network and stored under the key ``'input'``.
            params: Accepted for interface compatibility; not used.
            retain_grad: If true, ``retain_grad()`` is called on every
                intermediate tensor.

        Returns:
            ``OrderedDict`` keyed by ``"<sublayer class>_<block index>"``.
        '''
        # The former ``self.get_subdict(...)`` lookup was removed: nn.Module has no
        # such method (it raised AttributeError) and its result was never used.
        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            for sublayer in layer:
                x = sublayer(x)

                if retain_grad:
                    x.retain_grad()
                activations['_'.join((str(sublayer.__class__), "%d" % i))] = x
        return activations

class LipBoundedPosEnc(nn.Module):
    """Sinusoidal positional encoding with per-frequency normalisation.

    Each of the first ``inp_features`` input channels is encoded with
    ``sin`` and ``cos`` at the frequencies ``2**k * pi`` for
    ``k = 0 .. n_freq - 1``. Every encoded value is divided by its frequency
    and the whole vector by ``sqrt(2 * n_freq)`` (or ``sqrt(2 * n_freq + 1)``
    when the raw input is appended via ``cat_inp``).

    Attributes:
        out_dim: Size of the last dimension produced by ``forward``:
            ``2 * n_freq * inp_features`` plus ``inp_features`` if ``cat_inp``.
    """

    def __init__(self, inp_features, n_freq=5, cat_inp=True):
        super().__init__()
        self.inp_feat = inp_features
        self.n_freq = n_freq
        self.cat_inp = cat_inp
        self.out_dim = 2 * self.n_freq * self.inp_feat
        if self.cat_inp:
            self.out_dim += self.inp_feat

    def forward(self, x):
        """
        :param x: (bs, npoints, >= inp_features); only the first ``inp_features``
                  channels of the last dimension are encoded
        :return: (bs, npoints, out_dim)
        """
        # Previously the channels were hard-coded as [0, 1, 2], which only worked for
        # inp_features == 3; slicing by inp_feat selects the same channels in that case.
        x = x[..., :self.inp_feat]

        bs, npts = x.size(0), x.size(1)
        const = (2 ** torch.arange(self.n_freq) * np.pi).view(1, 1, 1, -1)
        const = const.to(x)  # match dtype and device of the input

        # (bs, npoints, inp_feat, n_freq): every channel at every frequency.
        phase = const * x.unsqueeze(-1)
        sin_feat = torch.sin(phase).view(bs, npts, self.inp_feat, -1)
        cos_feat = torch.cos(phase).view(bs, npts, self.inp_feat, -1)

        # Per channel layout: n_freq sine values followed by n_freq cosine values.
        out = torch.cat([sin_feat, cos_feat], dim=-1).reshape(
            bs, npts, 2 * self.inp_feat * self.n_freq)

        # Frequencies laid out to match ``out`` element for element.
        const_norm = torch.cat([const, const], dim=-1).view(
            1, 1, 1, self.n_freq * 2).expand(
            -1, -1, self.inp_feat, -1).reshape(
            1, 1, 2 * self.inp_feat * self.n_freq)

        if self.cat_inp:
            # The raw coordinates are appended and left unscaled by any frequency.
            out = torch.cat([out, x], dim=-1)
            const_norm = torch.cat(
                [const_norm, torch.ones(1, 1, self.inp_feat).to(x)], dim=-1)
            xyz_out = out / const_norm / np.sqrt(self.n_freq * 2 + 1)
        else:
            xyz_out = out / const_norm / np.sqrt(self.n_freq * 2)

        return xyz_out


class BaseDeepSDF(nn.Module):
    '''A fully connected network fed with coordinates concatenated with a latent embedding.

    Every hidden block is ``nn.Linear`` -> ``nn.LayerNorm`` -> nonlinearity ->
    ``nn.Dropout(p=0.2)``. Blocks whose index is listed in ``latent_in`` emit
    ``hidden_features - in_features`` values; ``forward`` concatenates the
    coordinates in front of that output so the next block again receives
    ``hidden_features`` values.
    '''

    def __init__(self,
                 in_features,
                 latent_size,
                 out_features,
                 num_hidden_layers,
                 hidden_features,
                 latent_in=None,
                 outermost_linear=False,
                 nonlinearity='relu',
                 weight_init=None):
        '''
        Args:
            in_features: Size of the last dimension of ``coords``.
            latent_size: Size of the last dimension of ``embedding``.
            out_features: Size of the network output.
            num_hidden_layers: Number of hidden blocks after the input block.
            hidden_features: Width of the hidden blocks.
            latent_in: Indices (into ``self.net``) of the blocks after which
                the coordinates are concatenated again. Defaults to ``[4]``.
            outermost_linear: If true, no ``Tanh`` is applied to the output.
            nonlinearity: One of ``'sine'``, ``'relu'``, ``'sigmoid'``,
                ``'tanh'``, ``'selu'``, ``'softplus'``, ``'elu'``.
            weight_init: Optional callable applied to every submodule via
                ``Module.apply``; overrides the per-nonlinearity default.

        Raises:
            KeyError: If ``nonlinearity`` is not one of the supported names.
        '''
        super().__init__()

        self.in_features = in_features
        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(alpha=0.2,inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # NOTE(review): each of these module instances is reused by several blocks, so the
        # LayerNorm affine parameters are shared between those blocks - confirm this is intended.
        bn = nn.LayerNorm(hidden_features)
        bn_ = nn.LayerNorm(hidden_features-in_features)
        dp = nn.Dropout(p=0.2)

        # An explicitly passed initialiser overrides the per-nonlinearity default.
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # A None sentinel replaces the former mutable ``[4]`` default so that
        # instances never share one list object.
        self.latent_in = [4] if latent_in is None else latent_in

        layers = [nn.Sequential(
            nn.Linear(in_features + latent_size, hidden_features), bn, nl, dp
        )]

        for i in range(num_hidden_layers):
            if i + 1 in self.latent_in:
                # Leave room for the coordinates that forward() concatenates back in.
                layers.append(nn.Sequential(
                    nn.Linear(hidden_features, hidden_features - in_features), bn_, nl, dp
                ))
            else:
                layers.append(nn.Sequential(
                    nn.Linear(hidden_features, hidden_features), bn, nl, dp
                ))

        if outermost_linear:
            layers.append(nn.Sequential(nn.Linear(hidden_features, out_features)))
        else:
            layers.append(nn.Sequential(
                nn.Linear(hidden_features, out_features), nn.Tanh(),
            ))

        self.net = nn.Sequential(*layers)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None:  # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, embedding, coords):
        '''Evaluate the network.

        Args:
            embedding: Latent code; concatenated after ``coords`` along the last dimension.
            coords: Input coordinates; re-concatenated in front of the
                activations after every block whose index is in ``latent_in``.

        Returns:
            Output of the last block.
        '''
        model_input = torch.cat((coords, embedding), dim=-1)
        for net_i in range(len(self.net) - 1):
            output = self.net[net_i](model_input)
            if net_i in self.latent_in:
                model_input = torch.cat((coords, output), dim=-1)
            else:
                model_input = output
        return self.net[-1](model_input)

    def forward_with_activations(self, coords, params=None, retain_grad=False, embedding=None):
        '''Returns the intermediate activations (not just the model output).

        Mirrors ``forward``: the embedding (if given) is concatenated after the
        coordinates, and the coordinates are concatenated again after every
        block whose index is in ``latent_in``.

        Args:
            coords: Input coordinates; a detached clone with gradients enabled
                is used and stored under the key ``'input'``.
            params: Accepted for interface compatibility; not used.
            retain_grad: If true, ``retain_grad()`` is called on every
                intermediate tensor.
            embedding: Optional latent code. Required whenever the network was
                built with ``latent_size > 0``.

        Returns:
            ``OrderedDict`` keyed by ``"<sublayer class>_<block index>"``.
        '''
        # The former ``self.get_subdict(...)`` lookup was removed: nn.Module has no
        # such method (it raised AttributeError) and its result was never used.
        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        model_input = x if embedding is None else torch.cat((x, embedding), dim=-1)
        last_index = len(self.net) - 1
        for i, layer in enumerate(self.net):
            for sublayer in layer:
                model_input = sublayer(model_input)

                if retain_grad:
                    model_input.retain_grad()
                activations['_'.join((str(sublayer.__class__), "%d" % i))] = model_input

            # Same coordinate re-injection as forward(); never applied after the output block.
            if i != last_index and i in self.latent_in:
                model_input = torch.cat((x, model_input), dim=-1)
        return activations








def init_weights_normal(m):
    """Kaiming-normal (fan-in, ReLU gain) initialisation of ``m.weight``.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` are
    touched; anything else is left unchanged. Intended for ``Module.apply``.
    """
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    """Draw ``m.weight`` from a zero-mean normal with ``std = 1 / sqrt(fan_in)``.

    ``fan_in`` is the size of the last weight dimension. Only modules whose
    exact type is ``BatchLinear`` or ``nn.Linear`` are touched.
    """
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    """Draw ``m.weight`` from a zero-mean normal with ``std = sqrt(1.5505188080679277) / sqrt(fan_in)``.

    ``fan_in`` is the size of the last weight dimension. Only modules whose
    exact type is ``BatchLinear`` or ``nn.Linear`` are touched.
    """
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    """Xavier-normal initialisation of ``m.weight``.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` are
    touched; anything else is left unchanged. Intended for ``Module.apply``.
    """
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    nn.init.xavier_normal_(m.weight)


def sine_init(m):
    """Uniformly re-draw ``m.weight`` in ``[-sqrt(6 / fan_in) / 30, sqrt(6 / fan_in) / 30]``.

    ``fan_in`` is the size of the last weight dimension. Modules without a
    ``weight`` attribute are left unchanged. Intended for ``Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See supplement Sec. 1.5 for discussion of factor 30
    bound = np.sqrt(6 / fan_in) / 30
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)


def first_layer_sine_init(m):
    """Uniformly re-draw ``m.weight`` in ``[-1 / fan_in, 1 / fan_in]``.

    ``fan_in`` is the size of the last weight dimension. Modules without a
    ``weight`` attribute are left unchanged. Intended for ``Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
    bound = 1 / fan_in
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)


###################
# Complex operators
def compl_conj(x):
    """Return a copy of ``x`` with every odd-indexed entry of the last dimension negated.

    With real parts at even indices and imaginary parts at odd indices (the
    layout used by the other ``compl_*`` helpers) this is the complex
    conjugate. The input tensor is not modified.
    """
    conjugated = x.clone()
    odd_entries = conjugated[..., 1::2]
    conjugated[..., 1::2] = -1 * odd_entries
    return conjugated


def compl_div(x, y):
    ''' x / y

    Both tensors hold complex numbers interleaved along the last dimension:
    even indices are real parts, odd indices are imaginary parts. The result
    has the shape and dtype of ``x``.
    '''
    x_re, x_im = x[..., ::2], x[..., 1::2]
    y_re, y_im = y[..., ::2], y[..., 1::2]

    # (a + bi) / (c + di) = ((ac + bd) + (bc - ad) i) / (c^2 + d^2)
    denom = y_re ** 2 + y_im ** 2

    quotient = torch.zeros_like(x)
    quotient[..., ::2] = (x_re * y_re + x_im * y_im) / denom
    quotient[..., 1::2] = (x_im * y_re - x_re * y_im) / denom
    return quotient


def compl_mul(x, y):
    '''  x * y

    Both tensors hold complex numbers interleaved along the last dimension:
    even indices are real parts, odd indices are imaginary parts. The result
    has the shape and dtype of ``x``.
    '''
    x_re, x_im = x[..., ::2], x[..., 1::2]
    y_re, y_im = y[..., ::2], y[..., 1::2]

    product = torch.zeros_like(x)
    # Real part: ac - bd.
    product[..., ::2] = x_re * y_re - x_im * y_im
    # Imaginary part ad + bc, written as (a + b)(c + d) - ac - bd.
    product[..., 1::2] = (x_re + x_im) * (y_re + y_im) - x_re * y_re - x_im * y_im
    return product


