import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad


def gradient(inputs, outputs):
    """Return the gradient of ``outputs`` with respect to ``inputs``.

    Every element of ``outputs`` is seeded with 1, so the result equals the
    gradient of ``outputs.sum()`` with respect to ``inputs``.

    The graph is kept (``create_graph=True``), so the returned tensor can
    itself be differentiated, e.g. for higher-order derivatives or
    gradient-based loss terms.

    Args:
        inputs: Tensor (or sequence of tensors) to differentiate with respect to.
        outputs: Tensor computed from ``inputs``.

    Returns:
        The gradient for the first entry of ``inputs``.
    """
    # ones_like already matches the dtype and device of ``outputs`` and does
    # not require grad.
    seed = torch.ones_like(outputs)
    all_grads = grad(
        outputs,
        inputs,
        grad_outputs=seed,
        create_graph=True,  # keep the result differentiable
        retain_graph=True,  # allow further backward passes through outputs
    )
    return all_grads[0]

class ImplicitNet(nn.Module):
    """Multi-layer perceptron with residual hidden layers.

    Maps ``d_in`` input features to ``d_out`` output features:

    * the first layer computes ``activation(lin0(x))``;
    * every further hidden layer ``k`` applies a residual update
      ``h = h + activation(lin_k(h))``;
    * the final layer is linear, with no activation.

    With no hidden layers (``dims`` empty), the network is the single
    linear map ``lin0``.

    Args:
        d_in: Number of input features.
        dims: Hidden layer widths (any sequence of ints). Because hidden
            layers after the first are residual, all hidden widths must be
            equal for ``forward`` to succeed.
        d_out: Number of output features.
        skip_in: Layer indices whose preceding ``nn.Linear`` is built with
            ``d_in`` fewer output features, reserving room for a
            concatenated input. NOTE(review): ``forward`` does not
            concatenate the input at these layers, so a non-empty
            ``skip_in`` generally yields shape mismatches. Only layer
            sizing honours it.
        beta: If ``> 0``, the activation is ``nn.Tanh``. Otherwise it is
            ``nn.LeakyReLU`` with its default slope. Only the sign of
            ``beta`` is used.
    """

    def __init__(
        self,
        d_in,
        dims,
        d_out=1,
        skip_in=(),
        beta=1,
    ):
        super().__init__()
        # list() lets callers pass any sequence (e.g. a tuple) of widths
        # without mutating it.
        dims = [d_in] + list(dims) + [d_out]
        self.d_in = d_in
        self.num_layers = len(dims)
        self.skip_in = skip_in
        # NOTE(review): not read anywhere in this class; kept in case
        # external code relies on the attribute.
        self.p = 100

        for layer in range(0, self.num_layers - 1):
            # A layer feeding a skip layer leaves room for d_in input features.
            if layer + 1 in skip_in:
                out_dim = dims[layer + 1] - d_in
            else:
                out_dim = dims[layer + 1]
            # Registered as lin0, lin1, ... so parameter names (and therefore
            # state_dict keys) are stable.
            setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

        if beta > 0:
            self.activation = nn.Tanh()
        else:
            # LeakyReLU (not plain ReLU) when beta <= 0.
            self.activation = nn.LeakyReLU()

    def forward(self, input, return_grad=True, return_auggrad=True):
        """Evaluate the network on ``input``.

        Args:
            input: Tensor whose last dimension has ``d_in`` features.
            return_grad: Unused; accepted for call-site compatibility.
            return_auggrad: Unused; accepted for call-site compatibility.

        Returns:
            Tensor whose last dimension has ``d_out`` features.
        """
        last = self.num_layers - 2
        if last == 0:
            # No hidden layers: lin0 is also the output layer, so apply it
            # exactly once, without an activation.
            return self.lin0(input)

        output = self.activation(self.lin0(input))
        for layer in range(1, last):
            lin = getattr(self, "lin" + str(layer))
            # Residual update; requires lin to preserve the feature width.
            output = output + self.activation(lin(output))

        # Final projection to d_out features, without activation.
        return getattr(self, "lin" + str(last))(output)