import torch.nn as nn


####################### Basic accessories setup ###############################


def get_activation(activation, leaky_relu_slope=0.6):
    """Build a fresh activation module from its (case-sensitive) name.

    Args:
        activation: One of 'relu', 'celu', 'selu', 'tanh', 'softsign',
            'Prelu', 'Rrelu', 'hardshrink', 'sigmoid', 'tanhshrink',
            'leaky_relu'.
        leaky_relu_slope: Negative slope used by 'leaky_relu'.

    Returns:
        A new ``nn.Module`` on every call, so parametric activations such as
        PReLU are never shared between layers.

    Raises:
        NotImplementedError: if ``activation`` is not a supported name.
    """
    # Lazy factories: only the requested module is ever constructed.
    factories = {
        'relu': lambda: nn.ReLU(True),  # in-place ReLU
        'celu': nn.CELU,
        'selu': nn.SELU,
        'tanh': nn.Tanh,
        'softsign': nn.Softsign,  # maps to [-1, 1]
        'Prelu': nn.PReLU,
        'Rrelu': lambda: nn.RReLU(0.5, 0.8),
        'hardshrink': nn.Hardshrink,
        'sigmoid': nn.Sigmoid,
        'tanhshrink': nn.Tanhshrink,
        # Previously leaky_relu_slope was accepted but unused; expose it here.
        'leaky_relu': lambda: nn.LeakyReLU(leaky_relu_slope),
    }
    try:
        factory = factories[activation]
    except (KeyError, TypeError):  # TypeError: unhashable name
        raise NotImplementedError('activation [%s] is not found' % activation)
    return factory()

# * MLP

#! target for h/T


class Fully_connected(nn.Module):
    """Multilayer perceptron with optional batch-norm, dropout and output post-processing.

    Layout::

        input -> Linear(input_dim, hidden_dim) -> act
              -> num_layer x [Linear(hidden_dim, hidden_dim) -> BatchNorm1d? -> Dropout? -> act]
              -> Linear(hidden_dim, output_dim) -> final act (only when full_activ)

    Post-processing: if ``res`` is truthy the raw input is added to the
    output; otherwise, if ``quadr`` is truthy, the output is squared
    elementwise. ``res`` takes precedence over ``quadr``.

    NOTE(review): with ``res`` set, the output must be broadcast-compatible
    with the input; e.g. ``output_dim=1`` silently broadcasts across the
    input width instead of failing.
    """

    def __init__(self, input_dim=785, output_dim=1, hidden_dim=1024, num_layer=1, activation='Prelu', final_actv='Prelu', full_activ=True, bias=True, dropout=False, batch_nml=False, res=0, quadr=1):
        super().__init__()
        # Configuration; forward() branches on several of these.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.num_layer = num_layer
        self.full_activ = full_activ
        self.final_actv = final_actv
        self.dropout = dropout
        self.batch_nml = batch_nml
        self.res = res
        self.quadr = quadr

        hidden_range = range(num_layer)
        # Keep this creation order (first, hidden, last linear layers):
        # nn.Linear initialises from the global RNG, so reordering would
        # change the weights produced under a fixed seed.
        self.layer1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.layer1_activ = get_activation(activation)
        self.linearblock = nn.ModuleList(
            nn.Linear(hidden_dim, hidden_dim, bias=bias) for _ in hidden_range)
        self.atvt_list = nn.ModuleList(
            get_activation(activation) for _ in hidden_range)
        if batch_nml:
            self.batchnormal = nn.ModuleList(
                nn.BatchNorm1d(hidden_dim) for _ in hidden_range)
        if dropout > 0:
            self.dropout_list = nn.ModuleList(
                nn.Dropout(dropout) for _ in hidden_range)

        self.last_layer = nn.Linear(hidden_dim, output_dim, bias=bias)
        if full_activ:
            self.last_layer_activ = get_activation(final_actv)

    def forward(self, input):
        hidden = self.layer1_activ(self.layer1(input))

        for idx, (linear, act) in enumerate(zip(self.linearblock, self.atvt_list)):
            hidden = linear(hidden)
            if self.batch_nml:
                hidden = self.batchnormal[idx](hidden)
            if self.dropout > 0:
                hidden = self.dropout_list[idx](hidden)
            hidden = act(hidden)

        hidden = self.last_layer(hidden)
        if self.full_activ:
            hidden = self.last_layer_activ(hidden)

        if self.res:
            return hidden + input
        return hidden ** 2 if self.quadr else hidden


class FC_linear(nn.Module):
    """Stack of affine layers with no nonlinearity between them.

    input -> Linear(input_dim, hidden_dim) -> num_layer x Linear(hidden_dim, hidden_dim)
          -> Linear(hidden_dim, output_dim)

    The composition is itself affine. If ``res`` is truthy the raw input is
    added to the output, which requires broadcast-compatible shapes.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_layer, res=0):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.res = res

        # Created in this order (first, hidden, last) so seeded init is stable.
        self.fc1_normal = nn.Linear(input_dim, hidden_dim)
        self.linearblock = nn.ModuleList(
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layer))
        self.last_normal = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        out = self.fc1_normal(input)
        for hidden_layer in self.linearblock:
            out = hidden_layer(out)
        out = self.last_normal(out)
        return out + input if self.res else out

# * ICNN


class ConvexLinear(nn.Linear):
    """``nn.Linear`` whose weight parameter is tagged with ``be_positive = 1.0``.

    The tag has no effect inside this class; the forward pass is a plain
    affine map. NOTE(review): presumably code elsewhere reads the tag to keep
    these weights non-negative (needed for input convexity) — confirm.
    """

    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)

        weight = self.weight
        # Only set the marker if something has not already attached one.
        if not hasattr(weight, 'be_positive'):
            weight.be_positive = 1.0

    def forward(self, input):
        # Same computation as nn.Linear: F.linear(input, weight, bias).
        return super().forward(input)


class ICNN_LastInp_Quadratic(nn.Module):
    """ICNN-style network whose output adds ``0.5 * ||x||^2`` of the input.

    With ``act`` the chosen activation, ``A_k`` plain linear layers applied to
    the raw input and ``C_k`` bias-free ``ConvexLinear`` layers::

        z_1 = act(A_1 x + a_1) ** 2
        z_k = act(C_k z_{k-1} + A_k x + a_k)        k = 2..num_layer
        out = c^T z_{num_layer} + 0.5 * ||x||^2

    The ``ConvexLinear`` weights are only tagged ``be_positive``; nothing in
    this class enforces non-negativity. NOTE(review): presumably a training
    loop clamps/projects those weights — confirm, since convexity depends on it.

    Args:
        input_dim: Width of each input row. Input is expected to be 2-D,
            ``(batch, input_dim)``; the quadratic term reduces over dim 1.
        hidden_dim: Width of every hidden layer.
        activation: Name passed to ``get_activation``.
        num_layer: Number of hidden blocks, counting the first one (the final
            scalar ``ConvexLinear`` is not counted).
        dropout: Dropout probability applied after each block k >= 2; 0 disables it.
    """

    def __init__(self, input_dim, hidden_dim, activation, num_layer, dropout=0):
        super(ICNN_LastInp_Quadratic, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.num_layer = num_layer
        self.dropout = dropout

        # First block: plain affine map of the input, activated then squared in forward().
        self.fc1_normal = nn.Linear(self.input_dim, self.hidden_dim, bias=True)
        self.activ_1 = get_activation(self.activation)

        # Blocks 2..num_layer: an input skip (A_k x + a_k) plus a convex path (C_k z).
        self.normal = nn.ModuleList([nn.Linear(
            self.input_dim, self.hidden_dim, bias=True) for _ in range(2, self.num_layer + 1)])

        self.convex = nn.ModuleList([ConvexLinear(
            self.hidden_dim, self.hidden_dim, bias=False) for _ in range(2, self.num_layer + 1)])
        if dropout > 0:
            # One dropout per block k >= 2 — the same count as `normal`/`convex`;
            # the previous list had one extra, never-used entry.
            self.dropout_list = nn.ModuleList(
                [nn.Dropout(dropout) for _ in range(2, self.num_layer + 1)])

        self.acti_list = nn.ModuleList(
            [get_activation(self.activation) for _ in range(2, self.num_layer + 1)])

        self.last_convex = ConvexLinear(self.hidden_dim, 1, bias=False)

    def forward(self, input):
        x = self.activ_1(self.fc1_normal(input)).pow(2)

        for i in range(self.num_layer - 1):
            x = self.acti_list[i](self.convex[i](x).add(self.normal[i](input)))
            if self.dropout > 0:
                x = self.dropout_list[i](x)

        # 0.5 * ||input||^2 per row, as a direct sum of squares (no sqrt-then-square);
        # keepdim gives the (batch, 1) shape that matches last_convex's output.
        quadratic = 0.5 * input.pow(2).sum(dim=1, keepdim=True)
        return self.last_convex(x).add(quadratic)


import torch
import torch.autograd as autograd
import torch.nn.functional as F


class ConvexQuadratic(nn.Module):
    '''Convex Quadratic Layer.

    For input rows ``x`` of width ``in_features``, output unit ``o`` is::

        sum_r (x . Q[:, r, o]) ** 2  +  (W x + b)[o]

    ``Q`` has shape ``(in_features, rank, out_features)``. Each output is a sum
    of squared linear forms plus an affine term, hence convex in ``x``.

    Initialisation: ``Q`` and ``W`` ~ N(0, 1); ``b`` = 0 (when ``bias=True``).
    '''
    __constants__ = ['in_features', 'out_features', 'quadratic_decomposed', 'weight', 'bias']

    def __init__(self, in_features, out_features, bias=True, rank=1):
        super(ConvexQuadratic, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        self.quadratic_decomposed = nn.Parameter(
            torch.randn(in_features, rank, out_features))
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            # torch.Tensor(n) returns *uninitialized* memory (possibly NaN/huge);
            # start the bias from zeros instead. zeros() draws no randomness, so the
            # RNG stream seen by later layers is unchanged.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        # (N, in) @ (rank, in, out) broadcasts to (rank, N, out); move rank to dim 1.
        projections = input.matmul(self.quadratic_decomposed.transpose(1, 0)).transpose(1, 0)
        quad = (projections ** 2).sum(dim=1)  # sum of squares over rank -> (N, out)
        linear = F.linear(input, self.weight, self.bias)
        return quad + linear


class WeightTransformedLinear(nn.Linear):
    """Linear layer that applies ``w_transform`` to its weight on every forward call.

    The stored parameter stays untransformed; only the weight used in the
    affine map is ``w_transform(self.weight)``. The default transform is the
    identity.
    """

    def __init__(self, in_features, out_features, bias=True, w_transform=lambda x: x):
        super().__init__(in_features, out_features, bias=bias)
        self._w_transform = w_transform

    def forward(self, input):
        effective_weight = self._w_transform(self.weight)
        return F.linear(input, effective_weight, self.bias)


class GradNN(nn.Module):
    '''Base module whose ``push`` returns the gradient of ``forward`` w.r.t. its input.

    Subclasses override ``forward``; this base implementation returns ``None``.

    Args:
        batch_size: Maximum number of rows differentiated per ``autograd.grad``
            call inside ``push``.
    '''

    def __init__(self, batch_size=1024):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, input):
        # Placeholder; concrete networks supply the real forward pass.
        return None

    def push(self, input, create_graph=True, retain_graph=True):
        '''
        Return d forward(input) / d input, with the same shape as ``input``.

        By default the computational graph is preserved (``create_graph`` and
        ``retain_graph`` are forwarded to ``autograd.grad``). Inputs longer than
        ``batch_size`` rows are processed chunk by chunk. ``input`` must require
        grad; the seed is built from ``input[:, :1]``, so ``forward`` is assumed
        to return one value per row, shape ``(N, 1)``.
        '''
        chunk = self.batch_size
        if len(input) > chunk:
            # Too many rows: differentiate each chunk and write it into place.
            pushed = torch.zeros_like(input, requires_grad=False)
            for start in range(0, input.size(0), chunk):
                rows = slice(start, start + chunk)
                pushed[rows] = self.push(
                    input[rows],
                    create_graph=create_graph, retain_graph=retain_graph)
            return pushed

        values = self.forward(input)
        seed = torch.ones_like(input[:, :1], requires_grad=False)
        (gradient,) = autograd.grad(
            outputs=values, inputs=input,
            create_graph=create_graph, retain_graph=retain_graph,
            only_inputs=True,
            grad_outputs=seed,
        )
        return gradient


class DenseICNN(GradNN):
    '''Fully connected ICNN with input-quadratic skip connections.

    With ``q_k`` the k-th ``ConvexQuadratic`` layer (applied to the raw input),
    ``W_k`` the k-th weight-transformed convex layer and ``act`` the activation::

        z_1 = q_1(x)
        z_k = act(W_k z_{k-1} + q_k(x))            k = 2..len(hidden_layer_sizes)
        out = w_final^T z_L + 0.5 * strong_convexity * ||x||^2

    Args:
        dim: Input width.
        hidden_layer_sizes: Widths of the hidden layers; any non-empty
            sequence of ints (copied into a new list per instance).
        rank: Rank of every quadratic term (passed to ``ConvexQuadratic``).
        activation: 'celu', 'softplus' or 'relu'. Validated lazily in
            ``forward`` (only when there is more than one hidden layer).
        strong_convexity: Coefficient of the ``0.5 * ||x||^2`` term.
        batch_size: Chunk size used by ``GradNN.push``.
        conv_layers_w_trf: Weight transform applied by the convex and final
            layers on every forward pass.
        forse_w_positive: If true, ``convexify`` clamps those layers' weights
            to be non-negative.
    '''

    def __init__(
        self, dim,
        hidden_layer_sizes=(32, 32, 32),
        rank=1, activation='celu',
        strong_convexity=1e-6,
        batch_size=1024,
        conv_layers_w_trf=lambda x: x,
        forse_w_positive=True
    ):
        super(DenseICNN, self).__init__(batch_size)

        # Copy so each instance owns its list: the old mutable-list default was
        # stored by reference and shared (and mutable) across all default instances.
        hidden_layer_sizes = list(hidden_layer_sizes)

        self.dim = dim
        self.strong_convexity = strong_convexity
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
        self.rank = rank
        self.conv_layers_w_trf = conv_layers_w_trf
        self.forse_w_positive = forse_w_positive

        # One input-quadratic skip per hidden layer.
        self.quadratic_layers = nn.ModuleList([
            ConvexQuadratic(dim, out_features, rank=rank, bias=True)
            for out_features in hidden_layer_sizes
        ])

        # Convex (hidden -> hidden) layers between consecutive hidden widths.
        sizes = zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:])
        self.convex_layers = nn.ModuleList([
            WeightTransformedLinear(
                in_features, out_features, bias=False, w_transform=self.conv_layers_w_trf)
            for (in_features, out_features) in sizes
        ])

        # Raises IndexError if hidden_layer_sizes is empty.
        self.final_layer = WeightTransformedLinear(
            hidden_layer_sizes[-1], 1, bias=False, w_transform=self.conv_layers_w_trf)

    def forward(self, input):
        '''Evaluation of the discriminator value. Preserves the computational graph.

        Returns a ``(batch, 1)`` tensor; ``input`` is expected to be 2-D.
        '''
        output = self.quadratic_layers[0](input)
        for quadratic_layer, convex_layer in zip(self.quadratic_layers[1:], self.convex_layers):
            output = convex_layer(output) + quadratic_layer(input)
            if self.activation == 'celu':
                output = torch.celu(output)
            elif self.activation == 'softplus':
                output = F.softplus(output)
            elif self.activation == 'relu':
                output = F.relu(output)
            else:
                raise Exception('Activation is not specified or unknown.')

        return self.final_layer(output) + .5 * self.strong_convexity * (input ** 2).sum(dim=1).reshape(-1, 1)

    def convexify(self):
        '''Clamp convex-layer and final-layer weights to >= 0 in place (if enabled).'''
        if self.forse_w_positive:
            for layer in self.convex_layers:
                if (isinstance(layer, nn.Linear)):
                    layer.weight.data.clamp_(0)
            self.final_layer.weight.data.clamp_(0)
