import math

import torch
import torch.nn as nn

from models.metamodule import MetaModule, MetaSequential, MetaBatchLinear


class Sine(nn.Module):
    """Element-wise sinusoidal activation computing ``sin(w0 * x)``.

    Args:
        w0: Frequency multiplier applied to the input before the sine.
    """

    def __init__(self, w0=30.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        """Return ``sin(w0 * x)`` element-wise; output has the same shape as ``x``."""
        scaled = x * self.w0
        return scaled.sin()


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding of input coordinates.

    Every input feature ``x_i`` is multiplied by ``2**k`` for ``num_freqs``
    exponents ``k`` evenly spaced over ``[0, max_freq]``. The sine and cosine of
    each product are then returned.

    Output layout along the last dim: all sine terms first, then all cosine
    terms. Within each half the order is frequency-major, i.e. every input
    feature at the lowest frequency, then every feature at the next frequency,
    and so on. The last dim therefore grows from ``in_features`` to
    ``2 * num_freqs * in_features``. The raw input is not included.

    ``freqs`` is a registered buffer, so it follows ``.to()`` and is stored in
    the state dict.
    """
    def __init__(self, max_freq, num_freqs):
        super().__init__()
        exponents = torch.linspace(0, max_freq, num_freqs)
        self.register_buffer("freqs", 2 ** exponents)  # shape: (num_freqs,)

    def forward(self, x):
        lead_shape = x.shape[:-1]
        # (..., 1, in_features) * (num_freqs, 1) -> (..., num_freqs, in_features)
        scaled = x[..., None, :] * self.freqs[:, None]
        scaled = scaled.reshape(*lead_shape, -1)  # (..., num_freqs * in_features)
        return torch.cat((scaled.sin(), scaled.cos()), dim=-1)


class MetaSirenLayer(MetaModule):
    """
    Single SIREN layer: a ``MetaBatchLinear`` followed by ``Sine(w0)`` (or
    ``nn.Identity`` when ``is_final``), initialised with the SIREN scheme.

    Args:
        dim_in: Input feature size of the linear map.
        dim_out: Output feature size of the linear map.
        w0: Frequency multiplier of the sine activation. It also scales the
            init bound of non-first layers.
        c: Constant in the non-first-layer init bound ``sqrt(c / dim_in) / w0``.
        is_first: Use the first-layer bound ``1 / dim_in`` instead.
        is_final: Replace the sine activation with ``nn.Identity``.
        w0_type: Weight init scheme: ``'uniform'`` (uniform in
            ``[-w_std, w_std]``), ``'sparse'`` or ``'orthogonal'``. In every
            case the bias is drawn uniformly from ``[-w_std, w_std]``.

    Raises:
        ValueError: if ``w0_type`` is not one of the supported schemes.
    """
    def __init__(self, dim_in, dim_out, w0=30., c=6., is_first=False, is_final=False, w0_type='uniform'):
        super().__init__()
        # Encapsulates MetaLinear and activation.
        self.linear = MetaBatchLinear(dim_in, dim_out)
        self.activation = nn.Identity() if is_final else Sine(w0)
        self.w0_type = w0_type
        # Initializes according to SIREN init.
        self.init_(c=c, w0=w0, is_first=is_first)

    def init_(self, c, w0, is_first):
        """Re-initialise ``self.linear`` in place according to ``self.w0_type``."""
        # NOTE(review): assumes MetaBatchLinear stores weight as (dim_out, dim_in)
        # like nn.Linear, so size(1) is the fan-in — confirm in models.metamodule.
        dim_in = self.linear.weight.size(1)
        w_std = 1/dim_in if is_first else (math.sqrt(c/dim_in)/w0)
        if self.w0_type == 'uniform':
            nn.init.uniform_(self.linear.weight, -w_std, w_std)
        elif self.w0_type == 'sparse':
            nn.init.sparse_(self.linear.weight, sparsity=0.1)
        elif self.w0_type == 'orthogonal':
            nn.init.orthogonal_(self.linear.weight)
        else:
            # Silently keeping the default init would quietly disable SIREN
            # initialisation (e.g. on a typo), so fail loudly instead.
            raise ValueError(
                f"Unknown w0_type {self.w0_type!r}; "
                "expected one of 'uniform', 'sparse', 'orthogonal'"
            )
        # Bias init is shared by all schemes; it runs after the weight init,
        # matching the original RNG consumption order.
        nn.init.uniform_(self.linear.bias, -w_std, w_std)

    def forward(self, x, params=None):
        """Apply the linear map (optionally with ``params['linear']``-prefixed
        overrides via ``get_subdict``) followed by the activation."""
        return self.activation(self.linear(x, self.get_subdict(params, 'linear')))


class MetaReLULayer(MetaModule):
    """
    Single ReLU MLP layer: a ``MetaBatchLinear`` followed by ``nn.ReLU`` (or
    ``nn.Identity`` when ``is_final``). Weights use Kaiming-normal init for
    ReLU and the bias is zeroed.

    ``w0``, ``c`` and ``is_first`` are accepted only so the constructor mirrors
    ``MetaSirenLayer``; they have no effect on initialisation.
    """
    def __init__(self, dim_in, dim_out, w0=30., c=6., is_first=False, is_final=False):
        super().__init__()
        # Encapsulates MetaLinear and activation.
        self.linear = MetaBatchLinear(dim_in, dim_out)
        self.activation = nn.Identity() if is_final else nn.ReLU()
        # Kaiming (He) initialisation; the SIREN-style arguments are ignored.
        self.init_(c=c, w0=w0, is_first=is_first)

    def init_(self, c, w0, is_first):
        """Kaiming-normal weights (ReLU gain) and zero bias; arguments unused."""
        nn.init.kaiming_normal_(self.linear.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear.bias)

    def forward(self, x, params=None):
        """Apply the linear map (optionally with ``params['linear']``-prefixed
        overrides via ``get_subdict``) followed by the activation."""
        return self.activation(self.linear(x, self.get_subdict(params, 'linear')))


class MetaReLU(MetaModule):
    """
    ReLU MLP meta-network with a sinusoidal positional encoding on its input.

    Architecture:
        ``PositionalEncoding(pe_max_freq, pe_num_freqs)`` ->
        ``num_layers - 1`` hidden ``MetaReLULayer``s of width ``dim_hidden`` ->
        a final ``MetaReLULayer`` without activation mapping to ``dim_out``.
    ``0.5`` is added to the network output.

    Args:
        dim_in: Number of raw input coordinates (before encoding).
        dim_hidden: Hidden layer width.
        dim_out: Output feature size.
        num_layers: Total number of linear layers. Values below 2 give a
            single output layer applied directly to the encoded input.
        w0, w0_initial: Forwarded to the layers, but ``MetaReLULayer`` ignores
            them.
        data_type: Stored as ``self.data_type``.
        data_size, w0_type: Accepted for signature parity with ``MetaSiren``;
            unused here.
        pe_max_freq: Largest frequency exponent of the positional encoding.
        pe_num_freqs: Number of encoding frequencies.
    """
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers=5, w0=30., w0_initial=30.,
                 data_type='img', data_size=(128, 128, 3), w0_type='uniform',
                 pe_max_freq=8, pe_num_freqs=20):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.data_type = data_type
        self.w0 = w0
        # PositionalEncoding emits a sin and a cos per (frequency, input feature),
        # so derive the encoded width from the same constant it is built with.
        encoded_dim = 2 * dim_in * pe_num_freqs
        layers = [PositionalEncoding(pe_max_freq, pe_num_freqs)]
        for ind in range(num_layers - 1):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = encoded_dim if is_first else dim_hidden
            layers.append(MetaReLULayer(dim_in=layer_dim_in, dim_out=dim_hidden,
                                        w0=layer_w0, is_first=is_first))
        # Without hidden layers the output layer consumes the encoding directly.
        final_dim_in = dim_hidden if num_layers > 1 else encoded_dim
        layers.append(MetaReLULayer(dim_in=final_dim_in, dim_out=dim_out,
                                    w0=w0, is_final=True))
        self.layers = MetaSequential(*layers)

    def forward(self, x, params=None):
        """Run the encoded MLP and add ``0.5``. ``params`` is presumably an
        optional mapping of parameter overrides; its ``'layers'`` sub-dict is
        routed to ``self.layers`` via ``get_subdict``."""
        return self.layers(x, params=self.get_subdict(params, 'layers')) + 0.5


class MetaSiren(MetaModule):
    """
    SIREN (sine-activated MLP) as a meta-network.

    ``num_layers - 1`` sine hidden layers of width ``dim_hidden`` are followed
    by a final ``MetaSirenLayer`` without activation mapping to ``dim_out``.
    ``0.5`` is added to the network output. The first hidden layer uses
    ``w0_initial`` and the first-layer SIREN init; the later layers use ``w0``.
    ``w0_type`` selects the weight init scheme of every layer.

    Args:
        num_layers: Total number of linear layers. Values below 2 give a
            single output layer applied directly to the input.
        data_type, data_size: Accepted but unused by this class.
    """
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers=4, w0=30., w0_initial=30.,
                 data_type='img', data_size=(3, 178, 178), w0_type='uniform'):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.w0 = w0
        layers = []
        for ind in range(num_layers - 1):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden
            layers.append(MetaSirenLayer(dim_in=layer_dim_in, dim_out=dim_hidden,
                                         w0=layer_w0, is_first=is_first, w0_type=w0_type))
        # Without hidden layers the output layer must read the raw input.
        final_dim_in = dim_hidden if num_layers > 1 else dim_in
        layers.append(MetaSirenLayer(dim_in=final_dim_in, dim_out=dim_out,
                                     w0=w0, is_final=True, w0_type=w0_type))
        self.layers = MetaSequential(*layers)

    def forward(self, x, params=None):
        """Run the SIREN and add ``0.5``. ``params`` is presumably an optional
        mapping of parameter overrides; its ``'layers'`` sub-dict is routed to
        ``self.layers`` via ``get_subdict``."""
        return self.layers(x, params=self.get_subdict(params, 'layers')) + 0.5


class MetaSirenPenultimate(MetaModule):
    """
    SIREN meta-network that can also return its penultimate features.

    The hidden stack (``self.layers``) and the output layer
    (``self.last_layer``) are separate submodules, so
    ``forward(..., get_features=True)`` can return the hidden representation
    together with the output. ``0.5`` is added to the output only, not to the
    features.

    Args:
        num_layers: Total number of linear layers. Values below 2 leave
            ``self.layers`` empty and feed the input straight to
            ``last_layer``. This assumes an empty ``MetaSequential`` passes its
            input through unchanged — TODO confirm.
        data_type, data_size: Accepted but unused by this class.
        w0_type: Weight init scheme passed to every ``MetaSirenLayer``. It
            defaults to ``'uniform'``, the scheme previously always used.
    """
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers=4, w0=30., w0_initial=30.,
                 data_type='img', data_size=(3, 178, 178), w0_type='uniform'):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.w0 = w0
        layers = []
        for ind in range(num_layers - 1):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden
            layers.append(MetaSirenLayer(dim_in=layer_dim_in, dim_out=dim_hidden,
                                         w0=layer_w0, is_first=is_first, w0_type=w0_type))
        self.layers = MetaSequential(*layers)
        # Without hidden layers the output layer must read the raw input.
        final_dim_in = dim_hidden if num_layers > 1 else dim_in
        self.last_layer = MetaSirenLayer(dim_in=final_dim_in, dim_out=dim_out,
                                         w0=w0, is_final=True, w0_type=w0_type)

    def forward(self, x, params=None, get_features=False):
        """Run the network; return ``out`` or ``(out, feature)``.

        ``feature`` is the output of the hidden stack. ``out`` is
        ``last_layer(feature) + 0.5``. The ``params`` sub-dicts ``'layers'``
        and ``'last_layer'`` are routed to the matching submodules via
        ``get_subdict``.
        """
        feature = self.layers(x, params=self.get_subdict(params, 'layers'))
        out = self.last_layer(feature, params=self.get_subdict(params, 'last_layer')) + 0.5

        if get_features:
            return out, feature
        return out
