import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad


def gradient(inputs, outputs):
    """Differentiate ``outputs`` with respect to ``inputs`` using an all-ones seed.

    This equals the gradient of ``outputs.sum()`` w.r.t. ``inputs``. The
    autograd graph is both retained (``retain_graph=True``) and extended
    (``create_graph=True``), so the returned tensor can itself be
    differentiated again.

    Args:
        inputs: tensor (or sequence of tensors) to differentiate with respect to.
        outputs: tensor computed from ``inputs``.

    Returns:
        The gradient for the first entry of ``inputs``.
    """
    all_grads = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=torch.ones_like(outputs, requires_grad=False, device=outputs.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    return all_grads[0]

class ImplicitNet(nn.Module):
    """Plain MLP with optional skip connections that re-inject the raw input.

    Layers are stored as attributes ``lin0 ... lin{num_layers - 2}``.

    Args:
        d_in: number of input features.
        dims: list of hidden layer widths.
        d_out: number of output features.
        skip_in: indices of layers whose input is the previous hidden state
            concatenated with the raw network input, scaled by ``1 / sqrt(2)``.
        beta: Softplus ``beta``; when ``beta <= 0`` a LeakyReLU is used instead.
    """

    def __init__(
        self,
        d_in,
        dims,
        d_out=1,
        skip_in=(),
        beta=100.,
    ):
        super().__init__()
        widths = [d_in] + dims + [d_out]
        self.d_in = d_in
        self.num_layers = len(widths)
        self.skip_in = skip_in
        self.p = 100  # NOTE(review): not read anywhere in this class

        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # A layer that feeds a skip layer shrinks its output by d_in so that,
            # after the raw input is concatenated back on, the width is widths[idx + 1].
            if idx + 1 in skip_in:
                fan_out -= d_in
            setattr(self, f"lin{idx}", nn.Linear(fan_in, fan_out))

        # Softplus for positive beta; otherwise a LeakyReLU (not a plain ReLU).
        self.activation = nn.Softplus(beta=beta) if beta > 0 else nn.LeakyReLU()

    def forward(self, input, return_grad=True, return_auggrad=True):
        """Evaluate the MLP on ``input``.

        ``return_grad`` and ``return_auggrad`` are accepted for call
        compatibility but are not used. No activation follows the last layer.
        """
        last = self.num_layers - 2
        hidden = input
        for idx in range(last + 1):
            if idx in self.skip_in:
                hidden = torch.cat([hidden, input], -1) / np.sqrt(2)
            hidden = getattr(self, f"lin{idx}")(hidden)
            if idx != last:
                hidden = self.activation(hidden)
        return hidden

class ConvexQuadratic(nn.Module):
    '''Convex Quadratic Layer.

    For each output unit ``o`` computes

        sum_r (x @ Q[:, r, o]) ** 2 + (x @ W.T)[o] + b[o]

    i.e. a sum of ``rank`` squared linear forms (convex in ``x``) plus an
    affine term.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if False, the layer has no additive bias (``self.bias is None``).
        rank: number of squared linear forms per output unit.
    '''
    __constants__ = ['in_features', 'out_features', 'quadratic_decomposed', 'weight', 'bias']

    def __init__(self, in_features, out_features, bias=True, rank=1):
        super(ConvexQuadratic, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        # Q has shape (in_features, rank, out_features); W has shape (out_features, in_features).
        self.quadratic_decomposed = nn.Parameter(
            torch.randn(in_features, rank, out_features)
        )
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            # torch.Tensor(n) would hand back uninitialized memory; start from
            # zeros instead. Zeros draw nothing from the RNG, so seeded
            # initialization done after construction is unaffected.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        '''Apply the layer to a 2-D ``input`` of shape (N, in_features) -> (N, out_features).'''
        # (N, in) @ (rank, in, out) broadcasts to (rank, N, out); move rank to
        # dim 1 so the squared forms can be summed over it.
        projections = input.matmul(self.quadratic_decomposed.transpose(1, 0)).transpose(1, 0)
        quad = (projections ** 2).sum(dim=1)
        linear = torch.nn.functional.linear(input, self.weight, self.bias)
        return quad + linear


class DenseICNN(nn.Module):
    '''Fully Connected ICNN with input-quadratic skip connections.

    The first hidden state is ``Quad_0(x)``, with no activation. Every further
    hidden state is ``act(W_k z_{k-1} + Quad_k(x))``, where ``W_k`` is a
    bias-free ``nn.Linear`` and ``Quad_k`` is a ``ConvexQuadratic`` of the raw
    input. The output is a bias-free linear readout plus
    ``0.5 * strong_convexity * ||x||^2``. ``convexify()`` clamps the ``W_k``
    and the readout weights to be non-negative.

    Args:
        dim: input dimension.
        hidden_layer_sizes: widths of the hidden layers (stored as a list).
        rank: rank passed to every ``ConvexQuadratic`` layer.
        activation: one of ``'celu'``, ``'softplus'``, ``'relu'``; checked in ``forward``.
        strong_convexity: coefficient of the added quadratic term.
        batch_size: maximum batch accepted by ``push``; minibatch size used by ``push_nograd``.
        weights_init_std: std of the normal re-initialization applied to every parameter.
    '''
    def __init__(
        self, dim,
        hidden_layer_sizes=(32, 32, 32),
        rank=1, activation='celu',
        strong_convexity=1e-6,
        batch_size=20000,
        weights_init_std=0.1,
    ):
        super(DenseICNN, self).__init__()

        # Copy to a list: avoids a shared mutable default and aliasing the caller's object.
        hidden_layer_sizes = list(hidden_layer_sizes)

        self.dim = dim
        self.strong_convexity = strong_convexity
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
        self.rank = rank
        self.batch_size = batch_size

        self.quadratic_layers = nn.ModuleList([
            ConvexQuadratic(dim, out_features, rank=rank, bias=True)
            for out_features in hidden_layer_sizes
        ])

        sizes = zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:])
        self.convex_layers = nn.ModuleList([
            nn.Linear(in_features, out_features, bias=False)
            for (in_features, out_features) in sizes
        ])

        self.final_layer = nn.Linear(hidden_layer_sizes[-1], 1, bias=False)

        self._init_weights(weights_init_std)

    def _init_weights(self, std):
        '''Overwrite every parameter with i.i.d. N(0, std**2) samples (matching dtype/device).'''
        for p in self.parameters():
            p.data = (torch.randn(p.shape, dtype=torch.float32) * std).to(p)

    def _activate(self, output):
        '''Apply the configured activation; raises ValueError for an unknown name.'''
        # celu, softplus and relu are all convex and non-decreasing, as the ICNN construction requires.
        if self.activation == 'celu':
            return torch.celu(output)
        if self.activation == 'softplus':
            return torch.nn.functional.softplus(output)
        if self.activation == 'relu':
            return torch.nn.functional.relu(output)
        raise ValueError('Activation is not specified or unknown.')

    def forward(self, input):
        '''Evaluation of the discriminator value. Preserves the computational graph.

        Args:
            input: tensor of shape (N, dim).

        Returns:
            Tensor of shape (N, 1).
        '''
        output = self.quadratic_layers[0](input)
        for quadratic_layer, convex_layer in zip(self.quadratic_layers[1:], self.convex_layers):
            output = self._activate(convex_layer(output) + quadratic_layer(input))

        return self.final_layer(output) + .5 * self.strong_convexity * (input ** 2).sum(dim=1).reshape(-1, 1)

    def push(self, input, create_graph=True, retain_graph=True):
        '''
        Pushes input by using the gradient of the network. By default preserves the computational graph.
        Apply to small batches.

        ``input`` must take part in autograd (``requires_grad=True``) and hold at most
        ``self.batch_size`` rows. Returns a tensor shaped like ``input``.
        '''
        assert len(input) <= self.batch_size
        output = torch.autograd.grad(
            outputs=self.forward(input), inputs=input,
            create_graph=create_graph, retain_graph=retain_graph,
            only_inputs=True,
            grad_outputs=torch.ones_like(input[:, :1], requires_grad=False)
        )[0]
        return output

    def push_nograd(self, input):
        '''
        Pushes input by using the gradient of the network. Does not preserve the computational graph.
        Use for pushing large batches (the function uses minibatches).

        Works on tensors that do not require grad and inside an outer ``torch.no_grad()``.
        '''
        output = torch.zeros_like(input, requires_grad=False)
        # The push itself needs autograd even when the caller has disabled it.
        with torch.enable_grad():
            for start in range(0, len(input), self.batch_size):
                stop = start + self.batch_size
                # A detached leaf that requires grad: plain tensors can be pushed,
                # and the caller's graph (if any) is left untouched.
                batch = input[start:stop].detach().requires_grad_(True)
                pushed = self.push(batch, create_graph=False, retain_graph=False)
                output[start:stop] = pushed.detach()
        return output

    def convexify(self):
        '''Clamp the hidden-to-hidden and readout weights to be non-negative, in place.'''
        for layer in self.convex_layers:
            if (isinstance(layer, nn.Linear)):
                layer.weight.data.clamp_(0)
        self.final_layer.weight.data.clamp_(0)