import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class ExplicitLinear(nn.Linear):
    """Affine layer that exposes its pre-activation and a manual reverse pass.

    ``forward`` returns both the pre-activation ``z = x W^T + b`` and the
    activated output ``act(z)``. For this base class the activation is the
    identity; subclasses override ``act`` / ``reverse_act`` to plug in a
    non-linearity together with its element-wise derivative.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_features, out_features, bias, **factory_kwargs)

    def act(self, x):
        """Activation applied to the pre-activation (identity here)."""
        return x

    def reverse_act(self, x):
        """Element-wise derivative of ``act`` evaluated at ``x`` (all ones here)."""
        return torch.ones_like(x)

    def forward(self, x):
        """Return the tuple ``(z, act(z))`` for the input ``x``."""
        pre_activation = F.linear(x, self.weight, self.bias)
        return pre_activation, self.act(pre_activation)

    def reverse(self, dfdhl, zl):
        """Back-propagate a gradient through this layer.

        Args:
            dfdhl: gradient with respect to this layer's output, [N, nl].
            zl: pre-activation that produced that output, [N, nl].

        Returns:
            ``(dfdzl, df_dhlm1)``: the gradient with respect to the
            pre-activation, [N, nl], and with respect to the layer input,
            [N, nlm1].
        """
        # Chain rule through the activation: df/dz = df/dh * act'(z).
        grad_pre_activation = dfdhl * self.reverse_act(zl)
        # Chain rule through z = h_prev W^T + b: df/dh_prev = df/dz W.
        grad_input = torch.matmul(grad_pre_activation, self.weight)
        return grad_pre_activation, grad_input


class ExplicitSineLayer(ExplicitLinear):
    """Linear layer followed by ``sin(omega_0 * z)``.

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for the
    discussion of omega_0.

    omega_0 multiplies the pre-activation before the sine is applied.
    When ``is_first`` is True the weights are drawn uniformly from
    ``[-1 / in_features, 1 / in_features]``; different signals may require a
    different omega_0 in the first layer, so it is a hyperparameter.
    When ``is_first`` is False the initialisation range
    ``sqrt(6 / in_features)`` is divided by omega_0, which keeps the magnitude
    of activations constant but boosts gradients to the weight matrix
    (see supplement Sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)

        # These attributes must exist before init_weights() reads them.
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.init_weights()

    def act(self, x):
        """Sine activation with frequency factor omega_0."""
        return torch.sin(self.omega_0 * x)

    def reverse_act(self, x):
        """Derivative of ``sin(omega_0 * x)`` with respect to ``x``."""
        return self.omega_0 * torch.cos(self.omega_0 * x)

    def init_weights(self):
        """Re-draw the weight matrix from the symmetric uniform range described above."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

class ExplicitReLULayer(ExplicitLinear):
    """Linear layer followed by a ReLU, with an explicit derivative.

    The weight matrix is initialised with Kaiming-normal (fan-in, ReLU gain)
    and the bias, when the layer has one, is set to zero.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.init_weights()

    def act(self, x):
        """ReLU activation."""
        return F.relu(x)

    def reverse_act(self, x):
        """Derivative of ReLU: 1 where ``x > 0`` and 0 elsewhere (including at 0)."""
        # Cast the mask to the input's dtype instead of a hard-coded float32 so
        # the gradient keeps the precision of the tensors it multiplies.
        return (x > 0).to(x.dtype)

    def init_weights(self):
        """Apply Kaiming-normal weights and, if present, a zero bias."""
        with torch.no_grad():
            nn.init.kaiming_normal_(self.weight, mode='fan_in', nonlinearity='relu')
            # With bias=False nn.Linear stores None here; nothing to initialise.
            if self.bias is not None:
                nn.init.constant_(self.bias, 0)


#Reference: https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L222
class ExplicitPosEncoderLayer(nn.Module):
    """Fixed sinusoidal positional encoding with a manual reverse pass.

    The output is ``cat([x, sin(x @ weight + bias)])``. ``weight`` holds the
    frequencies ``2**k * pi`` for ``k = 0 .. num_frequencies - 1``, once per
    input coordinate and repeated in two halves. ``bias`` is 0 on the first
    half and ``pi / 2`` on the second half, so the second half evaluates to
    cosines. Both tensors are registered as buffers, not parameters.

    Args:
        in_features: number of input coordinates.
        sidelength: number of samples per axis; it determines the number of
            frequencies and therefore must be provided.

    Raises:
        TypeError: if ``sidelength`` is None.
        ValueError: if ``sidelength`` is not positive.
    """

    def __init__(self, in_features, sidelength=None):
        super().__init__()

        self.in_features = in_features
        self.num_frequencies = self.get_num_frequencies_nyquist(sidelength)
        self.out_features = in_features + 2 * in_features * self.num_frequencies
        w, b = self.make_pos_encoding_buffers()

        self.register_buffer('weight', w)
        self.register_buffer('bias', b)

    def get_num_frequencies_nyquist(self, samples):
        """Return ``floor(log2(samples / 4))``, clamped to be non-negative."""
        if samples is None:
            # Fail with a readable message instead of "int / NoneType".
            raise TypeError(
                "sidelength must be provided to derive the number of frequencies"
            )
        if samples <= 0:
            raise ValueError(f"sidelength must be positive, got {samples}")

        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        # log2 is exact for powers of two; fewer than 4 samples would give a
        # negative count, which is clamped to "no sinusoidal features".
        return max(0, int(math.floor(math.log2(nyquist_rate))))

    def make_pos_encoding_buffers(self):
        """Build the frequency matrix [in, out - in] and the phase vector [out - in]."""
        n_encoded = self.out_features - self.in_features
        w = torch.zeros(self.in_features, n_encoded)
        b = torch.zeros(n_encoded)

        half = self.in_features * self.num_frequencies
        # A phase shift of pi/2 turns the second half of the sines into cosines.
        b[half:] = np.pi / 2

        freqs = 2 ** torch.arange(self.num_frequencies) * np.pi
        for part in range(2):
            for j in range(self.in_features):
                start = part * half + j * self.num_frequencies
                w[j, start:start + self.num_frequencies] = freqs

        return w, b

    def forward(self, x):
        """Return ``(x, encoded)``; the input itself plays the role of the cached "z"."""
        z = x @ self.weight + self.bias
        out = torch.cat([x, torch.sin(z)], dim=-1)
        return x, out

    def reverse(self, dfdhl, zl):
        """Back-propagate through the encoding.

        Args:
            dfdhl: gradient with respect to the encoded output, [..., out_features].
            zl: the value cached by ``forward`` as "z", i.e. the layer input,
                [..., in_features].

        Returns:
            The gradient with respect to the input, twice, because the cached
            "z" of this layer is its input.
        """
        phase = zl @ self.weight + self.bias
        s_der = torch.cos(phase)

        # The output is [identity part | sinusoidal part]; split the gradient alike.
        id_dfdzl = dfdhl[..., :self.in_features]
        s_dfdzl = dfdhl[..., self.in_features:]

        dfdzl = id_dfdzl + (s_dfdzl * s_der) @ self.weight.T

        return dfdzl, dfdzl


#Reference: https://github.com/jmclong/random-fourier-features-pytorch/blob/main/rff/functional.py
class ExplicitFourierLayer(nn.Module):
    """Random Fourier feature layer with a manual reverse pass.

    The projection ``weight`` is drawn once from a standard normal scaled by
    ``sigma`` and stored as a buffer (it is not a trainable parameter). The
    layer maps ``x`` to ``z = x @ weight`` and returns
    ``cat([cos(2 pi z), sin(2 pi z)])``, so the output is twice as wide as
    ``out_features``.

    Args:
        in_features: width of the input.
        out_features: number of random projections (half the output width).
        sigma: standard deviation of the random projection matrix.
    """

    def __init__(self, in_features, out_features, sigma=10.0):
        super().__init__()

        self.sigma = sigma
        self.in_features = in_features
        self.register_buffer('weight', torch.randn(in_features, out_features) * self.sigma)

    def act(self, x):
        """Return ``[cos(2 pi x), sin(2 pi x)]`` concatenated on the last axis."""
        vp = 2 * np.pi * x
        return torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1)

    def forward(self, x):
        """Return ``(z, act(z))`` with ``z = x @ weight``."""
        z = x @ self.weight
        out = self.act(z)
        return z, out

    def reverse(self, dfdhl, zl):
        """Back-propagate a gradient through the layer.

        Args:
            dfdhl: gradient with respect to the layer output, [..., 2 * nl].
            zl: projection cached by ``forward``, [..., nl].

        Returns:
            ``(dfdzl, df_dhlm1)``: gradient with respect to the projection,
            [..., nl], and with respect to the layer input, [..., in_features].
        """
        # Index the feature axis from the end so any number of leading batch
        # dimensions works, matching what forward() accepts.
        nl = zl.shape[-1]

        vp = 2 * np.pi * zl
        c_der = -2 * np.pi * torch.sin(vp)
        s_der = 2 * np.pi * torch.cos(vp)

        # The output is laid out as [cos part | sin part].
        c_dfdzl = dfdhl[..., :nl]
        s_dfdzl = dfdhl[..., nl:]

        dfdzl = c_dfdzl * c_der + s_dfdzl * s_der

        df_dhlm1 = dfdzl @ self.weight.T

        return dfdzl, df_dhlm1


#Reference: https://github.com/vishwa91/wire/blob/main/modules/gauss.py#L11
class ExplicitGaussLayer(ExplicitLinear):
    """Linear layer followed by the Gaussian bump ``exp(-(sigma_0 * z)**2)``.

    ``sigma_0`` scales the pre-activation before it is squared. The weights
    keep the initialisation inherited from the parent linear layer.
    """

    def __init__(self, in_features, out_features, bias=True, sigma_0=10.0):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.sigma_0 = sigma_0
        self.in_features = in_features

    def act(self, x):
        """Gaussian activation of the scaled input."""
        scaled = self.sigma_0 * x
        return torch.exp(-(scaled ** 2))

    def reverse_act(self, x):
        """Derivative of ``act``: ``-2 * sigma_0**2 * x * exp(-(sigma_0 * x)**2)``."""
        slope = -2 * self.sigma_0 ** 2 * x
        bump = torch.exp(-((self.sigma_0 * x) ** 2))
        return slope * bump


class ExplicitSequential(nn.Module):
    """Stack of layers with hand-written forward and reverse passes.

    Every layer is called as ``z, h = layer(h_prev)`` and must provide
    ``layer.reverse(dfdh, z) -> (dfdz, dfdh_prev)``. Subclasses supply the
    layers by overriding ``build_net``; all constructor arguments are passed
    on to it unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.net = self.build_net(*args, **kwargs)
        self.n_layers = len(self.net)

    def build_net(self, *args, **kwargs):
        """Return the sequence of layers; must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self, coords, keep_cache=False, detach=True):
        """Run the network on ``coords``.

        Args:
            coords: network input.
            keep_cache: when True, also return ``{'zs': [...], 'hs': [...]}``
                holding each layer's two return values.
            detach: when True the cached tensors are detached copies,
                otherwise the live tensors are stored.

        Returns:
            The last layer's output, or ``(output, cache)`` if ``keep_cache``.
        """
        cache = {'zs': [], 'hs': []} if keep_cache else None

        h = coords
        for layer in self.net:
            z, h = layer(h)

            if keep_cache:
                cache['zs'].append(z.detach().clone() if detach else z)
                cache['hs'].append(h.detach().clone() if detach else h)

        if keep_cache:
            return h, cache
        return h

    def reverse(self, J, cache, detach=True):
        """Propagate the output gradient ``J`` back through every layer.

        Args:
            J: gradient with respect to the network output, [N, out_features].
            cache: cache produced by ``forward(..., keep_cache=True)``.
            detach: when True, return detached copies.

        Returns:
            ``(g, dfdzs)``: the gradient with respect to the network input and
            the list of per-layer ``dfdz`` tensors, ordered first layer first.
        """
        dfdzs = []
        g = J
        for i in reversed(range(self.n_layers)):
            dfdzl, g = self.net[i].reverse(g, cache['zs'][i])
            dfdzs.append(dfdzl.detach().clone() if detach else dfdzl)

        # Collected last-to-first; flip once instead of inserting at the front.
        dfdzs.reverse()

        if detach:
            return g.detach().clone(), dfdzs
        return g, dfdzs

    def jac_forward(self, cache):
        """Forward-accumulate ``dz_l / dx`` for every layer.

        Relies on every layer exposing ``weight`` laid out as [n_l, n_{l-1}]
        and on every layer but the last exposing ``reverse_act``.

        Returns:
            A list of detached tensors of shape [N, n_l, n_x], one per layer.
        """
        N = cache['zs'][0].shape[0]

        with torch.no_grad():
            # Snapshot the first weight so that, like the other entries, the
            # result neither tracks gradients nor aliases the live parameter.
            w0 = self.net[0].weight.detach().clone()
            dzdxs = [w0[None, :, :].expand((N,) + w0.shape)]

            for l in range(1, self.n_layers):
                wl = self.net[l].weight  # [nl, nl-1]
                zlm1 = cache['zs'][l - 1]  # [N, nl-1]
                rzlm1 = self.net[l - 1].reverse_act(zlm1)
                glm1 = dzdxs[l - 1]  # [N, nl-1, nx]

                dhlm1dx = rzlm1[:, :, None] * glm1  # [N, nl-1, nx]
                dzldx = (dhlm1dx.permute(0, 2, 1) @ wl.T).permute(0, 2, 1)

                dzdxs.append(dzldx)

        return dzdxs

    def dfdx(self, coords):
        """Gradient of the summed network output with respect to ``coords``."""
        out_size = self.net[-1].out_features
        g0 = coords.new_ones((coords.shape[0], out_size))

        _, cache = self.forward(coords, keep_cache=True)
        dfdx, _ = self.reverse(g0, cache)

        return dfdx

    @staticmethod
    def _has_parameter(layer, name):
        """Tell whether ``layer.<name>`` is an ``nn.Parameter`` (not a buffer, None or missing)."""
        return isinstance(getattr(layer, name, None), nn.Parameter)

    def grad_theta_mag(self, coords):
        """Per-sample squared norm of the gradient of the summed output w.r.t. the parameters.

        For a layer ``z = W h + b`` the gradient with respect to ``W`` is the
        outer product of ``dfdz`` and ``h`` and the gradient with respect to
        ``b`` is ``dfdz``, so the layer contributes
        ``|dfdz|^2 * (|h|^2 + 1)``. The ``|h|^2`` term is only counted when
        ``weight`` is a parameter and the ``1`` only when ``bias`` is a
        parameter; layers holding buffers contribute nothing.
        """
        out_size = self.net[-1].out_features
        g0 = coords.new_ones((coords.shape[0], out_size))

        _, cache = self.forward(coords, keep_cache=True)
        _, dfdzs = self.reverse(g0, cache)

        # Input seen by layer i: the coordinates for i = 0, else the previous output.
        hs = [coords] + cache['hs'][:-1]

        mag_theta = 0

        for i, layer in enumerate(self.net):
            mag_dfdz = (dfdzs[i] ** 2).sum(-1)

            per_unit = 0.0
            if self._has_parameter(layer, 'weight'):
                per_unit = per_unit + (hs[i] ** 2).sum(-1)
            if self._has_parameter(layer, 'bias'):
                per_unit = per_unit + 1.0

            mag_theta = mag_theta + mag_dfdz * per_unit

        return mag_theta



class ExplicitSiren(ExplicitSequential):
    """Network of sine layers, optionally closed by a plain linear layer.

    Layout: one first sine layer (``is_first=True``, ``first_omega_0``),
    ``hidden_layers`` hidden sine layers (``hidden_omega_0``) and one output
    layer, which is linear when ``outermost_linear`` is True and a sine layer
    otherwise.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__(
            in_features, hidden_features, hidden_layers, out_features,
            outermost_linear, first_omega_0, hidden_omega_0,
        )

    def build_net(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear,
                  first_omega_0, hidden_omega_0):
        """Assemble the layers in input-to-output order and wrap them in a ModuleList."""
        layers = [
            ExplicitSineLayer(in_features, hidden_features,
                              is_first=True, omega_0=first_omega_0)
        ]

        layers.extend(
            ExplicitSineLayer(hidden_features, hidden_features,
                              is_first=False, omega_0=hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if not outermost_linear:
            layers.append(ExplicitSineLayer(hidden_features, out_features,
                                            is_first=False, omega_0=hidden_omega_0))
        else:
            head = ExplicitLinear(hidden_features, out_features)
            # Same uniform range as a hidden sine layer.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
            layers.append(head)

        return nn.ModuleList(layers)
    

#TODO
class ExplicitFourierNet(ExplicitSequential):
    """Random Fourier feature layer, ReLU hidden layers and a linear output layer.

    Args:
        in_features: width of the input coordinates.
        hidden_features: width of the hidden ReLU layers.
        hidden_layers: number of ReLU layers after the Fourier layer.
        out_features: width of the output.
        sigma: scale of the random Fourier projection; it also divides the
            initialisation range of the output layer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, sigma=10):

        super().__init__(in_features, hidden_features, hidden_layers, out_features, sigma)

    def build_net(self, in_features, hidden_features, hidden_layers, out_features, sigma):
        """Assemble the layers in input-to-output order and wrap them in a ModuleList."""
        # The Fourier layer emits a cos and a sin feature per projection, so
        # request half the hidden width ...
        n_projections = hidden_features // 2
        # ... and feed its true output width to the next layer. For an odd
        # hidden_features this is hidden_features - 1, which used to cause a
        # shape mismatch in the first layer after the encoding.
        width = 2 * n_projections

        net = [ExplicitFourierLayer(in_features, n_projections, sigma)]

        for _ in range(hidden_layers):
            net.append(ExplicitReLULayer(width, hidden_features))
            width = hidden_features

        final_linear = ExplicitLinear(width, out_features)
        bound = np.sqrt(6 / width) / sigma
        with torch.no_grad():
            final_linear.weight.uniform_(-bound, bound)

        net.append(final_linear)

        return nn.ModuleList(net)
    

class ExplicitPositionalReLUNet(ExplicitSequential):
    """Positional encoding, ReLU hidden layers and a linear output layer.

    The first ReLU layer consumes the encoder output and is always present;
    ``hidden_layers - 1`` further ReLU layers of width ``hidden_features``
    follow it, so any ``hidden_layers <= 1`` yields exactly one ReLU layer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, sidelength=64):
        super().__init__(in_features, hidden_features, hidden_layers, out_features, sidelength)

    def build_net(self, in_features, hidden_features, hidden_layers, out_features, sidelength):
        """Assemble the layers in input-to-output order and wrap them in a ModuleList."""
        encoder = ExplicitPosEncoderLayer(in_features, sidelength)

        # The encoder decides its own output width; the first ReLU layer adapts to it.
        layers = [encoder, ExplicitReLULayer(encoder.out_features, hidden_features)]

        layers += [
            ExplicitReLULayer(hidden_features, hidden_features)
            for _ in range(hidden_layers - 1)
        ]

        layers.append(ExplicitLinear(hidden_features, out_features))

        return nn.ModuleList(layers)

#TODO
class ExplicitGaussNet(ExplicitSequential):
    """Gaussian-activation layers followed by a linear output layer.

    Args:
        in_features: width of the input coordinates.
        hidden_features: width of every Gaussian layer.
        hidden_layers: number of Gaussian layers after the first one.
        out_features: width of the output.
        sigma: scale handed to every Gaussian layer as ``sigma_0``; it also
            divides the initialisation range of the output layer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, sigma=10):

        super().__init__(in_features, hidden_features, hidden_layers, out_features, sigma)

    def build_net(self, in_features, hidden_features, hidden_layers, out_features, sigma):
        """Assemble the layers in input-to-output order and wrap them in a ModuleList."""
        # ExplicitGaussLayer's third positional parameter is `bias`, so sigma
        # has to be passed by keyword; passing it positionally silently set
        # `bias` and left the activation scale at its default.
        net = [ExplicitGaussLayer(in_features, hidden_features, sigma_0=sigma)]

        for _ in range(hidden_layers):
            net.append(ExplicitGaussLayer(hidden_features, hidden_features, sigma_0=sigma))

        final_linear = ExplicitLinear(hidden_features, out_features)
        bound = np.sqrt(6 / hidden_features) / sigma
        with torch.no_grad():
            final_linear.weight.uniform_(-bound, bound)

        net.append(final_linear)

        return nn.ModuleList(net)