import torch
import torch.nn as nn

from torch.nn.utils.parametrizations import weight_norm

class FNN(nn.Module):
    """Feedforward neural network (FNN) model.

    Attributes
    ----------
    model : torch.nn.Module
        Sequential module implementing the FNN.
    constant_net : bool
        True when the input size is zero. The network then ignores its input
        and its output is the bias of a single linear layer whose weights are
        initialised to zero.
    """

    def __init__(self, layers=(1, 32, 32, 1), activation='relu'):
        """
        Parameters
        ----------
        layers : sequence of ints
            FNN architecture, including input and output layers. A leading 0
            builds a constant network whose output does not depend on the input.
        activation : str ('relu', 'tanh')
            String specifying the activation applied after every hidden layer.

        Raises
        ------
        ValueError
            If ``activation`` is unknown and at least one hidden layer is built.
        """
        super().__init__()

        self.model = nn.Sequential()

        if layers[0] == 0:
            # Input size is zero: create a constant layer fed with zeros.
            # Its weights are set to zero, so only the biases are effectively trained.
            self.constant_net = True
            constant_layer = nn.Linear(1, layers[-1])
            torch.nn.init.constant_(constant_layer.weight, 0.)
            self.model.append(constant_layer)
        else:
            # Otherwise, create a regular feedforward neural network.
            self.constant_net = False
            # Input and hidden layers, each followed by the activation.
            for n_in, n_out in zip(layers[:-2], layers[1:-1]):
                self.model.append(nn.Linear(n_in, n_out))
                if activation == 'relu':
                    self.model.append(nn.ReLU())
                elif activation == 'tanh':
                    self.model.append(nn.Tanh())
                else:
                    raise ValueError(f'Unknown activation function: {activation}')

            # Linear output layer (no activation).
            self.model.append(nn.Linear(layers[-2], layers[-1]))

    def forward(self, x):
        """Apply the network to ``x``.

        In constant mode only ``x.shape[0]`` is used: a zero column of that
        length is fed to the single linear layer.
        """
        if self.constant_net:
            # Allocate the zero input with the layer's dtype/device so the model
            # keeps working after .double(), .half() or .to(device).
            ref = self.model[0].weight
            x = torch.zeros(x.shape[0], 1, dtype=ref.dtype, device=ref.device)
        return self.model(x)


class CustomActivation(nn.Module):
    """Continuous piecewise-linear activation.

    Acts as the identity on [-1, 1]. Outside that interval the slope is
    ``alpha / 2``: ``1 + alpha/2 * (x - 1)`` for ``x > 1`` and
    ``-1 + alpha/2 * (x + 1)`` for ``x < -1``. It is built from two leaky
    ReLUs with negative slope ``alpha - 1``.
    """

    def __init__(self, alpha=0.05):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        slope = self.alpha - 1.
        rising = nn.functional.leaky_relu(1 + x, negative_slope=slope)
        falling = nn.functional.leaky_relu(1 - x, negative_slope=slope)
        return 0.5 * rising - 0.5 * falling


class NodeNetwork:
    """Batch of scalar one-hidden-layer networks, one parameter set per batch row.

    Row ``b`` computes
    ``f_b(x) = sum_j w2[b, j] * act(w1[b, j] * x + b1[b, j]) + b2[b]``
    with ``act = CustomActivation(alpha)``. Because ``act`` is piecewise
    linear, so is ``f_b``. ``forward``/``inverse`` return the log of the
    active piece's slope as the log-determinant. The slope lookup from an
    output value assumes each ``f_b`` is increasing in ``x`` (TODO confirm
    callers guarantee this, e.g. via ``w1 * w2 > 0``).
    """

    def __init__(self, weight1, weight2, bias1, bias2, alpha):
        """ Constructor.

        Parameters:
        -----------
        weight1 : tensor, shape (batch_size, nh)
        weight2 : tensor, shape (batch_size, nh)
        bias1 : tensor, shape (batch_size, nh)
        bias2 : tensor, shape (batch_size, 1)
        alpha : float
            Parameter of CustomActivation (outer-piece slope is alpha / 2).
        """

        self.weight1 = weight1
        self.weight2 = weight2
        self.bias1 = bias1
        self.bias2 = bias2
        self.alpha = alpha
        self.act = CustomActivation(alpha=alpha)

        self.bs = weight1.shape[0]
        self.nh = weight1.shape[1]

    def multi_batch_forward(self, x):
        """ Multi-batch model forward.

        This function uses broadcasting to compute efficiently the model output
        on (batch, subbatch) inputs x, where each (batch, subbatch) input
        is processed using the parameters of each batch.

        Arguments:
        ----------
        x : tensor, shape (batch_size, sub_batch_size, 1, 1)

        Returns:
        --------
        y : tensor, shape (batch_size, sub_batch_size, 1, 1)
            Network output for each (batch, subbatch) using batch's parameters.
        """

        # Reshape parameters to broadcast over the sub_batch dimension (dim 1)
        # and put hidden units on dim 2, which is summed over below.
        w1 = self.weight1.reshape(self.bs, 1, self.nh, 1)
        w2 = self.weight2.reshape(self.bs, 1, self.nh, 1)
        b1 = self.bias1.reshape(self.bs, 1, self.nh, 1)
        b2 = self.bias2.reshape(self.bs, 1, 1, 1)

        y = torch.sum(w2*self.act(w1*x + b1), dim=2)[:, :, :, None] + b2

        return y

    def get_weights_and_biases(self, y):
        """ Get slope and intercept of the linear piece of f_b that produces y.

        Arguments:
        ----------
        y : tensor, shape (bs, 1)
            Model output.

        Returns:
        --------
        weights : tensor, shape (bs,)
            Slope of f_b on the active piece.
        biases : tensor, shape (bs,)
            Intercept of f_b on the active piece (includes bias2).
        """
        w1, w2, b1 = self.weight1, self.weight2, self.bias1
        half_alpha = 0.5*self.alpha

        # Inputs x at which each unit's pre-activation z = w1*x + b1 equals -1 / +1.
        x_at_zminus = torch.div(-1 - b1, w1)
        x_at_zplus = torch.div(1 - b1, w1)

        # Intercept of w2*act(z) (as a function of x) on the regions z < -1 and z > 1.
        bias_zlow = torch.mul(w2, half_alpha*(b1 + 1.) - 1.)
        bias_zhigh = torch.mul(w2, half_alpha*(b1 - 1.) + 1.)

        # When w1 < 0 the map x -> z reverses order: small x gives z > 1.
        # Reorder so breakpoints and outer intercepts are listed by increasing x.
        flip = w1 < 0
        bp_lo = torch.where(flip, x_at_zplus, x_at_zminus)
        bp_hi = torch.where(flip, x_at_zminus, x_at_zplus)
        bias_lo = torch.where(flip, bias_zhigh, bias_zlow)
        bias_hi = torch.where(flip, bias_zlow, bias_zhigh)

        breakpoints = torch.stack((bp_lo, bp_hi), dim=2)  # shape (bs, nh, 2)

        f_at_bp = self.multi_batch_forward(
            breakpoints.reshape(self.bs, self.nh*2, 1, 1)).reshape(self.bs, self.nh, 2)
        # Allocate on the same dtype/device as the network values so the
        # comparison with y below does not mix devices or truncate precision.
        fnc_at_breakpoints = torch.empty((self.bs, self.nh, 4), dtype=f_at_bp.dtype, device=f_at_bp.device)
        fnc_at_breakpoints[:, :, 0] = -torch.inf
        fnc_at_breakpoints[:, :, 3] = torch.inf
        fnc_at_breakpoints[:, :, 1:3] = f_at_bp

        # For increasing f, (f <= y) along the 4 entries is a run of True then
        # False; the boolean diff marks the single piece (of 3) containing y.
        indices = torch.diff(fnc_at_breakpoints <= y[:, :, None], dim=2)  # shape (bs, nh, 3)

        outer_weight = half_alpha*torch.mul(w1, w2)
        weights = torch.stack((outer_weight,
                               torch.mul(w1, w2),
                               outer_weight), dim=2)  # shape (bs, nh, 3)

        biases = torch.stack((bias_lo,
                              torch.mul(w2, b1),
                              bias_hi), dim=2)  # shape (bs, nh, 3)

        weights = torch.sum(weights*indices, dim=(1, 2))
        biases = torch.sum(biases*indices, dim=(1, 2)) + self.bias2.reshape(-1)

        return weights, biases

    def forward(self, x):
        """ Forward of model using batch of parameters.

        Arguments:
        ----------
        x : tensor, shape (batch_size, 1)

        Returns:
        --------
        y : tensor, shape (batch_size, 1)
        log_det : tensor, shape (batch_size,)
            log of the local slope dy/dx.
        """

        y = self.multi_batch_forward(x.reshape(-1, 1, 1, 1)).reshape(-1, 1)
        weights, _ = self.get_weights_and_biases(y)
        log_det = torch.log(weights)

        return y, log_det

    def inverse(self, y):
        """ Invert the piecewise-linear map on its active piece.

        Arguments:
        ----------
        y : tensor, shape (batch_size, 1)

        Returns:
        --------
        x : tensor, shape (batch_size, 1)
        log_det : tensor, shape (batch_size,)
            log of dx/dy, i.e. minus the log slope of the forward map.
        """

        weights, biases = self.get_weights_and_biases(y)
        x = ((y.reshape(-1) - biases)/weights)[:, None]
        log_det = -torch.log(weights)

        return x, log_det


#################################################################################################################
### Temporal Convolutional Network (TCN)
#################################################################################################################

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along dimension 2 (the length axis
    of a Conv1d output).

    A ``chomp_size`` of 0 returns the input unchanged (made contiguous).
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x[:, :, :-0] is an empty slice, so a zero chomp must be special-cased.
        # This occurs e.g. for kernel_size=1, where the padding is 0.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Stack of TemporalBlocks with exponentially increasing dilation.

    Block ``i`` maps ``num_channels[i-1]`` channels (``num_inputs`` for the
    first block) to ``num_channels[i]``. It uses dilation ``2**i`` and
    padding ``(kernel_size - 1) * 2**i``, with stride 1.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        width_in = num_inputs
        for level, width_out in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(TemporalBlock(width_in, width_out, kernel_size, stride=1, dilation=dilation,
                                        padding=(kernel_size - 1) * dilation, dropout=dropout))
            width_in = width_out

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        # x is expected in Conv1d layout (batch, channels, length).
        out = self.network(x)
        return out

class TCN(nn.Module):
    """TemporalConvNet followed by a linear read-out applied at every time step."""

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.init_weights()

    def init_weights(self):
        """Draw the read-out weights from N(0, 0.01); the bias is left as is."""
        with torch.no_grad():
            self.linear.weight.normal_(0, 0.01)

    def forward(self, x):
        features = self.tcn(x)
        # Move channels last so the Linear acts per time step:
        # result has shape (batch, length, output_size).
        per_step = features.transpose(1, 2)
        return self.linear(per_step)