from copy import copy
import math

import torch
from torch import nn

from apex._autocast_utils import _cast_if_autocast_enabled
import mlp_cuda


class MlpFunction(torch.autograd.Function):
    """Autograd function that routes forward and backward through ``mlp_cuda``."""

    @staticmethod
    def forward(ctx, bias, activation, *args):
        """Run ``mlp_cuda.forward`` and stash what ``backward`` will need.

        Args:
            ctx: Autograd context object.
            bias: Flag handed unchanged to the extension (``MLP`` passes 0 or 1).
            activation: Activation code handed unchanged to the extension.
            *args: Tensors given to the extension as one tuple; ``MLP.forward``
                passes the input followed by the weights and then the biases.

        Returns:
            The first element of whatever ``mlp_cuda.forward`` returns.
        """
        results = mlp_cuda.forward(bias, activation, args)
        # The complete result is kept on ctx (not only the element returned)
        # because backward passes all of it back to the extension.
        # NOTE(review): presumably the extra entries are intermediate
        # activations / reserved buffers -- confirm against the extension.
        ctx.outputs = results
        ctx.activation = activation
        ctx.bias = bias
        ctx.save_for_backward(*args)
        return results[0]

    @staticmethod
    def backward(ctx, grad_o):
        """Compute gradients through ``mlp_cuda.backward``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_o: Gradient with respect to the value ``forward`` returned.

        Returns:
            A tuple starting with two ``None`` entries (for ``bias`` and
            ``activation``) followed by the gradients produced by the
            extension -- presumably one per tensor in ``*args``; confirm
            against the extension.
        """
        tensor_grads = mlp_cuda.backward(
            ctx.bias,
            ctx.activation,
            grad_o,
            ctx.outputs,
            ctx.saved_tensors,
        )
        # Drop the reference to the forward results once they have been used.
        del ctx.outputs
        return (None, None) + tuple(tensor_grads)


def mlp_function(bias, activation, *args):
    """Apply ``MlpFunction`` to the given arguments.

    All arguments are first passed through ``_cast_if_autocast_enabled``
    (presumably a cast to the autocast dtype when autocast is active -- see
    ``apex._autocast_utils``), and the result is unpacked into
    ``MlpFunction.apply``.

    Args:
        bias: Bias flag forwarded to ``MlpFunction``.
        activation: Activation code forwarded to ``MlpFunction``.
        *args: Remaining positional arguments forwarded to ``MlpFunction``.

    Returns:
        Whatever ``MlpFunction.apply`` returns.
    """
    return MlpFunction.apply(
        *_cast_if_autocast_enabled(bias, activation, *args)
    )


class MLP(torch.nn.Module):
    """Multi-layer perceptron whose forward pass is delegated to ``mlp_function``.

    Args:
        mlp_sizes (list of int): Layer widths. Example: ``[1024, 1024, 1024]``
            creates 2 layers, each with a 1024x1024 weight.
        bias (bool): Whether every layer gets a bias parameter. Default: True.
        activation (str): One of ``'none'``, ``'relu'`` or ``'sigmoid'``.
            Default: ``'relu'``.

    Raises:
        TypeError: If ``activation`` is not one of the supported names.
    """
    def __init__(self, mlp_sizes, bias=True, activation='relu'):
        super().__init__()
        self.num_layers = len(mlp_sizes) - 1
        # Shallow copy so later mutation of the caller's list does not alter
        # what extra_repr reports.
        self.mlp_sizes = copy(mlp_sizes)
        # Stored as the integer flag that is handed to mlp_function.
        self.bias = 1 if bias else 0

        # Integer activation codes handed to mlp_function.
        if activation == 'none':
            self.activation = 0
        elif activation == 'relu':
            self.activation = 1
        elif activation == 'sigmoid':
            self.activation = 2
        else:
            # TypeError (rather than ValueError) is kept so existing callers
            # that catch it keep working; the message now lists every
            # accepted value, including 'sigmoid'.
            raise TypeError("activation must be one of 'none', 'relu' or 'sigmoid'.")

        # Plain lists fix the order in which parameters are passed in
        # forward(); setattr is what registers each one with the module.
        self.weights = []
        self.biases = []
        for i in range(self.num_layers):
            w = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1], mlp_sizes[i]))
            self.weights.append(w)
            setattr(self, f'weight_{i}', w)
            if self.bias:
                b = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1]))
                self.biases.append(b)
                setattr(self, f'bias_{i}', b)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all weights and biases in place from normal distributions.

        Weights use ``std = sqrt(2 / (rows + cols))`` and biases use
        ``std = sqrt(1 / size)``, both with zero mean.
        """
        for weight in self.weights:
            dimsum = weight.size(0) + weight.size(1)
            std = math.sqrt(2. / float(dimsum))
            nn.init.normal_(weight, 0., std)
        if self.bias:
            for bias in self.biases:
                std = math.sqrt(1. / float(bias.size(0)))
                nn.init.normal_(bias, 0., std)

    def forward(self, input):
        """Call ``mlp_function`` with the input, then all weights, then all biases."""
        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)

    def extra_repr(self):
        """Return the layer sizes, bias flag and activation code as a string."""
        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
        return s
