#!/usr/bin/env python

import pdb
import math

import numpy as np

import torch
from torch import nn
from torchmeta.modules import MetaModule, MetaSequential, MetaLinear

class GaussLayer(MetaModule):
    '''
        Linear layer followed by a Gaussian non-linearity.

        Intended as a drop-in replacement for SineLayer: the constructor
        accepts the same arguments, and the output is
        exp(-(scale * linear(x))**2).
    '''
    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, scale=10.0):
        '''
            Inputs:
                in_features: Input width of the wrapped MetaLinear
                out_features: Output width of the wrapped MetaLinear
                bias: Forwarded to MetaLinear
                is_first: Stored on the instance, not used in forward
                omega_0: Stored on the instance, not used in forward
                scale: Multiplier applied to the linear output before
                    it is squared inside the exponential
        '''
        super().__init__()

        # Kept only so the layer exposes the same attributes the
        # constructor was given; forward() never reads these two.
        self.is_first = is_first
        self.omega_0 = omega_0

        self.in_features = in_features
        self.scale = scale
        self.linear = MetaLinear(in_features, out_features, bias=bias)

    def forward(self, input, params=None):
        '''
            Apply the linear map, then the Gaussian activation.

            params, when given, is narrowed to the 'linear' sub-dictionary
            before being handed to the wrapped MetaLinear.
        '''
        linear_params = self.get_subdict(params, 'linear')
        pre_activation = self.linear(input, params=linear_params)
        scaled = self.scale*pre_activation
        return torch.exp(-(scaled**2))
    

class INR(MetaModule):
    '''
        Network made of a stack of GaussLayer blocks wrapped in a
        MetaSequential, optionally finished by a plain MetaLinear layer.
    '''
    def __init__(self, in_features,
                 hidden_features, hidden_layers,
                 out_features, outermost_linear=True,
                 first_omega_0=30, hidden_omega_0=30., scale=10.0,
                 pos_encode=False, sidelength=512, fn_samples=None,
                 use_nyquist=True):
        '''
            Inputs:
                in_features: Width of the input to the first layer
                hidden_features: Width of every hidden layer
                hidden_layers: Number of hidden GaussLayer blocks placed
                    after the first layer
                out_features: Width of the network output
                outermost_linear: If True the last layer is a MetaLinear,
                    otherwise it is another GaussLayer
                first_omega_0: omega_0 handed to the first layer
                hidden_omega_0: omega_0 handed to all other GaussLayers
                scale: scale handed to every GaussLayer
                pos_encode: Stored on the instance; not otherwise used
                    in this class
                sidelength, fn_samples, use_nyquist: Accepted but not
                    referenced in this class
        '''
        super().__init__()
        self.pos_encode = pos_encode

        # Always False here, so the final linear layer is real-valued.
        self.complex = False
        self.nonlin = GaussLayer

        def gauss_block(n_in, n_out, first, omega):
            # Single place where the non-linear layers are constructed.
            return self.nonlin(n_in, n_out,
                               is_first=first, omega_0=omega,
                               scale=scale)

        # Layers are created strictly in network order: first layer,
        # hidden layers, then the output layer.
        layers = [gauss_block(in_features, hidden_features,
                              True, first_omega_0)]
        layers.extend(
            gauss_block(hidden_features, hidden_features,
                        False, hidden_omega_0)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            out_dtype = torch.cfloat if self.complex else torch.float
            last_layer = MetaLinear(hidden_features,
                                    out_features,
                                    dtype=out_dtype)
        else:
            last_layer = gauss_block(hidden_features, out_features,
                                     False, hidden_omega_0)
        layers.append(last_layer)

        self.net = MetaSequential(*layers)

    def forward(self, coords, params=None):
        '''
            Run coords through the layer stack.

            params, when given, is narrowed to the 'net' sub-dictionary
            before being handed to the MetaSequential.
        '''
        net_params = self.get_subdict(params, 'net')
        return self.net(coords, params=net_params)